[#32] Make basic ACL check in all object request

Signed-off-by: Alex Vanin <alexey@nspcc.ru>
This commit is contained in:
Alex Vanin 2020-09-22 19:18:41 +03:00
parent 49ee9a14a1
commit 91fef72bb6

View file

@ -1,6 +1,7 @@
package acl
import (
"bytes"
"context"
acl "github.com/nspcc-dev/neofs-api-go/pkg/acl/eacl"
@ -25,7 +26,8 @@ type (
}
getStreamBasicChecker struct {
object.GetObjectStreamer
next object.GetObjectStreamer
info requestInfo
}
searchStreamBasicChecker struct {
@ -40,6 +42,7 @@ type (
basicACL uint32
requestRole acl.Role
operation acl.Operation // put, get, head, etc.
owner *refs.OwnerID // container owner
}
)
@ -47,7 +50,7 @@ var (
ErrMalformedRequest = errors.New("malformed request")
ErrUnknownRole = errors.New("can't classify request sender")
ErrUnknownContainer = errors.New("can't fetch container info")
ErrBasicAccessDenied = errors.New("access denied by basic ACL")
ErrBasicAccessDenied = errors.New("access denied by basic acl")
)
// NewBasicChecker is a constructor for basic ACL checker of object requests.
@ -82,7 +85,11 @@ func (b BasicChecker) Get(
}
stream, err := b.next.Get(ctx, request)
return getStreamBasicChecker{stream}, err
return getStreamBasicChecker{
next: stream,
info: reqInfo,
}, err
}
func (b BasicChecker) Put(ctx context.Context) (object.PutObjectStreamer, error) {
@ -216,12 +223,17 @@ func (p putStreamBasicChecker) Send(request *object.PutRequest) error {
return err
}
owner, err := getObjectOwnerFromMessage(request)
if err != nil {
return err
}
reqInfo, err := p.source.findRequestInfo(request, cid, acl.OperationPut)
if err != nil {
return err
}
if !basicACLCheck(reqInfo) {
if !basicACLCheck(reqInfo) || !stickyBitCheck(reqInfo, owner) {
return ErrBasicAccessDenied
}
}
@ -233,6 +245,32 @@ func (p putStreamBasicChecker) CloseAndRecv() (*object.PutResponse, error) {
return p.next.CloseAndRecv()
}
func (g getStreamBasicChecker) Recv() (*object.GetResponse, error) {
resp, err := g.next.Recv()
if err != nil {
return resp, err
}
body := resp.GetBody()
if body == nil {
return resp, err
}
part := body.GetObjectPart()
if _, ok := part.(*object.GetObjectPartInit); ok {
owner, err := getObjectOwnerFromMessage(resp)
if err != nil {
return nil, err
}
if !stickyBitCheck(g.info, owner) {
return nil, ErrBasicAccessDenied
}
}
return resp, err
}
func (b BasicChecker) findRequestInfo(
req RequestV2,
cid *refs.ContainerID,
@ -257,6 +295,7 @@ func (b BasicChecker) findRequestInfo(
info.basicACL = cnr.GetBasicACL()
info.requestRole = role
info.operation = op
info.owner = cnr.GetOwnerID()
return info, nil
}
@ -294,6 +333,66 @@ func getContainerIDFromRequest(req interface{}) (id *refs.ContainerID, err error
}
}
func basicACLCheck(info requestInfo) bool {
panic("implement me")
func getObjectOwnerFromMessage(req interface{}) (id *refs.OwnerID, err error) {
defer func() {
// if there is a NPE on get body and get address
if r := recover(); r != nil {
err = ErrMalformedRequest
}
}()
switch v := req.(type) {
case *object.PutRequest:
objPart := v.GetBody().GetObjectPart()
if part, ok := objPart.(*object.PutObjectPartInit); ok {
return part.GetHeader().GetOwnerID(), nil
} else {
return nil, errors.New("can't get cid in chunk")
}
case *object.GetResponse:
objPart := v.GetBody().GetObjectPart()
if part, ok := objPart.(*object.GetObjectPartInit); ok {
return part.GetHeader().GetOwnerID(), nil
} else {
return nil, errors.New("can't get cid in chunk")
}
default:
return nil, errors.New("unsupported request type")
}
}
// main check function for basic ACL
func basicACLCheck(info requestInfo) bool {
rule := basicACLHelper(info.basicACL)
// check basic ACL permissions
var checkFn func(acl.Operation) bool
switch info.requestRole {
case acl.RoleUser:
checkFn = rule.UserAllowed
case acl.RoleSystem:
checkFn = rule.SystemAllowed
case acl.RoleOthers:
checkFn = rule.OthersAllowed
default:
// log there
return false
}
return checkFn(info.operation)
}
func stickyBitCheck(info requestInfo, owner *refs.OwnerID) bool {
if owner == nil || info.owner == nil {
return false
}
rule := basicACLHelper(info.basicACL)
if !rule.Sticky() {
return true
}
return bytes.Equal(owner.GetValue(), info.owner.GetValue())
}