From 5b0b70b6ac60791792866ca8f2af7aafb256d0dc Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 9 Feb 2024 13:55:03 +0300 Subject: [PATCH] aclsvc: Provide context to CheckEACL Remove context.TODO() and allow passing RequestContext struct later. Signed-off-by: Evgenii Stratonikov --- pkg/services/object/acl/acl.go | 8 +++---- pkg/services/object/acl/eacl/v2/eacl_test.go | 1 + pkg/services/object/acl/eacl/v2/headers.go | 22 ++++++++++---------- pkg/services/object/acl/v2/service.go | 22 ++++++++++---------- pkg/services/object/acl/v2/types.go | 4 +++- 5 files changed, 30 insertions(+), 27 deletions(-) diff --git a/pkg/services/object/acl/acl.go b/pkg/services/object/acl/acl.go index fe58f07b9..d806c75b2 100644 --- a/pkg/services/object/acl/acl.go +++ b/pkg/services/object/acl/acl.go @@ -100,7 +100,7 @@ func (c *Checker) StickyBitCheck(info v2.RequestInfo, owner user.ID) bool { } // CheckEACL is a main check function for extended ACL. -func (c *Checker) CheckEACL(msg any, reqInfo v2.RequestInfo) error { +func (c *Checker) CheckEACL(ctx context.Context, msg any, reqInfo v2.RequestInfo) error { basicACL := reqInfo.BasicACL() if !basicACL.Extendable() { return nil @@ -136,7 +136,7 @@ func (c *Checker) CheckEACL(msg any, reqInfo v2.RequestInfo) error { return err } - hdrSrc, err := c.getHeaderSource(cnr, msg, reqInfo) + hdrSrc, err := c.getHeaderSource(ctx, cnr, msg, reqInfo) if err != nil { return err } @@ -173,7 +173,7 @@ func getRole(reqInfo v2.RequestInfo) eaclSDK.Role { return eaclRole } -func (c *Checker) getHeaderSource(cnr cid.ID, msg any, reqInfo v2.RequestInfo) (eaclSDK.TypedHeaderSource, error) { +func (c *Checker) getHeaderSource(ctx context.Context, cnr cid.ID, msg any, reqInfo v2.RequestInfo) (eaclSDK.TypedHeaderSource, error) { var xHeaderSource eaclV2.XHeaderSource if req, ok := msg.(eaclV2.Request); ok { xHeaderSource = eaclV2.NewRequestXHeaderSource(req) @@ -181,7 +181,7 @@ func (c *Checker) getHeaderSource(cnr cid.ID, msg any, reqInfo v2.RequestInfo) ( xHeaderSource = eaclV2.NewResponseXHeaderSource(msg.(eaclV2.Response), reqInfo.Request().(eaclV2.Request)) } - hdrSrc, err := eaclV2.NewMessageHeaderSource(&localStorage{ls: c.localStorage}, xHeaderSource, cnr, reqInfo.ObjectID()) + hdrSrc, err := eaclV2.NewMessageHeaderSource(ctx, &localStorage{ls: c.localStorage}, xHeaderSource, cnr, reqInfo.ObjectID()) if err != nil { return nil, fmt.Errorf("can't parse headers: %w", err) } diff --git a/pkg/services/object/acl/eacl/v2/eacl_test.go b/pkg/services/object/acl/eacl/v2/eacl_test.go index b5dfc9b82..e925fdb1b 100644 --- a/pkg/services/object/acl/eacl/v2/eacl_test.go +++ b/pkg/services/object/acl/eacl/v2/eacl_test.go @@ -103,6 +103,7 @@ func TestHeadRequest(t *testing.T) { newSource := func(t *testing.T) eaclSDK.TypedHeaderSource { hdrSrc, err := NewMessageHeaderSource( + context.TODO(), lStorage, NewRequestXHeaderSource(req), addr.Container(), diff --git a/pkg/services/object/acl/eacl/v2/headers.go b/pkg/services/object/acl/eacl/v2/headers.go index 378481357..a830e9f21 100644 --- a/pkg/services/object/acl/eacl/v2/headers.go +++ b/pkg/services/object/acl/eacl/v2/headers.go @@ -44,7 +44,7 @@ type headerSource struct { incompleteObjectHeaders bool } -func NewMessageHeaderSource(os ObjectStorage, xhs XHeaderSource, cnrID cid.ID, objID *oid.ID) (eaclSDK.TypedHeaderSource, error) { +func NewMessageHeaderSource(ctx context.Context, os ObjectStorage, xhs XHeaderSource, cnrID cid.ID, objID *oid.ID) (eaclSDK.TypedHeaderSource, error) { cfg := &cfg{ storage: os, cnr: cnrID, @@ -58,7 +58,7 @@ func NewMessageHeaderSource(os ObjectStorage, xhs XHeaderSource, cnrID cid.ID, o var res headerSource - err := cfg.readObjectHeaders(&res) + err := cfg.readObjectHeaders(ctx, &res) if err != nil { return nil, err } @@ -91,18 +91,18 @@ func (x xHeader) Value() string { var errMissingOID = errors.New("object ID is missing") -func (h *cfg) readObjectHeaders(dst *headerSource) error { +func (h *cfg) readObjectHeaders(ctx context.Context, dst *headerSource) error { switch m := h.msg.(type) { default: panic(fmt.Sprintf("unexpected message type %T", h.msg)) case requestXHeaderSource: - return h.readObjectHeadersFromRequestXHeaderSource(m, dst) + return h.readObjectHeadersFromRequestXHeaderSource(ctx, m, dst) case responseXHeaderSource: - return h.readObjectHeadersResponseXHeaderSource(m, dst) + return h.readObjectHeadersResponseXHeaderSource(ctx, m, dst) } } -func (h *cfg) readObjectHeadersFromRequestXHeaderSource(m requestXHeaderSource, dst *headerSource) error { +func (h *cfg) readObjectHeadersFromRequestXHeaderSource(ctx context.Context, m requestXHeaderSource, dst *headerSource) error { switch req := m.req.(type) { case *objectV2.GetRequest, @@ -111,7 +111,7 @@ func (h *cfg) readObjectHeadersFromRequestXHeaderSource(m requestXHeaderSource, return errMissingOID } - objHeaders, completed := h.localObjectHeaders(h.cnr, h.obj) + objHeaders, completed := h.localObjectHeaders(ctx, h.cnr, h.obj) dst.objectHeaders = objHeaders dst.incompleteObjectHeaders = !completed @@ -149,10 +149,10 @@ func (h *cfg) readObjectHeadersFromRequestXHeaderSource(m requestXHeaderSource, return nil } -func (h *cfg) readObjectHeadersResponseXHeaderSource(m responseXHeaderSource, dst *headerSource) error { +func (h *cfg) readObjectHeadersResponseXHeaderSource(ctx context.Context, m responseXHeaderSource, dst *headerSource) error { switch resp := m.resp.(type) { default: - objectHeaders, completed := h.localObjectHeaders(h.cnr, h.obj) + objectHeaders, completed := h.localObjectHeaders(ctx, h.cnr, h.obj) dst.objectHeaders = objectHeaders dst.incompleteObjectHeaders = !completed @@ -193,13 +193,13 @@ func (h *cfg) readObjectHeadersResponseXHeaderSource(m responseXHeaderSource, ds return nil } -func (h *cfg) localObjectHeaders(cnr cid.ID, idObj *oid.ID) ([]eaclSDK.Header, bool) { +func (h *cfg) localObjectHeaders(ctx context.Context, cnr cid.ID, idObj *oid.ID) ([]eaclSDK.Header, bool) { if idObj != nil { var addr oid.Address addr.SetContainer(cnr) addr.SetObject(*idObj) - obj, err := h.storage.Head(context.TODO(), addr) + obj, err := h.storage.Head(ctx, addr) if err == nil { return headersFromObject(obj, cnr, idObj), true } diff --git a/pkg/services/object/acl/v2/service.go b/pkg/services/object/acl/v2/service.go index 22ccacd47..b3895df8a 100644 --- a/pkg/services/object/acl/v2/service.go +++ b/pkg/services/object/acl/v2/service.go @@ -206,7 +206,7 @@ func (b Service) Get(request *objectV2.GetRequest, stream object.GetObjectStream if !b.checker.CheckBasicACL(reqInfo) { return basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } else if err := b.checker.CheckEACL(stream.Context(), request, reqInfo); err != nil { return eACLErr(reqInfo, err) } @@ -273,13 +273,13 @@ func (b Service) Head( if !b.checker.CheckBasicACL(reqInfo) { return nil, basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } else if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil { return nil, eACLErr(reqInfo, err) } resp, err := b.next.Head(requestContext(ctx, reqInfo), request) if err == nil { - if err = b.checker.CheckEACL(resp, reqInfo); err != nil { + if err = b.checker.CheckEACL(ctx, resp, reqInfo); err != nil { err = eACLErr(reqInfo, err) } } @@ -324,7 +324,7 @@ func (b Service) Search(request *objectV2.SearchRequest, stream object.SearchStr if !b.checker.CheckBasicACL(reqInfo) { return basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } else if err := b.checker.CheckEACL(stream.Context(), request, reqInfo); err != nil { return eACLErr(reqInfo, err) } @@ -382,7 +382,7 @@ func (b Service) Delete( if !b.checker.CheckBasicACL(reqInfo) { return nil, basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } else if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil { return nil, eACLErr(reqInfo, err) } @@ -433,7 +433,7 @@ func (b Service) GetRange(request *objectV2.GetRangeRequest, stream object.GetOb if !b.checker.CheckBasicACL(reqInfo) { return basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } else if err := b.checker.CheckEACL(stream.Context(), request, reqInfo); err != nil { return eACLErr(reqInfo, err) } @@ -495,7 +495,7 @@ func (b Service) GetRangeHash( if !b.checker.CheckBasicACL(reqInfo) { return nil, basicACLErr(reqInfo) - } else if err := b.checker.CheckEACL(request, reqInfo); err != nil { + } else if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil { return nil, eACLErr(reqInfo, err) } @@ -554,7 +554,7 @@ func (b Service) PutSingle(ctx context.Context, request *objectV2.PutSingleReque return nil, basicACLErr(reqInfo) } - if err := b.checker.CheckEACL(request, reqInfo); err != nil { + if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil { return nil, eACLErr(reqInfo, err) } @@ -667,7 +667,7 @@ func (p putStreamBasicChecker) CloseAndRecv(ctx context.Context) (*objectV2.PutR func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error { if _, ok := resp.GetBody().GetObjectPart().(*objectV2.GetObjectPartInit); ok { - if err := g.checker.CheckEACL(resp, g.info); err != nil { + if err := g.checker.CheckEACL(g.GetObjectStream.Context(), resp, g.info); err != nil { return eACLErr(g.info, err) } } @@ -676,7 +676,7 @@ func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error { } func (g *rangeStreamBasicChecker) Send(resp *objectV2.GetRangeResponse) error { - if err := g.checker.CheckEACL(resp, g.info); err != nil { + if err := g.checker.CheckEACL(g.GetObjectRangeStream.Context(), resp, g.info); err != nil { return eACLErr(g.info, err) } @@ -684,7 +684,7 @@ func (g *rangeStreamBasicChecker) Send(resp *objectV2.GetRangeResponse) error { } func (g *searchStreamBasicChecker) Send(resp *objectV2.SearchResponse) error { - if err := g.checker.CheckEACL(resp, g.info); err != nil { + if err := g.checker.CheckEACL(g.SearchStream.Context(), resp, g.info); err != nil { return eACLErr(g.info, err) } diff --git a/pkg/services/object/acl/v2/types.go b/pkg/services/object/acl/v2/types.go index 061cd26b6..cd139dd7f 100644 --- a/pkg/services/object/acl/v2/types.go +++ b/pkg/services/object/acl/v2/types.go @@ -1,6 +1,8 @@ package v2 import ( + "context" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user" ) @@ -12,7 +14,7 @@ type ACLChecker interface { CheckBasicACL(RequestInfo) bool // CheckEACL must return non-nil error if request // doesn't pass extended ACL validation. - CheckEACL(any, RequestInfo) error + CheckEACL(context.Context, any, RequestInfo) error // StickyBitCheck must return true only if sticky bit // is disabled or enabled but request contains correct // owner field.