frostfs-node/pkg/services/object/get/v2/util.go
Dmitrii Stepanov 89924071cd [#193] getsvc: Edit request forwarder signature
Pass context to the forwarder directly, without a closure.

Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
2023-04-05 14:38:48 +00:00

378 lines
8.6 KiB
Go

package getsvc
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"hash"
"sync"
objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object"
"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/refs"
"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/session"
"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/status"
"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/client"
"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/network"
objectSvc "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object"
getsvc "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object/get"
"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object/util"
apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
versionSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/version"
"git.frostfs.info/TrueCloudLab/tzhash/tz"
)
var (
	// errWrongMessageSeq is the package's sentinel error for an incorrect
	// message sequence. It is not referenced in this file; presumably other
	// files of the package return it when a response stream delivers
	// messages in an unexpected order — verify against its users.
	errWrongMessageSeq = errors.New("incorrect message sequence")
)
// toPrm converts a v2 Get request into getsvc parameters.
//
// The object address is taken from the request body (its absence or an
// invalid encoding is an error). Object parts are written back through
// stream. Unless the request is restricted to the local node, a request
// forwarder is attached that tries every address of a remote node in turn.
func (s *Service) toPrm(req *objectV2.GetRequest, stream objectSvc.GetObjectStream) (*getsvc.Prm, error) {
	body := req.GetBody()

	addrV2 := body.GetAddress()
	if addrV2 == nil {
		return nil, errors.New("missing object address")
	}

	var addr oid.Address
	if err := addr.ReadFromV2(*addrV2); err != nil {
		return nil, fmt.Errorf("invalid object address: %w", err)
	}

	commonPrm, err := util.CommonPrmFromV2(req)
	if err != nil {
		return nil, err
	}

	writer := &streamObjectWriter{stream}

	prm := new(getsvc.Prm)
	prm.SetCommonParameters(commonPrm)
	prm.WithAddress(addr)
	prm.WithRawFlag(body.GetRaw())
	prm.SetObjectWriter(writer)

	if commonPrm.LocalOnly() {
		return prm, nil
	}

	// The forwarder shares the stream writer with the local path so that
	// remote data reaches the same client stream.
	fwd := &getRequestForwarder{
		OnceResign:        new(sync.Once),
		OnceHeaderSending: new(sync.Once),
		GlobalProgress:    0,
		KeyStorage:        s.keyStorage,
		Request:           req,
		Stream:            writer,
	}
	prm.SetRequestForwarder(groupAddressRequestForwarder(fwd.forwardRequestToNode))

	return prm, nil
}
// toRangePrm converts a v2 GetRange request into getsvc range parameters.
//
// Payload chunks are written back through stream. The assembled parameters
// are validated with Validate before use; a validation failure is wrapped
// and returned. Unless the request is restricted to the local node, a
// request forwarder is attached that tries every address of a remote node.
func (s *Service) toRangePrm(req *objectV2.GetRangeRequest, stream objectSvc.GetObjectRangeStream) (*getsvc.RangePrm, error) {
	body := req.GetBody()

	addrV2 := body.GetAddress()
	if addrV2 == nil {
		return nil, errors.New("missing object address")
	}

	var addr oid.Address
	if err := addr.ReadFromV2(*addrV2); err != nil {
		return nil, fmt.Errorf("invalid object address: %w", err)
	}

	commonPrm, err := util.CommonPrmFromV2(req)
	if err != nil {
		return nil, err
	}

	writer := &streamObjectRangeWriter{stream}

	prm := new(getsvc.RangePrm)
	prm.SetCommonParameters(commonPrm)
	prm.WithAddress(addr)
	prm.WithRawFlag(body.GetRaw())
	prm.SetChunkWriter(writer)
	prm.SetRange(object.NewRangeFromV2(body.GetRange()))

	if err = prm.Validate(); err != nil {
		return nil, fmt.Errorf("request params validation: %w", err)
	}

	if !commonPrm.LocalOnly() {
		fwd := &getRangeRequestForwarder{
			OnceResign:     new(sync.Once),
			GlobalProgress: 0,
			KeyStorage:     s.keyStorage,
			Request:        req,
			Stream:         writer,
		}
		prm.SetRequestForwarder(groupAddressRequestForwarder(fwd.forwardRequestToNode))
	}

	return prm, nil
}
// toHashRangePrm converts a v2 GetRangeHash request into getsvc range-hash
// parameters.
//
// When the request carries a session token, the signing key for that
// session (ID + issuer) is looked up in the key storage. If the storage
// reports apistatus.SessionTokenNotFound, the request's tokens are dropped
// via ForgetTokens and the key returned by GetKey(nil) is used instead
// (presumably the node's own key — confirm in the key storage). Any other
// key lookup failure is returned.
//
// The checksum type picks the hash constructor: SHA-256 or Tillich-Zémor.
// Any other type is rejected with an error.
func (s *Service) toHashRangePrm(req *objectV2.GetRangeHashRequest) (*getsvc.RangeHashPrm, error) {
	body := req.GetBody()

	addrV2 := body.GetAddress()
	if addrV2 == nil {
		return nil, errors.New("missing object address")
	}

	var addr oid.Address
	if err := addr.ReadFromV2(*addrV2); err != nil {
		return nil, fmt.Errorf("invalid object address: %w", err)
	}

	commonPrm, err := util.CommonPrmFromV2(req)
	if err != nil {
		return nil, err
	}

	prm := new(getsvc.RangeHashPrm)
	prm.SetCommonParameters(commonPrm)
	prm.WithAddress(addr)

	if tok := commonPrm.SessionToken(); tok != nil {
		key, keyErr := s.keyStorage.GetKey(&util.SessionInfo{
			ID:    tok.ID(),
			Owner: tok.Issuer(),
		})
		if keyErr != nil && errors.As(keyErr, new(apistatus.SessionTokenNotFound)) {
			// Unknown session: drop the tokens and fall back to GetKey(nil).
			commonPrm.ForgetTokens()
			key, keyErr = s.keyStorage.GetKey(nil)
		}
		if keyErr != nil {
			return nil, fmt.Errorf("fetching session key: %w", keyErr)
		}
		prm.WithCachedSignerKey(key)
	}

	v2Ranges := body.GetRanges()
	ranges := make([]object.Range, 0, len(v2Ranges))
	for i := range v2Ranges {
		ranges = append(ranges, *object.NewRangeFromV2(&v2Ranges[i]))
	}
	prm.SetRangeList(ranges)
	prm.SetSalt(body.GetSalt())

	var newHash func() hash.Hash
	switch t := body.GetType(); t {
	case refs.SHA256:
		newHash = sha256.New
	case refs.TillichZemor:
		newHash = func() hash.Hash {
			return tz.New()
		}
	default:
		return nil, fmt.Errorf("unknown checksum type %v", t)
	}
	prm.SetHashGenerator(newHash)

	return prm, nil
}
// headResponseWriter receives the object header produced by a Head operation
// and stores it in a prepared v2 HeadResponse body. It is installed as the
// header writer of getsvc.HeadPrm.
type headResponseWriter struct {
	// mainOnly selects the short header form instead of the full header
	// with signature.
	mainOnly bool

	// body is the response body that receives the header part.
	body *objectV2.HeadResponseBody
}

