diff --git a/pool/cache.go b/pool/cache.go index c90fa47..151aa17 100644 --- a/pool/cache.go +++ b/pool/cache.go @@ -14,7 +14,8 @@ type sessionCache struct { } type cacheValue struct { - token session.Object + token session.Object + expiration uint64 } func newCache() (*sessionCache, error) { @@ -44,9 +45,10 @@ func (c *sessionCache) Get(key string) (session.Object, bool) { return value.token, true } -func (c *sessionCache) Put(key string, token session.Object) bool { +func (c *sessionCache) Put(key string, token session.Object, exp uint64) bool { return c.cache.Add(key, &cacheValue{ - token: token, + token: token, + expiration: exp, }) } @@ -67,5 +69,5 @@ func (c *sessionCache) updateEpoch(newEpoch uint64) { func (c *sessionCache) expired(val *cacheValue) bool { epoch := atomic.LoadUint64(&c.currentEpoch) - return val.token.ExpiredAt(epoch) + return val.expiration <= epoch } diff --git a/pool/cache_test.go b/pool/cache_test.go index 1d425c6..ae5cf21 100644 --- a/pool/cache_test.go +++ b/pool/cache_test.go @@ -23,7 +23,7 @@ func TestSessionCache_GetUnmodifiedToken(t *testing.T) { cache, err := newCache() require.NoError(t, err) - cache.Put(key, target) + cache.Put(key, target, 10) value, ok := cache.Get(key) require.True(t, ok) check(t, value, "before sign") diff --git a/pool/pool.go b/pool/pool.go index 0144101..8e3a092 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -10,6 +10,7 @@ import ( "math" "math/rand" "sort" + "strings" "sync" "time" @@ -1617,7 +1618,7 @@ func (p *Pool) Dial(ctx context.Context) error { } var st session.Object - err := initSessionForDuration(ctx, &st, clients[j], p.rebalanceParams.sessionExpirationDuration, *p.key) + exp, err := initSessionForDuration(ctx, &st, clients[j], p.rebalanceParams.sessionExpirationDuration, *p.key) if err != nil { clients[j].setUnhealthy() if p.logger != nil { @@ -1627,7 +1628,7 @@ func (p *Pool) Dial(ctx context.Context) error { continue } - _ = p.cache.Put(formCacheKey(addr, p.key), st) + _ = p.cache.Put(formCacheKey(addr, p.key), st, exp) atLeastOneHealthy = true } source := rand.NewSource(time.Now().UnixNano()) @@ -1857,7 +1858,7 @@ func (p *Pool) checkSessionTokenErr(err error, address string) bool { return false } - if sdkClient.IsErrSessionNotFound(err) || sdkClient.IsErrSessionExpired(err) { + if sdkClient.IsErrSessionNotFound(err) || sdkClient.IsErrSessionExpired(err) || strings.Contains(err.Error(), "token is invalid") { p.cache.DeleteByPrefix(address) return true } @@ -1865,10 +1866,10 @@ func (p *Pool) checkSessionTokenErr(err error, address string) bool { return false } -func initSessionForDuration(ctx context.Context, dst *session.Object, c client, dur uint64, ownerKey ecdsa.PrivateKey) error { +func initSessionForDuration(ctx context.Context, dst *session.Object, c client, dur uint64, ownerKey ecdsa.PrivateKey) (uint64, error) { ni, err := c.networkInfo(ctx, prmNetworkInfo{}) if err != nil { - return err + return 0, err } epoch := ni.CurrentEpoch() @@ -1885,28 +1886,28 @@ func initSessionForDuration(ctx context.Context, dst *session.Object, c client, res, err := c.sessionCreate(ctx, prm) if err != nil { - return err + return 0, err } var id uuid.UUID err = id.UnmarshalBinary(res.id) if err != nil { - return fmt.Errorf("invalid session token ID: %w", err) + return 0, fmt.Errorf("invalid session token ID: %w", err) } var key neofsecdsa.PublicKey err = key.Decode(res.sessionKey) if err != nil { - return fmt.Errorf("invalid public session key: %w", err) + return 0, fmt.Errorf("invalid public session key: %w", err) } dst.SetID(id) dst.SetAuthKey(&key) dst.SetExp(exp) - return nil + return exp, nil } type callContext struct { @@ -1969,13 +1970,13 @@ func (p *Pool) openDefaultSession(ctx *callContext) error { tok, ok := p.cache.Get(cacheKey) if !ok { // init new session - err := initSessionForDuration(ctx, &tok, ctx.client, p.stokenDuration, *ctx.key) + exp, err := initSessionForDuration(ctx, &tok, ctx.client, p.stokenDuration, *ctx.key) if err != nil { return fmt.Errorf("session API client: %w", err) } // cache the opened session - p.cache.Put(cacheKey, tok) + p.cache.Put(cacheKey, tok, exp) } tok.ForVerb(ctx.sessionVerb)