package grpc

import (
	"context"
	"fmt"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// clientStream wraps a grpc.ClientStream and mirrors its operations onto a
// tracing span as start/finish events. Errors from the wrapped stream (and,
// for RPCs whose desc has ServerStreams == false, the completion of the single
// response) are reported on finished. Each such send is done in a select
// against done, so closing done releases any pending send.
type clientStream struct {
	originalStream grpc.ClientStream // the stream being wrapped; every method delegates to it
	desc           *grpc.StreamDesc  // RPC descriptor; ServerStreams is consulted in RecvMsg
	span           trace.Span        // span that receives the start/finish events
	finished       chan<- error      // receives the error (or nil) that ends the stream
	done           <-chan struct{}   // once closed, pending sends on finished are released
}

// newgRPCClientStream returns originalStream wrapped so that its operations
// are recorded as events on span and the outcome that ends the stream is sent
// on finished. Closing done releases any send on finished that is still
// waiting for a receiver.
func newgRPCClientStream(originalStream grpc.ClientStream, desc *grpc.StreamDesc, span trace.Span, finished chan<- error, done <-chan struct{}) grpc.ClientStream {
	cs := new(clientStream)
	cs.originalStream = originalStream
	cs.desc = desc
	cs.span = span
	cs.finished = finished
	cs.done = done
	return cs
}

// Header returns the header metadata of the wrapped stream. If the wrapped
// call fails, the error is also forwarded on cs.finished (or dropped once
// cs.done is closed) before being returned to the caller.
func (cs *clientStream) Header() (metadata.MD, error) {
	md, err := cs.originalStream.Header()
	if err == nil {
		return md, nil
	}
	select {
	case cs.finished <- err:
	case <-cs.done:
	}
	return md, err
}

// Trailer returns the trailer metadata of the wrapped stream. It is a plain
// pass-through and records no span events.
func (cs *clientStream) Trailer() metadata.MD {
	trailer := cs.originalStream.Trailer()
	return trailer
}

// CloseSend closes the send direction of the wrapped stream, recording
// "client.stream.close.send.start" and "client.stream.close.send.finish"
// events on cs.span around the call. If the wrapped CloseSend fails, the
// error is forwarded on cs.finished (or dropped once cs.done is closed) and
// returned to the caller.
func (cs *clientStream) CloseSend() error {
	cs.span.AddEvent("client.stream.close.send.start")
	err := cs.originalStream.CloseSend()
	// Record the finish event BEFORE notifying cs.finished. The receiver of
	// that channel presumably ends the span as soon as it gets the error
	// (TODO confirm against the interceptor), and events added to an ended
	// span are ignored.
	cs.span.AddEvent("client.stream.close.send.finish")
	if err != nil {
		select {
		case <-cs.done:
		case cs.finished <- err:
		}
	}
	return err
}

// Context returns the context of the wrapped stream unchanged.
func (cs *clientStream) Context() context.Context {
	ctx := cs.originalStream.Context()
	return ctx
}

// SendMsg sends m on the wrapped stream, recording
// "client.stream.send.msg.start" and "client.stream.send.msg.finish" events
// on cs.span, both tagged with the Go type of m as "message.type". If the
// wrapped SendMsg fails, the error is forwarded on cs.finished (or dropped
// once cs.done is closed) and returned to the caller.
//
// NOTE(review): gRPC documents that SendMsg returns io.EOF when the stream
// was terminated, with the real status only available from RecvMsg. That
// io.EOF is forwarded here unchanged; confirm the receiver of cs.finished
// does not treat it as the final RPC status.
func (cs *clientStream) SendMsg(m any) error {
	// The type of m cannot change across the call, so format it once.
	msgType := attribute.String("message.type", fmt.Sprintf("%T", m))
	cs.span.AddEvent("client.stream.send.msg.start", trace.WithAttributes(msgType))
	err := cs.originalStream.SendMsg(m)
	// Record the finish event BEFORE notifying cs.finished. The receiver
	// presumably ends the span on notification (TODO confirm), and events
	// added to an ended span are ignored.
	cs.span.AddEvent("client.stream.send.msg.finish", trace.WithAttributes(msgType))
	if err != nil {
		select {
		case <-cs.done:
		case cs.finished <- err:
		}
	}
	return err
}

// RecvMsg receives the next message from the wrapped stream into m, recording
// "client.stream.receive.msg.start" and "client.stream.receive.msg.finish"
// events on cs.span, both tagged with the Go type of m as "message.type".
//
// The stream is treated as finished, and the result (nil on success) is
// forwarded on cs.finished, when either:
//   - the wrapped RecvMsg returns an error, or
//   - the RPC has no server streaming (cs.desc.ServerStreams is false), so the
//     single response has now been received.
//
// gRPC reports a clean end of a server stream as io.EOF; that value is
// forwarded unchanged. The send is dropped once cs.done is closed.
func (cs *clientStream) RecvMsg(m any) error {
	// The type of m cannot change across the call, so format it once.
	msgType := attribute.String("message.type", fmt.Sprintf("%T", m))
	cs.span.AddEvent("client.stream.receive.msg.start", trace.WithAttributes(msgType))
	err := cs.originalStream.RecvMsg(m)
	// Record the finish event BEFORE notifying cs.finished. This branch is the
	// usual terminal path, the receiver presumably ends the span on
	// notification (TODO confirm), and events added to an ended span are
	// ignored.
	cs.span.AddEvent("client.stream.receive.msg.finish", trace.WithAttributes(msgType))
	if err != nil || !cs.desc.ServerStreams {
		select {
		case <-cs.done:
		case cs.finished <- err:
		}
	}
	return err
}