diff --git a/go.mod b/go.mod index e000c68..91beb9b 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.16 require ( github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521073959-f0d4d129b7f1 + github.com/bluele/gcache v0.0.2 github.com/golang/mock v1.6.0 github.com/google/uuid v1.2.0 github.com/mr-tron/base58 v1.2.0 diff --git a/go.sum b/go.sum index c54d774..773a7df 100644 --- a/go.sum +++ b/go.sum @@ -31,6 +31,8 @@ github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZx github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw= +github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0t/fp0= github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= github.com/btcsuite/btcd v0.22.0-beta h1:LTDpDKUM5EeOFBPM8IXpinEcmZ6FWfNZbE3lfrfdnWo= github.com/btcsuite/btcd v0.22.0-beta/go.mod h1:9n5ntfhhHQBIhUvlhDvD3Qg6fRUj4jkN0VB8L8svzOA= diff --git a/pool/cache.go b/pool/cache.go new file mode 100644 index 0000000..39e56b8 --- /dev/null +++ b/pool/cache.go @@ -0,0 +1,47 @@ +package pool + +import ( + "strings" + + "github.com/bluele/gcache" + "github.com/nspcc-dev/neofs-api-go/pkg/session" +) + +type SessionCache struct { + cache gcache.Cache +} + +func NewCache() *SessionCache { + return &SessionCache{ + cache: gcache.New(100).Build(), + } +} + +func (c *SessionCache) Get(key string) *session.Token { + tokenRaw, err := c.cache.Get(key) + if err != nil { + return nil + } + token, ok := tokenRaw.(*session.Token) + if !ok { + return nil + } + + return token +} + +func (c *SessionCache) Put(key string, token *session.Token) error { + return c.cache.Set(key, token) +} + +func (c *SessionCache) DeleteByPrefix(prefix string) { + for _, key := range c.cache.Keys(false) { + keyStr, ok := key.(string) + if !ok { + continue + } + if strings.HasPrefix(keyStr, prefix) { + c.cache.Remove(key) + } + } +} diff --git a/pool/pool.go b/pool/pool.go index 34e88c8..dff8711 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -5,10 +5,13 @@ import ( "crypto/ecdsa" "errors" "fmt" + "math" "math/rand" + "strings" "sync" "time" + "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neofs-sdk-go/client" "github.com/nspcc-dev/neofs-sdk-go/container" cid "github.com/nspcc-dev/neofs-sdk-go/container/id" @@ -89,12 +92,34 @@ type Pool interface { OwnerID() *owner.ID WaitForContainerPresence(context.Context, *cid.ID, *ContainerPollingParams) error Close() + + PutObjectParam(ctx context.Context, params *client.PutObjectParams, callParam CallParam) (*object.ID, error) + DeleteObjectParam(ctx context.Context, params *client.DeleteObjectParams, callParam CallParam) error + GetObjectParam(ctx context.Context, params *client.GetObjectParams, callParam CallParam) (*object.Object, error) + GetObjectHeaderParam(ctx context.Context, params *client.ObjectHeaderParams, callParam CallParam) (*object.Object, error) + ObjectPayloadRangeDataParam(ctx context.Context, params *client.RangeDataParams, callParam CallParam) ([]byte, error) + ObjectPayloadRangeSHA256Param(ctx context.Context, params *client.RangeChecksumParams, callParam CallParam) ([][32]byte, error) + ObjectPayloadRangeTZParam(ctx context.Context, params *client.RangeChecksumParams, callParam CallParam) ([][64]byte, error) + SearchObjectParam(ctx context.Context, params *client.SearchObjectParams, callParam CallParam) ([]*object.ID, error) + PutContainerParam(ctx context.Context, cnr *container.Container, callParam CallParam) (*cid.ID, error) + GetContainerParam(ctx context.Context, cid *cid.ID, callParam CallParam) (*container.Container, error) + ListContainersParam(ctx context.Context, ownerID *owner.ID, callParam CallParam) ([]*cid.ID, error) + DeleteContainerParam(ctx context.Context, cid *cid.ID, callParam CallParam) error + GetEACLParam(ctx context.Context, cid *cid.ID, callParam CallParam) (*client.EACLWithSignature, error) + SetEACLParam(ctx context.Context, table *eacl.Table, callParam CallParam) error + AnnounceContainerUsedSpaceParam(ctx context.Context, announce []container.UsedSpaceAnnouncement, callParam CallParam) error } type clientPack struct { - client client.Client - sessionToken *session.Token - healthy bool + client client.Client + healthy bool + address string +} + +type CallParam struct { + Key *ecdsa.PrivateKey + + Options []client.CallOption } var _ Pool = (*pool)(nil) @@ -102,13 +127,17 @@ var _ Pool = (*pool)(nil) type pool struct { lock sync.RWMutex sampler *Sampler + key *ecdsa.PrivateKey owner *owner.ID clientPacks []*clientPack cancel context.CancelFunc closedCh chan struct{} + cache *SessionCache } func newPool(ctx context.Context, options *BuilderOptions) (Pool, error) { + cache := NewCache() + clientPacks := make([]*clientPack, len(options.weights)) var atLeastOneHealthy bool for i, address := range options.addresses { @@ -126,8 +155,9 @@ func newPool(ctx context.Context, options *BuilderOptions) (Pool, error) { zap.Error(err)) } else if err == nil { healthy, atLeastOneHealthy = true, true + _ = cache.Put(formCacheKey(address, options.Key), st) } - clientPacks[i] = &clientPack{client: c, sessionToken: st, healthy: healthy} + clientPacks[i] = &clientPack{client: c, healthy: healthy, address: address} } if !atLeastOneHealthy { @@ -143,7 +173,15 @@ func newPool(ctx context.Context, options *BuilderOptions) (Pool, error) { ownerID := owner.NewIDFromNeo3Wallet(wallet) ctx, cancel := context.WithCancel(ctx) - pool := &pool{sampler: sampler, owner: ownerID, clientPacks: clientPacks, cancel: cancel, closedCh: make(chan struct{})} + pool := &pool{ + sampler: sampler, + key: options.Key, + owner: ownerID, + clientPacks: clientPacks, + cancel: cancel, + closedCh: make(chan struct{}), + cache: cache, + } go startRebalance(ctx, pool, options) return pool, nil } @@ -175,33 +213,34 @@ func updateNodesHealth(ctx context.Context, p *pool, options *BuilderOptions, bu go func(i int, client client.Client) { defer wg.Done() - var ( - tkn *session.Token - err error - ) ok := true tctx, c := context.WithTimeout(ctx, options.NodeRequestTimeout) defer c() - if _, err = client.EndpointInfo(tctx); err != nil { + if _, err := client.EndpointInfo(tctx); err != nil { ok = false bufferWeights[i] = 0 } + p.lock.RLock() + cp := *p.clientPacks[i] + p.lock.RUnlock() + if ok { bufferWeights[i] = options.weights[i] - p.lock.RLock() - if !p.clientPacks[i].healthy { - if tkn, err = client.CreateSession(ctx, options.SessionExpirationEpoch); err != nil { + if !cp.healthy { + if tkn, err := client.CreateSession(ctx, options.SessionExpirationEpoch); err != nil { ok = false bufferWeights[i] = 0 + } else { + _ = p.cache.Put(formCacheKey(cp.address, p.key), tkn) } } - p.lock.RUnlock() + } else { + p.cache.DeleteByPrefix(cp.address) } p.lock.Lock() if p.clientPacks[i].healthy != ok { p.clientPacks[i].healthy = ok - p.clientPacks[i].sessionToken = tkn healthyChanged = true } p.lock.Unlock() @@ -234,23 +273,33 @@ func adjustWeights(weights []float64) []float64 { } func (p *pool) Connection() (client.Client, *session.Token, error) { + cp, err := p.connection() + if err != nil { + return nil, nil, err + } + + token := p.cache.Get(formCacheKey(cp.address, p.key)) + return cp.client, token, nil +} + +func (p *pool) connection() (*clientPack, error) { p.lock.RLock() defer p.lock.RUnlock() if len(p.clientPacks) == 1 { cp := p.clientPacks[0] if cp.healthy { - return cp.client, cp.sessionToken, nil + return cp, nil } - return nil, nil, errors.New("no healthy client") + return nil, errors.New("no healthy client") } attempts := 3 * len(p.clientPacks) for k := 0; k < attempts; k++ { i := p.sampler.Next() if cp := p.clientPacks[i]; cp.healthy { - return cp.client, cp.sessionToken, nil + return cp, nil } } - return nil, nil, errors.New("no healthy client") + return nil, errors.New("no healthy client") } func (p *pool) OwnerID() *owner.ID { @@ -265,6 +314,36 @@ func (p *pool) conn(option []client.CallOption) (client.Client, []client.CallOpt return conn, append([]client.CallOption{client.WithSession(token)}, option...), nil } +func formCacheKey(address string, key *ecdsa.PrivateKey) string { + k := keys.PrivateKey{PrivateKey: *key} + return address + k.String() +} + +func (p *pool) connParam(ctx context.Context, param CallParam) (*clientPack, []client.CallOption, error) { + cp, err := p.connection() + if err != nil { + return nil, nil, err + } + + key := p.key + if param.Key != nil { + key = param.Key + } + + param.Options = append(param.Options, client.WithKey(key)) + cacheKey := formCacheKey(cp.address, key) + token := p.cache.Get(cacheKey) + if token == nil { + token, err = cp.client.CreateSession(ctx, math.MaxUint32, param.Options...) + if err != nil { + return nil, nil, err + } + _ = p.cache.Put(cacheKey, token) + } + + return cp, append([]client.CallOption{client.WithSession(token)}, param.Options...), nil +} + func (p *pool) PutObject(ctx context.Context, params *client.PutObjectParams, option ...client.CallOption) (*object.ID, error) { conn, options, err := p.conn(option) if err != nil { @@ -385,6 +464,166 @@ func (p *pool) AnnounceContainerUsedSpace(ctx context.Context, announce []contai return conn.AnnounceContainerUsedSpace(ctx, announce, options...) } +func (p *pool) checkSessionTokenErr(err error, address string) { + if err == nil { + return + } + + if strings.Contains(err.Error(), "session token does not exist") { + p.cache.DeleteByPrefix(address) + } +} + +func (p *pool) PutObjectParam(ctx context.Context, params *client.PutObjectParams, callParam CallParam) (*object.ID, error) { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return nil, err + } + res, err := cp.client.PutObject(ctx, params, options...) + p.checkSessionTokenErr(err, cp.address) + return res, err +} + +func (p *pool) DeleteObjectParam(ctx context.Context, params *client.DeleteObjectParams, callParam CallParam) error { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return err + } + err = cp.client.DeleteObject(ctx, params, options...) + p.checkSessionTokenErr(err, cp.address) + return err +} + +func (p *pool) GetObjectParam(ctx context.Context, params *client.GetObjectParams, callParam CallParam) (*object.Object, error) { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return nil, err + } + res, err := cp.client.GetObject(ctx, params, options...) + p.checkSessionTokenErr(err, cp.address) + return res, err +} + +func (p *pool) GetObjectHeaderParam(ctx context.Context, params *client.ObjectHeaderParams, callParam CallParam) (*object.Object, error) { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return nil, err + } + res, err := cp.client.GetObjectHeader(ctx, params, options...) + p.checkSessionTokenErr(err, cp.address) + return res, err +} + +func (p *pool) ObjectPayloadRangeDataParam(ctx context.Context, params *client.RangeDataParams, callParam CallParam) ([]byte, error) { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return nil, err + } + res, err := cp.client.ObjectPayloadRangeData(ctx, params, options...) + p.checkSessionTokenErr(err, cp.address) + return res, err +} + +func (p *pool) ObjectPayloadRangeSHA256Param(ctx context.Context, params *client.RangeChecksumParams, callParam CallParam) ([][32]byte, error) { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return nil, err + } + res, err := cp.client.ObjectPayloadRangeSHA256(ctx, params, options...) + p.checkSessionTokenErr(err, cp.address) + return res, err +} + +func (p *pool) ObjectPayloadRangeTZParam(ctx context.Context, params *client.RangeChecksumParams, callParam CallParam) ([][64]byte, error) { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return nil, err + } + res, err := cp.client.ObjectPayloadRangeTZ(ctx, params, options...) + p.checkSessionTokenErr(err, cp.address) + return res, err +} + +func (p *pool) SearchObjectParam(ctx context.Context, params *client.SearchObjectParams, callParam CallParam) ([]*object.ID, error) { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return nil, err + } + res, err := cp.client.SearchObject(ctx, params, options...) + p.checkSessionTokenErr(err, cp.address) + return res, err +} + +func (p *pool) PutContainerParam(ctx context.Context, cnr *container.Container, callParam CallParam) (*cid.ID, error) { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return nil, err + } + res, err := cp.client.PutContainer(ctx, cnr, options...) + p.checkSessionTokenErr(err, cp.address) + return res, err +} + +func (p *pool) GetContainerParam(ctx context.Context, cid *cid.ID, callParam CallParam) (*container.Container, error) { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return nil, err + } + res, err := cp.client.GetContainer(ctx, cid, options...) + p.checkSessionTokenErr(err, cp.address) + return res, err +} + +func (p *pool) ListContainersParam(ctx context.Context, ownerID *owner.ID, callParam CallParam) ([]*cid.ID, error) { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return nil, err + } + res, err := cp.client.ListContainers(ctx, ownerID, options...) + p.checkSessionTokenErr(err, cp.address) + return res, err +} + +func (p *pool) DeleteContainerParam(ctx context.Context, cid *cid.ID, callParam CallParam) error { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return err + } + err = cp.client.DeleteContainer(ctx, cid, options...) + p.checkSessionTokenErr(err, cp.address) + return err +} + +func (p *pool) GetEACLParam(ctx context.Context, cid *cid.ID, callParam CallParam) (*client.EACLWithSignature, error) { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return nil, err + } + res, err := cp.client.GetEACL(ctx, cid, options...) + p.checkSessionTokenErr(err, cp.address) + return res, err +} + +func (p *pool) SetEACLParam(ctx context.Context, table *eacl.Table, callParam CallParam) error { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return err + } + err = cp.client.SetEACL(ctx, table, options...) + p.checkSessionTokenErr(err, cp.address) + return err +} + +func (p *pool) AnnounceContainerUsedSpaceParam(ctx context.Context, announce []container.UsedSpaceAnnouncement, callParam CallParam) error { + cp, options, err := p.connParam(ctx, callParam) + if err != nil { + return err + } + err = cp.client.AnnounceContainerUsedSpace(ctx, announce, options...) + p.checkSessionTokenErr(err, cp.address) + return err +} + func (p *pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollParams *ContainerPollingParams) error { conn, _, err := p.Connection() if err != nil { diff --git a/pool/pool_test.go b/pool/pool_test.go index fb4aede..f9ff43f 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -307,6 +307,10 @@ func TestTwoFailed(t *testing.T) { require.Contains(t, err.Error(), "no healthy") } +func TestSessionCache(t *testing.T) { + +} + func newToken(t *testing.T) *session.Token { tok := session.NewToken() uid, err := uuid.New().MarshalBinary() @@ -324,12 +328,17 @@ func TestWaitPresence(t *testing.T) { mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient.EXPECT().GetContainer(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + key, err := keys.NewPrivateKey() + require.NoError(t, err) + p := &pool{ sampler: NewSampler([]float64{1}, rand.NewSource(0)), clientPacks: []*clientPack{{ client: mockClient, healthy: true, }}, + key: &key.PrivateKey, + cache: NewCache(), } t.Run("context canceled", func(t *testing.T) { diff --git a/pool/sampler_test.go b/pool/sampler_test.go index 5b259d5..e71b559 100644 --- a/pool/sampler_test.go +++ b/pool/sampler_test.go @@ -6,6 +6,7 @@ import ( "math/rand" "testing" + "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neofs-sdk-go/client" "github.com/nspcc-dev/neofs-sdk-go/session" "github.com/stretchr/testify/require" @@ -73,11 +74,16 @@ func TestHealthyReweight(t *testing.T) { buffer = make([]float64, len(weights)) ) + key, err := keys.NewPrivateKey() + require.NoError(t, err) + p := &pool{ sampler: NewSampler(weights, rand.NewSource(0)), clientPacks: []*clientPack{ - {client: newNetmapMock(names[0], true), healthy: true}, - {client: newNetmapMock(names[1], false), healthy: true}}, + {client: newNetmapMock(names[0], true), healthy: true, address: "address0"}, + {client: newNetmapMock(names[1], false), healthy: true, address: "address1"}}, + cache: NewCache(), + key: &key.PrivateKey, } // check getting first node connection before rebalance happened @@ -105,8 +111,6 @@ func TestHealthyReweight(t *testing.T) { require.NoError(t, err) mock0 = connection0.(clientMock) require.Equal(t, names[0], mock0.name) - - require.NotNil(t, p.clientPacks[0].sessionToken) } func TestHealthyNoReweight(t *testing.T) {