package tracing import ( "context" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) type grpcMetadataCarrier struct { md *metadata.MD } func (c *grpcMetadataCarrier) Get(key string) string { values := c.md.Get(key) if len(values) > 0 { return values[0] } return "" } func (c *grpcMetadataCarrier) Set(key string, value string) { c.md.Set(key, value) } func (c *grpcMetadataCarrier) Keys() []string { result := make([]string, 0, c.md.Len()) for key := range *c.md { result = append(result, key) } return result } func extractGRPCTraceInfo(ctx context.Context) context.Context { md, ok := metadata.FromIncomingContext(ctx) if !ok { return ctx } carrier := &grpcMetadataCarrier{ md: &md, } return Propagator.Extract(ctx, carrier) } func setGRPCTraceInfo(ctx context.Context) context.Context { md, ok := metadata.FromOutgoingContext(ctx) if !ok { md = metadata.MD{} } carrier := &grpcMetadataCarrier{ md: &md, } Propagator.Inject(ctx, carrier) return metadata.NewOutgoingContext(ctx, md) } type clientStream struct { originalStream grpc.ClientStream desc *grpc.StreamDesc finished chan<- error done <-chan struct{} } func newgRPCClientStream(originalStream grpc.ClientStream, desc *grpc.StreamDesc, finished chan<- error, done <-chan struct{}) grpc.ClientStream { return &clientStream{ originalStream: originalStream, desc: desc, finished: finished, done: done, } } func (cs *clientStream) Header() (metadata.MD, error) { md, err := cs.originalStream.Header() if err != nil { select { case <-cs.done: case cs.finished <- err: } } return md, err } func (cs *clientStream) Trailer() metadata.MD { return cs.originalStream.Trailer() } func (cs *clientStream) CloseSend() error { err := cs.originalStream.CloseSend() if err != nil { select { case <-cs.done: case cs.finished <- err: } } return err } func (cs *clientStream) Context() context.Context { return cs.originalStream.Context() } func (cs *clientStream) SendMsg(m any) error { err := cs.originalStream.SendMsg(m) if err != nil { select { case <-cs.done: case cs.finished <- err: } } return err } func (cs *clientStream) RecvMsg(m any) error { err := cs.originalStream.RecvMsg(m) if err != nil || !cs.desc.ServerStreams { select { case <-cs.done: case cs.finished <- err: } } return err } type serverStream struct { originalStream grpc.ServerStream ctx context.Context // nolint:containedctx } func newgRPCServerStream(ctx context.Context, originalStream grpc.ServerStream) grpc.ServerStream { return &serverStream{ originalStream: originalStream, ctx: ctx, } } func (ss *serverStream) SetHeader(md metadata.MD) error { return ss.originalStream.SendHeader(md) } func (ss *serverStream) SendHeader(md metadata.MD) error { return ss.originalStream.SendHeader(md) } func (ss *serverStream) SetTrailer(md metadata.MD) { ss.originalStream.SetTrailer(md) } func (ss *serverStream) Context() context.Context { return ss.ctx } func (ss *serverStream) SendMsg(m any) error { return ss.originalStream.SendMsg(m) } func (ss *serverStream) RecvMsg(m any) error { return ss.originalStream.RecvMsg(m) }