aclsvc: Provide context to CheckEACL
Remove context.TODO() and allow passing RequestContext struct later. Signed-off-by: Evgenii Stratonikov <e.stratonikov@yadro.com>
This commit is contained in:
parent
de139b30c0
commit
e66007b1f2
5 changed files with 30 additions and 27 deletions
|
@ -100,7 +100,7 @@ func (c *Checker) StickyBitCheck(info v2.RequestInfo, owner user.ID) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckEACL is a main check function for extended ACL.
|
// CheckEACL is a main check function for extended ACL.
|
||||||
func (c *Checker) CheckEACL(msg any, reqInfo v2.RequestInfo) error {
|
func (c *Checker) CheckEACL(ctx context.Context, msg any, reqInfo v2.RequestInfo) error {
|
||||||
basicACL := reqInfo.BasicACL()
|
basicACL := reqInfo.BasicACL()
|
||||||
if !basicACL.Extendable() {
|
if !basicACL.Extendable() {
|
||||||
return nil
|
return nil
|
||||||
|
@ -136,7 +136,7 @@ func (c *Checker) CheckEACL(msg any, reqInfo v2.RequestInfo) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
hdrSrc, err := c.getHeaderSource(cnr, msg, reqInfo)
|
hdrSrc, err := c.getHeaderSource(ctx, cnr, msg, reqInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -173,7 +173,7 @@ func getRole(reqInfo v2.RequestInfo) eaclSDK.Role {
|
||||||
return eaclRole
|
return eaclRole
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Checker) getHeaderSource(cnr cid.ID, msg any, reqInfo v2.RequestInfo) (eaclSDK.TypedHeaderSource, error) {
|
func (c *Checker) getHeaderSource(ctx context.Context, cnr cid.ID, msg any, reqInfo v2.RequestInfo) (eaclSDK.TypedHeaderSource, error) {
|
||||||
var xHeaderSource eaclV2.XHeaderSource
|
var xHeaderSource eaclV2.XHeaderSource
|
||||||
if req, ok := msg.(eaclV2.Request); ok {
|
if req, ok := msg.(eaclV2.Request); ok {
|
||||||
xHeaderSource = eaclV2.NewRequestXHeaderSource(req)
|
xHeaderSource = eaclV2.NewRequestXHeaderSource(req)
|
||||||
|
@ -181,7 +181,7 @@ func (c *Checker) getHeaderSource(cnr cid.ID, msg any, reqInfo v2.RequestInfo) (
|
||||||
xHeaderSource = eaclV2.NewResponseXHeaderSource(msg.(eaclV2.Response), reqInfo.Request().(eaclV2.Request))
|
xHeaderSource = eaclV2.NewResponseXHeaderSource(msg.(eaclV2.Response), reqInfo.Request().(eaclV2.Request))
|
||||||
}
|
}
|
||||||
|
|
||||||
hdrSrc, err := eaclV2.NewMessageHeaderSource(&localStorage{ls: c.localStorage}, xHeaderSource, cnr, reqInfo.ObjectID())
|
hdrSrc, err := eaclV2.NewMessageHeaderSource(ctx, &localStorage{ls: c.localStorage}, xHeaderSource, cnr, reqInfo.ObjectID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("can't parse headers: %w", err)
|
return nil, fmt.Errorf("can't parse headers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,6 +103,7 @@ func TestHeadRequest(t *testing.T) {
|
||||||
|
|
||||||
newSource := func(t *testing.T) eaclSDK.TypedHeaderSource {
|
newSource := func(t *testing.T) eaclSDK.TypedHeaderSource {
|
||||||
hdrSrc, err := NewMessageHeaderSource(
|
hdrSrc, err := NewMessageHeaderSource(
|
||||||
|
context.TODO(),
|
||||||
lStorage,
|
lStorage,
|
||||||
NewRequestXHeaderSource(req),
|
NewRequestXHeaderSource(req),
|
||||||
addr.Container(),
|
addr.Container(),
|
||||||
|
|
|
@ -44,7 +44,7 @@ type headerSource struct {
|
||||||
incompleteObjectHeaders bool
|
incompleteObjectHeaders bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMessageHeaderSource(os ObjectStorage, xhs XHeaderSource, cnrID cid.ID, objID *oid.ID) (eaclSDK.TypedHeaderSource, error) {
|
func NewMessageHeaderSource(ctx context.Context, os ObjectStorage, xhs XHeaderSource, cnrID cid.ID, objID *oid.ID) (eaclSDK.TypedHeaderSource, error) {
|
||||||
cfg := &cfg{
|
cfg := &cfg{
|
||||||
storage: os,
|
storage: os,
|
||||||
cnr: cnrID,
|
cnr: cnrID,
|
||||||
|
@ -58,7 +58,7 @@ func NewMessageHeaderSource(os ObjectStorage, xhs XHeaderSource, cnrID cid.ID, o
|
||||||
|
|
||||||
var res headerSource
|
var res headerSource
|
||||||
|
|
||||||
err := cfg.readObjectHeaders(&res)
|
err := cfg.readObjectHeaders(ctx, &res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -91,18 +91,18 @@ func (x xHeader) Value() string {
|
||||||
|
|
||||||
var errMissingOID = errors.New("object ID is missing")
|
var errMissingOID = errors.New("object ID is missing")
|
||||||
|
|
||||||
func (h *cfg) readObjectHeaders(dst *headerSource) error {
|
func (h *cfg) readObjectHeaders(ctx context.Context, dst *headerSource) error {
|
||||||
switch m := h.msg.(type) {
|
switch m := h.msg.(type) {
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("unexpected message type %T", h.msg))
|
panic(fmt.Sprintf("unexpected message type %T", h.msg))
|
||||||
case requestXHeaderSource:
|
case requestXHeaderSource:
|
||||||
return h.readObjectHeadersFromRequestXHeaderSource(m, dst)
|
return h.readObjectHeadersFromRequestXHeaderSource(ctx, m, dst)
|
||||||
case responseXHeaderSource:
|
case responseXHeaderSource:
|
||||||
return h.readObjectHeadersResponseXHeaderSource(m, dst)
|
return h.readObjectHeadersResponseXHeaderSource(ctx, m, dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cfg) readObjectHeadersFromRequestXHeaderSource(m requestXHeaderSource, dst *headerSource) error {
|
func (h *cfg) readObjectHeadersFromRequestXHeaderSource(ctx context.Context, m requestXHeaderSource, dst *headerSource) error {
|
||||||
switch req := m.req.(type) {
|
switch req := m.req.(type) {
|
||||||
case
|
case
|
||||||
*objectV2.GetRequest,
|
*objectV2.GetRequest,
|
||||||
|
@ -111,7 +111,7 @@ func (h *cfg) readObjectHeadersFromRequestXHeaderSource(m requestXHeaderSource,
|
||||||
return errMissingOID
|
return errMissingOID
|
||||||
}
|
}
|
||||||
|
|
||||||
objHeaders, completed := h.localObjectHeaders(h.cnr, h.obj)
|
objHeaders, completed := h.localObjectHeaders(ctx, h.cnr, h.obj)
|
||||||
|
|
||||||
dst.objectHeaders = objHeaders
|
dst.objectHeaders = objHeaders
|
||||||
dst.incompleteObjectHeaders = !completed
|
dst.incompleteObjectHeaders = !completed
|
||||||
|
@ -149,10 +149,10 @@ func (h *cfg) readObjectHeadersFromRequestXHeaderSource(m requestXHeaderSource,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cfg) readObjectHeadersResponseXHeaderSource(m responseXHeaderSource, dst *headerSource) error {
|
func (h *cfg) readObjectHeadersResponseXHeaderSource(ctx context.Context, m responseXHeaderSource, dst *headerSource) error {
|
||||||
switch resp := m.resp.(type) {
|
switch resp := m.resp.(type) {
|
||||||
default:
|
default:
|
||||||
objectHeaders, completed := h.localObjectHeaders(h.cnr, h.obj)
|
objectHeaders, completed := h.localObjectHeaders(ctx, h.cnr, h.obj)
|
||||||
|
|
||||||
dst.objectHeaders = objectHeaders
|
dst.objectHeaders = objectHeaders
|
||||||
dst.incompleteObjectHeaders = !completed
|
dst.incompleteObjectHeaders = !completed
|
||||||
|
@ -193,13 +193,13 @@ func (h *cfg) readObjectHeadersResponseXHeaderSource(m responseXHeaderSource, ds
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cfg) localObjectHeaders(cnr cid.ID, idObj *oid.ID) ([]eaclSDK.Header, bool) {
|
func (h *cfg) localObjectHeaders(ctx context.Context, cnr cid.ID, idObj *oid.ID) ([]eaclSDK.Header, bool) {
|
||||||
if idObj != nil {
|
if idObj != nil {
|
||||||
var addr oid.Address
|
var addr oid.Address
|
||||||
addr.SetContainer(cnr)
|
addr.SetContainer(cnr)
|
||||||
addr.SetObject(*idObj)
|
addr.SetObject(*idObj)
|
||||||
|
|
||||||
obj, err := h.storage.Head(context.TODO(), addr)
|
obj, err := h.storage.Head(ctx, addr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return headersFromObject(obj, cnr, idObj), true
|
return headersFromObject(obj, cnr, idObj), true
|
||||||
}
|
}
|
||||||
|
|
|
@ -207,7 +207,7 @@ func (b Service) Get(request *objectV2.GetRequest, stream object.GetObjectStream
|
||||||
if reqInfo.IsSoftAPECheck() {
|
if reqInfo.IsSoftAPECheck() {
|
||||||
if !b.checker.CheckBasicACL(reqInfo) {
|
if !b.checker.CheckBasicACL(reqInfo) {
|
||||||
return basicACLErr(reqInfo)
|
return basicACLErr(reqInfo)
|
||||||
} else if err := b.checker.CheckEACL(request, reqInfo); err != nil {
|
} else if err := b.checker.CheckEACL(stream.Context(), request, reqInfo); err != nil {
|
||||||
return eACLErr(reqInfo, err)
|
return eACLErr(reqInfo, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -276,14 +276,14 @@ func (b Service) Head(
|
||||||
if reqInfo.IsSoftAPECheck() {
|
if reqInfo.IsSoftAPECheck() {
|
||||||
if !b.checker.CheckBasicACL(reqInfo) {
|
if !b.checker.CheckBasicACL(reqInfo) {
|
||||||
return nil, basicACLErr(reqInfo)
|
return nil, basicACLErr(reqInfo)
|
||||||
} else if err := b.checker.CheckEACL(request, reqInfo); err != nil {
|
} else if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil {
|
||||||
return nil, eACLErr(reqInfo, err)
|
return nil, eACLErr(reqInfo, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := b.next.Head(requestContext(ctx, reqInfo), request)
|
resp, err := b.next.Head(requestContext(ctx, reqInfo), request)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if err = b.checker.CheckEACL(resp, reqInfo); err != nil {
|
if err = b.checker.CheckEACL(ctx, resp, reqInfo); err != nil {
|
||||||
err = eACLErr(reqInfo, err)
|
err = eACLErr(reqInfo, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -329,7 +329,7 @@ func (b Service) Search(request *objectV2.SearchRequest, stream object.SearchStr
|
||||||
if reqInfo.IsSoftAPECheck() {
|
if reqInfo.IsSoftAPECheck() {
|
||||||
if !b.checker.CheckBasicACL(reqInfo) {
|
if !b.checker.CheckBasicACL(reqInfo) {
|
||||||
return basicACLErr(reqInfo)
|
return basicACLErr(reqInfo)
|
||||||
} else if err := b.checker.CheckEACL(request, reqInfo); err != nil {
|
} else if err := b.checker.CheckEACL(stream.Context(), request, reqInfo); err != nil {
|
||||||
return eACLErr(reqInfo, err)
|
return eACLErr(reqInfo, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -389,7 +389,7 @@ func (b Service) Delete(
|
||||||
if reqInfo.IsSoftAPECheck() {
|
if reqInfo.IsSoftAPECheck() {
|
||||||
if !b.checker.CheckBasicACL(reqInfo) {
|
if !b.checker.CheckBasicACL(reqInfo) {
|
||||||
return nil, basicACLErr(reqInfo)
|
return nil, basicACLErr(reqInfo)
|
||||||
} else if err := b.checker.CheckEACL(request, reqInfo); err != nil {
|
} else if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil {
|
||||||
return nil, eACLErr(reqInfo, err)
|
return nil, eACLErr(reqInfo, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -442,7 +442,7 @@ func (b Service) GetRange(request *objectV2.GetRangeRequest, stream object.GetOb
|
||||||
if reqInfo.IsSoftAPECheck() {
|
if reqInfo.IsSoftAPECheck() {
|
||||||
if !b.checker.CheckBasicACL(reqInfo) {
|
if !b.checker.CheckBasicACL(reqInfo) {
|
||||||
return basicACLErr(reqInfo)
|
return basicACLErr(reqInfo)
|
||||||
} else if err := b.checker.CheckEACL(request, reqInfo); err != nil {
|
} else if err := b.checker.CheckEACL(stream.Context(), request, reqInfo); err != nil {
|
||||||
return eACLErr(reqInfo, err)
|
return eACLErr(reqInfo, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -506,7 +506,7 @@ func (b Service) GetRangeHash(
|
||||||
if reqInfo.IsSoftAPECheck() {
|
if reqInfo.IsSoftAPECheck() {
|
||||||
if !b.checker.CheckBasicACL(reqInfo) {
|
if !b.checker.CheckBasicACL(reqInfo) {
|
||||||
return nil, basicACLErr(reqInfo)
|
return nil, basicACLErr(reqInfo)
|
||||||
} else if err := b.checker.CheckEACL(request, reqInfo); err != nil {
|
} else if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil {
|
||||||
return nil, eACLErr(reqInfo, err)
|
return nil, eACLErr(reqInfo, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -566,7 +566,7 @@ func (b Service) PutSingle(ctx context.Context, request *objectV2.PutSingleReque
|
||||||
if !b.checker.CheckBasicACL(reqInfo) || !b.checker.StickyBitCheck(reqInfo, idOwner) {
|
if !b.checker.CheckBasicACL(reqInfo) || !b.checker.StickyBitCheck(reqInfo, idOwner) {
|
||||||
return nil, basicACLErr(reqInfo)
|
return nil, basicACLErr(reqInfo)
|
||||||
}
|
}
|
||||||
if err := b.checker.CheckEACL(request, reqInfo); err != nil {
|
if err := b.checker.CheckEACL(ctx, request, reqInfo); err != nil {
|
||||||
return nil, eACLErr(reqInfo, err)
|
return nil, eACLErr(reqInfo, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -682,7 +682,7 @@ func (p putStreamBasicChecker) CloseAndRecv(ctx context.Context) (*objectV2.PutR
|
||||||
|
|
||||||
func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error {
|
func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error {
|
||||||
if _, ok := resp.GetBody().GetObjectPart().(*objectV2.GetObjectPartInit); ok {
|
if _, ok := resp.GetBody().GetObjectPart().(*objectV2.GetObjectPartInit); ok {
|
||||||
if err := g.checker.CheckEACL(resp, g.info); err != nil {
|
if err := g.checker.CheckEACL(g.GetObjectStream.Context(), resp, g.info); err != nil {
|
||||||
return eACLErr(g.info, err)
|
return eACLErr(g.info, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -691,7 +691,7 @@ func (g *getStreamBasicChecker) Send(resp *objectV2.GetResponse) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *rangeStreamBasicChecker) Send(resp *objectV2.GetRangeResponse) error {
|
func (g *rangeStreamBasicChecker) Send(resp *objectV2.GetRangeResponse) error {
|
||||||
if err := g.checker.CheckEACL(resp, g.info); err != nil {
|
if err := g.checker.CheckEACL(g.GetObjectRangeStream.Context(), resp, g.info); err != nil {
|
||||||
return eACLErr(g.info, err)
|
return eACLErr(g.info, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -699,7 +699,7 @@ func (g *rangeStreamBasicChecker) Send(resp *objectV2.GetRangeResponse) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *searchStreamBasicChecker) Send(resp *objectV2.SearchResponse) error {
|
func (g *searchStreamBasicChecker) Send(resp *objectV2.SearchResponse) error {
|
||||||
if err := g.checker.CheckEACL(resp, g.info); err != nil {
|
if err := g.checker.CheckEACL(g.SearchStream.Context(), resp, g.info); err != nil {
|
||||||
return eACLErr(g.info, err)
|
return eACLErr(g.info, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package v2
|
package v2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
|
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,7 +14,7 @@ type ACLChecker interface {
|
||||||
CheckBasicACL(RequestInfo) bool
|
CheckBasicACL(RequestInfo) bool
|
||||||
// CheckEACL must return non-nil error if request
|
// CheckEACL must return non-nil error if request
|
||||||
// doesn't pass extended ACL validation.
|
// doesn't pass extended ACL validation.
|
||||||
CheckEACL(any, RequestInfo) error
|
CheckEACL(context.Context, any, RequestInfo) error
|
||||||
// StickyBitCheck must return true only if sticky bit
|
// StickyBitCheck must return true only if sticky bit
|
||||||
// is disabled or enabled but request contains correct
|
// is disabled or enabled but request contains correct
|
||||||
// owner field.
|
// owner field.
|
||||||
|
|
Loading…
Reference in a new issue