diff --git a/pool/cache.go b/pool/cache.go
index 1614e8a..5c39888 100644
--- a/pool/cache.go
+++ b/pool/cache.go
@@ -69,3 +69,7 @@ func (c *sessionCache) expired(val *cacheValue) bool {
 	// use epoch+1 (clear cache beforehand) to prevent 'expired session token' error right after epoch tick
 	return val.token.ExpiredAt(epoch + 1)
 }
+
+func (c *sessionCache) Epoch() uint64 {
+	return c.currentEpoch.Load()
+}
diff --git a/pool/object_put_transformer.go b/pool/object_put_transformer.go
new file mode 100644
index 0000000..74c88c3
--- /dev/null
+++ b/pool/object_put_transformer.go
@@ -0,0 +1,107 @@
+package pool
+
+import (
+	"context"
+	"crypto/ecdsa"
+
+	sdkClient "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client"
+	apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
+	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
+	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
+	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/transformer"
+	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/session"
+)
+
+type PrmObjectPutClientCutInit struct {
+	sdkClient.PrmObjectPutInit
+	key                    *ecdsa.PrivateKey
+	maxSize                uint64
+	epochSource            transformer.EpochSource
+	withoutHomomorphicHash bool
+	stoken                 *session.Object
+}
+
+func (c *clientWrapper) objectPutInitTransformer(prm PrmObjectPutClientCutInit) (*objectWriterTransformer, error) {
+	var w objectWriterTransformer
+	w.it = internalTarget{
+		client: c.client,
+		prm:    prm,
+	}
+	key := &c.prm.key
+	if prm.key != nil {
+		key = prm.key
+	}
+
+	w.ot = transformer.NewPayloadSizeLimiter(transformer.Params{
+		Key:                    key,
+		NextTargetInit:         func() transformer.ObjectWriter { return &w.it },
+		MaxSize:                prm.maxSize,
+		WithoutHomomorphicHash: prm.withoutHomomorphicHash,
+		NetworkState:           prm.epochSource,
+		SessionToken:           prm.stoken,
+	})
+	return &w, nil
+}
+
+type objectWriterTransformer struct {
+	ot  transformer.ChunkedObjectWriter
+	it  internalTarget
+	err error
+}
+
+func (x *objectWriterTransformer) WriteHeader(ctx context.Context, hdr object.Object) bool {
+	x.err = x.ot.WriteHeader(ctx, &hdr)
+	return x.err == nil
+}
+
+func (x *objectWriterTransformer) WritePayloadChunk(ctx context.Context, chunk []byte) bool {
+	_, x.err = x.ot.Write(ctx, chunk)
+	return x.err == nil
+}
+
+// ResObjectPut groups the final result values of ObjectPutInit operation.
+type ResObjectPut struct {
+	Status apistatus.Status
+	OID    oid.ID
+}
+
+func (x *objectWriterTransformer) Close(ctx context.Context) (*ResObjectPut, error) {
+	ai, err := x.ot.Close(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	if ai != nil && ai.ParentID != nil {
+		x.it.res.OID = *ai.ParentID
+	}
+	return &x.it.res, nil
+}
+
+type internalTarget struct {
+	client    *sdkClient.Client
+	res       ResObjectPut
+	prm       PrmObjectPutClientCutInit
+	useStream bool
+}
+
+func (it *internalTarget) WriteObject(ctx context.Context, o *object.Object) error {
+	// TODO: support PutSingle
+	it.useStream = true
+	return it.putAsStream(ctx, o)
+}
+
+func (it *internalTarget) putAsStream(ctx context.Context, o *object.Object) error {
+	wrt, err := it.client.ObjectPutInit(ctx, it.prm.PrmObjectPutInit)
+	if err != nil {
+		return err
+	}
+	if wrt.WriteHeader(ctx, *o) {
+		wrt.WritePayloadChunk(ctx, o.Payload())
+	}
+	res, err := wrt.Close(ctx)
+	if res != nil {
+		it.res.Status = res.Status()
+		it.res.OID = res.StoredObjectID()
+	}
+	return err
+}
diff --git a/pool/parts_buffer_pool.go b/pool/parts_buffer_pool.go
new file mode 100644
index 0000000..b204f50
--- /dev/null
+++ b/pool/parts_buffer_pool.go
@@ -0,0 +1,64 @@
+package pool
+
+import (
+	"fmt"
+	"sync"
+)
+
+type PartBuffer struct {
+	Buffer []byte
+	len    uint64
+}
+
+type PartsBufferPool struct {
+	syncPool      *sync.Pool
+	limit         uint64
+	maxObjectSize uint64
+
+	mu        sync.Mutex
+	available uint64
+}
+
+func NewPartBufferPool(limit uint64, maxObjectSize uint64) *PartsBufferPool {
+	return &PartsBufferPool{
+		limit:         limit,
+		maxObjectSize: maxObjectSize,
+		available:     limit,
+		syncPool:      &sync.Pool{New: func() any { return make([]byte, maxObjectSize) }},
+	}
+}
+
+func (p *PartsBufferPool) PartBufferSize() uint64 {
+	return p.maxObjectSize
+}
+
+func (p *PartsBufferPool) GetBuffer() (*PartBuffer, error) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	if p.maxObjectSize > p.available {
+		return nil, fmt.Errorf("requested buffer size %d is greater than available: %d", p.maxObjectSize, p.available)
+	}
+
+	p.available -= p.maxObjectSize
+
+	return &PartBuffer{
+		Buffer: p.syncPool.Get().([]byte),
+		len:    p.maxObjectSize,
+	}, nil
+}
+
+func (p *PartsBufferPool) FreeBuffer(buff *PartBuffer) error {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	used := p.limit - p.available
+	if buff.len > used {
+		return fmt.Errorf("buffer size %d to free is greater than used: %d", buff.len, used)
+	}
+
+	p.available += buff.len
+	p.syncPool.Put(buff.Buffer)
+
+	return nil
+}
diff --git a/pool/pool.go b/pool/pool.go
index 2f662d3..cf6c9b6 100644
--- a/pool/pool.go
+++ b/pool/pool.go
@@ -621,6 +621,14 @@ func (c *clientWrapper) networkInfo(ctx context.Context, _ prmNetworkInfo) (netm
 // objectPut writes object to FrostFS.
 func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID, error) {
+	if prm.clientCut {
+		return c.objectPutClientCut(ctx, prm)
+	}
+
+	return c.objectPutServerCut(ctx, prm)
+}
+
+func (c *clientWrapper) objectPutServerCut(ctx context.Context, prm PrmObjectPut) (oid.ID, error) {
 	cl, err := c.getClient()
 	if err != nil {
 		return oid.ID{}, err
@@ -702,6 +710,80 @@ func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID
 	return res.StoredObjectID(), nil
 }
 
+func (c *clientWrapper) objectPutClientCut(ctx context.Context, prm PrmObjectPut) (oid.ID, error) {
+	var cliPrm sdkClient.PrmObjectPutInit
+	cliPrm.SetCopiesNumberByVectors(prm.copiesNumber)
+	if prm.stoken != nil {
+		cliPrm.WithinSession(*prm.stoken)
+	}
+	if prm.key != nil {
+		cliPrm.UseKey(*prm.key)
+	}
+	if prm.btoken != nil {
+		cliPrm.WithBearerToken(*prm.btoken)
+	}
+
+	putInitPrm := PrmObjectPutClientCutInit{
+		PrmObjectPutInit: cliPrm,
+		key:              prm.key,
+		maxSize:          prm.networkInfo.MaxObjectSize(),
+		epochSource:      prm.networkInfo,
+		stoken:           prm.stoken,
+	}
+
+	start := time.Now()
+	wObj, err := c.objectPutInitTransformer(putInitPrm)
+	c.incRequests(time.Since(start), methodObjectPut)
+	if err = c.handleError(ctx, nil, err); err != nil {
+		return oid.ID{}, fmt.Errorf("init writing on API client: %w", err)
+	}
+
+	if wObj.WriteHeader(ctx, prm.hdr) {
+		if data := prm.hdr.Payload(); len(data) > 0 {
+			if prm.payload != nil {
+				prm.payload = io.MultiReader(bytes.NewReader(data), prm.payload)
+			} else {
+				prm.payload = bytes.NewReader(data)
+			}
+		}
+
+		if prm.payload != nil {
+			var n int
+
+			for {
+				n, err = prm.payload.Read(prm.partBuffer)
+				if n > 0 {
+					start = time.Now()
+					successWrite := wObj.WritePayloadChunk(ctx, prm.partBuffer[:n])
+					c.incRequests(time.Since(start), methodObjectPut)
+					if !successWrite {
+						break
+					}
+
+					continue
+				}
+
+				if errors.Is(err, io.EOF) {
+					break
+				}
+
+				return oid.ID{}, fmt.Errorf("read payload: %w", c.handleError(ctx, nil, err))
+			}
+		}
+	}
+
+	res, err := wObj.Close(ctx)
+	var st apistatus.Status
+	if res != nil {
+		st = res.Status
+	}
+	if err = c.handleError(ctx, st, err); err != nil { // here err already carries both status and client errors
+		return oid.ID{}, fmt.Errorf("client failure: %w", err)
+	}
+
+	return res.OID, nil
+}
+
 // objectDelete invokes sdkClient.ObjectDelete parse response status to error.
 func (c *clientWrapper) objectDelete(ctx context.Context, prm PrmObjectDelete) error {
 	cl, err := c.getClient()
@@ -1064,6 +1146,7 @@ type InitParameters struct {
 	nodeParams                []NodeParam
 	requestCallback           func(RequestInfo)
 	dialOptions               []grpc.DialOption
+	maxClientCutMemory        uint64
 
 	clientBuilder clientBuilder
 }
@@ -1128,6 +1211,14 @@ func (x *InitParameters) SetGRPCDialOptions(opts ...grpc.DialOption) {
 	x.dialOptions = opts
 }
 
+// SetMaxClientCutMemory sets the maximum number of bytes that can be used during client cut (see PrmObjectPut.SetClientCut).
+// The default value is 1 GiB, which is enough for roughly 200 concurrent PUT requests with a MaxObjectSize of 50 MiB.
+// If the MaxObjectSize network parameter is greater than the limit set by this method,
+// Pool.PutObject operations with PrmObjectPut.SetClientCut will fail.
+func (x *InitParameters) SetMaxClientCutMemory(size uint64) {
+	x.maxClientCutMemory = size
+}
+
 // setClientBuilder sets clientBuilder used for client construction.
 // Wraps setClientBuilderContext without a context.
 func (x *InitParameters) setClientBuilder(builder clientBuilder) {
@@ -1291,6 +1382,10 @@ type PrmObjectPut struct {
 	payload io.Reader
 
 	copiesNumber []uint32
+
+	clientCut   bool
+	partBuffer  []byte
+	networkInfo netmap.NetworkInfo
 }
 
 // SetHeader specifies header of the object.
@@ -1315,6 +1410,24 @@ func (x *PrmObjectPut) SetCopiesNumberVector(copiesNumber []uint32) {
 	x.copiesNumber = copiesNumber
 }
 
+// SetClientCut enables client cut for objects: the full object is prepared on the client side,
+// which makes retrying possible, at the cost of additional memory used to buffer object parts.
+// The buffer size for every put is the MaxObjectSize value from the FrostFS network.
+// There is a limit on the total memory allocated for in-flight requests; it
+// can be set with InitParameters.SetMaxClientCutMemory (the default value is 1 GiB).
+// Put requests fail once this limit is reached.
+func (x *PrmObjectPut) SetClientCut(clientCut bool) {
+	x.clientCut = clientCut
+}
+
+func (x *PrmObjectPut) setPartBuffer(partBuffer []byte) {
+	x.partBuffer = partBuffer
+}
+
+func (x *PrmObjectPut) setNetworkInfo(ni netmap.NetworkInfo) {
+	x.networkInfo = ni
+}
+
 // PrmObjectDelete groups parameters of DeleteObject operation.
 type PrmObjectDelete struct {
 	prmCommon
@@ -1600,6 +1713,11 @@ type Pool struct {
 	rebalanceParams rebalanceParameters
 	clientBuilder   clientBuilder
 	logger          *zap.Logger
+
+	// partBufferPool cannot be initialized in NewPool (MaxObjectSize is not known yet),
+	// so the maxClientCutMemory param is saved here for later initialization in Dial.
+	maxClientCutMemory uint64
+	partsBufferPool    *PartsBufferPool
 }
 
 type innerPool struct {
@@ -1611,6 +1729,7 @@
 const (
 	defaultSessionTokenExpirationDuration = 100 // in blocks
 	defaultErrorThreshold                 = 100
+	defaultMaxClientCutMemory             = 1024 * 1024 * 1024 // 1 GiB
 
 	defaultRebalanceInterval  = 15 * time.Second
 	defaultHealthcheckTimeout = 4 * time.Second
@@ -1647,7 +1766,8 @@ func NewPool(options InitParameters) (*Pool, error) {
 			clientRebalanceInterval:   options.clientRebalanceInterval,
 			sessionExpirationDuration: options.sessionExpirationDuration,
 		},
-		clientBuilder: options.clientBuilder,
+		clientBuilder:      options.clientBuilder,
+		maxClientCutMemory: options.maxClientCutMemory,
 	}
 
 	return pool, nil
@@ -1675,7 +1795,7 @@ func (p *Pool) Dial(ctx context.Context) error {
 			}
 
 			var st session.Object
-			err := initSessionForDuration(ctx, &st, clients[j], p.rebalanceParams.sessionExpirationDuration, *p.key)
+			err := initSessionForDuration(ctx, &st, clients[j], p.rebalanceParams.sessionExpirationDuration, *p.key, false)
 			if err != nil {
 				clients[j].setUnhealthy()
 				p.log(zap.WarnLevel, "failed to create frostfs session token for client",
@@ -1683,7 +1803,7 @@
 				continue
 			}
 
-			_ = p.cache.Put(formCacheKey(addr, p.key), st)
+			_ = p.cache.Put(formCacheKey(addr, p.key, false), st)
 			atLeastOneHealthy = true
 		}
 		source := rand.NewSource(time.Now().UnixNano())
@@ -1704,6 +1824,12 @@
 	p.closedCh = make(chan struct{})
 	p.innerPools = inner
 
+	ni, err := p.NetworkInfo(ctx)
+	if err != nil {
+		return fmt.Errorf("get network info for max object size: %w", err)
+	}
+	p.partsBufferPool = NewPartBufferPool(p.maxClientCutMemory, ni.MaxObjectSize())
+
 	go p.startRebalance(ctx)
 	return nil
 }
@@ -1725,6 +1851,10 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) {
 		params.errorThreshold = defaultErrorThreshold
 	}
 
+	if params.maxClientCutMemory == 0 {
+		params.maxClientCutMemory = defaultMaxClientCutMemory
+	}
+
 	if params.clientRebalanceInterval <= 0 {
 		params.clientRebalanceInterval = defaultRebalanceInterval
 	}
@@ -1914,9 +2044,15 @@ func (p *innerPool) connection() (client, error) {
 	return nil, errors.New("no healthy client")
 }
 
-func formCacheKey(address string, key *ecdsa.PrivateKey) string {
+func formCacheKey(address string, key *ecdsa.PrivateKey, clientCut bool) string {
 	k := keys.PrivateKey{PrivateKey: *key}
-	return address + k.String()
+
+	stype := "server"
+	if clientCut {
+		stype = "client"
+	}
+
+	return address + stype + k.String()
 }
 
 func (p *Pool) checkSessionTokenErr(err error, address string) bool {
@@ -1932,7 +2068,7 @@ func (p *Pool) checkSessionTokenErr(err error, address string) bool {
 	return false
 }
 
-func initSessionForDuration(ctx context.Context, dst *session.Object, c client, dur uint64, ownerKey ecdsa.PrivateKey) error {
+func initSessionForDuration(ctx context.Context, dst *session.Object, c client, dur uint64, ownerKey ecdsa.PrivateKey, clientCut bool) error {
 	ni, err := c.networkInfo(ctx, prmNetworkInfo{})
 	if err != nil {
 		return err
@@ -1950,23 +2086,26 @@ func initSessionForDuration(ctx context.Context, dst *session.Object, c client,
 	prm.setExp(exp)
 	prm.useKey(ownerKey)
 
-	res, err := c.sessionCreate(ctx, prm)
-	if err != nil {
-		return err
-	}
+	var (
+		id  uuid.UUID
+		key frostfsecdsa.PublicKey
+	)
 
-	var id uuid.UUID
+	if clientCut {
+		id = uuid.New()
+		key = frostfsecdsa.PublicKey(ownerKey.PublicKey)
-	err = id.UnmarshalBinary(res.id)
-	if err != nil {
-		return fmt.Errorf("invalid session token ID: %w", err)
-	}
-
-	var key frostfsecdsa.PublicKey
-
-	err = key.Decode(res.sessionKey)
-	if err != nil {
-		return fmt.Errorf("invalid public session key: %w", err)
+	} else {
+		res, err := c.sessionCreate(ctx, prm)
+		if err != nil {
+			return err
+		}
+		if err = id.UnmarshalBinary(res.id); err != nil {
+			return fmt.Errorf("invalid session token ID: %w", err)
+		}
+		if err = key.Decode(res.sessionKey); err != nil {
+			return fmt.Errorf("invalid public session key: %w", err)
+		}
 	}
 
 	dst.SetID(id)
@@ -1992,6 +2131,8 @@ type callContext struct {
 	sessionCnr    cid.ID
 	sessionObjSet bool
 	sessionObjs   []oid.ID
+
+	sessionClientCut bool
 }
 
 func (p *Pool) initCallContext(ctx *callContext, cfg prmCommon, prmCtx prmContext) error {
@@ -2028,12 +2169,12 @@ func (p *Pool) initCallContext(ctx *callContext, cfg prmCommon, prmCtx prmContex
 
 // opens new session or uses cached one.
 // Must be called only on initialized callContext with set sessionTarget.
 func (p *Pool) openDefaultSession(ctx context.Context, cc *callContext) error {
-	cacheKey := formCacheKey(cc.endpoint, cc.key)
+	cacheKey := formCacheKey(cc.endpoint, cc.key, cc.sessionClientCut)
 	tok, ok := p.cache.Get(cacheKey)
 	if !ok {
 		// init new session
-		err := initSessionForDuration(ctx, &tok, cc.client, p.stokenDuration, *cc.key)
+		err := initSessionForDuration(ctx, &tok, cc.client, p.stokenDuration, *cc.key, cc.sessionClientCut)
 		if err != nil {
 			return fmt.Errorf("session API client: %w", err)
 		}
@@ -2098,6 +2239,7 @@ func (p *Pool) PutObject(ctx context.Context, prm PrmObjectPut) (oid.ID, error)
 	p.fillAppropriateKey(&prm.prmCommon)
 
 	var ctxCall callContext
+	ctxCall.sessionClientCut = prm.clientCut
 	if err := p.initCallContext(&ctxCall, prm.prmCommon, prmCtx); err != nil {
 		return oid.ID{}, fmt.Errorf("init call context: %w", err)
 	}
@@ -2109,6 +2251,25 @@
 		}
 	}
 
+	buff, err := p.partsBufferPool.GetBuffer()
+	if err != nil {
+		return oid.ID{}, fmt.Errorf("cannot get buffer for put operations: %w", err)
+	}
+
+	defer func() {
+		if errFree := p.partsBufferPool.FreeBuffer(buff); errFree != nil {
+			p.log(zap.WarnLevel, "failed to free part buffer", zap.Error(errFree))
+		}
+	}()
+
+	prm.setPartBuffer(buff.Buffer)
+
+	var ni netmap.NetworkInfo
+	ni.SetCurrentEpoch(p.cache.Epoch())
+	ni.SetMaxObjectSize(p.partsBufferPool.PartBufferSize()) // we want to use the initial max object size in PayloadSizeLimiter
+
+	prm.setNetworkInfo(ni)
+
 	id, err := ctxCall.client.objectPut(ctx, prm)
 	if err != nil {
 		// removes session token from cache in case of token error
diff --git a/pool/pool_test.go b/pool/pool_test.go
index bc0407e..b5a6893 100644
--- a/pool/pool_test.go
+++ b/pool/pool_test.go
@@ -106,7 +106,7 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
 		if err != nil {
 			return false
 		}
-		st, _ := clientPool.cache.Get(formCacheKey(cp.address(), clientPool.key))
+		st, _ := clientPool.cache.Get(formCacheKey(cp.address(), clientPool.key, false))
 		return st.AssertAuthKey(&expectedAuthKey)
 	}
 	require.Never(t, condition, 900*time.Millisecond, 100*time.Millisecond)
@@ -141,7 +141,7 @@ func TestOneNode(t *testing.T) {
 
 	cp, err := pool.connection()
 	require.NoError(t, err)
-	st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key))
+	st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
 	expectedAuthKey := frostfsecdsa.PublicKey(key1.PublicKey)
 	require.True(t, st.AssertAuthKey(&expectedAuthKey))
 }
@@ -171,7 +171,7 @@ func TestTwoNodes(t *testing.T) {
 
 	cp, err := pool.connection()
 	require.NoError(t, err)
-	st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key))
+	st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
 	require.True(t, assertAuthKeyForAny(st, clientKeys))
 }
@@ -226,7 +226,7 @@ func TestOneOfTwoFailed(t *testing.T) {
 	for i := 0; i < 5; i++ {
 		cp, err := pool.connection()
 		require.NoError(t, err)
-		st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key))
+		st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
 		require.True(t, assertAuthKeyForAny(st, clientKeys))
 	}
 }
@@ -296,7 +296,7 @@ func TestSessionCache(t *testing.T) {
 	// cache must contain session token
 	cp, err := pool.connection()
 	require.NoError(t, err)
-	st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key))
+	st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
 	require.True(t, st.AssertAuthKey(&expectedAuthKey))
 
 	var prm PrmObjectGet
@@ -309,7 +309,7 @@ func TestSessionCache(t *testing.T) {
 	// cache must not contain session token
 	cp, err = pool.connection()
 	require.NoError(t, err)
-	_, ok := pool.cache.Get(formCacheKey(cp.address(), pool.key))
+	_, ok := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
 	require.False(t, ok)
 
 	var prm2 PrmObjectPut
@@ -321,7 +321,7 @@ func TestSessionCache(t *testing.T) {
 	// cache must contain session token
 	cp, err = pool.connection()
 	require.NoError(t, err)
-	st, _ = pool.cache.Get(formCacheKey(cp.address(), pool.key))
+	st, _ = pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
 	require.True(t, st.AssertAuthKey(&expectedAuthKey))
 }
@@ -365,7 +365,7 @@ func TestPriority(t *testing.T) {
 	firstNode := func() bool {
 		cp, err := pool.connection()
 		require.NoError(t, err)
-		st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key))
+		st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
 		return st.AssertAuthKey(&expectedAuthKey1)
 	}
 
@@ -373,7 +373,7 @@ func TestPriority(t *testing.T) {
 	secondNode := func() bool {
 		cp, err := pool.connection()
 		require.NoError(t, err)
-		st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key))
+		st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
 		return st.AssertAuthKey(&expectedAuthKey2)
 	}
 	require.Never(t, secondNode, time.Second, 200*time.Millisecond)
@@ -410,7 +410,7 @@ func TestSessionCacheWithKey(t *testing.T) {
 	// cache must contain session token
 	cp, err := pool.connection()
 	require.NoError(t, err)
-	st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key))
+	st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
 	require.True(t, st.AssertAuthKey(&expectedAuthKey))
 
 	var prm PrmObjectDelete
@@ -420,7 +420,7 @@ func TestSessionCacheWithKey(t *testing.T) {
 	err = pool.DeleteObject(ctx, prm)
 	require.NoError(t, err)
 
-	st, _ = pool.cache.Get(formCacheKey(cp.address(), anonKey))
+	st, _ = pool.cache.Get(formCacheKey(cp.address(), anonKey, false))
 	require.True(t, st.AssertAuthKey(&expectedAuthKey))
 }
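
Usage note: below is a minimal sketch of the client-cut API added by this change. The endpoint, container ID, and key handling are illustrative assumptions, not part of the diff; SetKey, AddNode, NewNodeParam, and SetPayload are pre-existing pool APIs.

```go
package main

import (
	"bytes"
	"context"

	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
)

func main() {
	key, err := keys.NewPrivateKey()
	if err != nil {
		panic(err)
	}

	var prm pool.InitParameters
	prm.SetKey(&key.PrivateKey)
	prm.AddNode(pool.NewNodeParam(1, "grpcs://example-node:8802", 1)) // hypothetical endpoint
	prm.SetMaxClientCutMemory(512 * 1024 * 1024)                     // cap part buffers at 512 MiB total

	p, err := pool.NewPool(prm)
	if err != nil {
		panic(err)
	}
	// Dial also fetches MaxObjectSize from the network and initializes the parts buffer pool.
	if err = p.Dial(context.Background()); err != nil {
		panic(err)
	}

	var cnr cid.ID // must be decoded from a real container ID

	var hdr object.Object
	hdr.SetContainerID(cnr)

	var prmPut pool.PrmObjectPut
	prmPut.SetHeader(hdr)
	prmPut.SetPayload(bytes.NewReader(make([]byte, 100*1024*1024)))
	prmPut.SetClientCut(true) // split the object into MaxObjectSize parts on the client

	// Each in-flight client-cut put borrows one MaxObjectSize buffer from the pool.
	id, err := p.PutObject(context.Background(), prmPut)
	if err != nil {
		panic(err)
	}
	_ = id
}
```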
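Design note on PartsBufferPool: GetBuffer admits at most limit/maxObjectSize buffers at a time and FreeBuffer returns capacity, so concurrent client-cut puts degrade into explicit errors instead of unbounded allocation. A standalone sketch of that invariant with hypothetical sizes (assumes the diff above is applied):

```go
package main

import (
	"fmt"

	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool"
)

func main() {
	// A 128 MiB limit with 64 MiB parts: at most two buffers may be held at once.
	p := pool.NewPartBufferPool(128*1024*1024, 64*1024*1024)

	b1, _ := p.GetBuffer()
	b2, _ := p.GetBuffer()

	// The third request exceeds the remaining capacity and fails.
	if _, err := p.GetBuffer(); err != nil {
		fmt.Println("expected:", err)
	}

	// Freeing a buffer returns 64 MiB of capacity to the pool.
	_ = p.FreeBuffer(b1)
	if b3, err := p.GetBuffer(); err == nil {
		_ = p.FreeBuffer(b3)
	}
	_ = p.FreeBuffer(b2)
}
```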