package grpc

import (
	"context"
	"errors"
	"fmt"
	"io"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// clientStream wraps a grpc.ClientStream. It records span events around
// CloseSend, SendMsg and RecvMsg, and reports the stream's terminal outcome
// on finished.
type clientStream struct {
	originalStream grpc.ClientStream
	desc           *grpc.StreamDesc
	span           trace.Span
	// finished receives the stream's terminal result: a non-nil error on
	// failure, or nil when the stream completed successfully.
	finished chan<- error
	// done, once ready, abandons any pending send on finished.
	// NOTE(review): presumably closed by the owner when it stops reading
	// from finished — confirm with the caller.
	done <-chan struct{}
}

// Compile-time check that clientStream satisfies grpc.ClientStream.
var _ grpc.ClientStream = (*clientStream)(nil)

// newgRPCClientStream wraps originalStream so that its operations are
// recorded on span, and its terminal result is delivered on finished
// unless done becomes ready first.
func newgRPCClientStream(originalStream grpc.ClientStream, desc *grpc.StreamDesc, span trace.Span, finished chan<- error, done <-chan struct{}) grpc.ClientStream {
	return &clientStream{
		originalStream: originalStream,
		desc:           desc,
		span:           span,
		finished:       finished,
		done:           done,
	}
}

// reportFinished sends err on cs.finished. It blocks until either the value
// is accepted or cs.done becomes ready, whichever happens first.
func (cs *clientStream) reportFinished(err error) {
	select {
	case <-cs.done:
	case cs.finished <- err:
	}
}

// Header returns the header metadata from the underlying stream. A non-nil
// error is also reported as the stream's terminal result.
func (cs *clientStream) Header() (metadata.MD, error) {
	md, err := cs.originalStream.Header()
	if err != nil {
		cs.reportFinished(err)
	}
	return md, err
}

// Trailer returns the trailer metadata from the underlying stream.
func (cs *clientStream) Trailer() metadata.MD {
	return cs.originalStream.Trailer()
}

// CloseSend closes the send direction of the underlying stream. Start and
// finish span events are recorded around the call. A non-nil error is
// reported as the stream's terminal result.
func (cs *clientStream) CloseSend() error {
	cs.span.AddEvent("client.stream.close.send.start")
	err := cs.originalStream.CloseSend()
	if err != nil {
		cs.reportFinished(err)
	}
	cs.span.AddEvent("client.stream.close.send.finish")
	return err
}

// Context returns the underlying stream's context.
func (cs *clientStream) Context() context.Context {
	return cs.originalStream.Context()
}

// SendMsg sends m on the underlying stream. Start and finish span events
// are recorded around the call, tagged with m's dynamic Go type. A non-nil
// error is reported as the stream's terminal result.
func (cs *clientStream) SendMsg(m any) error {
	msgType := attribute.String("message.type", fmt.Sprintf("%T", m))
	cs.span.AddEvent("client.stream.send.msg.start", trace.WithAttributes(msgType))

	err := cs.originalStream.SendMsg(m)
	if err != nil {
		cs.reportFinished(err)
	}

	cs.span.AddEvent("client.stream.send.msg.finish", trace.WithAttributes(msgType))
	return err
}

// RecvMsg receives into m from the underlying stream. Start and finish span
// events are recorded around the call, tagged with m's dynamic Go type.
//
// The stream's terminal result is reported when any of these hold:
//   - the receive fails;
//   - the stream reaches io.EOF;
//   - the RPC is not server-streaming, so one receive completes it.
//
// The error from the underlying stream (including io.EOF) is always
// returned to the caller unchanged.
func (cs *clientStream) RecvMsg(m any) error {
	msgType := attribute.String("message.type", fmt.Sprintf("%T", m))
	cs.span.AddEvent("client.stream.receive.msg.start", trace.WithAttributes(msgType))

	err := cs.originalStream.RecvMsg(m)
	switch {
	case errors.Is(err, io.EOF):
		// gRPC signals a successful end of stream with io.EOF; report it as
		// success (nil) rather than as a failure.
		cs.reportFinished(nil)
	case err != nil, !cs.desc.ServerStreams:
		// A failure ends the stream; without server streaming, a single
		// successful receive also completes it (err is nil in that case).
		cs.reportFinished(err)
	}

	cs.span.AddEvent("client.stream.receive.msg.finish", trace.WithAttributes(msgType))
	return err
}