diff --git a/pool/cache.go b/pool/cache.go index 5c39888..05bbcae 100644 --- a/pool/cache.go +++ b/pool/cache.go @@ -9,21 +9,22 @@ import ( ) type sessionCache struct { - cache *lru.Cache[string, *cacheValue] - currentEpoch atomic.Uint64 + cache *lru.Cache[string, *cacheValue] + currentEpoch atomic.Uint64 + tokenDuration uint64 } type cacheValue struct { token session.Object } -func newCache() (*sessionCache, error) { +func newCache(tokenDuration uint64) (*sessionCache, error) { cache, err := lru.New[string, *cacheValue](100) if err != nil { return nil, err } - return &sessionCache{cache: cache}, nil + return &sessionCache{cache: cache, tokenDuration: tokenDuration}, nil } // Get returns a copy of the session token from the cache without signature @@ -66,8 +67,12 @@ func (c *sessionCache) updateEpoch(newEpoch uint64) { func (c *sessionCache) expired(val *cacheValue) bool { epoch := c.currentEpoch.Load() - // use epoch+1 (clear cache beforehand) to prevent 'expired session token' error right after epoch tick - return val.token.ExpiredAt(epoch + 1) + preExpiredDur := c.tokenDuration / 2 + if preExpiredDur == 0 { + preExpiredDur = 1 + } + + return val.token.ExpiredAt(epoch + preExpiredDur) } func (c *sessionCache) Epoch() uint64 { diff --git a/pool/cache_test.go b/pool/cache_test.go index c1f12c8..d49c4d8 100644 --- a/pool/cache_test.go +++ b/pool/cache_test.go @@ -20,7 +20,7 @@ func TestSessionCache_GetUnmodifiedToken(t *testing.T) { require.False(t, tok.VerifySignature(), extra) } - cache, err := newCache() + cache, err := newCache(0) require.NoError(t, err) cache.Put(key, target) diff --git a/pool/pool.go b/pool/pool.go index 7f100cf..0cffc30 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -1947,7 +1947,7 @@ type innerPool struct { } const ( - defaultSessionTokenExpirationDuration = 100 // in blocks + defaultSessionTokenExpirationDuration = 100 // in epochs defaultErrorThreshold = 100 defaultRebalanceInterval = 15 * time.Second @@ -1969,7 +1969,7 @@ func NewPool(options InitParameters) (*Pool, error) { return nil, err } - cache, err := newCache() + cache, err := newCache(options.sessionExpirationDuration) if err != nil { return nil, fmt.Errorf("couldn't create cache: %w", err) } @@ -2087,6 +2087,10 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) { params.nodeStreamTimeout = defaultStreamTimeout } + if cache.tokenDuration == 0 { + cache.tokenDuration = defaultSessionTokenExpirationDuration + } + if params.isMissingClientBuilder() { params.setClientBuilder(func(addr string) client { var prm wrapperPrm diff --git a/pool/sampler_test.go b/pool/sampler_test.go index 5ea2326..5ece768 100644 --- a/pool/sampler_test.go +++ b/pool/sampler_test.go @@ -47,7 +47,7 @@ func TestHealthyReweight(t *testing.T) { buffer = make([]float64, len(weights)) ) - cache, err := newCache() + cache, err := newCache(0) require.NoError(t, err) client1 := newMockClient(names[0], *newPrivateKey(t))