package sidechannel

import (
	"bytes"
	"context"
	"crypto/rand"
	"io"
	"net"
	"sync"
	"testing"

	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/gitaly/v16/internal/grpc/backchannel"
	"gitlab.com/gitlab-org/gitaly/v16/internal/grpc/listenmux"
	"gitlab.com/gitlab-org/gitaly/v16/internal/testhelper"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
)

// TestSidechannel sends a random blob to the server over a sidechannel and
// checks that the server echoes it back unchanged. The server handler copies
// exactly blobSize bytes from the sidechannel back into it.
func TestSidechannel(t *testing.T) {
	ctx := testhelper.Context(t)

	const blobSize = 1024 * 1024

	in := make([]byte, blobSize)
	_, err := rand.Read(in)
	require.NoError(t, err)

	var out []byte
	require.NotEqual(t, in, out)

	addr := startServer(
		t,
		func(ctx context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
			conn, err := OpenSidechannel(ctx)
			if err != nil {
				return nil, err
			}
			defer conn.Close()

			// Echo exactly blobSize bytes back to the client.
			if _, err := io.CopyN(conn, conn, blobSize); err != nil {
				return nil, err
			}
			return &healthpb.HealthCheckResponse{}, conn.Close()
		},
	)

	conn, registry := dial(t, addr)
	err = call(
		ctx, conn, registry,
		func(conn *ClientConn) error {
			// Read concurrently with writing so the echo can flow back while
			// the blob is still being sent.
			readErr := make(chan error, 1)
			go func() {
				var err error
				out, err = io.ReadAll(conn)
				readErr <- err
			}()

			// This callback runs while the test goroutine is blocked in the
			// Check RPC, i.e. on another goroutine, so report failures by
			// returning them instead of calling require (t.FailNow must only
			// be called from the test goroutine).
			if _, err := io.Copy(conn, bytes.NewReader(in)); err != nil {
				return err
			}
			return <-readErr
		},
	)
	require.NoError(t, err)
	require.Equal(t, in, out, "byte stream works")
}

// TestSidechannelConcurrency conducts multiple concurrent requests, each with
// its own sidechannel, over the same gRPC connection. It checks that every
// request gets back exactly the blob it sent.
func TestSidechannelConcurrency(t *testing.T) {
	ctx := testhelper.Context(t)

	const concurrency = 10
	const blobSize = 1024 * 1024

	ins := make([][]byte, concurrency)
	for i := 0; i < concurrency; i++ {
		ins[i] = make([]byte, blobSize)
		_, err := rand.Read(ins[i])
		require.NoError(t, err)
	}

	outs := make([][]byte, concurrency)
	for i := 0; i < concurrency; i++ {
		require.NotEqual(t, ins[i], outs[i])
	}

	addr := startServer(
		t,
		func(ctx context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
			conn, err := OpenSidechannel(ctx)
			if err != nil {
				return nil, err
			}
			defer conn.Close()

			// Echo exactly blobSize bytes back to the client.
			if _, err := io.CopyN(conn, conn, blobSize); err != nil {
				return nil, err
			}

			return &healthpb.HealthCheckResponse{}, conn.Close()
		},
	)

	conn, registry := dial(t, addr)

	// Buffered so no worker blocks on send; closed after all workers finish.
	errCh := make(chan error, concurrency)

	var wg sync.WaitGroup
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()

			errCh <- call(
				ctx, conn, registry,
				func(conn *ClientConn) error {
					// Read concurrently with writing so the echo can flow back
					// while the blob is still being sent.
					readErr := make(chan error, 1)
					go func() {
						var err error
						outs[i], err = io.ReadAll(conn)
						readErr <- err
					}()

					if _, err := io.Copy(conn, bytes.NewReader(ins[i])); err != nil {
						return err
					}
					return <-readErr
				},
			)
		}(i)
	}
	wg.Wait()
	close(errCh)

	// Check call errors before payloads: a failed call leaves its output
	// incomplete, and a byte-slice mismatch would hide the real cause.
	for err := range errCh {
		require.NoError(t, err)
	}
	for i := 0; i < concurrency; i++ {
		require.Equal(t, ins[i], outs[i], "byte stream works")
	}
}

// startServer starts a gRPC server on an ephemeral localhost TCP port whose
// health Check RPC is served by th. The server's transport credentials are a
// listenmux over insecure credentials with a backchannel server handshaker
// registered. Extra server options may be passed via opts. The server is
// stopped through t.Cleanup; the listener address is returned.
func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) string {
	t.Helper()

	mux := listenmux.New(insecure.NewCredentials())
	mux.Register(backchannel.NewServerHandshaker(testhelper.SharedLogger(t), backchannel.NewRegistry(), nil))

	srv := grpc.NewServer(append(opts, grpc.Creds(mux))...)
	t.Cleanup(srv.Stop)

	healthpb.RegisterHealthServer(srv, &server{testHandler: th})

	listener, err := net.Listen("tcp", "localhost:0")
	require.NoError(t, err)

	go testhelper.MustServe(t, srv, listener)

	return listener.Addr().String()
}

// dial connects to addr using a sidechannel client handshaker backed by a
// fresh Registry. It returns the connection together with that registry so
// callers can register sidechannels on it. The connection is closed through
// t.Cleanup.
func dial(t *testing.T, addr string) (*grpc.ClientConn, *Registry) {
	// Attribute require failures to the calling test, as startServer does.
	t.Helper()

	registry := NewRegistry()
	clientHandshaker := NewClientHandshaker(testhelper.SharedLogger(t), registry)
	dialOpt := grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials()))

	conn, err := grpc.Dial(addr, dialOpt)
	require.NoError(t, err)
	// Best-effort close during cleanup; the error is deliberately ignored.
	t.Cleanup(func() { conn.Close() })

	return conn, registry
}

// call registers handler as a sidechannel in registry and issues a health
// Check RPC on conn using the resulting context. If Check fails, its error is
// returned. Otherwise the error from closing the sidechannel waiter is
// returned.
func call(ctx context.Context, conn *grpc.ClientConn, registry *Registry, handler func(*ClientConn) error) error {
	healthClient := healthpb.NewHealthClient(conn)

	sidechannelCtx, waiter := RegisterSidechannel(ctx, registry, handler)
	// Make sure the waiter is closed on every path. Its error only matters on
	// the success path, where it is checked explicitly below.
	defer func() { _ = waiter.Close() }()

	_, err := healthClient.Check(sidechannelCtx, &healthpb.HealthCheckRequest{})
	if err != nil {
		return err
	}

	if closeErr := waiter.Close(); closeErr != nil {
		return closeErr
	}
	return nil
}

// testHandler is the function signature tests supply to implement the health
// Check RPC.
type testHandler func(context.Context, *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error)

// server is a health service whose Check RPC is forwarded to the embedded
// testHandler. Every other method falls back to UnimplementedHealthServer.
type server struct {
	healthpb.UnimplementedHealthServer
	testHandler
}

// Check forwards the request to the configured testHandler.
func (s *server) Check(ctx context.Context, req *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
	handle := s.testHandler
	return handle(ctx, req)
}
