diff --git a/pkg/services/object/acl/acl.go b/pkg/services/object/acl/acl.go index 921545c8b..d806c75b2 100644 --- a/pkg/services/object/acl/acl.go +++ b/pkg/services/object/acl/acl.go @@ -100,7 +100,7 @@ func (c *Checker) StickyBitCheck(info v2.RequestInfo, owner user.ID) bool { } // CheckEACL is a main check function for extended ACL. -func (c *Checker) CheckEACL(msg any, reqInfo v2.RequestInfo) error { +func (c *Checker) CheckEACL(ctx context.Context, msg any, reqInfo v2.RequestInfo) error { basicACL := reqInfo.BasicACL() if !basicACL.Extendable() { return nil @@ -136,7 +136,7 @@ func (c *Checker) CheckEACL(msg any, reqInfo v2.RequestInfo) error { return err } - hdrSrc, err := c.getHeaderSource(cnr, msg, reqInfo) + hdrSrc, err := c.getHeaderSource(ctx, cnr, msg, reqInfo) if err != nil { return err } @@ -173,7 +173,7 @@ func getRole(reqInfo v2.RequestInfo) eaclSDK.Role { return eaclRole } -func (c *Checker) getHeaderSource(cnr cid.ID, msg any, reqInfo v2.RequestInfo) (eaclSDK.TypedHeaderSource, error) { +func (c *Checker) getHeaderSource(ctx context.Context, cnr cid.ID, msg any, reqInfo v2.RequestInfo) (eaclSDK.TypedHeaderSource, error) { var xHeaderSource eaclV2.XHeaderSource if req, ok := msg.(eaclV2.Request); ok { xHeaderSource = eaclV2.NewRequestXHeaderSource(req) @@ -181,7 +181,7 @@ func (c *Checker) getHeaderSource(cnr cid.ID, msg any, reqInfo v2.RequestInfo) ( xHeaderSource = eaclV2.NewResponseXHeaderSource(msg.(eaclV2.Response), reqInfo.Request().(eaclV2.Request)) } - hdrSrc, err := eaclV2.NewMessageHeaderSource(&localStorage{ls: c.localStorage}, xHeaderSource, cnr, eaclV2.WithOID(reqInfo.ObjectID())) + hdrSrc, err := eaclV2.NewMessageHeaderSource(ctx, &localStorage{ls: c.localStorage}, xHeaderSource, cnr, reqInfo.ObjectID()) if err != nil { return nil, fmt.Errorf("can't parse headers: %w", err) } diff --git a/pkg/services/object/acl/eacl/v2/eacl_test.go b/pkg/services/object/acl/eacl/v2/eacl_test.go index 023b99239..e925fdb1b 100644 --- a/pkg/services/object/acl/eacl/v2/eacl_test.go +++ b/pkg/services/object/acl/eacl/v2/eacl_test.go @@ -103,10 +103,11 @@ func TestHeadRequest(t *testing.T) { newSource := func(t *testing.T) eaclSDK.TypedHeaderSource { hdrSrc, err := NewMessageHeaderSource( + context.TODO(), lStorage, NewRequestXHeaderSource(req), addr.Container(), - WithOID(&id)) + &id) require.NoError(t, err) return hdrSrc } diff --git a/pkg/services/object/acl/eacl/v2/headers.go b/pkg/services/object/acl/eacl/v2/headers.go index 34975e1e6..906ef4d8e 100644 --- a/pkg/services/object/acl/eacl/v2/headers.go +++ b/pkg/services/object/acl/eacl/v2/headers.go @@ -9,6 +9,7 @@ import ( objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object" refsV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/refs" "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/session" + "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object" cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id" eaclSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/eacl" objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object" @@ -16,8 +17,6 @@ import ( "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user" ) -type Option func(*cfg) - type cfg struct { storage ObjectStorage @@ -46,24 +45,21 @@ type headerSource struct { incompleteObjectHeaders bool } -func NewMessageHeaderSource(os ObjectStorage, xhs XHeaderSource, cnrID cid.ID, opts ...Option) (eaclSDK.TypedHeaderSource, error) { +func NewMessageHeaderSource(ctx context.Context, os ObjectStorage, xhs XHeaderSource, cnrID cid.ID, objID *oid.ID) (eaclSDK.TypedHeaderSource, error) { cfg := &cfg{ storage: os, cnr: cnrID, + obj: objID, msg: xhs, } - for i := range opts { - opts[i](cfg) - } - if cfg.msg == nil { return nil, errors.New("message is not provided") } var res headerSource - err := cfg.readObjectHeaders(&res) + err := cfg.readObjectHeaders(ctx, &res) if err != nil { return nil, err } @@ -96,18 +92,18 @@ func (x xHeader) Value() string { var errMissingOID = errors.New("object ID is missing") -func (h *cfg) readObjectHeaders(dst *headerSource) error { +func (h *cfg) readObjectHeaders(ctx context.Context, dst *headerSource) error { switch m := h.msg.(type) { default: panic(fmt.Sprintf("unexpected message type %T", h.msg)) case requestXHeaderSource: - return h.readObjectHeadersFromRequestXHeaderSource(m, dst) + return h.readObjectHeadersFromRequestXHeaderSource(ctx, m, dst) case responseXHeaderSource: - return h.readObjectHeadersResponseXHeaderSource(m, dst) + return h.readObjectHeadersResponseXHeaderSource(ctx, m, dst) } } -func (h *cfg) readObjectHeadersFromRequestXHeaderSource(m requestXHeaderSource, dst *headerSource) error { +func (h *cfg) readObjectHeadersFromRequestXHeaderSource(ctx context.Context, m requestXHeaderSource, dst *headerSource) error { switch req := m.req.(type) { case *objectV2.GetRequest, @@ -116,7 +112,7 @@ func (h *cfg) readObjectHeadersFromRequestXHeaderSource(m requestXHeaderSource, return errMissingOID } - objHeaders, completed := h.localObjectHeaders(h.cnr, h.obj) + objHeaders, completed := h.localObjectHeaders(ctx, h.cnr, h.obj) dst.objectHeaders = objHeaders dst.incompleteObjectHeaders = !completed @@ -154,10 +150,10 @@ func (h *cfg) readObjectHeadersFromRequestXHeaderSource(m requestXHeaderSource, return nil } -func (h *cfg) readObjectHeadersResponseXHeaderSource(m responseXHeaderSource, dst *headerSource) error { +func (h *cfg) readObjectHeadersResponseXHeaderSource(ctx context.Context, m responseXHeaderSource, dst *headerSource) error { switch resp := m.resp.(type) { default: - objectHeaders, completed := h.localObjectHeaders(h.cnr, h.obj) + objectHeaders, completed := h.localObjectHeaders(ctx, h.cnr, h.obj) dst.objectHeaders = objectHeaders dst.incompleteObjectHeaders = !completed @@ -198,14 +194,25 @@ func (h *cfg) readObjectHeadersResponseXHeaderSource(m responseXHeaderSource, ds return nil } -func (h *cfg) localObjectHeaders(cnr cid.ID, idObj *oid.ID) ([]eaclSDK.Header, bool) { +func (h *cfg) localObjectHeaders(ctx context.Context, cnr cid.ID, idObj *oid.ID) ([]eaclSDK.Header, bool) { if idObj != nil { var addr oid.Address addr.SetContainer(cnr) addr.SetObject(*idObj) - obj, err := h.storage.Head(context.TODO(), addr) + reqCtx, _ := object.FromRequestContext(ctx) + if reqCtx != nil { + hdr := reqCtx.GetHeader(addr) + if hdr != nil { + return headersFromObject(hdr, cnr, idObj), true + } + } + + obj, err := h.storage.Head(ctx, addr) if err == nil { + if reqCtx != nil { + reqCtx.Header.Store(obj) + } return headersFromObject(obj, cnr, idObj), true } } diff --git a/pkg/services/object/acl/eacl/v2/opts.go b/pkg/services/object/acl/eacl/v2/opts.go deleted file mode 100644 index d91a21c75..000000000 --- a/pkg/services/object/acl/eacl/v2/opts.go +++ /dev/null @@ -1,11 +0,0 @@ -package v2 - -import ( - oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id" -) - -func WithOID(v *oid.ID) Option { - return func(c *cfg) { - c.obj = v - } -} diff --git a/pkg/services/object/acl/v2/request.go b/pkg/services/object/acl/v2/request.go index 74279e453..d003a0e6d 100644 --- a/pkg/services/object/acl/v2/request.go +++ b/pkg/services/object/acl/v2/request.go @@ -5,6 +5,7 @@ import ( "fmt" sessionV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/session" + "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/acl" cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id" @@ -157,3 +158,13 @@ func unmarshalPublicKeyWithOwner(rawKey []byte) (*user.ID, *keys.PublicKey, erro return &idSender, key, nil } + +func (r RequestInfo) toRequestContext() *object.RequestContext { + return &object.RequestContext{ + Namespace: r.ContainerNamespace(), + ContainerOwner: r.ContainerOwner(), + SenderKey: r.SenderKey(), + Role: r.RequestRole(), + SoftAPECheck: r.IsSoftAPECheck(), + } +} diff --git a/pkg/services/object/acl/v2/service.go b/pkg/services/object/acl/v2/service.go index 1f1c275cf..cf53bd513 100644 --- a/pkg/services/object/acl/v2/service.go +++ b/pkg/services/object/acl/v2/service.go @@ -108,24 +108,11 @@ func New(next object.ServiceServer, type wrappedGetObjectStream struct { object.GetObjectStream - requestInfo RequestInfo + requestContext *object.RequestContext } func (w *wrappedGetObjectStream) Context() context.Context { - return context.WithValue(w.GetObjectStream.Context(), object.RequestContextKey, &object.RequestContext{ - Namespace: w.requestInfo.ContainerNamespace(), - ContainerOwner: w.requestInfo.ContainerOwner(), - SenderKey: w.requestInfo.SenderKey(), - Role: w.requestInfo.RequestRole(), - SoftAPECheck: w.requestInfo.IsSoftAPECheck(), - }) -} - -func newWrappedGetObjectStreamStream(getObjectStream object.GetObjectStream, reqInfo RequestInfo) object.GetObjectStream { - return &wrappedGetObjectStream{ - GetObjectStream: getObjectStream, - requestInfo: reqInfo, - } + return object.NewRequestContext(w.GetObjectStream.Context(), w.requestContext) } // wrappedRangeStream propagates RequestContext into GetObjectRangeStream's context. @@ -133,24 +120,11 @@ func newWrappedGetObjectStreamStream(getObjectStream object.GetObjectStream, req type wrappedRangeStream struct { object.GetObjectRangeStream - requestInfo RequestInfo + requestContext *object.RequestContext } func (w *wrappedRangeStream) Context() context.Context { - return context.WithValue(w.GetObjectRangeStream.Context(), object.RequestContextKey, &object.RequestContext{ - Namespace: w.requestInfo.ContainerNamespace(), - ContainerOwner: w.requestInfo.ContainerOwner(), - SenderKey: w.requestInfo.SenderKey(), - Role: w.requestInfo.RequestRole(), - SoftAPECheck: w.requestInfo.IsSoftAPECheck(), - }) -} - -func newWrappedRangeStream(rangeStream object.GetObjectRangeStream, reqInfo RequestInfo) object.GetObjectRangeStream { - return &wrappedRangeStream{ - GetObjectRangeStream: rangeStream, - requestInfo: reqInfo, - } + return object.NewRequestContext(w.GetObjectRangeStream.Context(), w.requestContext) } // wrappedSearchStream propagates RequestContext into SearchStream's context. @@ -158,24 +132,11 @@ func newWrappedRangeStream(rangeStream object.GetObjectRangeStream, reqInfo Requ type wrappedSearchStream struct { object.SearchStream - requestInfo RequestInfo + requestContext *object.RequestContext } func (w *wrappedSearchStream) Context() context.Context { - return context.WithValue(w.SearchStream.Context(), object.RequestContextKey, &object.RequestContext{ - Namespace: w.requestInfo.ContainerNamespace(), - ContainerOwner: w.requestInfo.ContainerOwner(), - SenderKey: w.requestInfo.SenderKey(), - Role: w.requestInfo.RequestRole(), - SoftAPECheck: w.requestInfo.IsSoftAPECheck(), - }) -} - -func newWrappedSearchStream(searchStream object.SearchStream, reqInfo RequestInfo) object.SearchStream { - return &wrappedSearchStream{ - SearchStream: searchStream, - requestInfo: reqInfo, - } + return object.NewRequestContext(w.SearchStream.Context(), w.requestContext) } // Get implements ServiceServer interface, makes ACL checks and calls @@ -222,18 +183,25 @@ func (b Service) Get(request *objectV2.GetRequest, stream object.GetObjectStream reqInfo.obj = obj + reqCtx := reqInfo.toRequestContext() if reqInfo.IsSoftAPECheck() { if !b.checker.CheckBasicACL(reqInfo) { return basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } + + ctx := object.NewRequestContext(stream.Context(), reqCtx) + if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil { return eACLErr(reqInfo, err) } } return b.next.Get(request, &getStreamBasicChecker{ - GetObjectStream: newWrappedGetObjectStreamStream(stream, reqInfo), - info: reqInfo, - checker: b.checker, + GetObjectStream: &wrappedGetObjectStream{ + GetObjectStream: stream, + requestContext: reqCtx, + }, + info: reqInfo, + checker: b.checker, }) } @@ -291,17 +259,20 @@ func (b Service) Head( reqInfo.obj = obj + ctx = requestContext(ctx, reqInfo) if reqInfo.IsSoftAPECheck() { if !b.checker.CheckBasicACL(reqInfo) { return nil, basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } + + if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil { return nil, eACLErr(reqInfo, err) } } - resp, err := b.next.Head(requestContext(ctx, reqInfo), request) + resp, err := b.next.Head(ctx, request) if err == nil { - if err = b.checker.CheckEACL(resp, reqInfo); err != nil { + if err = b.checker.CheckEACL(ctx, resp, reqInfo); err != nil { err = eACLErr(reqInfo, err) } } @@ -344,18 +315,25 @@ func (b Service) Search(request *objectV2.SearchRequest, stream object.SearchStr return err } + reqCtx := reqInfo.toRequestContext() if reqInfo.IsSoftAPECheck() { if !b.checker.CheckBasicACL(reqInfo) { return basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } + + ctx := object.NewRequestContext(stream.Context(), reqCtx) + if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil { return eACLErr(reqInfo, err) } } return b.next.Search(request, &searchStreamBasicChecker{ - checker: b.checker, - SearchStream: newWrappedSearchStream(stream, reqInfo), - info: reqInfo, + checker: b.checker, + SearchStream: &wrappedSearchStream{ + SearchStream: stream, + requestContext: reqCtx, + }, + info: reqInfo, }) } @@ -404,15 +382,15 @@ func (b Service) Delete( reqInfo.obj = obj + ctx = requestContext(ctx, reqInfo) if reqInfo.IsSoftAPECheck() { if !b.checker.CheckBasicACL(reqInfo) { return nil, basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } else if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil { return nil, eACLErr(reqInfo, err) } } - - return b.next.Delete(requestContext(ctx, reqInfo), request) + return b.next.Delete(ctx, request) } func (b Service) GetRange(request *objectV2.GetRangeRequest, stream object.GetObjectRangeStream) error { @@ -457,29 +435,30 @@ func (b Service) GetRange(request *objectV2.GetRangeRequest, stream object.GetOb reqInfo.obj = obj + reqCtx := reqInfo.toRequestContext() if reqInfo.IsSoftAPECheck() { if !b.checker.CheckBasicACL(reqInfo) { return basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } + + ctx := object.NewRequestContext(stream.Context(), reqCtx) + if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil { return eACLErr(reqInfo, err) } } return b.next.GetRange(request, &rangeStreamBasicChecker{ - checker: b.checker, - GetObjectRangeStream: newWrappedRangeStream(stream, reqInfo), - info: reqInfo, + checker: b.checker, + GetObjectRangeStream: &wrappedRangeStream{ + GetObjectRangeStream: stream, + requestContext: reqCtx, + }, + info: reqInfo, }) } func requestContext(ctx context.Context, reqInfo RequestInfo) context.Context { - return context.WithValue(ctx, object.RequestContextKey, &object.RequestContext{ - Namespace: reqInfo.ContainerNamespace(), - ContainerOwner: reqInfo.ContainerOwner(), - SenderKey: reqInfo.SenderKey(), - Role: reqInfo.RequestRole(), - SoftAPECheck: reqInfo.IsSoftAPECheck(), - }) + return object.NewRequestContext(ctx, reqInfo.toRequestContext()) } func (b Service) GetRangeHash( @@ -527,15 +506,18 @@ func (b Service) GetRangeHash( reqInfo.obj = obj + ctx = requestContext(ctx, reqInfo) if reqInfo.IsSoftAPECheck() { if !b.checker.CheckBasicACL(reqInfo) { return nil, basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } + + if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil { return nil, eACLErr(reqInfo, err) } } - return b.next.GetRangeHash(requestContext(ctx, reqInfo), request) + return b.next.GetRangeHash(ctx, request) } func (b Service) PutSingle(ctx context.Context, request *objectV2.PutSingleRequest) (*objectV2.PutSingleResponse, error) { @@ -586,16 +568,18 @@ func (b Service) PutSingle(ctx context.Context, request *objectV2.PutSingleReque reqInfo.obj = obj + ctx = requestContext(ctx, reqInfo) if reqInfo.IsSoftAPECheck() { if !b.checker.CheckBasicACL(reqInfo) || !b.checker.StickyBitCheck(reqInfo, idOwner) { return nil, basicACLErr(reqInfo) } - if err := b.checker.CheckEACL(request, reqInfo); err != nil { + + if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil { return nil, eACLErr(reqInfo, err) } } - return b.next.PutSingle(requestContext(ctx, reqInfo), request) + return b.next.PutSingle(ctx, request) } func (p putStreamBasicChecker) Send(ctx context.Context, request *objectV2.PutRequest) error { @@ -706,7 +690,7 @@ func (p putStreamBasicChecker) CloseAndRecv(ctx context.Context) (*objectV2.PutR func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error { if _, ok := resp.GetBody().GetObjectPart().(*objectV2.GetObjectPartInit); ok { - if err := g.checker.CheckEACL(resp, g.info); err != nil { + if err := g.checker.CheckEACL(g.GetObjectStream.Context(), resp, g.info); err != nil { return eACLErr(g.info, err) } } @@ -715,7 +699,7 @@ func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error { } func (g *rangeStreamBasicChecker) Send(resp *objectV2.GetRangeResponse) error { - if err := g.checker.CheckEACL(resp, g.info); err != nil { + if err := g.checker.CheckEACL(g.GetObjectRangeStream.Context(), resp, g.info); err != nil { return eACLErr(g.info, err) } @@ -723,7 +707,7 @@ func (g *rangeStreamBasicChecker) Send(resp *objectV2.GetRangeResponse) error { } func (g *searchStreamBasicChecker) Send(resp *objectV2.SearchResponse) error { - if err := g.checker.CheckEACL(resp, g.info); err != nil { + if err := g.checker.CheckEACL(g.SearchStream.Context(), resp, g.info); err != nil { return eACLErr(g.info, err) } diff --git a/pkg/services/object/acl/v2/types.go b/pkg/services/object/acl/v2/types.go index 061cd26b6..cd139dd7f 100644 --- a/pkg/services/object/acl/v2/types.go +++ b/pkg/services/object/acl/v2/types.go @@ -1,6 +1,8 @@ package v2 import ( + "context" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user" ) @@ -12,7 +14,7 @@ type ACLChecker interface { CheckBasicACL(RequestInfo) bool // CheckEACL must return non-nil error if request // doesn't pass extended ACL validation. - CheckEACL(any, RequestInfo) error + CheckEACL(context.Context, any, RequestInfo) error // StickyBitCheck must return true only if sticky bit // is disabled or enabled but request contains correct // owner field. diff --git a/pkg/services/object/ape/service.go b/pkg/services/object/ape/service.go index 0c203209d..b97efb002 100644 --- a/pkg/services/object/ape/service.go +++ b/pkg/services/object/ape/service.go @@ -3,8 +3,6 @@ package ape import ( "context" "encoding/hex" - "errors" - "fmt" objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object" "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/refs" @@ -18,8 +16,6 @@ import ( nativeschema "git.frostfs.info/TrueCloudLab/policy-engine/schema/native" ) -var errFailedToCastToRequestContext = errors.New("failed cast to RequestContext") - type Service struct { log *logger.Logger @@ -101,25 +97,13 @@ func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error { return g.GetObjectStream.Send(resp) } -func requestContext(ctx context.Context) (*objectSvc.RequestContext, error) { - untyped := ctx.Value(objectSvc.RequestContextKey) - if untyped == nil { - return nil, fmt.Errorf("no key %s in context", objectSvc.RequestContextKey) - } - rc, ok := untyped.(*objectSvc.RequestContext) - if !ok { - return nil, errFailedToCastToRequestContext - } - return rc, nil -} - func (c *Service) Get(request *objectV2.GetRequest, stream objectSvc.GetObjectStream) error { cnrID, objID, err := getAddressParamsSDK(request.GetBody().GetAddress().GetContainerID(), request.GetBody().GetAddress().GetObjectID()) if err != nil { return toStatusErr(err) } - reqCtx, err := requestContext(stream.Context()) + reqCtx, err := objectSvc.FromRequestContext(stream.Context()) if err != nil { return toStatusErr(err) } @@ -156,7 +140,7 @@ type putStreamBasicChecker struct { func (p *putStreamBasicChecker) Send(ctx context.Context, request *objectV2.PutRequest) error { if partInit, ok := request.GetBody().GetObjectPart().(*objectV2.PutObjectPartInit); ok { - reqCtx, err := requestContext(ctx) + reqCtx, err := objectSvc.FromRequestContext(ctx) if err != nil { return toStatusErr(err) } @@ -205,7 +189,7 @@ func (c *Service) Head(ctx context.Context, request *objectV2.HeadRequest) (*obj return nil, err } - reqCtx, err := requestContext(ctx) + reqCtx, err := objectSvc.FromRequestContext(ctx) if err != nil { return nil, err } @@ -219,6 +203,7 @@ func (c *Service) Head(ctx context.Context, request *objectV2.HeadRequest) (*obj SenderKey: hex.EncodeToString(reqCtx.SenderKey), ContainerOwner: reqCtx.ContainerOwner, SoftAPECheck: reqCtx.SoftAPECheck, + Header: reqCtx.Header.Load().ToV2().GetHeader(), }) if err != nil { return nil, toStatusErr(err) @@ -273,7 +258,7 @@ func (c *Service) Search(request *objectV2.SearchRequest, stream objectSvc.Searc } } - reqCtx, err := requestContext(stream.Context()) + reqCtx, err := objectSvc.FromRequestContext(stream.Context()) if err != nil { return toStatusErr(err) } @@ -300,7 +285,7 @@ func (c *Service) Delete(ctx context.Context, request *objectV2.DeleteRequest) ( return nil, err } - reqCtx, err := requestContext(ctx) + reqCtx, err := objectSvc.FromRequestContext(ctx) if err != nil { return nil, err } @@ -333,7 +318,7 @@ func (c *Service) GetRange(request *objectV2.GetRangeRequest, stream objectSvc.G return toStatusErr(err) } - reqCtx, err := requestContext(stream.Context()) + reqCtx, err := objectSvc.FromRequestContext(stream.Context()) if err != nil { return toStatusErr(err) } @@ -361,7 +346,7 @@ func (c *Service) GetRangeHash(ctx context.Context, request *objectV2.GetRangeHa return nil, err } - reqCtx, err := requestContext(ctx) + reqCtx, err := objectSvc.FromRequestContext(ctx) if err != nil { return nil, err } @@ -398,7 +383,7 @@ func (c *Service) PutSingle(ctx context.Context, request *objectV2.PutSingleRequ return nil, err } - reqCtx, err := requestContext(ctx) + reqCtx, err := objectSvc.FromRequestContext(ctx) if err != nil { return nil, err } diff --git a/pkg/services/object/get/local.go b/pkg/services/object/get/local.go index 257465019..2d62361e1 100644 --- a/pkg/services/object/get/local.go +++ b/pkg/services/object/get/local.go @@ -5,6 +5,7 @@ import ( "errors" "git.frostfs.info/TrueCloudLab/frostfs-node/internal/logs" + "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object" "git.frostfs.info/TrueCloudLab/frostfs-observability/tracing" apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status" objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object" @@ -50,6 +51,12 @@ func (r *request) executeLocal(ctx context.Context) { func (r *request) get(ctx context.Context) (*objectSDK.Object, error) { if r.headOnly() { + reqCtx, _ := object.FromRequestContext(ctx) + if reqCtx != nil && !r.isRaw() { + if hdr := reqCtx.GetHeader(r.address()); hdr != nil { + return hdr, nil + } + } return r.localStorage.Head(ctx, r.address(), r.isRaw()) } if rng := r.ctxRange(); rng != nil { diff --git a/pkg/services/object/request_context.go b/pkg/services/object/request_context.go index 6a0965b40..23d904486 100644 --- a/pkg/services/object/request_context.go +++ b/pkg/services/object/request_context.go @@ -1,13 +1,20 @@ package object import ( + "context" + "fmt" + "sync/atomic" + + "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/object" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/acl" + objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object" + oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user" ) -type RequestContextKeyT struct{} +type requestContextKeyT struct{} -var RequestContextKey = RequestContextKeyT{} +var requestContextKey = requestContextKeyT{} // RequestContext is a context passed between middleware handlers. type RequestContext struct { @@ -20,4 +27,38 @@ type RequestContext struct { Role acl.Role SoftAPECheck bool + + Header atomic.Pointer[objectSDK.Object] +} + +// NewRequestContext returns a copy of ctx which carries value. +func NewRequestContext(ctx context.Context, value *RequestContext) context.Context { + return context.WithValue(ctx, requestContextKey, value) +} + +// FromRequestContext returns RequestContext value stored in ctx if any. +func FromRequestContext(ctx context.Context) (*RequestContext, error) { + reqCtx, ok := ctx.Value(requestContextKey).(*RequestContext) + if !ok { + return nil, fmt.Errorf("no key %s in context", requestContextKey) + } + return reqCtx, nil +} + +// GetHeader returns header if it is present and matches cid + oid pair. +func (r *RequestContext) GetHeader(addr oid.Address) *objectSDK.Object { + if r == nil { + return nil + } + + hdr := r.Header.Load() + if hdr == nil { + return nil + } + + storedAddr := object.AddressOf(hdr) + if addr.Equals(storedAddr) { + return hdr + } + return nil }