package tracing

import (
	"context"

	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// grpcMetadataCarrier exposes gRPC metadata through Get/Set/Keys so it can be
// handed to Propagator.Extract and Propagator.Inject. md points at the
// metadata being read from or written to.
type grpcMetadataCarrier struct{ md *metadata.MD }

// Get returns the first value that metadata.MD.Get reports for key, or the
// empty string when there is none.
func (c *grpcMetadataCarrier) Get(key string) string {
	if vals := c.md.Get(key); len(vals) != 0 {
		return vals[0]
	}
	return ""
}

// Set writes value under key in the wrapped metadata via metadata.MD.Set.
func (c *grpcMetadataCarrier) Set(key, value string) {
	c.md.Set(key, value)
}

// Keys returns every key present in the wrapped metadata. The order follows
// Go map iteration and is therefore unspecified.
func (c *grpcMetadataCarrier) Keys() []string {
	md := *c.md
	keys := make([]string, 0, len(md))
	for k := range md {
		keys = append(keys, k)
	}
	return keys
}

// extractGRPCTraceInfo reads trace propagation data from the incoming gRPC
// metadata attached to ctx and returns the context produced by
// Propagator.Extract. If ctx carries no incoming metadata, ctx is returned
// unchanged.
func extractGRPCTraceInfo(ctx context.Context) context.Context {
	md, found := metadata.FromIncomingContext(ctx)
	if found {
		return Propagator.Extract(ctx, &grpcMetadataCarrier{md: &md})
	}
	return ctx
}

func setGRPCTraceInfo(ctx context.Context) context.Context {
	md, ok := metadata.FromOutgoingContext(ctx)
	if !ok {
		md = metadata.MD{}
	}
	carrier := &grpcMetadataCarrier{
		md: &md,
	}
	Propagator.Inject(ctx, carrier)
	return metadata.NewOutgoingContext(ctx, md)
}

// clientStream wraps a grpc.ClientStream (originalStream) and reports call
// outcomes on finished. Header, CloseSend and SendMsg report any non-nil
// error. RecvMsg reports any non-nil error, and when desc.ServerStreams is
// false it also reports nil after a successful receive. Every report is a
// select against done, so a pending report is abandoned once done is closed.
type clientStream struct {
	originalStream grpc.ClientStream
	desc           *grpc.StreamDesc
	finished       chan<- error
	done           <-chan struct{}
}

// newgRPCClientStream returns a clientStream that delegates to originalStream
// and reports call outcomes on finished. desc decides whether RecvMsg reports
// after successful receives, and closing done abandons any pending report.
func newgRPCClientStream(originalStream grpc.ClientStream, desc *grpc.StreamDesc, finished chan<- error, done <-chan struct{}) grpc.ClientStream {
	cs := new(clientStream)
	cs.originalStream = originalStream
	cs.desc = desc
	cs.finished = finished
	cs.done = done
	return cs
}

// Header returns the header metadata of the wrapped stream. A non-nil error
// is reported on cs.finished, unless cs.done is closed first, and is then
// returned to the caller together with whatever metadata was obtained.
func (cs *clientStream) Header() (metadata.MD, error) {
	md, err := cs.originalStream.Header()
	if err == nil {
		return md, nil
	}
	select {
	case cs.finished <- err:
	case <-cs.done:
	}
	return md, err
}

// Trailer returns the trailer metadata of the wrapped stream. Nothing is
// reported on cs.finished.
func (cs *clientStream) Trailer() metadata.MD {
	trailer := cs.originalStream.Trailer()
	return trailer
}

// CloseSend calls CloseSend on the wrapped stream. A non-nil error is
// reported on cs.finished, unless cs.done is closed first, and is always
// returned to the caller.
func (cs *clientStream) CloseSend() error {
	err := cs.originalStream.CloseSend()
	if err == nil {
		return nil
	}
	select {
	case cs.finished <- err:
	case <-cs.done:
	}
	return err
}

// Context returns the context of the wrapped client stream.
func (cs *clientStream) Context() context.Context {
	ctx := cs.originalStream.Context()
	return ctx
}

// SendMsg passes m to the wrapped stream's SendMsg. A non-nil error is
// reported on cs.finished, unless cs.done is closed first, and is always
// returned to the caller.
func (cs *clientStream) SendMsg(m any) error {
	if err := cs.originalStream.SendMsg(m); err != nil {
		select {
		case cs.finished <- err:
		case <-cs.done:
		}
		return err
	}
	return nil
}

// RecvMsg receives into m from the wrapped stream. The result is reported on
// cs.finished, unless cs.done is closed first, in two cases: when the receive
// fails, or after any receive on a stream whose desc.ServerStreams is false.
// In the second case a successful receive reports nil.
// NOTE(review): io.EOF (a server stream ending normally, per grpc docs) is
// forwarded as-is; confirm the consumer of finished treats it as success.
func (cs *clientStream) RecvMsg(m any) error {
	err := cs.originalStream.RecvMsg(m)
	if err == nil && cs.desc.ServerStreams {
		// Successful mid-stream message on a server stream: nothing to report.
		return nil
	}
	select {
	case cs.finished <- err:
	case <-cs.done:
	}
	return err
}

// serverStream wraps a grpc.ServerStream (originalStream) so that Context
// returns the stored ctx instead of the wrapped stream's own context. The
// remaining methods call through to originalStream.
type serverStream struct {
	originalStream grpc.ServerStream
	ctx            context.Context // nolint:containedctx
}

// newgRPCServerStream returns a serverStream whose Context method yields ctx
// and whose other methods call through to originalStream.
func newgRPCServerStream(ctx context.Context, originalStream grpc.ServerStream) grpc.ServerStream {
	ss := &serverStream{ctx: ctx}
	ss.originalStream = originalStream
	return ss
}

// SetHeader records header metadata on the wrapped stream.
//
// It must delegate to the wrapped stream's SetHeader, not SendHeader. Per
// grpc.ServerStream, SetHeader may be called multiple times, merges the
// metadata, and defers sending until SendHeader, the first response, or the
// RPC status. SendHeader sends immediately and fails if called twice.
// Forwarding to SendHeader broke handlers that call SetHeader more than once
// or that follow SetHeader with SendHeader.
func (ss *serverStream) SetHeader(md metadata.MD) error {
	return ss.originalStream.SetHeader(md)
}

// SendHeader passes md to the wrapped stream's SendHeader and returns its
// error.
func (ss *serverStream) SendHeader(md metadata.MD) error {
	inner := ss.originalStream
	return inner.SendHeader(md)
}

// SetTrailer passes md to the wrapped stream's SetTrailer.
func (ss *serverStream) SetTrailer(md metadata.MD) {
	inner := ss.originalStream
	inner.SetTrailer(md)
}

// Context returns the context stored in ss.ctx rather than the wrapped
// stream's own context.
func (ss *serverStream) Context() context.Context {
	ctx := ss.ctx
	return ctx
}

// SendMsg passes m to the wrapped stream's SendMsg and returns its error.
func (ss *serverStream) SendMsg(m any) error {
	inner := ss.originalStream
	return inner.SendMsg(m)
}

// RecvMsg receives into m via the wrapped stream's RecvMsg and returns its
// error.
func (ss *serverStream) RecvMsg(m any) error {
	inner := ss.originalStream
	return inner.RecvMsg(m)
}