diff --git a/pool/pool.go b/pool/pool.go index d69fabcf..a5f95700 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -59,14 +59,14 @@ type Client interface { // BuilderOptions contains options used to build connection pool. type BuilderOptions struct { - Key *ecdsa.PrivateKey - Logger *zap.Logger - NodeConnectionTimeout time.Duration - NodeRequestTimeout time.Duration - ClientRebalanceInterval time.Duration - SessionExpirationEpoch uint64 - nodesParams []*NodesParam - clientBuilder func(opts ...client.Option) (Client, error) + Key *ecdsa.PrivateKey + Logger *zap.Logger + NodeConnectionTimeout time.Duration + NodeRequestTimeout time.Duration + ClientRebalanceInterval time.Duration + SessionExpirationDuration uint64 + nodesParams []*NodesParam + clientBuilder func(opts ...client.Option) (Client, error) } type NodesParam struct { @@ -239,12 +239,13 @@ func cfgFromOpts(opts ...CallOption) *callConfig { var _ Pool = (*pool)(nil) type pool struct { - innerPools []*innerPool - key *ecdsa.PrivateKey - owner *owner.ID - cancel context.CancelFunc - closedCh chan struct{} - cache *SessionCache + innerPools []*innerPool + key *ecdsa.PrivateKey + owner *owner.ID + cancel context.CancelFunc + closedCh chan struct{} + cache *SessionCache + stokenDuration uint64 } type innerPool struct { @@ -264,10 +265,6 @@ func newPool(ctx context.Context, options *BuilderOptions) (Pool, error) { inner := make([]*innerPool, len(options.nodesParams)) var atLeastOneHealthy bool - var cliPrm client.CreateSessionPrm - - cliPrm.SetExp(options.SessionExpirationEpoch) - for i, params := range options.nodesParams { clientPacks := make([]*clientPack, len(params.weights)) for j, address := range params.addresses { @@ -279,7 +276,7 @@ func newPool(ctx context.Context, options *BuilderOptions) (Pool, error) { return nil, err } var healthy bool - cliRes, err := c.CreateSession(ctx, cliPrm) + cliRes, err := createSessionTokenForDuration(ctx, c, options.SessionExpirationDuration) if err != nil && options.Logger != nil { options.Logger.Warn("failed to create neofs session token for client", zap.String("address", address), @@ -306,12 +303,13 @@ func newPool(ctx context.Context, options *BuilderOptions) (Pool, error) { ctx, cancel := context.WithCancel(ctx) pool := &pool{ - innerPools: inner, - key: options.Key, - owner: ownerID, - cancel: cancel, - closedCh: make(chan struct{}), - cache: cache, + innerPools: inner, + key: options.Key, + owner: ownerID, + cancel: cancel, + closedCh: make(chan struct{}), + cache: cache, + stokenDuration: options.SessionExpirationDuration, } go startRebalance(ctx, pool, options) return pool, nil @@ -359,12 +357,7 @@ func updateInnerNodesHealth(ctx context.Context, pool *pool, i int, options *Bui healthyChanged := false wg := sync.WaitGroup{} - var ( - prmEndpoint client.EndpointInfoPrm - prmSession client.CreateSessionPrm - ) - - prmSession.SetExp(options.SessionExpirationEpoch) + var prmEndpoint client.EndpointInfoPrm for j, cPack := range p.clientPacks { wg.Add(1) @@ -385,7 +378,8 @@ func updateInnerNodesHealth(ctx context.Context, pool *pool, i int, options *Bui if ok { bufferWeights[j] = options.nodesParams[i].weights[j] if !cp.healthy { - if cliRes, err := cli.CreateSession(ctx, prmSession); err != nil { + cliRes, err := createSessionTokenForDuration(ctx, cli, options.SessionExpirationDuration) + if err != nil { ok = false bufferWeights[j] = 0 } else { @@ -502,11 +496,7 @@ func (p *pool) conn(ctx context.Context, cfg *callConfig) (*clientPack, []client cacheKey := formCacheKey(cp.address, key) sessionToken = p.cache.Get(cacheKey) if sessionToken == nil { - var cliPrm client.CreateSessionPrm - - cliPrm.SetExp(math.MaxUint32) - - cliRes, err := cp.client.CreateSession(ctx, cliPrm) + cliRes, err := createSessionTokenForDuration(ctx, cp.client, p.stokenDuration) if err != nil { return nil, nil, err } @@ -541,6 +531,24 @@ func (p *pool) checkSessionTokenErr(err error, address string) bool { return false } +func createSessionTokenForDuration(ctx context.Context, c Client, dur uint64) (*client.CreateSessionRes, error) { + ni, err := c.NetworkInfo(ctx, client.NetworkInfoPrm{}) + if err != nil { + return nil, err + } + + epoch := ni.Info().CurrentEpoch() + + var prm client.CreateSessionPrm + if math.MaxUint64-epoch < dur { + prm.SetExp(math.MaxUint64) + } else { + prm.SetExp(epoch + dur) + } + + return c.CreateSession(ctx, prm) +} + func (p *pool) PutObject(ctx context.Context, params *client.PutObjectParams, opts ...CallOption) (*oid.ID, error) { cfg := cfgFromOpts(append(opts, useDefaultSession())...) cp, options, err := p.conn(ctx, cfg) diff --git a/pool/pool_test.go b/pool/pool_test.go index 6f3a5a6f..ba792c70 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -48,6 +48,7 @@ func TestBuildPoolCreateSessionFailed(t *testing.T) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error session")).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.EndpointInfoRes{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.NetworkInfoRes{}, nil).AnyTimes() return mockClient, nil } @@ -94,10 +95,12 @@ func TestBuildPoolOneNodeFailed(t *testing.T) { }).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.NetworkInfoRes{}, nil).AnyTimes() mockClient2 := NewMockClient(ctrl2) mockClient2.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.NetworkInfoRes{}, nil).AnyTimes() if clientCount == 0 { return mockClient, nil @@ -153,6 +156,7 @@ func TestOneNode(t *testing.T) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(tok, nil) mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.EndpointInfoRes{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.NetworkInfoRes{}, nil).AnyTimes() return mockClient, nil } @@ -190,6 +194,7 @@ func TestTwoNodes(t *testing.T) { return tok, err }) mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.EndpointInfoRes{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.NetworkInfoRes{}, nil).AnyTimes() return mockClient, nil } @@ -228,6 +233,7 @@ func TestOneOfTwoFailed(t *testing.T) { return tok, nil }).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.NetworkInfoRes{}, nil).AnyTimes() mockClient2 := NewMockClient(ctrl2) mockClient2.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { @@ -238,6 +244,9 @@ func TestOneOfTwoFailed(t *testing.T) { mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*client.EndpointInfoRes, error) { return nil, fmt.Errorf("error") }).AnyTimes() + mockClient2.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*client.NetworkInfoRes, error) { + return nil, fmt.Errorf("error") + }).AnyTimes() if clientCount == 0 { return mockClient, nil @@ -277,6 +286,7 @@ func TestTwoFailed(t *testing.T) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")).AnyTimes() return mockClient, nil } @@ -379,6 +389,7 @@ func TestPriority(t *testing.T) { return tok, nil }).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")).AnyTimes() mockClient2 := NewMockClient(ctrl2) mockClient2.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { @@ -387,6 +398,7 @@ func TestPriority(t *testing.T) { return tok, nil }).AnyTimes() mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockClient2.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() if clientCount == 0 { return mockClient, nil @@ -489,6 +501,7 @@ func TestSessionTokenOwner(t *testing.T) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(&client.CreateSessionRes{}, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.EndpointInfoRes{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.NetworkInfoRes{}, nil).AnyTimes() return mockClient, nil } @@ -527,6 +540,7 @@ func TestWaitPresence(t *testing.T) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.NetworkInfoRes{}, nil).AnyTimes() mockClient.EXPECT().GetContainer(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() cache, err := NewCache()