diff --git a/pool/mock_test.go b/pool/mock_test.go index 07e9ce3..592b9a1 100644 --- a/pool/mock_test.go +++ b/pool/mock_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/nspcc-dev/neofs-sdk-go/Pool (interfaces: Client) +// Source: github.com/nspcc-dev/neofs-sdk-go/Pool (interfaces: client) // Package pool is a generated GoMock package. package pool @@ -12,7 +12,7 @@ import ( client0 "github.com/nspcc-dev/neofs-sdk-go/client" ) -// MockClient is a mock of Client interface. +// MockClient is a mock of client interface. type MockClient struct { ctrl *gomock.Controller recorder *MockClientMockRecorder diff --git a/pool/pool.go b/pool/pool.go index 456a8d7..0d3f31d 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -17,10 +17,11 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/keys" sessionv2 "github.com/nspcc-dev/neofs-api-go/v2/session" "github.com/nspcc-dev/neofs-sdk-go/accounting" - "github.com/nspcc-dev/neofs-sdk-go/client" + sdkClient "github.com/nspcc-dev/neofs-sdk-go/client" "github.com/nspcc-dev/neofs-sdk-go/container" cid "github.com/nspcc-dev/neofs-sdk-go/container/id" "github.com/nspcc-dev/neofs-sdk-go/eacl" + "github.com/nspcc-dev/neofs-sdk-go/netmap" "github.com/nspcc-dev/neofs-sdk-go/object" "github.com/nspcc-dev/neofs-sdk-go/object/address" oid "github.com/nspcc-dev/neofs-sdk-go/object/id" @@ -30,27 +31,27 @@ import ( "go.uber.org/zap" ) -// Client is a wrapper for client.Client to generate mock. -type Client interface { - BalanceGet(context.Context, client.PrmBalanceGet) (*client.ResBalanceGet, error) - ContainerPut(context.Context, client.PrmContainerPut) (*client.ResContainerPut, error) - ContainerGet(context.Context, client.PrmContainerGet) (*client.ResContainerGet, error) - ContainerList(context.Context, client.PrmContainerList) (*client.ResContainerList, error) - ContainerDelete(context.Context, client.PrmContainerDelete) (*client.ResContainerDelete, error) - ContainerEACL(context.Context, client.PrmContainerEACL) (*client.ResContainerEACL, error) - ContainerSetEACL(context.Context, client.PrmContainerSetEACL) (*client.ResContainerSetEACL, error) - EndpointInfo(context.Context, client.PrmEndpointInfo) (*client.ResEndpointInfo, error) - NetworkInfo(context.Context, client.PrmNetworkInfo) (*client.ResNetworkInfo, error) - ObjectPutInit(context.Context, client.PrmObjectPutInit) (*client.ObjectWriter, error) - ObjectDelete(context.Context, client.PrmObjectDelete) (*client.ResObjectDelete, error) - ObjectGetInit(context.Context, client.PrmObjectGet) (*client.ObjectReader, error) - ObjectHead(context.Context, client.PrmObjectHead) (*client.ResObjectHead, error) - ObjectRangeInit(context.Context, client.PrmObjectRange) (*client.ObjectRangeReader, error) - ObjectSearchInit(context.Context, client.PrmObjectSearch) (*client.ObjectListReader, error) - SessionCreate(context.Context, client.PrmSessionCreate) (*client.ResSessionCreate, error) +// client is a wrapper for sdkClient.Client to generate mock. +type client interface { + BalanceGet(context.Context, sdkClient.PrmBalanceGet) (*sdkClient.ResBalanceGet, error) + ContainerPut(context.Context, sdkClient.PrmContainerPut) (*sdkClient.ResContainerPut, error) + ContainerGet(context.Context, sdkClient.PrmContainerGet) (*sdkClient.ResContainerGet, error) + ContainerList(context.Context, sdkClient.PrmContainerList) (*sdkClient.ResContainerList, error) + ContainerDelete(context.Context, sdkClient.PrmContainerDelete) (*sdkClient.ResContainerDelete, error) + ContainerEACL(context.Context, sdkClient.PrmContainerEACL) (*sdkClient.ResContainerEACL, error) + ContainerSetEACL(context.Context, sdkClient.PrmContainerSetEACL) (*sdkClient.ResContainerSetEACL, error) + EndpointInfo(context.Context, sdkClient.PrmEndpointInfo) (*sdkClient.ResEndpointInfo, error) + NetworkInfo(context.Context, sdkClient.PrmNetworkInfo) (*sdkClient.ResNetworkInfo, error) + ObjectPutInit(context.Context, sdkClient.PrmObjectPutInit) (*sdkClient.ObjectWriter, error) + ObjectDelete(context.Context, sdkClient.PrmObjectDelete) (*sdkClient.ResObjectDelete, error) + ObjectGetInit(context.Context, sdkClient.PrmObjectGet) (*sdkClient.ObjectReader, error) + ObjectHead(context.Context, sdkClient.PrmObjectHead) (*sdkClient.ResObjectHead, error) + ObjectRangeInit(context.Context, sdkClient.PrmObjectRange) (*sdkClient.ObjectRangeReader, error) + ObjectSearchInit(context.Context, sdkClient.PrmObjectSearch) (*sdkClient.ObjectListReader, error) + SessionCreate(context.Context, sdkClient.PrmSessionCreate) (*sdkClient.ResSessionCreate, error) } -// InitParameters contains options used to create connection Pool. +// InitParameters contains values used to initialize connection Pool. type InitParameters struct { Key *ecdsa.PrivateKey Logger *zap.Logger @@ -61,7 +62,7 @@ type InitParameters struct { SessionExpirationDuration uint64 NodeParams []NodeParam - clientBuilder func(endpoint string) (Client, error) + clientBuilder func(endpoint string) (client, error) } type rebalanceParameters struct { @@ -99,7 +100,7 @@ func DefaultPollingParams() *ContainerPollingParams { } type clientPack struct { - client Client + client client healthy bool address string } @@ -325,7 +326,7 @@ func (x *PrmBalanceGet) SetOwnerID(ownerID *owner.ID) { // Pool represents virtual connection to the NeoFS network to communicate // with multiple NeoFS servers without thinking about switching between servers // due to load balancing proportions or their unavailability. -// It is designed to provide a convenient abstraction from the multiple client.Client types. +// It is designed to provide a convenient abstraction from the multiple sdkClient.client types. // // Pool can be created and initialized using NewPool function. // Before executing the NeoFS operations using the Pool, connection to the @@ -338,9 +339,9 @@ func (x *PrmBalanceGet) SetOwnerID(ownerID *owner.ID) { // // Each method which produces a NeoFS API call may return an error. // Status of underlying server response is casted to built-in error instance. -// Certain statuses can be checked using `client` and standard `errors` packages. +// Certain statuses can be checked using `sdkClient` and standard `errors` packages. // Note that package provides some helper functions to work with status returns -// (e.g. client.IsErrContainerNotFound, client.IsErrObjectNotFound). +// (e.g. sdkClient.IsErrContainerNotFound, sdkClient.IsErrObjectNotFound). // // See pool package overview to get some examples. type Pool struct { @@ -353,7 +354,7 @@ type Pool struct { stokenDuration uint64 stokenThreshold time.Duration rebalanceParams rebalanceParameters - clientBuilder func(endpoint string) (Client, error) + clientBuilder func(endpoint string) (client, error) logger *zap.Logger } @@ -475,16 +476,16 @@ func fillDefaultInitParams(params *InitParameters) { } if params.clientBuilder == nil { - params.clientBuilder = func(addr string) (Client, error) { - var c client.Client + params.clientBuilder = func(addr string) (client, error) { + var c sdkClient.Client - var prmInit client.PrmInit + var prmInit sdkClient.PrmInit prmInit.ResolveNeoFSFailures() prmInit.SetDefaultPrivateKey(*params.Key) c.Init(prmInit) - var prmDial client.PrmDial + var prmDial sdkClient.PrmDial prmDial.SetServerURI(addr) prmDial.SetTimeout(params.NodeConnectionTimeout) @@ -565,11 +566,11 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights healthyChanged := false wg := sync.WaitGroup{} - var prmEndpoint client.PrmEndpointInfo + var prmEndpoint sdkClient.PrmEndpointInfo for j, cPack := range pool.clientPacks { wg.Add(1) - go func(j int, cli Client) { + go func(j int, cli client) { defer wg.Done() ok := true tctx, c := context.WithTimeout(ctx, options.nodeRequestTimeout) @@ -634,16 +635,6 @@ func adjustWeights(weights []float64) []float64 { return adjusted } -func (p *Pool) Connection() (Client, *session.Token, error) { - cp, err := p.connection() - if err != nil { - return nil, nil, err - } - - tok := p.cache.Get(formCacheKey(cp.address, p.key)) - return cp.client, tok, nil -} - func (p *Pool) connection() (*clientPack, error) { for _, inner := range p.innerPools { cp, err := inner.connection() @@ -732,15 +723,15 @@ func (p *Pool) checkSessionTokenErr(err error, address string) bool { return false } -func createSessionTokenForDuration(ctx context.Context, c Client, dur uint64) (*client.ResSessionCreate, error) { - ni, err := c.NetworkInfo(ctx, client.PrmNetworkInfo{}) +func createSessionTokenForDuration(ctx context.Context, c client, dur uint64) (*sdkClient.ResSessionCreate, error) { + ni, err := c.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{}) if err != nil { return nil, err } epoch := ni.Info().CurrentEpoch() - var prm client.PrmSessionCreate + var prm sdkClient.PrmSessionCreate if math.MaxUint64-epoch < dur { prm.SetExp(math.MaxUint64) } else { @@ -773,7 +764,7 @@ type callContext struct { // base context for RPC context.Context - client Client + client client // client endpoint endpoint string @@ -921,7 +912,7 @@ func (p *Pool) PutObject(ctx context.Context, prm PrmObjectPut) (*oid.ID, error) return nil, fmt.Errorf("init call context") } - var cliPrm client.PrmObjectPutInit + var cliPrm sdkClient.PrmObjectPutInit wObj, err := ctxCall.client.ObjectPutInit(ctx, cliPrm) if err != nil { @@ -1009,7 +1000,7 @@ func (p *Pool) DeleteObject(ctx context.Context, prm PrmObjectDelete) error { prm.useVerb(sessionv2.ObjectVerbDelete) prm.useAddress(&prm.addr) - var cliPrm client.PrmObjectDelete + var cliPrm sdkClient.PrmObjectDelete var cc callContextWithRetry @@ -1041,16 +1032,16 @@ func (p *Pool) DeleteObject(ctx context.Context, prm PrmObjectDelete) error { }) } -type objectReadCloser client.ObjectReader +type objectReadCloser sdkClient.ObjectReader // Read implements io.Reader of the object payload. func (x *objectReadCloser) Read(p []byte) (int, error) { - return (*client.ObjectReader)(x).Read(p) + return (*sdkClient.ObjectReader)(x).Read(p) } // Close implements io.Closer of the object payload. func (x *objectReadCloser) Close() error { - _, err := (*client.ObjectReader)(x).Close() + _, err := (*sdkClient.ObjectReader)(x).Close() return err } @@ -1067,7 +1058,7 @@ func (p *Pool) GetObject(ctx context.Context, prm PrmObjectGet) (*ResGetObject, prm.useVerb(sessionv2.ObjectVerbGet) prm.useAddress(&prm.addr) - var cliPrm client.PrmObjectGet + var cliPrm sdkClient.PrmObjectGet var cc callContextWithRetry @@ -1119,7 +1110,7 @@ func (p *Pool) HeadObject(ctx context.Context, prm PrmObjectHead) (*object.Objec prm.useVerb(sessionv2.ObjectVerbHead) prm.useAddress(&prm.addr) - var cliPrm client.PrmObjectHead + var cliPrm sdkClient.PrmObjectHead var cc callContextWithRetry @@ -1168,7 +1159,7 @@ func (p *Pool) HeadObject(ctx context.Context, prm PrmObjectHead) (*object.Objec // Must be initialized using Pool.ObjectRange, any other // usage is unsafe. type ResObjectRange struct { - payload *client.ObjectRangeReader + payload *sdkClient.ObjectRangeReader } // Read implements io.Reader of the object payload. @@ -1190,7 +1181,7 @@ func (p *Pool) ObjectRange(ctx context.Context, prm PrmObjectRange) (*ResObjectR prm.useVerb(sessionv2.ObjectVerbRange) prm.useAddress(&prm.addr) - var cliPrm client.PrmObjectRange + var cliPrm sdkClient.PrmObjectRange cliPrm.SetOffset(prm.off) cliPrm.SetLength(prm.ln) @@ -1238,7 +1229,7 @@ func (p *Pool) ObjectRange(ctx context.Context, prm PrmObjectRange) (*ResObjectR // // Must be initialized using Pool.SearchObjects, any other usage is unsafe. type ResObjectSearch struct { - r *client.ObjectListReader + r *sdkClient.ObjectListReader } // Read reads another list of the object identifiers. @@ -1280,7 +1271,7 @@ func (p *Pool) SearchObjects(ctx context.Context, prm PrmObjectSearch) (*ResObje prm.useVerb(sessionv2.ObjectVerbSearch) prm.useAddress(newAddressFromCnrID(&prm.cnrID)) - var cliPrm client.PrmObjectSearch + var cliPrm sdkClient.PrmObjectSearch cliPrm.InContainer(prm.cnrID) cliPrm.SetFilters(prm.filters) @@ -1328,7 +1319,7 @@ func (p *Pool) PutContainer(ctx context.Context, prm PrmContainerPut) (*cid.ID, return nil, err } - var cliPrm client.PrmContainerPut + var cliPrm sdkClient.PrmContainerPut if prm.cnr != nil { cliPrm.SetContainer(*prm.cnr) @@ -1349,7 +1340,7 @@ func (p *Pool) GetContainer(ctx context.Context, prm PrmContainerGet) (*containe return nil, err } - var cliPrm client.PrmContainerGet + var cliPrm sdkClient.PrmContainerGet if prm.cnrID != nil { cliPrm.SetContainer(*prm.cnrID) @@ -1370,7 +1361,7 @@ func (p *Pool) ListContainers(ctx context.Context, prm PrmContainerList) ([]cid. return nil, err } - var cliPrm client.PrmContainerList + var cliPrm sdkClient.PrmContainerList if prm.ownerID != nil { cliPrm.SetAccount(*prm.ownerID) @@ -1396,7 +1387,7 @@ func (p *Pool) DeleteContainer(ctx context.Context, prm PrmContainerDelete) erro return err } - var cliPrm client.PrmContainerDelete + var cliPrm sdkClient.PrmContainerDelete if prm.cnrID != nil { cliPrm.SetContainer(*prm.cnrID) @@ -1420,7 +1411,7 @@ func (p *Pool) GetEACL(ctx context.Context, prm PrmContainerEACL) (*eacl.Table, return nil, err } - var cliPrm client.PrmContainerEACL + var cliPrm sdkClient.PrmContainerEACL if prm.cnrID != nil { cliPrm.SetContainer(*prm.cnrID) @@ -1446,7 +1437,7 @@ func (p *Pool) SetEACL(ctx context.Context, prm PrmContainerSetEACL) error { return err } - var cliPrm client.PrmContainerSetEACL + var cliPrm sdkClient.PrmContainerSetEACL if prm.table != nil { cliPrm.SetTable(*prm.table) @@ -1466,7 +1457,7 @@ func (p *Pool) Balance(ctx context.Context, prm PrmBalanceGet) (*accounting.Deci return nil, err } - var cliPrm client.PrmBalanceGet + var cliPrm sdkClient.PrmBalanceGet if prm.ownerID != nil { cliPrm.SetAccount(*prm.ownerID) @@ -1482,7 +1473,7 @@ func (p *Pool) Balance(ctx context.Context, prm PrmBalanceGet) (*accounting.Deci // WaitForContainerPresence waits until the container is found on the NeoFS network. func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollParams *ContainerPollingParams) error { - conn, _, err := p.Connection() + cp, err := p.connection() if err != nil { return err } @@ -1493,7 +1484,7 @@ func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollPa wdone := wctx.Done() done := ctx.Done() - var cliPrm client.PrmContainerGet + var cliPrm sdkClient.PrmContainerGet if cid != nil { cliPrm.SetContainer(*cid) @@ -1506,7 +1497,7 @@ func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollPa case <-wdone: return wctx.Err() case <-ticker.C: - _, err = conn.ContainerGet(ctx, cliPrm) + _, err = cp.client.ContainerGet(ctx, cliPrm) if err == nil { return nil } @@ -1515,6 +1506,20 @@ func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollPa } } +func (p *Pool) NetworkInfo(ctx context.Context) (*netmap.NetworkInfo, error) { + cp, err := p.connection() + if err != nil { + return nil, err + } + + res, err := cp.client.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{}) + if err != nil { // here err already carries both status and client errors + return nil, err + } + + return res.Info(), nil +} + // Close closes the Pool and releases all the associated resources. func (p *Pool) Close() { p.cancel() @@ -1522,12 +1527,12 @@ func (p *Pool) Close() { } // creates new session token from SessionCreate call result. -func (p *Pool) newSessionToken(cliRes *client.ResSessionCreate) *session.Token { +func (p *Pool) newSessionToken(cliRes *sdkClient.ResSessionCreate) *session.Token { return sessionTokenForOwner(p.owner, cliRes) } // creates new session token with specified owner from SessionCreate call result. -func sessionTokenForOwner(id *owner.ID, cliRes *client.ResSessionCreate) *session.Token { +func sessionTokenForOwner(id *owner.ID, cliRes *sdkClient.ResSessionCreate) *session.Token { st := session.NewToken() st.SetOwnerID(id) st.SetID(cliRes.ID()) diff --git a/pool/pool_test.go b/pool/pool_test.go index 9f468d7..8e90ec1 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -13,7 +13,7 @@ import ( "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" - "github.com/nspcc-dev/neofs-sdk-go/client" + sdkClient "github.com/nspcc-dev/neofs-sdk-go/client" "github.com/nspcc-dev/neofs-sdk-go/netmap" "github.com/nspcc-dev/neofs-sdk-go/object" "github.com/nspcc-dev/neofs-sdk-go/object/address" @@ -24,7 +24,7 @@ import ( ) func TestBuildPoolClientFailed(t *testing.T) { - clientBuilder := func(_ string) (Client, error) { + clientBuilder := func(_ string) (client, error) { return nil, fmt.Errorf("error") } @@ -46,11 +46,11 @@ func TestBuildPoolCreateSessionFailed(t *testing.T) { ni := &netmap.NodeInfo{} ni.SetAddresses("addr1", "addr2") - clientBuilder := func(_ string) (Client, error) { + clientBuilder := func(_ string) (client, error) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error session")).AnyTimes() - mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes() - mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() + mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes() return mockClient, nil } @@ -83,7 +83,7 @@ func TestBuildPoolOneNodeFailed(t *testing.T) { var expectedToken *session.Token clientCount := -1 - clientBuilder := func(_ string) (Client, error) { + clientBuilder := func(_ string) (client, error) { clientCount++ mockClient := NewMockClient(ctrl) mockInvokes := 0 @@ -97,12 +97,12 @@ func TestBuildPoolOneNodeFailed(t *testing.T) { }).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes() mockClient2 := NewMockClient(ctrl2) mockClient2.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes() if clientCount == 0 { return mockClient, nil @@ -130,8 +130,12 @@ func TestBuildPoolOneNodeFailed(t *testing.T) { t.Cleanup(clientPool.Close) condition := func() bool { - _, st, err := clientPool.Connection() - return err == nil && st == expectedToken + cp, err := clientPool.connection() + if err != nil { + return false + } + st := clientPool.cache.Get(formCacheKey(cp.address, clientPool.key)) + return st == expectedToken } require.Never(t, condition, 900*time.Millisecond, 100*time.Millisecond) require.Eventually(t, condition, 3*time.Second, 300*time.Millisecond) @@ -155,11 +159,11 @@ func TestOneNode(t *testing.T) { require.NoError(t, err) tok.SetID(uid) - clientBuilder := func(_ string) (Client, error) { + clientBuilder := func(_ string) (client, error) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(tok, nil) - mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes() - mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() + mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes() return mockClient, nil } @@ -175,8 +179,9 @@ func TestOneNode(t *testing.T) { require.NoError(t, err) t.Cleanup(pool.Close) - _, st, err := pool.Connection() + cp, err := pool.connection() require.NoError(t, err) + st := pool.cache.Get(formCacheKey(cp.address, pool.key)) require.Equal(t, tok, st) } @@ -186,7 +191,7 @@ func TestTwoNodes(t *testing.T) { ctrl := gomock.NewController(t) var tokens []*session.Token - clientBuilder := func(_ string) (Client, error) { + clientBuilder := func(_ string) (client, error) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { tok := session.NewToken() @@ -196,8 +201,8 @@ func TestTwoNodes(t *testing.T) { tokens = append(tokens, tok) return tok, err }) - mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes() - mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() + mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes() return mockClient, nil } @@ -216,8 +221,9 @@ func TestTwoNodes(t *testing.T) { require.NoError(t, err) t.Cleanup(pool.Close) - _, st, err := pool.Connection() + cp, err := pool.connection() require.NoError(t, err) + st := pool.cache.Get(formCacheKey(cp.address, pool.key)) require.Contains(t, tokens, st) } @@ -229,7 +235,7 @@ func TestOneOfTwoFailed(t *testing.T) { var tokens []*session.Token clientCount := -1 - clientBuilder := func(_ string) (Client, error) { + clientBuilder := func(_ string) (client, error) { clientCount++ mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { @@ -238,7 +244,7 @@ func TestOneOfTwoFailed(t *testing.T) { return tok, nil }).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes() mockClient2 := NewMockClient(ctrl2) mockClient2.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { @@ -246,10 +252,10 @@ func TestOneOfTwoFailed(t *testing.T) { tokens = append(tokens, tok) return tok, nil }).AnyTimes() - mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*client.ResEndpointInfo, error) { + mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*sdkClient.ResEndpointInfo, error) { return nil, fmt.Errorf("error") }).AnyTimes() - mockClient2.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*client.ResNetworkInfo, error) { + mockClient2.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*sdkClient.ResNetworkInfo, error) { return nil, fmt.Errorf("error") }).AnyTimes() @@ -280,8 +286,9 @@ func TestOneOfTwoFailed(t *testing.T) { time.Sleep(2 * time.Second) for i := 0; i < 5; i++ { - _, st, err := pool.Connection() + cp, err := pool.connection() require.NoError(t, err) + st := pool.cache.Get(formCacheKey(cp.address, pool.key)) require.Equal(t, tokens[0], st) } } @@ -291,7 +298,7 @@ func TestTwoFailed(t *testing.T) { ctrl := gomock.NewController(t) - clientBuilder := func(_ string) (Client, error) { + clientBuilder := func(_ string) (client, error) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")).AnyTimes() @@ -318,7 +325,7 @@ func TestTwoFailed(t *testing.T) { time.Sleep(2 * time.Second) - _, _, err = pool.Connection() + _, err = pool.connection() require.Error(t, err) require.Contains(t, err.Error(), "no healthy") } @@ -329,7 +336,7 @@ func TestSessionCache(t *testing.T) { ctrl := gomock.NewController(t) var tokens []*session.Token - clientBuilder := func(_ string) (Client, error) { + clientBuilder := func(_ string) (client, error) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { tok := session.NewToken() @@ -365,8 +372,9 @@ func TestSessionCache(t *testing.T) { t.Cleanup(pool.Close) // cache must contain session token - _, st, err := pool.Connection() + cp, err := pool.connection() require.NoError(t, err) + st := pool.cache.Get(formCacheKey(cp.address, pool.key)) require.Contains(t, tokens, st) var prm PrmObjectGet @@ -376,8 +384,9 @@ func TestSessionCache(t *testing.T) { require.Error(t, err) // cache must not contain session token - _, st, err = pool.Connection() + cp, err = pool.connection() require.NoError(t, err) + st = pool.cache.Get(formCacheKey(cp.address, pool.key)) require.Nil(t, st) var prm2 PrmObjectPut @@ -387,8 +396,9 @@ func TestSessionCache(t *testing.T) { require.NoError(t, err) // cache must contain session token - _, st, err = pool.Connection() + cp, err = pool.connection() require.NoError(t, err) + st = pool.cache.Get(formCacheKey(cp.address, pool.key)) require.Contains(t, tokens, st) } @@ -400,7 +410,7 @@ func TestPriority(t *testing.T) { var tokens []*session.Token clientCount := -1 - clientBuilder := func(_ string) (Client, error) { + clientBuilder := func(_ string) (client, error) { clientCount++ mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { @@ -446,13 +456,15 @@ func TestPriority(t *testing.T) { t.Cleanup(pool.Close) firstNode := func() bool { - _, st, err := pool.Connection() + cp, err := pool.connection() require.NoError(t, err) + st := pool.cache.Get(formCacheKey(cp.address, pool.key)) return st == tokens[0] } secondNode := func() bool { - _, st, err := pool.Connection() + cp, err := pool.connection() require.NoError(t, err) + st := pool.cache.Get(formCacheKey(cp.address, pool.key)) return st == tokens[1] } require.Never(t, secondNode, time.Second, 200*time.Millisecond) @@ -467,7 +479,7 @@ func TestSessionCacheWithKey(t *testing.T) { ctrl := gomock.NewController(t) var tokens []*session.Token - clientBuilder := func(_ string) (Client, error) { + clientBuilder := func(_ string) (client, error) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { tok := session.NewToken() @@ -501,8 +513,9 @@ func TestSessionCacheWithKey(t *testing.T) { require.NoError(t, err) // cache must contain session token - _, st, err := pool.Connection() + cp, err := pool.connection() require.NoError(t, err) + st := pool.cache.Get(formCacheKey(cp.address, pool.key)) require.Contains(t, tokens, st) var prm PrmObjectGet @@ -525,11 +538,11 @@ func newToken(t *testing.T) *session.Token { func TestSessionTokenOwner(t *testing.T) { ctrl := gomock.NewController(t) - clientBuilder := func(_ string) (Client, error) { + clientBuilder := func(_ string) (client, error) { mockClient := NewMockClient(ctrl) - mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(&client.ResSessionCreate{}, nil).AnyTimes() - mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes() - mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() + mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(&sdkClient.ResSessionCreate{}, nil).AnyTimes() + mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes() return mockClient, nil } @@ -569,7 +582,7 @@ func TestWaitPresence(t *testing.T) { mockClient := NewMockClient(ctrl) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() + mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes() mockClient.EXPECT().GetContainer(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() cache, err := newCache() diff --git a/pool/sampler_test.go b/pool/sampler_test.go index 77a26d2..f41dbdb 100644 --- a/pool/sampler_test.go +++ b/pool/sampler_test.go @@ -6,7 +6,7 @@ import ( "math/rand" "testing" - "github.com/nspcc-dev/neofs-sdk-go/client" + sdkClient "github.com/nspcc-dev/neofs-sdk-go/client" "github.com/stretchr/testify/require" ) @@ -43,16 +43,16 @@ func TestSamplerStability(t *testing.T) { } type clientMock struct { - client.Client + sdkClient.Client name string err error } -func (c *clientMock) EndpointInfo(context.Context, client.PrmEndpointInfo) (*client.ResEndpointInfo, error) { +func (c *clientMock) EndpointInfo(context.Context, sdkClient.PrmEndpointInfo) (*sdkClient.ResEndpointInfo, error) { return nil, nil } -func (c *clientMock) NetworkInfo(context.Context, client.PrmNetworkInfo) (*client.ResNetworkInfo, error) { +func (c *clientMock) NetworkInfo(context.Context, sdkClient.PrmNetworkInfo) (*sdkClient.ResNetworkInfo, error) { return nil, nil } @@ -88,16 +88,16 @@ func TestHealthyReweight(t *testing.T) { } // check getting first node connection before rebalance happened - connection0, _, err := p.Connection() + connection0, err := p.connection() require.NoError(t, err) - mock0 := connection0.(*clientMock) + mock0 := connection0.client.(*clientMock) require.Equal(t, names[0], mock0.name) p.updateInnerNodesHealth(context.TODO(), 0, buffer) - connection1, _, err := p.Connection() + connection1, err := p.connection() require.NoError(t, err) - mock1 := connection1.(*clientMock) + mock1 := connection1.client.(*clientMock) require.Equal(t, names[1], mock1.name) // enabled first node again @@ -108,9 +108,9 @@ func TestHealthyReweight(t *testing.T) { p.updateInnerNodesHealth(context.TODO(), 0, buffer) inner.sampler = newSampler(weights, rand.NewSource(0)) - connection0, _, err = p.Connection() + connection0, err = p.connection() require.NoError(t, err) - mock0 = connection0.(*clientMock) + mock0 = connection0.client.(*clientMock) require.Equal(t, names[0], mock0.name) }