[#165] pool: make client interface unexported

Signed-off-by: Denis Kirillov <denis@nspcc.ru>
This commit is contained in:
Denis Kirillov 2022-03-15 15:00:38 +03:00 committed by Alex Vanin
parent 52548fe176
commit 9be9697856
4 changed files with 138 additions and 120 deletions

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/nspcc-dev/neofs-sdk-go/Pool (interfaces: Client)
// Source: github.com/nspcc-dev/neofs-sdk-go/Pool (interfaces: client)
// Package pool is a generated GoMock package.
package pool
@ -12,7 +12,7 @@ import (
client0 "github.com/nspcc-dev/neofs-sdk-go/client"
)
// MockClient is a mock of Client interface.
// MockClient is a mock of client interface.
type MockClient struct {
ctrl *gomock.Controller
recorder *MockClientMockRecorder

View file

@ -17,10 +17,11 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
sessionv2 "github.com/nspcc-dev/neofs-api-go/v2/session"
"github.com/nspcc-dev/neofs-sdk-go/accounting"
"github.com/nspcc-dev/neofs-sdk-go/client"
sdkClient "github.com/nspcc-dev/neofs-sdk-go/client"
"github.com/nspcc-dev/neofs-sdk-go/container"
cid "github.com/nspcc-dev/neofs-sdk-go/container/id"
"github.com/nspcc-dev/neofs-sdk-go/eacl"
"github.com/nspcc-dev/neofs-sdk-go/netmap"
"github.com/nspcc-dev/neofs-sdk-go/object"
"github.com/nspcc-dev/neofs-sdk-go/object/address"
oid "github.com/nspcc-dev/neofs-sdk-go/object/id"
@ -30,27 +31,27 @@ import (
"go.uber.org/zap"
)
// Client is a wrapper for client.Client to generate mock.
type Client interface {
BalanceGet(context.Context, client.PrmBalanceGet) (*client.ResBalanceGet, error)
ContainerPut(context.Context, client.PrmContainerPut) (*client.ResContainerPut, error)
ContainerGet(context.Context, client.PrmContainerGet) (*client.ResContainerGet, error)
ContainerList(context.Context, client.PrmContainerList) (*client.ResContainerList, error)
ContainerDelete(context.Context, client.PrmContainerDelete) (*client.ResContainerDelete, error)
ContainerEACL(context.Context, client.PrmContainerEACL) (*client.ResContainerEACL, error)
ContainerSetEACL(context.Context, client.PrmContainerSetEACL) (*client.ResContainerSetEACL, error)
EndpointInfo(context.Context, client.PrmEndpointInfo) (*client.ResEndpointInfo, error)
NetworkInfo(context.Context, client.PrmNetworkInfo) (*client.ResNetworkInfo, error)
ObjectPutInit(context.Context, client.PrmObjectPutInit) (*client.ObjectWriter, error)
ObjectDelete(context.Context, client.PrmObjectDelete) (*client.ResObjectDelete, error)
ObjectGetInit(context.Context, client.PrmObjectGet) (*client.ObjectReader, error)
ObjectHead(context.Context, client.PrmObjectHead) (*client.ResObjectHead, error)
ObjectRangeInit(context.Context, client.PrmObjectRange) (*client.ObjectRangeReader, error)
ObjectSearchInit(context.Context, client.PrmObjectSearch) (*client.ObjectListReader, error)
SessionCreate(context.Context, client.PrmSessionCreate) (*client.ResSessionCreate, error)
// client is a wrapper for sdkClient.Client to generate mock.
type client interface {
BalanceGet(context.Context, sdkClient.PrmBalanceGet) (*sdkClient.ResBalanceGet, error)
ContainerPut(context.Context, sdkClient.PrmContainerPut) (*sdkClient.ResContainerPut, error)
ContainerGet(context.Context, sdkClient.PrmContainerGet) (*sdkClient.ResContainerGet, error)
ContainerList(context.Context, sdkClient.PrmContainerList) (*sdkClient.ResContainerList, error)
ContainerDelete(context.Context, sdkClient.PrmContainerDelete) (*sdkClient.ResContainerDelete, error)
ContainerEACL(context.Context, sdkClient.PrmContainerEACL) (*sdkClient.ResContainerEACL, error)
ContainerSetEACL(context.Context, sdkClient.PrmContainerSetEACL) (*sdkClient.ResContainerSetEACL, error)
EndpointInfo(context.Context, sdkClient.PrmEndpointInfo) (*sdkClient.ResEndpointInfo, error)
NetworkInfo(context.Context, sdkClient.PrmNetworkInfo) (*sdkClient.ResNetworkInfo, error)
ObjectPutInit(context.Context, sdkClient.PrmObjectPutInit) (*sdkClient.ObjectWriter, error)
ObjectDelete(context.Context, sdkClient.PrmObjectDelete) (*sdkClient.ResObjectDelete, error)
ObjectGetInit(context.Context, sdkClient.PrmObjectGet) (*sdkClient.ObjectReader, error)
ObjectHead(context.Context, sdkClient.PrmObjectHead) (*sdkClient.ResObjectHead, error)
ObjectRangeInit(context.Context, sdkClient.PrmObjectRange) (*sdkClient.ObjectRangeReader, error)
ObjectSearchInit(context.Context, sdkClient.PrmObjectSearch) (*sdkClient.ObjectListReader, error)
SessionCreate(context.Context, sdkClient.PrmSessionCreate) (*sdkClient.ResSessionCreate, error)
}
// InitParameters contains options used to create connection Pool.
// InitParameters contains values used to initialize connection Pool.
type InitParameters struct {
Key *ecdsa.PrivateKey
Logger *zap.Logger
@ -61,7 +62,7 @@ type InitParameters struct {
SessionExpirationDuration uint64
NodeParams []NodeParam
clientBuilder func(endpoint string) (Client, error)
clientBuilder func(endpoint string) (client, error)
}
type rebalanceParameters struct {
@ -99,7 +100,7 @@ func DefaultPollingParams() *ContainerPollingParams {
}
type clientPack struct {
client Client
client client
healthy bool
address string
}
@ -325,7 +326,7 @@ func (x *PrmBalanceGet) SetOwnerID(ownerID *owner.ID) {
// Pool represents virtual connection to the NeoFS network to communicate
// with multiple NeoFS servers without thinking about switching between servers
// due to load balancing proportions or their unavailability.
// It is designed to provide a convenient abstraction from the multiple client.Client types.
// It is designed to provide a convenient abstraction from the multiple sdkClient.client types.
//
// Pool can be created and initialized using NewPool function.
// Before executing the NeoFS operations using the Pool, connection to the
@ -338,9 +339,9 @@ func (x *PrmBalanceGet) SetOwnerID(ownerID *owner.ID) {
//
// Each method which produces a NeoFS API call may return an error.
// Status of underlying server response is casted to built-in error instance.
// Certain statuses can be checked using `client` and standard `errors` packages.
// Certain statuses can be checked using `sdkClient` and standard `errors` packages.
// Note that package provides some helper functions to work with status returns
// (e.g. client.IsErrContainerNotFound, client.IsErrObjectNotFound).
// (e.g. sdkClient.IsErrContainerNotFound, sdkClient.IsErrObjectNotFound).
//
// See pool package overview to get some examples.
type Pool struct {
@ -353,7 +354,7 @@ type Pool struct {
stokenDuration uint64
stokenThreshold time.Duration
rebalanceParams rebalanceParameters
clientBuilder func(endpoint string) (Client, error)
clientBuilder func(endpoint string) (client, error)
logger *zap.Logger
}
@ -475,16 +476,16 @@ func fillDefaultInitParams(params *InitParameters) {
}
if params.clientBuilder == nil {
params.clientBuilder = func(addr string) (Client, error) {
var c client.Client
params.clientBuilder = func(addr string) (client, error) {
var c sdkClient.Client
var prmInit client.PrmInit
var prmInit sdkClient.PrmInit
prmInit.ResolveNeoFSFailures()
prmInit.SetDefaultPrivateKey(*params.Key)
c.Init(prmInit)
var prmDial client.PrmDial
var prmDial sdkClient.PrmDial
prmDial.SetServerURI(addr)
prmDial.SetTimeout(params.NodeConnectionTimeout)
@ -565,11 +566,11 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights
healthyChanged := false
wg := sync.WaitGroup{}
var prmEndpoint client.PrmEndpointInfo
var prmEndpoint sdkClient.PrmEndpointInfo
for j, cPack := range pool.clientPacks {
wg.Add(1)
go func(j int, cli Client) {
go func(j int, cli client) {
defer wg.Done()
ok := true
tctx, c := context.WithTimeout(ctx, options.nodeRequestTimeout)
@ -634,16 +635,6 @@ func adjustWeights(weights []float64) []float64 {
return adjusted
}
func (p *Pool) Connection() (Client, *session.Token, error) {
cp, err := p.connection()
if err != nil {
return nil, nil, err
}
tok := p.cache.Get(formCacheKey(cp.address, p.key))
return cp.client, tok, nil
}
func (p *Pool) connection() (*clientPack, error) {
for _, inner := range p.innerPools {
cp, err := inner.connection()
@ -732,15 +723,15 @@ func (p *Pool) checkSessionTokenErr(err error, address string) bool {
return false
}
func createSessionTokenForDuration(ctx context.Context, c Client, dur uint64) (*client.ResSessionCreate, error) {
ni, err := c.NetworkInfo(ctx, client.PrmNetworkInfo{})
func createSessionTokenForDuration(ctx context.Context, c client, dur uint64) (*sdkClient.ResSessionCreate, error) {
ni, err := c.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{})
if err != nil {
return nil, err
}
epoch := ni.Info().CurrentEpoch()
var prm client.PrmSessionCreate
var prm sdkClient.PrmSessionCreate
if math.MaxUint64-epoch < dur {
prm.SetExp(math.MaxUint64)
} else {
@ -773,7 +764,7 @@ type callContext struct {
// base context for RPC
context.Context
client Client
client client
// client endpoint
endpoint string
@ -921,7 +912,7 @@ func (p *Pool) PutObject(ctx context.Context, prm PrmObjectPut) (*oid.ID, error)
return nil, fmt.Errorf("init call context")
}
var cliPrm client.PrmObjectPutInit
var cliPrm sdkClient.PrmObjectPutInit
wObj, err := ctxCall.client.ObjectPutInit(ctx, cliPrm)
if err != nil {
@ -1009,7 +1000,7 @@ func (p *Pool) DeleteObject(ctx context.Context, prm PrmObjectDelete) error {
prm.useVerb(sessionv2.ObjectVerbDelete)
prm.useAddress(&prm.addr)
var cliPrm client.PrmObjectDelete
var cliPrm sdkClient.PrmObjectDelete
var cc callContextWithRetry
@ -1041,16 +1032,16 @@ func (p *Pool) DeleteObject(ctx context.Context, prm PrmObjectDelete) error {
})
}
type objectReadCloser client.ObjectReader
type objectReadCloser sdkClient.ObjectReader
// Read implements io.Reader of the object payload.
func (x *objectReadCloser) Read(p []byte) (int, error) {
return (*client.ObjectReader)(x).Read(p)
return (*sdkClient.ObjectReader)(x).Read(p)
}
// Close implements io.Closer of the object payload.
func (x *objectReadCloser) Close() error {
_, err := (*client.ObjectReader)(x).Close()
_, err := (*sdkClient.ObjectReader)(x).Close()
return err
}
@ -1067,7 +1058,7 @@ func (p *Pool) GetObject(ctx context.Context, prm PrmObjectGet) (*ResGetObject,
prm.useVerb(sessionv2.ObjectVerbGet)
prm.useAddress(&prm.addr)
var cliPrm client.PrmObjectGet
var cliPrm sdkClient.PrmObjectGet
var cc callContextWithRetry
@ -1119,7 +1110,7 @@ func (p *Pool) HeadObject(ctx context.Context, prm PrmObjectHead) (*object.Objec
prm.useVerb(sessionv2.ObjectVerbHead)
prm.useAddress(&prm.addr)
var cliPrm client.PrmObjectHead
var cliPrm sdkClient.PrmObjectHead
var cc callContextWithRetry
@ -1168,7 +1159,7 @@ func (p *Pool) HeadObject(ctx context.Context, prm PrmObjectHead) (*object.Objec
// Must be initialized using Pool.ObjectRange, any other
// usage is unsafe.
type ResObjectRange struct {
payload *client.ObjectRangeReader
payload *sdkClient.ObjectRangeReader
}
// Read implements io.Reader of the object payload.
@ -1190,7 +1181,7 @@ func (p *Pool) ObjectRange(ctx context.Context, prm PrmObjectRange) (*ResObjectR
prm.useVerb(sessionv2.ObjectVerbRange)
prm.useAddress(&prm.addr)
var cliPrm client.PrmObjectRange
var cliPrm sdkClient.PrmObjectRange
cliPrm.SetOffset(prm.off)
cliPrm.SetLength(prm.ln)
@ -1238,7 +1229,7 @@ func (p *Pool) ObjectRange(ctx context.Context, prm PrmObjectRange) (*ResObjectR
//
// Must be initialized using Pool.SearchObjects, any other usage is unsafe.
type ResObjectSearch struct {
r *client.ObjectListReader
r *sdkClient.ObjectListReader
}
// Read reads another list of the object identifiers.
@ -1280,7 +1271,7 @@ func (p *Pool) SearchObjects(ctx context.Context, prm PrmObjectSearch) (*ResObje
prm.useVerb(sessionv2.ObjectVerbSearch)
prm.useAddress(newAddressFromCnrID(&prm.cnrID))
var cliPrm client.PrmObjectSearch
var cliPrm sdkClient.PrmObjectSearch
cliPrm.InContainer(prm.cnrID)
cliPrm.SetFilters(prm.filters)
@ -1328,7 +1319,7 @@ func (p *Pool) PutContainer(ctx context.Context, prm PrmContainerPut) (*cid.ID,
return nil, err
}
var cliPrm client.PrmContainerPut
var cliPrm sdkClient.PrmContainerPut
if prm.cnr != nil {
cliPrm.SetContainer(*prm.cnr)
@ -1349,7 +1340,7 @@ func (p *Pool) GetContainer(ctx context.Context, prm PrmContainerGet) (*containe
return nil, err
}
var cliPrm client.PrmContainerGet
var cliPrm sdkClient.PrmContainerGet
if prm.cnrID != nil {
cliPrm.SetContainer(*prm.cnrID)
@ -1370,7 +1361,7 @@ func (p *Pool) ListContainers(ctx context.Context, prm PrmContainerList) ([]cid.
return nil, err
}
var cliPrm client.PrmContainerList
var cliPrm sdkClient.PrmContainerList
if prm.ownerID != nil {
cliPrm.SetAccount(*prm.ownerID)
@ -1396,7 +1387,7 @@ func (p *Pool) DeleteContainer(ctx context.Context, prm PrmContainerDelete) erro
return err
}
var cliPrm client.PrmContainerDelete
var cliPrm sdkClient.PrmContainerDelete
if prm.cnrID != nil {
cliPrm.SetContainer(*prm.cnrID)
@ -1420,7 +1411,7 @@ func (p *Pool) GetEACL(ctx context.Context, prm PrmContainerEACL) (*eacl.Table,
return nil, err
}
var cliPrm client.PrmContainerEACL
var cliPrm sdkClient.PrmContainerEACL
if prm.cnrID != nil {
cliPrm.SetContainer(*prm.cnrID)
@ -1446,7 +1437,7 @@ func (p *Pool) SetEACL(ctx context.Context, prm PrmContainerSetEACL) error {
return err
}
var cliPrm client.PrmContainerSetEACL
var cliPrm sdkClient.PrmContainerSetEACL
if prm.table != nil {
cliPrm.SetTable(*prm.table)
@ -1466,7 +1457,7 @@ func (p *Pool) Balance(ctx context.Context, prm PrmBalanceGet) (*accounting.Deci
return nil, err
}
var cliPrm client.PrmBalanceGet
var cliPrm sdkClient.PrmBalanceGet
if prm.ownerID != nil {
cliPrm.SetAccount(*prm.ownerID)
@ -1482,7 +1473,7 @@ func (p *Pool) Balance(ctx context.Context, prm PrmBalanceGet) (*accounting.Deci
// WaitForContainerPresence waits until the container is found on the NeoFS network.
func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollParams *ContainerPollingParams) error {
conn, _, err := p.Connection()
cp, err := p.connection()
if err != nil {
return err
}
@ -1493,7 +1484,7 @@ func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollPa
wdone := wctx.Done()
done := ctx.Done()
var cliPrm client.PrmContainerGet
var cliPrm sdkClient.PrmContainerGet
if cid != nil {
cliPrm.SetContainer(*cid)
@ -1506,7 +1497,7 @@ func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollPa
case <-wdone:
return wctx.Err()
case <-ticker.C:
_, err = conn.ContainerGet(ctx, cliPrm)
_, err = cp.client.ContainerGet(ctx, cliPrm)
if err == nil {
return nil
}
@ -1515,6 +1506,20 @@ func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollPa
}
}
func (p *Pool) NetworkInfo(ctx context.Context) (*netmap.NetworkInfo, error) {
cp, err := p.connection()
if err != nil {
return nil, err
}
res, err := cp.client.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{})
if err != nil { // here err already carries both status and client errors
return nil, err
}
return res.Info(), nil
}
// Close closes the Pool and releases all the associated resources.
func (p *Pool) Close() {
p.cancel()
@ -1522,12 +1527,12 @@ func (p *Pool) Close() {
}
// creates new session token from SessionCreate call result.
func (p *Pool) newSessionToken(cliRes *client.ResSessionCreate) *session.Token {
func (p *Pool) newSessionToken(cliRes *sdkClient.ResSessionCreate) *session.Token {
return sessionTokenForOwner(p.owner, cliRes)
}
// creates new session token with specified owner from SessionCreate call result.
func sessionTokenForOwner(id *owner.ID, cliRes *client.ResSessionCreate) *session.Token {
func sessionTokenForOwner(id *owner.ID, cliRes *sdkClient.ResSessionCreate) *session.Token {
st := session.NewToken()
st.SetOwnerID(id)
st.SetID(cliRes.ID())

View file

@ -13,7 +13,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neofs-sdk-go/client"
sdkClient "github.com/nspcc-dev/neofs-sdk-go/client"
"github.com/nspcc-dev/neofs-sdk-go/netmap"
"github.com/nspcc-dev/neofs-sdk-go/object"
"github.com/nspcc-dev/neofs-sdk-go/object/address"
@ -24,7 +24,7 @@ import (
)
func TestBuildPoolClientFailed(t *testing.T) {
clientBuilder := func(_ string) (Client, error) {
clientBuilder := func(_ string) (client, error) {
return nil, fmt.Errorf("error")
}
@ -46,11 +46,11 @@ func TestBuildPoolCreateSessionFailed(t *testing.T) {
ni := &netmap.NodeInfo{}
ni.SetAddresses("addr1", "addr2")
clientBuilder := func(_ string) (Client, error) {
clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error session")).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
return mockClient, nil
}
@ -83,7 +83,7 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
var expectedToken *session.Token
clientCount := -1
clientBuilder := func(_ string) (Client, error) {
clientBuilder := func(_ string) (client, error) {
clientCount++
mockClient := NewMockClient(ctrl)
mockInvokes := 0
@ -97,12 +97,12 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
}).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
mockClient2 := NewMockClient(ctrl2)
mockClient2.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
if clientCount == 0 {
return mockClient, nil
@ -130,8 +130,12 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
t.Cleanup(clientPool.Close)
condition := func() bool {
_, st, err := clientPool.Connection()
return err == nil && st == expectedToken
cp, err := clientPool.connection()
if err != nil {
return false
}
st := clientPool.cache.Get(formCacheKey(cp.address, clientPool.key))
return st == expectedToken
}
require.Never(t, condition, 900*time.Millisecond, 100*time.Millisecond)
require.Eventually(t, condition, 3*time.Second, 300*time.Millisecond)
@ -155,11 +159,11 @@ func TestOneNode(t *testing.T) {
require.NoError(t, err)
tok.SetID(uid)
clientBuilder := func(_ string) (Client, error) {
clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(tok, nil)
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
return mockClient, nil
}
@ -175,8 +179,9 @@ func TestOneNode(t *testing.T) {
require.NoError(t, err)
t.Cleanup(pool.Close)
_, st, err := pool.Connection()
cp, err := pool.connection()
require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Equal(t, tok, st)
}
@ -186,7 +191,7 @@ func TestTwoNodes(t *testing.T) {
ctrl := gomock.NewController(t)
var tokens []*session.Token
clientBuilder := func(_ string) (Client, error) {
clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
tok := session.NewToken()
@ -196,8 +201,8 @@ func TestTwoNodes(t *testing.T) {
tokens = append(tokens, tok)
return tok, err
})
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
return mockClient, nil
}
@ -216,8 +221,9 @@ func TestTwoNodes(t *testing.T) {
require.NoError(t, err)
t.Cleanup(pool.Close)
_, st, err := pool.Connection()
cp, err := pool.connection()
require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Contains(t, tokens, st)
}
@ -229,7 +235,7 @@ func TestOneOfTwoFailed(t *testing.T) {
var tokens []*session.Token
clientCount := -1
clientBuilder := func(_ string) (Client, error) {
clientBuilder := func(_ string) (client, error) {
clientCount++
mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
@ -238,7 +244,7 @@ func TestOneOfTwoFailed(t *testing.T) {
return tok, nil
}).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
mockClient2 := NewMockClient(ctrl2)
mockClient2.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
@ -246,10 +252,10 @@ func TestOneOfTwoFailed(t *testing.T) {
tokens = append(tokens, tok)
return tok, nil
}).AnyTimes()
mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*client.ResEndpointInfo, error) {
mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*sdkClient.ResEndpointInfo, error) {
return nil, fmt.Errorf("error")
}).AnyTimes()
mockClient2.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*client.ResNetworkInfo, error) {
mockClient2.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*sdkClient.ResNetworkInfo, error) {
return nil, fmt.Errorf("error")
}).AnyTimes()
@ -280,8 +286,9 @@ func TestOneOfTwoFailed(t *testing.T) {
time.Sleep(2 * time.Second)
for i := 0; i < 5; i++ {
_, st, err := pool.Connection()
cp, err := pool.connection()
require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Equal(t, tokens[0], st)
}
}
@ -291,7 +298,7 @@ func TestTwoFailed(t *testing.T) {
ctrl := gomock.NewController(t)
clientBuilder := func(_ string) (Client, error) {
clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")).AnyTimes()
@ -318,7 +325,7 @@ func TestTwoFailed(t *testing.T) {
time.Sleep(2 * time.Second)
_, _, err = pool.Connection()
_, err = pool.connection()
require.Error(t, err)
require.Contains(t, err.Error(), "no healthy")
}
@ -329,7 +336,7 @@ func TestSessionCache(t *testing.T) {
ctrl := gomock.NewController(t)
var tokens []*session.Token
clientBuilder := func(_ string) (Client, error) {
clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
tok := session.NewToken()
@ -365,8 +372,9 @@ func TestSessionCache(t *testing.T) {
t.Cleanup(pool.Close)
// cache must contain session token
_, st, err := pool.Connection()
cp, err := pool.connection()
require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Contains(t, tokens, st)
var prm PrmObjectGet
@ -376,8 +384,9 @@ func TestSessionCache(t *testing.T) {
require.Error(t, err)
// cache must not contain session token
_, st, err = pool.Connection()
cp, err = pool.connection()
require.NoError(t, err)
st = pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Nil(t, st)
var prm2 PrmObjectPut
@ -387,8 +396,9 @@ func TestSessionCache(t *testing.T) {
require.NoError(t, err)
// cache must contain session token
_, st, err = pool.Connection()
cp, err = pool.connection()
require.NoError(t, err)
st = pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Contains(t, tokens, st)
}
@ -400,7 +410,7 @@ func TestPriority(t *testing.T) {
var tokens []*session.Token
clientCount := -1
clientBuilder := func(_ string) (Client, error) {
clientBuilder := func(_ string) (client, error) {
clientCount++
mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
@ -446,13 +456,15 @@ func TestPriority(t *testing.T) {
t.Cleanup(pool.Close)
firstNode := func() bool {
_, st, err := pool.Connection()
cp, err := pool.connection()
require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
return st == tokens[0]
}
secondNode := func() bool {
_, st, err := pool.Connection()
cp, err := pool.connection()
require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
return st == tokens[1]
}
require.Never(t, secondNode, time.Second, 200*time.Millisecond)
@ -467,7 +479,7 @@ func TestSessionCacheWithKey(t *testing.T) {
ctrl := gomock.NewController(t)
var tokens []*session.Token
clientBuilder := func(_ string) (Client, error) {
clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
tok := session.NewToken()
@ -501,8 +513,9 @@ func TestSessionCacheWithKey(t *testing.T) {
require.NoError(t, err)
// cache must contain session token
_, st, err := pool.Connection()
cp, err := pool.connection()
require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Contains(t, tokens, st)
var prm PrmObjectGet
@ -525,11 +538,11 @@ func newToken(t *testing.T) *session.Token {
func TestSessionTokenOwner(t *testing.T) {
ctrl := gomock.NewController(t)
clientBuilder := func(_ string) (Client, error) {
clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(&client.ResSessionCreate{}, nil).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes()
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(&sdkClient.ResSessionCreate{}, nil).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
return mockClient, nil
}
@ -569,7 +582,7 @@ func TestWaitPresence(t *testing.T) {
mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
mockClient.EXPECT().GetContainer(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
cache, err := newCache()

View file

@ -6,7 +6,7 @@ import (
"math/rand"
"testing"
"github.com/nspcc-dev/neofs-sdk-go/client"
sdkClient "github.com/nspcc-dev/neofs-sdk-go/client"
"github.com/stretchr/testify/require"
)
@ -43,16 +43,16 @@ func TestSamplerStability(t *testing.T) {
}
type clientMock struct {
client.Client
sdkClient.Client
name string
err error
}
func (c *clientMock) EndpointInfo(context.Context, client.PrmEndpointInfo) (*client.ResEndpointInfo, error) {
func (c *clientMock) EndpointInfo(context.Context, sdkClient.PrmEndpointInfo) (*sdkClient.ResEndpointInfo, error) {
return nil, nil
}
func (c *clientMock) NetworkInfo(context.Context, client.PrmNetworkInfo) (*client.ResNetworkInfo, error) {
func (c *clientMock) NetworkInfo(context.Context, sdkClient.PrmNetworkInfo) (*sdkClient.ResNetworkInfo, error) {
return nil, nil
}
@ -88,16 +88,16 @@ func TestHealthyReweight(t *testing.T) {
}
// check getting first node connection before rebalance happened
connection0, _, err := p.Connection()
connection0, err := p.connection()
require.NoError(t, err)
mock0 := connection0.(*clientMock)
mock0 := connection0.client.(*clientMock)
require.Equal(t, names[0], mock0.name)
p.updateInnerNodesHealth(context.TODO(), 0, buffer)
connection1, _, err := p.Connection()
connection1, err := p.connection()
require.NoError(t, err)
mock1 := connection1.(*clientMock)
mock1 := connection1.client.(*clientMock)
require.Equal(t, names[1], mock1.name)
// enabled first node again
@ -108,9 +108,9 @@ func TestHealthyReweight(t *testing.T) {
p.updateInnerNodesHealth(context.TODO(), 0, buffer)
inner.sampler = newSampler(weights, rand.NewSource(0))
connection0, _, err = p.Connection()
connection0, err = p.connection()
require.NoError(t, err)
mock0 = connection0.(*clientMock)
mock0 = connection0.client.(*clientMock)
require.Equal(t, names[0], mock0.name)
}