// WriteHeader sets hdr as the header part of the response body. It uses the
// short form when mainOnly is set and the full header-with-signature form
// otherwise. It always returns nil.
func (w *headResponseWriter) WriteHeader(_ context.Context, hdr *object.Object) error {
	convert := toFullObjectHeader
	if w.mainOnly {
		convert = toShortObjectHeader
	}

	w.body.SetHeaderPart(convert(hdr))

	return nil
}
// toHeadPrm converts a v2 Head request into getsvc head parameters.
//
// The header found by the operation is written into resp's body, in short
// form when the request asks for the main header only and in full form
// otherwise. Unless the request is restricted to the local node, a request
// forwarder (which is handed resp) is attached and tries every address of a
// remote node in turn.
//
// The context parameter is intentionally unused: forwarders receive the
// context at invocation time (see groupAddressRequestForwarder). It is kept
// only so the signature stays compatible with existing callers.
func (s *Service) toHeadPrm(_ context.Context, req *objectV2.HeadRequest, resp *objectV2.HeadResponse) (*getsvc.HeadPrm, error) {
	body := req.GetBody()

	addrV2 := body.GetAddress()
	if addrV2 == nil {
		return nil, errors.New("missing object address")
	}

	var objAddr oid.Address
	err := objAddr.ReadFromV2(*addrV2)
	if err != nil {
		return nil, fmt.Errorf("invalid object address: %w", err)
	}

	commonPrm, err := util.CommonPrmFromV2(req)
	if err != nil {
		return nil, err
	}

	p := new(getsvc.HeadPrm)
	p.SetCommonParameters(commonPrm)
	p.WithAddress(objAddr)
	p.WithRawFlag(body.GetRaw())
	p.SetHeaderWriter(&headResponseWriter{
		mainOnly: body.GetMainOnly(),
		body:     resp.GetBody(),
	})

	if commonPrm.LocalOnly() {
		return p, nil
	}

	forwarder := &headRequestForwarder{
		Request:    req,
		Response:   resp,
		OnceResign: &sync.Once{},
		ObjectAddr: objAddr,
		KeyStorage: s.keyStorage,
	}

	p.SetRequestForwarder(groupAddressRequestForwarder(forwarder.forwardRequestToNode))

	return p, nil
}
// splitInfoResponse builds a v2 Get response whose object part is the
// v2 form of the given split information.
func splitInfoResponse(info *object.SplitInfo) *objectV2.GetResponse {
	body := new(objectV2.GetResponseBody)
	body.SetObjectPart(info.ToV2())

	resp := new(objectV2.GetResponse)
	resp.SetBody(body)

	return resp
}
// splitInfoRangeResponse builds a v2 GetRange response whose range part is
// the v2 form of the given split information.
func splitInfoRangeResponse(info *object.SplitInfo) *objectV2.GetRangeResponse {
	body := new(objectV2.GetRangeResponseBody)
	body.SetRangePart(info.ToV2())

	resp := new(objectV2.GetRangeResponse)
	resp.SetBody(body)

	return resp
}
// setSplitInfoHeadResponse stores the v2 form of the split information as
// the header part of resp's existing body. resp's body is used as-is; no new
// body is created here.
func setSplitInfoHeadResponse(info *object.SplitInfo, resp *objectV2.HeadResponse) {
	body := resp.GetBody()
	body.SetHeaderPart(info.ToV2())
}
// toHashResponse wraps the hashes held by res into a v2 GetRangeHash
// response, tagging them with the checksum type typ.
func toHashResponse(typ refs.ChecksumType, res *getsvc.RangeHashRes) *objectV2.GetRangeHashResponse {
	body := new(objectV2.GetRangeHashResponseBody)
	body.SetType(typ)
	body.SetHashList(res.Hashes())

	resp := new(objectV2.GetRangeHashResponse)
	resp.SetBody(body)

	return resp
}
// toFullObjectHeader returns the full header representation of hdr: the v2
// object header together with the object's signature.
func toFullObjectHeader(hdr *object.Object) objectV2.GetHeaderPart {
	v2Obj := hdr.ToV2()

	withSig := new(objectV2.HeaderWithSignature)
	withSig.SetHeader(v2Obj.GetHeader())
	withSig.SetSignature(v2Obj.GetSignature())

	return withSig
}
// toShortObjectHeader returns the compact (main-only) header representation
// of hdr. Only these fields are copied from the v2 header: owner, creation
// epoch, payload length, version, object type, homomorphic hash and payload
// hash. The signature and all other header fields are left out.
func toShortObjectHeader(hdr *object.Object) objectV2.GetHeaderPart {
	full := hdr.ToV2().GetHeader()

	short := new(objectV2.ShortHeader)
	short.SetVersion(full.GetVersion())
	short.SetCreationEpoch(full.GetCreationEpoch())
	short.SetOwnerID(full.GetOwnerID())
	short.SetObjectType(full.GetObjectType())
	short.SetPayloadLength(full.GetPayloadLength())
	short.SetPayloadHash(full.GetPayloadHash())
	short.SetHomomorphicHash(full.GetHomomorphicHash())

	return short
}
// groupAddressRequestForwarder adapts f, which forwards a request to a single
// network address, into a getsvc.RequestForwarder that works on a whole node.
//
// The returned forwarder calls f for each address yielded by the node's
// AddressGroup().IterateAddresses. It passes the caller's context, the client
// and the node's public key. After the first successful call it asks the
// iteration to stop.
//
// Outcomes:
//   - on success, it returns that call's result and a nil error;
//   - if every call fails, it returns the result of the last call together
//     with the error of the first failed call;
//   - if the group yields no addresses, both return values are nil.
func groupAddressRequestForwarder(f func(context.Context, network.Address, client.MultiAddressClient, []byte) (*object.Object, error)) getsvc.RequestForwarder {
	return func(ctx context.Context, info client.NodeInfo, c client.MultiAddressClient) (*object.Object, error) {
		pubKey := info.PublicKey()

		var (
			lastRes  *object.Object
			firstErr error
		)

		info.AddressGroup().IterateAddresses(func(addr network.Address) bool {
			var err error
			lastRes, err = f(ctx, addr, c, pubKey)
			if err == nil {
				// Success: forget earlier failures and stop iterating.
				firstErr = nil
				return true
			}

			// Remember only the first failure. Later ones are dropped
			// (would be nice to log them).
			if firstErr == nil {
				firstErr = err
			}
			return false
		})

		return lastRes, firstErr
	}
}
// writeCurrentVersion sets metaHdr's version to the API version reported by
// versionSDK.Current(), converted to its v2 (refs) form.
func writeCurrentVersion(metaHdr *session.RequestMetaHeader) {
	current := versionSDK.Current()

	var v2 refs.Version
	current.WriteToV2(&v2)

	metaHdr.SetVersion(&v2)
}
// checkStatus converts a v2 response status into a Go error. Success codes
// yield nil. Any other code is converted with apistatus.FromStatusV2 and
// then apistatus.ErrFromStatus.
func checkStatus(stV2 *status.Status) error {
	if status.IsSuccess(stV2.Code()) {
		return nil
	}

	return apistatus.ErrFromStatus(apistatus.FromStatusV2(stV2))
}
// chunkToSend returns the part of chunk that still has to be written out.
//
// global appears to be the number of bytes already written overall, and
// local the offset at which chunk starts in the current source's stream
// (confirm against callers). The result is:
//   - chunk itself when global == local;
//   - nil when chunk ends at or before global (everything was already sent);
//   - otherwise, the suffix of chunk starting at offset global-local.
//
// local must not exceed global. Otherwise the slice offset is negative and
// the function panics.
func chunkToSend(global, local int, chunk []byte) []byte {
	switch skip := global - local; {
	case skip == 0:
		return chunk
	case skip >= len(chunk):
		// chunk has already been sent
		return nil
	default:
		return chunk[skip:]
	}
}