diff --git a/pkg/services/object/acl/v2/service.go b/pkg/services/object/acl/v2/service.go index af26a6fd1..3ebac22a5 100644 --- a/pkg/services/object/acl/v2/service.go +++ b/pkg/services/object/acl/v2/service.go @@ -112,7 +112,7 @@ type wrappedGetObjectStream struct { } func (w *wrappedGetObjectStream) Context() context.Context { - return context.WithValue(w.GetObjectStream.Context(), object.RequestContextKey, &object.RequestContext{ + return object.NewRequestContext(w.GetObjectStream.Context(), &object.RequestContext{ Namespace: w.requestInfo.ContainerNamespace(), SenderKey: w.requestInfo.SenderKey(), Role: w.requestInfo.RequestRole(), @@ -135,7 +135,7 @@ type wrappedRangeStream struct { } func (w *wrappedRangeStream) Context() context.Context { - return context.WithValue(w.GetObjectRangeStream.Context(), object.RequestContextKey, &object.RequestContext{ + return object.NewRequestContext(w.GetObjectRangeStream.Context(), &object.RequestContext{ Namespace: w.requestInfo.ContainerNamespace(), SenderKey: w.requestInfo.SenderKey(), Role: w.requestInfo.RequestRole(), @@ -158,7 +158,7 @@ type wrappedSearchStream struct { } func (w *wrappedSearchStream) Context() context.Context { - return context.WithValue(w.SearchStream.Context(), object.RequestContextKey, &object.RequestContext{ + return object.NewRequestContext(w.SearchStream.Context(), &object.RequestContext{ Namespace: w.requestInfo.ContainerNamespace(), SenderKey: w.requestInfo.SenderKey(), Role: w.requestInfo.RequestRole(), @@ -457,7 +457,7 @@ func (b Service) GetRange(request *objectV2.GetRangeRequest, stream object.GetOb } func requestContext(ctx context.Context, reqInfo RequestInfo) context.Context { - return context.WithValue(ctx, object.RequestContextKey, &object.RequestContext{ + return object.NewRequestContext(ctx, &object.RequestContext{ Namespace: reqInfo.ContainerNamespace(), SenderKey: reqInfo.SenderKey(), Role: reqInfo.RequestRole(), diff --git a/pkg/services/object/ape/service.go b/pkg/services/object/ape/service.go index 781f9df4b..3bbad402e 100644 --- a/pkg/services/object/ape/service.go +++ b/pkg/services/object/ape/service.go @@ -3,7 +3,6 @@ package ape import ( "context" "encoding/hex" - "fmt" objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object" "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/refs" @@ -88,25 +87,13 @@ func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error { return g.GetObjectStream.Send(resp) } -func requestContext(ctx context.Context) (*objectSvc.RequestContext, error) { - untyped := ctx.Value(objectSvc.RequestContextKey) - if untyped == nil { - return nil, fmt.Errorf("no key %s in context", objectSvc.RequestContextKey) - } - rc, ok := untyped.(*objectSvc.RequestContext) - if !ok { - return nil, fmt.Errorf("failed cast to RequestContext") - } - return rc, nil -} - func (c *Service) Get(request *objectV2.GetRequest, stream objectSvc.GetObjectStream) error { cnrID, objID, err := getAddressParamsSDK(request.GetBody().GetAddress().GetContainerID(), request.GetBody().GetAddress().GetObjectID()) if err != nil { return toStatusErr(err) } - reqCtx, err := requestContext(stream.Context()) + reqCtx, err := objectSvc.FromRequestContext(stream.Context()) if err != nil { return toStatusErr(err) } @@ -137,7 +124,7 @@ type putStreamBasicChecker struct { func (p *putStreamBasicChecker) Send(ctx context.Context, request *objectV2.PutRequest) error { if partInit, ok := request.GetBody().GetObjectPart().(*objectV2.PutObjectPartInit); ok { - reqCtx, err := requestContext(ctx) + reqCtx, err := objectSvc.FromRequestContext(ctx) if err != nil { return toStatusErr(err) } @@ -184,7 +171,7 @@ func (c *Service) Head(ctx context.Context, request *objectV2.HeadRequest) (*obj return nil, err } - reqCtx, err := requestContext(ctx) + reqCtx, err := objectSvc.FromRequestContext(ctx) if err != nil { return nil, err } @@ -248,7 +235,7 @@ func (c *Service) Search(request *objectV2.SearchRequest, stream objectSvc.Searc } } - reqCtx, err := requestContext(stream.Context()) + reqCtx, err := objectSvc.FromRequestContext(stream.Context()) if err != nil { return toStatusErr(err) } @@ -273,7 +260,7 @@ func (c *Service) Delete(ctx context.Context, request *objectV2.DeleteRequest) ( return nil, err } - reqCtx, err := requestContext(ctx) + reqCtx, err := objectSvc.FromRequestContext(ctx) if err != nil { return nil, err } @@ -304,7 +291,7 @@ func (c *Service) GetRange(request *objectV2.GetRangeRequest, stream objectSvc.G return toStatusErr(err) } - reqCtx, err := requestContext(stream.Context()) + reqCtx, err := objectSvc.FromRequestContext(stream.Context()) if err != nil { return toStatusErr(err) } @@ -330,7 +317,7 @@ func (c *Service) GetRangeHash(ctx context.Context, request *objectV2.GetRangeHa return nil, err } - reqCtx, err := requestContext(ctx) + reqCtx, err := objectSvc.FromRequestContext(ctx) if err != nil { return nil, err } @@ -365,7 +352,7 @@ func (c *Service) PutSingle(ctx context.Context, request *objectV2.PutSingleRequ return nil, err } - reqCtx, err := requestContext(ctx) + reqCtx, err := objectSvc.FromRequestContext(ctx) if err != nil { return nil, err } diff --git a/pkg/services/object/request_context.go b/pkg/services/object/request_context.go index 4b9aa04d1..8d49e3773 100644 --- a/pkg/services/object/request_context.go +++ b/pkg/services/object/request_context.go @@ -1,10 +1,15 @@ package object -import "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/acl" +import ( + "context" + "fmt" -type RequestContextKeyT struct{} + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/acl" +) -var RequestContextKey = RequestContextKeyT{} +type requestContextKeyT struct{} + +var requestContextKey = requestContextKeyT{} // RequestContext is a context passed between middleware handlers. type RequestContext struct { @@ -14,3 +19,17 @@ type RequestContext struct { Role acl.Role } + +// NewRequestContext returns a copy of ctx which carries value. +func NewRequestContext(ctx context.Context, value *RequestContext) context.Context { + return context.WithValue(ctx, requestContextKey, value) +} + +// FromRequestContext returns RequestContext value stored in ctx if any. +func FromRequestContext(ctx context.Context) (*RequestContext, error) { + reqCtx, ok := ctx.Value(requestContextKey).(*RequestContext) + if !ok { + return nil, fmt.Errorf("no key %s in context", requestContextKey) + } + return reqCtx, nil +}