object: Implement Patch method #1307

Merged
fyrchik merged 5 commits from aarifullin/frostfs-node:feat/patch/1 into master 2024-08-16 14:13:12 +00:00
5 changed files with 183 additions and 20 deletions
Showing only changes of commit 4a7c858124 - Show all commits

View file

@ -35,6 +35,12 @@ type putStreamBasicChecker struct {
next object.PutObjectStream next object.PutObjectStream
} }
type patchStreamBasicChecker struct {
source *Service
next object.PatchObjectStream
nonFirstSend bool
}
type getStreamBasicChecker struct { type getStreamBasicChecker struct {
checker ACLChecker checker ACLChecker
@ -250,7 +256,12 @@ func (b Service) Put() (object.PutObjectStream, error) {
} }
func (b Service) Patch() (object.PatchObjectStream, error) { func (b Service) Patch() (object.PatchObjectStream, error) {
return b.next.Patch() streamer, err := b.next.Patch()
return &patchStreamBasicChecker{
source: &b,
next: streamer,
}, err
} }
func (b Service) Head( func (b Service) Head(
@ -738,6 +749,65 @@ func (g *searchStreamBasicChecker) Send(resp *objectV2.SearchResponse) error {
return g.SearchStream.Send(resp) return g.SearchStream.Send(resp)
} }
func (p *patchStreamBasicChecker) Send(ctx context.Context, request *objectV2.PatchRequest) error {
body := request.GetBody()
if body == nil {
return errEmptyBody
}
if !p.nonFirstSend {
p.nonFirstSend = true
cnr, err := getContainerIDFromRequest(request)
if err != nil {
return err
}
objV2 := request.GetBody().GetAddress().GetObjectID()
if objV2 == nil {
return errors.New("missing oid")
}
obj := new(oid.ID)
err = obj.ReadFromV2(*objV2)
if err != nil {
return err
}
var sTok *sessionSDK.Object
sTok, err = readSessionToken(cnr, obj, request.GetMetaHeader().GetSessionToken())
if err != nil {
return err
}
bTok, err := originalBearerToken(request.GetMetaHeader())
if err != nil {
return err
}
req := MetaWithToken{
vheader: request.GetVerificationHeader(),
token: sTok,
bearer: bTok,
src: request,
}
reqInfo, err := p.source.findRequestInfoWithoutACLOperationAssert(req, cnr)
if err != nil {
return err
}
reqInfo.obj = obj
ctx = requestContext(ctx, reqInfo)
}
return p.next.Send(ctx, request)
}
func (p patchStreamBasicChecker) CloseAndRecv(ctx context.Context) (*objectV2.PatchResponse, error) {
return p.next.CloseAndRecv(ctx)
}
func (b Service) findRequestInfo(req MetaWithToken, idCnr cid.ID, op acl.Op) (info RequestInfo, err error) { func (b Service) findRequestInfo(req MetaWithToken, idCnr cid.ID, op acl.Op) (info RequestInfo, err error) {
cnr, err := b.containers.Get(idCnr) // fetch actual container cnr, err := b.containers.Get(idCnr) // fetch actual container
if err != nil { if err != nil {
@ -794,3 +864,56 @@ func (b Service) findRequestInfo(req MetaWithToken, idCnr cid.ID, op acl.Op) (in
return info, nil return info, nil
} }
// findRequestInfoWithoutACLOperationAssert is findRequestInfo without session token verb assert.
func (b Service) findRequestInfoWithoutACLOperationAssert(req MetaWithToken, idCnr cid.ID) (info RequestInfo, err error) {
cnr, err := b.containers.Get(idCnr) // fetch actual container
if err != nil {
return info, err
}
if req.token != nil {
currentEpoch, err := b.nm.Epoch()
if err != nil {
return info, errors.New("can't fetch current epoch")
}
if req.token.ExpiredAt(currentEpoch) {
return info, new(apistatus.SessionTokenExpired)
}
if req.token.InvalidAt(currentEpoch) {
return info, fmt.Errorf("%s: token is invalid at %d epoch)",
invalidRequestMessage, currentEpoch)
}
}
// find request role and key
ownerID, ownerKey, err := req.RequestOwner()
if err != nil {
return info, err
}
res, err := b.c.Classify(ownerID, ownerKey, idCnr, cnr.Value)
if err != nil {
return info, err
}
info.basicACL = cnr.Value.BasicACL()
info.requestRole = res.Role
info.cnrOwner = cnr.Value.Owner()
info.idCnr = idCnr
cnrNamespace, hasNamespace := strings.CutSuffix(cnrSDK.ReadDomain(cnr.Value).Zone(), ".ns")
if hasNamespace {
info.cnrNamespace = cnrNamespace
}
// it is assumed that at the moment the key will be valid,
// otherwise the request would not pass validation
info.senderKey = res.Key
// add bearer token if it is present in request
info.bearer = req.bearer
info.srcRequest = req.src
return info, nil
}

View file

@ -46,6 +46,8 @@ func getContainerIDFromRequest(req any) (cid.ID, error) {
idV2 = v.GetBody().GetAddress().GetContainerID() idV2 = v.GetBody().GetAddress().GetContainerID()
case *objectV2.PutSingleRequest: case *objectV2.PutSingleRequest:
idV2 = v.GetBody().GetObject().GetHeader().GetContainerID() idV2 = v.GetBody().GetObject().GetHeader().GetContainerID()
case *objectV2.PatchRequest:
idV2 = v.GetBody().GetAddress().GetContainerID()
default: default:
return cid.ID{}, errors.New("unknown request type") return cid.ID{}, errors.New("unknown request type")
} }

View file

@ -518,22 +518,7 @@ func TestAPECheck_BearerTokenOverrides(t *testing.T) {
ls := inmemory.NewInmemoryLocalStorage() ls := inmemory.NewInmemoryLocalStorage()
ms := inmemory.NewInmemoryMorphRuleChainStorage() ms := inmemory.NewInmemoryMorphRuleChainStorage()
node1Key, err := keys.NewPrivateKey() checker := NewChecker(ls, ms, headerProvider, frostfsidProvider, nil, &stMock{}, nil, nil)
require.NoError(t, err)
node1 := netmapSDK.NodeInfo{}
node1.SetPublicKey(node1Key.PublicKey().Bytes())
netmap := &netmapSDK.NetMap{}
netmap.SetEpoch(100)
netmap.SetNodes([]netmapSDK.NodeInfo{node1})
nm := &netmapStub{
currentEpoch: 100,
netmaps: map[uint64]*netmapSDK.NetMap{
100: netmap,
},
}
checker := NewChecker(ls, ms, headerProvider, frostfsidProvider, nm, &stMock{}, nil, nil)
prm := Prm{ prm := Prm{
Method: method, Method: method,
@ -556,7 +541,7 @@ func TestAPECheck_BearerTokenOverrides(t *testing.T) {
} }
} }
err = checker.CheckAPE(context.Background(), prm) err := checker.CheckAPE(context.Background(), prm)
if test.expectAPEErr { if test.expectAPEErr {
require.Error(t, err) require.Error(t, err)
} else { } else {

View file

@ -103,7 +103,8 @@ func (c *checkerImpl) newAPERequest(ctx context.Context, prm Prm) (aperequest.Re
nativeschema.MethodHeadObject, nativeschema.MethodHeadObject,
nativeschema.MethodRangeObject, nativeschema.MethodRangeObject,
nativeschema.MethodHashObject, nativeschema.MethodHashObject,
nativeschema.MethodDeleteObject: nativeschema.MethodDeleteObject,
nativeschema.MethodPatchObject:
if prm.Object == nil { if prm.Object == nil {
return defaultRequest, fmt.Errorf("method %s: %w", prm.Method, errMissingOID) return defaultRequest, fmt.Errorf("method %s: %w", prm.Method, errMissingOID)
} }

View file

@ -204,8 +204,60 @@ func (c *Service) Put() (objectSvc.PutObjectStream, error) {
}, err }, err
} }
type patchStreamBasicChecker struct {
apeChecker Checker
next objectSvc.PatchObjectStream
nonFirstSend bool
}
func (p *patchStreamBasicChecker) Send(ctx context.Context, request *objectV2.PatchRequest) error {
if !p.nonFirstSend {
p.nonFirstSend = true
reqCtx, err := requestContext(ctx)
if err != nil {
return toStatusErr(err)
}
cnrID, objID, err := getAddressParamsSDK(request.GetBody().GetAddress().GetContainerID(), request.GetBody().GetAddress().GetObjectID())
if err != nil {
return toStatusErr(err)
}
prm := Prm{
Namespace: reqCtx.Namespace,
Container: cnrID,
Object: objID,
Method: nativeschema.MethodPatchObject,
SenderKey: hex.EncodeToString(reqCtx.SenderKey),
ContainerOwner: reqCtx.ContainerOwner,
Role: nativeSchemaRole(reqCtx.Role),
SoftAPECheck: reqCtx.SoftAPECheck,
BearerToken: reqCtx.BearerToken,
XHeaders: request.GetMetaHeader().GetXHeaders(),
}
if err := p.apeChecker.CheckAPE(ctx, prm); err != nil {
return toStatusErr(err)
}
}
return p.next.Send(ctx, request)
}
func (p patchStreamBasicChecker) CloseAndRecv(ctx context.Context) (*objectV2.PatchResponse, error) {
return p.next.CloseAndRecv(ctx)
}
func (c *Service) Patch() (objectSvc.PatchObjectStream, error) { func (c *Service) Patch() (objectSvc.PatchObjectStream, error) {
return c.next.Patch() streamer, err := c.next.Patch()
return &patchStreamBasicChecker{
apeChecker: c.apeChecker,
next: streamer,
}, err
} }
func (c *Service) Head(ctx context.Context, request *objectV2.HeadRequest) (*objectV2.HeadResponse, error) { func (c *Service) Head(ctx context.Context, request *objectV2.HeadRequest) (*objectV2.HeadResponse, error) {