diff --git a/go.sum b/go.sum index ff4236fc38..1986cbbd28 100644 Binary files a/go.sum and b/go.sum differ diff --git a/pkg/services/object/acl/basic.go b/pkg/services/object/acl/basic.go index fb6289109b..40ba1c40e1 100644 --- a/pkg/services/object/acl/basic.go +++ b/pkg/services/object/acl/basic.go @@ -12,6 +12,7 @@ import ( type ( // ContainerGetter accesses NeoFS container storage. + // fixme: use core.container interface implementation ContainerGetter interface { Get(*refs.ContainerID) (*container.Container, error) } @@ -22,41 +23,52 @@ type ( // BasicChecker checks basic ACL rules. BasicChecker struct { - sender SenderClassifier - next object.Service + containers ContainerGetter + sender SenderClassifier + next object.Service } putStreamBasicChecker struct { - sender SenderClassifier + source *BasicChecker next object.PutObjectStreamer } getStreamBasicChecker struct { - sender SenderClassifier - next object.GetObjectStreamer + object.GetObjectStreamer } searchStreamBasicChecker struct { - sender SenderClassifier - next object.SearchObjectStreamer + object.SearchObjectStreamer } getRangeStreamBasicChecker struct { - sender SenderClassifier - next object.GetRangeObjectStreamer + object.GetRangeObjectStreamer + } + + requestInfo struct { + basicACL uint32 + requestRole acl.Role + operation acl.Operation // put, get, head, etc. } ) var ( - ErrMalformedRequest = errors.New("malformed request") - ErrUnknownRole = errors.New("can't classify request sender") + ErrMalformedRequest = errors.New("malformed request") + ErrUnknownRole = errors.New("can't classify request sender") + ErrUnknownContainer = errors.New("can't fetch container info") + ErrBasicAccessDenied = errors.New("access denied by basic ACL") ) // NewBasicChecker is a constructor for basic ACL checker of object requests. -func NewBasicChecker(c SenderClassifier, next object.Service) BasicChecker { +func NewBasicChecker( + c SenderClassifier, + cnr ContainerGetter, + next object.Service) BasicChecker { + return BasicChecker{ - sender: c, - next: next, + containers: cnr, + sender: c, + next: next, } } @@ -64,31 +76,29 @@ func (b BasicChecker) Get( ctx context.Context, request *object.GetRequest) (object.GetObjectStreamer, error) { - // get container address and do not panic at malformed request - var addr *refs.Address - if body := request.GetBody(); body == nil { - return nil, ErrMalformedRequest - } else { - addr = body.GetAddress() + cid, err := getContainerIDFromRequest(request) + if err != nil { + return nil, err } - role := b.sender.Classify(request, addr.GetContainerID()) - if role == acl.RoleUnknown { - return nil, ErrUnknownRole + reqInfo, err := b.findRequestInfo(request, cid, acl.OperationGet) + if err != nil { + return nil, err + } + + if !basicACLCheck(reqInfo) { + return nil, ErrBasicAccessDenied } stream, err := b.next.Get(ctx, request) - return getStreamBasicChecker{ - sender: b.sender, - next: stream, - }, err + return getStreamBasicChecker{stream}, err } func (b BasicChecker) Put(ctx context.Context) (object.PutObjectStreamer, error) { streamer, err := b.next.Put(ctx) return putStreamBasicChecker{ - sender: b.sender, + source: &b, next: streamer, }, err } @@ -97,6 +107,20 @@ func (b BasicChecker) Head( ctx context.Context, request *object.HeadRequest) (*object.HeadResponse, error) { + cid, err := getContainerIDFromRequest(request) + if err != nil { + return nil, err + } + + reqInfo, err := b.findRequestInfo(request, cid, acl.OperationHead) + if err != nil { + return nil, err + } + + if !basicACLCheck(reqInfo) { + return nil, ErrBasicAccessDenied + } + return b.next.Head(ctx, request) } @@ -104,17 +128,44 @@ func (b BasicChecker) Search( ctx context.Context, request *object.SearchRequest) (object.SearchObjectStreamer, error) { + var cid *refs.ContainerID + + cid, err := getContainerIDFromRequest(request) + if err != nil { + return nil, err + } + + reqInfo, err := b.findRequestInfo(request, cid, acl.OperationSearch) + if err != nil { + return nil, err + } + + if !basicACLCheck(reqInfo) { + return nil, ErrBasicAccessDenied + } + stream, err := b.next.Search(ctx, request) - return searchStreamBasicChecker{ - sender: b.sender, - next: stream, - }, err + return searchStreamBasicChecker{stream}, err } func (b BasicChecker) Delete( ctx context.Context, request *object.DeleteRequest) (*object.DeleteResponse, error) { + cid, err := getContainerIDFromRequest(request) + if err != nil { + return nil, err + } + + reqInfo, err := b.findRequestInfo(request, cid, acl.OperationDelete) + if err != nil { + return nil, err + } + + if !basicACLCheck(reqInfo) { + return nil, ErrBasicAccessDenied + } + return b.next.Delete(ctx, request) } @@ -122,21 +173,68 @@ func (b BasicChecker) GetRange( ctx context.Context, request *object.GetRangeRequest) (object.GetRangeObjectStreamer, error) { + cid, err := getContainerIDFromRequest(request) + if err != nil { + return nil, err + } + + reqInfo, err := b.findRequestInfo(request, cid, acl.OperationRange) + if err != nil { + return nil, err + } + + if !basicACLCheck(reqInfo) { + return nil, ErrBasicAccessDenied + } + stream, err := b.next.GetRange(ctx, request) - return getRangeStreamBasicChecker{ - sender: b.sender, - next: stream, - }, err + return getRangeStreamBasicChecker{stream}, err } func (b BasicChecker) GetRangeHash( ctx context.Context, request *object.GetRangeHashRequest) (*object.GetRangeHashResponse, error) { + cid, err := getContainerIDFromRequest(request) + if err != nil { + return nil, err + } + + reqInfo, err := b.findRequestInfo(request, cid, acl.OperationRangeHash) + if err != nil { + return nil, err + } + + if !basicACLCheck(reqInfo) { + return nil, ErrBasicAccessDenied + } + return b.next.GetRangeHash(ctx, request) } func (p putStreamBasicChecker) Send(request *object.PutRequest) error { + body := request.GetBody() + if body == nil { + return ErrMalformedRequest + } + + part := body.GetObjectPart() + if _, ok := part.(*object.PutObjectPartInit); ok { + cid, err := getContainerIDFromRequest(request) + if err != nil { + return err + } + + reqInfo, err := p.source.findRequestInfo(request, cid, acl.OperationPut) + if err != nil { + return err + } + + if !basicACLCheck(reqInfo) { + return ErrBasicAccessDenied + } + } + return p.next.Send(request) } @@ -144,14 +242,63 @@ func (p putStreamBasicChecker) CloseAndRecv() (*object.PutResponse, error) { return p.next.CloseAndRecv() } -func (g getStreamBasicChecker) Recv() (*object.GetResponse, error) { - return g.next.Recv() +func (b BasicChecker) findRequestInfo( + req RequestV2, + cid *refs.ContainerID, + op acl.Operation) (info requestInfo, err error) { + + // fetch actual container + cnr, err := b.containers.Get(cid) + if err != nil || cnr.GetOwnerID() == nil { + return info, ErrUnknownContainer + } + + // find request role + role := b.sender.Classify(req, cid, cnr) + if role == acl.RoleUnknown { + return info, ErrUnknownRole + } + + info.basicACL = cnr.GetBasicACL() + info.requestRole = role + info.operation = op + + return info, nil } -func (s searchStreamBasicChecker) Recv() (*object.SearchResponse, error) { - return s.next.Recv() +func getContainerIDFromRequest(req interface{}) (id *refs.ContainerID, err error) { + defer func() { + // if there is a NPE on get body and get address + if r := recover(); r != nil { + err = ErrMalformedRequest + } + }() + + switch v := req.(type) { + case *object.GetRequest: + return v.GetBody().GetAddress().GetContainerID(), nil + case *object.PutRequest: + objPart := v.GetBody().GetObjectPart() + if part, ok := objPart.(*object.PutObjectPartInit); ok { + return part.GetHeader().GetContainerID(), nil + } else { + return nil, errors.New("can't get cid in chunk") + } + case *object.HeadRequest: + return v.GetBody().GetAddress().GetContainerID(), nil + case *object.SearchRequest: + return v.GetBody().GetContainerID(), nil + case *object.DeleteRequest: + return v.GetBody().GetAddress().GetContainerID(), nil + case *object.GetRangeRequest: + return v.GetBody().GetAddress().GetContainerID(), nil + case *object.GetRangeHashRequest: + return v.GetBody().GetAddress().GetContainerID(), nil + default: + return nil, errors.New("unknown request type") + } } -func (g getRangeStreamBasicChecker) Recv() (*object.GetRangeResponse, error) { - return g.next.Recv() +func basicACLCheck(info requestInfo) bool { + panic("implement me") } diff --git a/pkg/services/object/acl/classifier.go b/pkg/services/object/acl/classifier.go index dfe1e646e7..9b71a594a4 100644 --- a/pkg/services/object/acl/classifier.go +++ b/pkg/services/object/acl/classifier.go @@ -15,12 +15,6 @@ import ( ) type ( - // ContainerFetcher accesses NeoFS container storage. - // fixme: use core.container interface implementation - ContainerFetcher interface { - Fetch(*refs.ContainerID) (*container.Container, error) - } - // fixme: use core.netmap interface implementation NetmapFetcher interface { Current() (netmap.Netmap, error) @@ -37,18 +31,23 @@ type ( } SenderClassifier struct { - containers ContainerFetcher - innerRing InnerRingFetcher - netmap NetmapFetcher + innerRing InnerRingFetcher + netmap NetmapFetcher } ) // fixme: update classifier constructor -func NewSenderClassifier() SenderClassifier { - return SenderClassifier{} +func NewSenderClassifier(ir InnerRingFetcher, nm NetmapFetcher) SenderClassifier { + return SenderClassifier{ + innerRing: ir, + netmap: nm, + } } -func (c SenderClassifier) Classify(req RequestV2, cid *refs.ContainerID) acl.Role { +func (c SenderClassifier) Classify( + req RequestV2, + cid *refs.ContainerID, + cnr *container.Container) acl.Role { if cid == nil || req == nil { // log there return acl.RoleUnknown @@ -62,15 +61,8 @@ func (c SenderClassifier) Classify(req RequestV2, cid *refs.ContainerID) acl.Rol // todo: get owner from neofs.id if present - // fetch actual container - cnr, err := c.containers.Fetch(cid) - if err != nil || cnr.GetOwnerID() == nil { - // log there - return acl.RoleUnknown - } - // if request owner is the same as container owner, return RoleUser - if bytes.Equal(cnr.GetOwnerID().GetValue(), cid.GetValue()) { + if bytes.Equal(cnr.GetOwnerID().GetValue(), ownerID.GetValue()) { return acl.RoleUser }