From 488ee50f9e6968aa7e23a195de1a89492857a66c Mon Sep 17 00:00:00 2001 From: Dmitrii Stepanov Date: Fri, 17 Mar 2023 15:54:33 +0300 Subject: [PATCH] [#12] tracing: Add gRPC middleware Signed-off-by: Dmitrii Stepanov --- pkg/tracing/grpc.go | 136 ++++++++++++++++++++++++++++++ pkg/tracing/grpc_internal.go | 159 +++++++++++++++++++++++++++++++++++ rpc/client/client.go | 10 +++ 3 files changed, 305 insertions(+) create mode 100644 pkg/tracing/grpc.go create mode 100644 pkg/tracing/grpc_internal.go diff --git a/pkg/tracing/grpc.go b/pkg/tracing/grpc.go new file mode 100644 index 0000000..cea59b0 --- /dev/null +++ b/pkg/tracing/grpc.go @@ -0,0 +1,136 @@ +package tracing + +import ( + "context" + "io" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + semconv "go.opentelemetry.io/otel/semconv/v1.17.0" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc" + grpc_codes "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// NewGRPCUnaryClientInteceptor creates new gRPC unary interceptor to save gRPC client traces. +func NewGRPCUnaryClientInteceptor() grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx, span := startClientSpan(ctx, cc, method) + defer span.End() + + err := invoker(ctx, method, req, reply, cc, opts...) + if err != nil { + grpcStatus, _ := status.FromError(err) + span.SetStatus(codes.Error, grpcStatus.Message()) + span.SetAttributes(semconv.RPCGRPCStatusCodeKey.Int64(int64(grpcStatus.Code()))) + } else { + span.SetStatus(codes.Ok, "") + span.SetAttributes(semconv.RPCGRPCStatusCodeKey.Int64(int64(grpc_codes.OK))) + } + + return err + } +} + +// NewGRPCStreamClientInterceptor creates new gRPC stream interceptor to save gRPC client traces. +func NewGRPCStreamClientInterceptor() grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + ctx, span := startClientSpan(ctx, cc, method) + str, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + grpcStatus, _ := status.FromError(err) + span.SetStatus(codes.Error, grpcStatus.Message()) + span.SetAttributes(semconv.RPCGRPCStatusCodeKey.Int64(int64(grpcStatus.Code()))) + span.End() + return str, err + } + + finished := make(chan error) + done := make(chan struct{}) + strWrp := newgRPCClientStream(str, desc, finished, done) + + go func() { + defer close(finished) + defer close(done) + defer span.End() + + select { + case err := <-finished: + if err == nil || err == io.EOF { + setGRPCSpanStatus(span, nil) + } else { + setGRPCSpanStatus(span, err) + } + return + case <-ctx.Done(): + setGRPCSpanStatus(span, ctx.Err()) + return + } + }() + + return strWrp, nil + } +} + +// NewGRPCUnaryServerInterceptor creates new gRPC unary interceptor to save gRPC server traces. +func NewGRPCUnaryServerInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + ctx = extractGRPCTraceInfo(ctx) + var span trace.Span + ctx, span = StartSpanFromContext(ctx, info.FullMethod, + trace.WithAttributes( + semconv.RPCSystemGRPC, + semconv.RPCMethod(info.FullMethod), + ), + trace.WithSpanKind(trace.SpanKindServer)) + defer span.End() + + resp, err = handler(ctx, req) + + setGRPCSpanStatus(span, err) + return + } +} + +// NewGRPCStreamServerInterceptor creates new gRPC stream interceptor to save gRPC server traces. +func NewGRPCStreamServerInterceptor() grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + ctx := extractGRPCTraceInfo(ss.Context()) + var span trace.Span + ctx, span = StartSpanFromContext(ctx, info.FullMethod, + trace.WithAttributes( + semconv.RPCSystemGRPC, + semconv.RPCMethod(info.FullMethod), + ), + trace.WithSpanKind(trace.SpanKindServer)) + defer span.End() + + err := handler(srv, newgRPCServerStream(ctx, ss)) + + setGRPCSpanStatus(span, err) + return err + } +} + +func startClientSpan(ctx context.Context, cc *grpc.ClientConn, method string) (context.Context, trace.Span) { + ctx, span := StartSpanFromContext(ctx, method, trace.WithAttributes( + semconv.RPCSystemGRPC, + semconv.RPCMethod(method), + attribute.String("rpc.grpc.target", cc.Target())), + trace.WithSpanKind(trace.SpanKindClient), + ) + ctx = setGRPCTraceInfo(ctx) + return ctx, span +} + +func setGRPCSpanStatus(span trace.Span, err error) { + if err != nil { + grpcStatus, _ := status.FromError(err) + span.SetStatus(codes.Error, grpcStatus.Message()) + span.SetAttributes(semconv.RPCGRPCStatusCodeKey.Int64(int64(grpcStatus.Code()))) + } else { + span.SetStatus(codes.Ok, "") + span.SetAttributes(semconv.RPCGRPCStatusCodeKey.Int64(int64(grpc_codes.OK))) + } +} diff --git a/pkg/tracing/grpc_internal.go b/pkg/tracing/grpc_internal.go new file mode 100644 index 0000000..25cf026 --- /dev/null +++ b/pkg/tracing/grpc_internal.go @@ -0,0 +1,159 @@ +package tracing + +import ( + "context" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +type grpcMetadataCarrier struct { + md *metadata.MD +} + +func (c *grpcMetadataCarrier) Get(key string) string { + values := c.md.Get(key) + if len(values) > 0 { + return values[0] + } + return "" +} + +func (c *grpcMetadataCarrier) Set(key string, value string) { + c.md.Set(key, value) +} + +func (c *grpcMetadataCarrier) Keys() []string { + result := make([]string, 0, c.md.Len()) + for key := range *c.md { + result = append(result, key) + } + return result +} + +func extractGRPCTraceInfo(ctx context.Context) context.Context { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return ctx + } + carrier := &grpcMetadataCarrier{ + md: &md, + } + return Propagator.Extract(ctx, carrier) +} + +func setGRPCTraceInfo(ctx context.Context) context.Context { + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + md = metadata.MD{} + } + carrier := &grpcMetadataCarrier{ + md: &md, + } + Propagator.Inject(ctx, carrier) + return metadata.NewOutgoingContext(ctx, md) +} + +type clientStream struct { + originalStream grpc.ClientStream + desc *grpc.StreamDesc + finished chan<- error + done <-chan struct{} +} + +func newgRPCClientStream(originalStream grpc.ClientStream, desc *grpc.StreamDesc, finished chan<- error, done <-chan struct{}) grpc.ClientStream { + return &clientStream{ + originalStream: originalStream, + desc: desc, + finished: finished, + done: done, + } +} + +func (cs *clientStream) Header() (metadata.MD, error) { + md, err := cs.originalStream.Header() + if err != nil { + select { + case <-cs.done: + case cs.finished <- err: + } + } + return md, err +} + +func (cs *clientStream) Trailer() metadata.MD { + return cs.originalStream.Trailer() +} + +func (cs *clientStream) CloseSend() error { + err := cs.originalStream.CloseSend() + if err != nil { + select { + case <-cs.done: + case cs.finished <- err: + } + } + return err +} + +func (cs *clientStream) Context() context.Context { + return cs.originalStream.Context() +} + +func (cs *clientStream) SendMsg(m interface{}) error { + err := cs.originalStream.SendMsg(m) + if err != nil { + select { + case <-cs.done: + case cs.finished <- err: + } + } + return err +} + +func (cs *clientStream) RecvMsg(m interface{}) error { + err := cs.originalStream.RecvMsg(m) + if err != nil || !cs.desc.ServerStreams { + select { + case <-cs.done: + case cs.finished <- err: + } + } + return err +} + +type serverStream struct { + originalStream grpc.ServerStream + ctx context.Context +} + +func newgRPCServerStream(ctx context.Context, originalStream grpc.ServerStream) grpc.ServerStream { + return &serverStream{ + originalStream: originalStream, + ctx: ctx, + } +} + +func (ss *serverStream) SetHeader(md metadata.MD) error { + return ss.originalStream.SendHeader(md) +} + +func (ss *serverStream) SendHeader(md metadata.MD) error { + return ss.originalStream.SendHeader(md) +} + +func (ss *serverStream) SetTrailer(md metadata.MD) { + ss.originalStream.SetTrailer(md) +} + +func (ss *serverStream) Context() context.Context { + return ss.ctx +} + +func (ss *serverStream) SendMsg(m interface{}) error { + return ss.originalStream.SendMsg(m) +} + +func (ss *serverStream) RecvMsg(m interface{}) error { + return ss.originalStream.RecvMsg(m) +} diff --git a/rpc/client/client.go b/rpc/client/client.go index 7e914db..4d12849 100644 --- a/rpc/client/client.go +++ b/rpc/client/client.go @@ -1,6 +1,7 @@ package client import ( + "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/pkg/tracing" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) @@ -24,5 +25,14 @@ func New(opts ...Option) *Client { c.grpcDialOpts = append(c.grpcDialOpts, grpc.WithTransportCredentials(credentials.NewTLS(c.tlsCfg))) } + c.grpcDialOpts = append(c.grpcDialOpts, + grpc.WithChainUnaryInterceptor( + tracing.NewGRPCUnaryClientInteceptor(), + ), + grpc.WithChainStreamInterceptor( + tracing.NewGRPCStreamClientInterceptor(), + ), + ) + return &c }