159 lines
3.2 KiB
Go
159 lines
3.2 KiB
Go
package tracing
|
|
|
|
import (
|
|
"context"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/metadata"
|
|
)
|
|
|
|
type grpcMetadataCarrier struct {
|
|
md *metadata.MD
|
|
}
|
|
|
|
func (c *grpcMetadataCarrier) Get(key string) string {
|
|
values := c.md.Get(key)
|
|
if len(values) > 0 {
|
|
return values[0]
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (c *grpcMetadataCarrier) Set(key string, value string) {
|
|
c.md.Set(key, value)
|
|
}
|
|
|
|
func (c *grpcMetadataCarrier) Keys() []string {
|
|
result := make([]string, 0, c.md.Len())
|
|
for key := range *c.md {
|
|
result = append(result, key)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func extractGRPCTraceInfo(ctx context.Context) context.Context {
|
|
md, ok := metadata.FromIncomingContext(ctx)
|
|
if !ok {
|
|
return ctx
|
|
}
|
|
carrier := &grpcMetadataCarrier{
|
|
md: &md,
|
|
}
|
|
return Propagator.Extract(ctx, carrier)
|
|
}
|
|
|
|
func setGRPCTraceInfo(ctx context.Context) context.Context {
|
|
md, ok := metadata.FromOutgoingContext(ctx)
|
|
if !ok {
|
|
md = metadata.MD{}
|
|
}
|
|
carrier := &grpcMetadataCarrier{
|
|
md: &md,
|
|
}
|
|
Propagator.Inject(ctx, carrier)
|
|
return metadata.NewOutgoingContext(ctx, md)
|
|
}
|
|
|
|
type clientStream struct {
|
|
originalStream grpc.ClientStream
|
|
desc *grpc.StreamDesc
|
|
finished chan<- error
|
|
done <-chan struct{}
|
|
}
|
|
|
|
func newgRPCClientStream(originalStream grpc.ClientStream, desc *grpc.StreamDesc, finished chan<- error, done <-chan struct{}) grpc.ClientStream {
|
|
return &clientStream{
|
|
originalStream: originalStream,
|
|
desc: desc,
|
|
finished: finished,
|
|
done: done,
|
|
}
|
|
}
|
|
|
|
func (cs *clientStream) Header() (metadata.MD, error) {
|
|
md, err := cs.originalStream.Header()
|
|
if err != nil {
|
|
select {
|
|
case <-cs.done:
|
|
case cs.finished <- err:
|
|
}
|
|
}
|
|
return md, err
|
|
}
|
|
|
|
func (cs *clientStream) Trailer() metadata.MD {
|
|
return cs.originalStream.Trailer()
|
|
}
|
|
|
|
func (cs *clientStream) CloseSend() error {
|
|
err := cs.originalStream.CloseSend()
|
|
if err != nil {
|
|
select {
|
|
case <-cs.done:
|
|
case cs.finished <- err:
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (cs *clientStream) Context() context.Context {
|
|
return cs.originalStream.Context()
|
|
}
|
|
|
|
func (cs *clientStream) SendMsg(m any) error {
|
|
err := cs.originalStream.SendMsg(m)
|
|
if err != nil {
|
|
select {
|
|
case <-cs.done:
|
|
case cs.finished <- err:
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (cs *clientStream) RecvMsg(m any) error {
|
|
err := cs.originalStream.RecvMsg(m)
|
|
if err != nil || !cs.desc.ServerStreams {
|
|
select {
|
|
case <-cs.done:
|
|
case cs.finished <- err:
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
type serverStream struct {
|
|
originalStream grpc.ServerStream
|
|
ctx context.Context // nolint:containedctx
|
|
}
|
|
|
|
func newgRPCServerStream(ctx context.Context, originalStream grpc.ServerStream) grpc.ServerStream {
|
|
return &serverStream{
|
|
originalStream: originalStream,
|
|
ctx: ctx,
|
|
}
|
|
}
|
|
|
|
func (ss *serverStream) SetHeader(md metadata.MD) error {
|
|
return ss.originalStream.SendHeader(md)
|
|
}
|
|
|
|
func (ss *serverStream) SendHeader(md metadata.MD) error {
|
|
return ss.originalStream.SendHeader(md)
|
|
}
|
|
|
|
func (ss *serverStream) SetTrailer(md metadata.MD) {
|
|
ss.originalStream.SetTrailer(md)
|
|
}
|
|
|
|
func (ss *serverStream) Context() context.Context {
|
|
return ss.ctx
|
|
}
|
|
|
|
func (ss *serverStream) SendMsg(m any) error {
|
|
return ss.originalStream.SendMsg(m)
|
|
}
|
|
|
|
func (ss *serverStream) RecvMsg(m any) error {
|
|
return ss.originalStream.RecvMsg(m)
|
|
}
|