package tagging import ( "context" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) const ( ioTagHeader = "x-frostfs-io-tag" ) // NewUnaryClientInteceptor creates new gRPC unary interceptor to set an IO tag to gRPC metadata. func NewUnaryClientInteceptor() grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { return invoker(setIOTagToGRPCMetadata(ctx), method, req, reply, cc, opts...) } } // NewStreamClientInterceptor creates new gRPC stream interceptor to set an IO tag to gRPC metadata. func NewStreamClientInterceptor() grpc.StreamClientInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { return streamer(setIOTagToGRPCMetadata(ctx), desc, cc, method, opts...) } } // NewUnaryServerInterceptor creates new gRPC unary interceptor to extract an IO tag to gRPC metadata. func NewUnaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { return handler(extractIOTagFromGRPCMetadata(ctx), req) } } // NewStreamServerInterceptor creates new gRPC stream interceptor to extract an IO tag to gRPC metadata. func NewStreamServerInterceptor() grpc.StreamServerInterceptor { return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { return handler(srv, &serverStream{origin: ss}) } } func setIOTagToGRPCMetadata(ctx context.Context) context.Context { ioTag, ok := IOTagFromContext(ctx) if !ok { return ctx } md, ok := metadata.FromOutgoingContext(ctx) if !ok { md = metadata.MD{} } md.Set(ioTagHeader, ioTag) return metadata.NewOutgoingContext(ctx, md) } func extractIOTagFromGRPCMetadata(ctx context.Context) context.Context { md, ok := metadata.FromIncomingContext(ctx) if !ok { return ctx } values := md.Get(ioTagHeader) if len(values) > 0 { return ContextWithIOTag(ctx, values[0]) } return ctx } var _ grpc.ServerStream = &serverStream{} type serverStream struct { origin grpc.ServerStream } func (s *serverStream) Context() context.Context { return extractIOTagFromGRPCMetadata(s.origin.Context()) } func (s *serverStream) RecvMsg(m any) error { return s.origin.RecvMsg(m) } func (s *serverStream) SendHeader(md metadata.MD) error { return s.origin.SendHeader(md) } func (s *serverStream) SendMsg(m any) error { return s.origin.SendMsg(m) } func (s *serverStream) SetHeader(md metadata.MD) error { return s.origin.SetHeader(md) } func (s *serverStream) SetTrailer(md metadata.MD) { s.origin.SetTrailer(md) }