frostfs-api-go-pogpp/pkg/tracing/grpc_internal.go
Dmitrii Stepanov 488ee50f9e [#12] tracing: Add gRPC middleware
Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
2023-04-11 10:55:42 +03:00

159 lines
3.2 KiB
Go

package tracing
import (
"context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// grpcMetadataCarrier exposes Get/Set/Keys access over a gRPC metadata.MD so
// that the metadata can be handed to Propagator.Extract and Propagator.Inject.
type grpcMetadataCarrier struct {
	// md points at the metadata that Get, Set and Keys operate on.
	md *metadata.MD
}
// Get returns the first value the wrapped metadata holds for key, or an
// empty string when the metadata has no values for that key.
func (c *grpcMetadataCarrier) Get(key string) string {
	if vals := c.md.Get(key); len(vals) != 0 {
		return vals[0]
	}
	return ""
}
// Set forwards the key/value pair to the Set method of the wrapped metadata.
func (c *grpcMetadataCarrier) Set(key string, value string) {
	target := c.md
	target.Set(key, value)
}
// Keys returns every key present in the wrapped metadata. The order of the
// returned slice follows map iteration and is therefore unspecified.
func (c *grpcMetadataCarrier) Keys() []string {
	md := *c.md
	keys := make([]string, 0, md.Len())
	for k := range md {
		keys = append(keys, k)
	}
	return keys
}
// extractGRPCTraceInfo reads the incoming gRPC metadata attached to ctx and
// passes it to Propagator.Extract, returning the context Extract produces.
// When ctx carries no incoming metadata, ctx is returned unchanged.
func extractGRPCTraceInfo(ctx context.Context) context.Context {
	incoming, found := metadata.FromIncomingContext(ctx)
	if found {
		return Propagator.Extract(ctx, &grpcMetadataCarrier{md: &incoming})
	}
	return ctx
}
// setGRPCTraceInfo lets Propagator.Inject write into the outgoing gRPC
// metadata of ctx and returns a context carrying the resulting metadata.
// If ctx has no outgoing metadata yet, an empty metadata.MD is used.
func setGRPCTraceInfo(ctx context.Context) context.Context {
	outgoing, found := metadata.FromOutgoingContext(ctx)
	if !found {
		outgoing = metadata.MD{}
	}

	// The carrier holds a pointer to outgoing, so whatever Inject sets is
	// visible in the value attached to the returned context.
	Propagator.Inject(ctx, &grpcMetadataCarrier{md: &outgoing})

	return metadata.NewOutgoingContext(ctx, outgoing)
}
// clientStream wraps a grpc.ClientStream and reports, through the finished
// channel, the error values its methods observe on the wrapped stream.
type clientStream struct {
	// originalStream is the stream every method delegates to.
	originalStream grpc.ClientStream
	// desc describes the stream; RecvMsg consults its ServerStreams flag.
	desc *grpc.StreamDesc
	// finished receives the error (nil in the non-server-streaming RecvMsg
	// case) observed by Header, CloseSend, SendMsg or RecvMsg.
	finished chan<- error
	// done, once readable, lets a pending report on finished be abandoned.
	done <-chan struct{}
}
// newgRPCClientStream builds a clientStream around originalStream and
// returns it as a grpc.ClientStream. desc, finished and done are stored
// as-is for use by the wrapper's methods.
func newgRPCClientStream(originalStream grpc.ClientStream, desc *grpc.StreamDesc, finished chan<- error, done <-chan struct{}) grpc.ClientStream {
	wrapped := new(clientStream)
	wrapped.originalStream = originalStream
	wrapped.desc = desc
	wrapped.finished = finished
	wrapped.done = done
	return wrapped
}
// Header returns the result of Header on the wrapped stream. On failure it
// blocks until the error has been delivered on finished or done becomes
// readable, then returns the same metadata and error to the caller.
func (cs *clientStream) Header() (metadata.MD, error) {
	header, err := cs.originalStream.Header()
	if err == nil {
		return header, nil
	}

	select {
	case cs.finished <- err:
	case <-cs.done:
	}
	return header, err
}
// Trailer returns the trailer metadata reported by the wrapped stream.
func (cs *clientStream) Trailer() metadata.MD {
	trailer := cs.originalStream.Trailer()
	return trailer
}
// CloseSend calls CloseSend on the wrapped stream. On failure it blocks
// until the error has been delivered on finished or done becomes readable,
// then returns that error.
func (cs *clientStream) CloseSend() error {
	err := cs.originalStream.CloseSend()
	if err == nil {
		return nil
	}

	select {
	case cs.finished <- err:
	case <-cs.done:
	}
	return err
}
// Context returns the context of the wrapped stream.
func (cs *clientStream) Context() context.Context {
	streamCtx := cs.originalStream.Context()
	return streamCtx
}
// SendMsg sends m on the wrapped stream. On failure it blocks until the
// error has been delivered on finished or done becomes readable, then
// returns that error.
func (cs *clientStream) SendMsg(m interface{}) error {
	err := cs.originalStream.SendMsg(m)
	if err == nil {
		return nil
	}

	select {
	case cs.finished <- err:
	case <-cs.done:
	}
	return err
}
// RecvMsg receives into m from the wrapped stream. The outcome is reported
// on finished when the receive failed, or when desc.ServerStreams is false
// (in which case a nil error is reported after a successful receive). The
// report blocks until it is delivered or done becomes readable.
func (cs *clientStream) RecvMsg(m interface{}) error {
	err := cs.originalStream.RecvMsg(m)

	// A successful receive on a server-streaming call is not terminal, so
	// nothing is reported yet.
	if err == nil && cs.desc.ServerStreams {
		return nil
	}

	select {
	case cs.finished <- err:
	case <-cs.done:
	}
	return err
}
// serverStream wraps a grpc.ServerStream and substitutes the context that
// is returned from Context.
type serverStream struct {
	// originalStream is the stream all operations except Context delegate to.
	originalStream grpc.ServerStream
	// ctx is the value returned by Context.
	ctx context.Context
}
// newgRPCServerStream wraps originalStream in a serverStream whose Context
// method returns ctx, and returns the wrapper as a grpc.ServerStream.
func newgRPCServerStream(ctx context.Context, originalStream grpc.ServerStream) grpc.ServerStream {
	wrapped := new(serverStream)
	wrapped.originalStream = originalStream
	wrapped.ctx = ctx
	return wrapped
}
// SetHeader forwards md to SetHeader on the wrapped stream and returns its
// result.
//
// It must delegate to SetHeader rather than SendHeader: per the
// grpc.ServerStream contract, SetHeader only records header metadata to be
// sent later, whereas SendHeader transmits the headers immediately. Forwarding
// to SendHeader would flush headers early and break handlers that call
// SetHeader more than once or follow it with SendHeader.
func (ss *serverStream) SetHeader(md metadata.MD) error {
	return ss.originalStream.SetHeader(md)
}
// SendHeader forwards md to SendHeader on the wrapped stream and returns
// its result.
func (ss *serverStream) SendHeader(md metadata.MD) error {
	inner := ss.originalStream
	return inner.SendHeader(md)
}
// SetTrailer forwards md to SetTrailer on the wrapped stream.
func (ss *serverStream) SetTrailer(md metadata.MD) {
	inner := ss.originalStream
	inner.SetTrailer(md)
}
// Context returns the context stored in the wrapper at construction time,
// not the context of the wrapped stream.
func (ss *serverStream) Context() context.Context {
	stored := ss.ctx
	return stored
}
// SendMsg forwards m to SendMsg on the wrapped stream and returns its
// result.
func (ss *serverStream) SendMsg(m interface{}) error {
	inner := ss.originalStream
	return inner.SendMsg(m)
}
// RecvMsg forwards m to RecvMsg on the wrapped stream and returns its
// result.
func (ss *serverStream) RecvMsg(m interface{}) error {
	inner := ss.originalStream
	return inner.RecvMsg(m)
}