[#165] pool: make client interface unexported

Signed-off-by: Denis Kirillov <denis@nspcc.ru>
This commit is contained in:
Denis Kirillov 2022-03-15 15:00:38 +03:00 committed by Alex Vanin
parent 52548fe176
commit 9be9697856
4 changed files with 138 additions and 120 deletions

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/nspcc-dev/neofs-sdk-go/Pool (interfaces: Client) // Source: github.com/nspcc-dev/neofs-sdk-go/Pool (interfaces: client)
// Package pool is a generated GoMock package. // Package pool is a generated GoMock package.
package pool package pool
@ -12,7 +12,7 @@ import (
client0 "github.com/nspcc-dev/neofs-sdk-go/client" client0 "github.com/nspcc-dev/neofs-sdk-go/client"
) )
// MockClient is a mock of Client interface. // MockClient is a mock of client interface.
type MockClient struct { type MockClient struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockClientMockRecorder recorder *MockClientMockRecorder

View file

@ -17,10 +17,11 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
sessionv2 "github.com/nspcc-dev/neofs-api-go/v2/session" sessionv2 "github.com/nspcc-dev/neofs-api-go/v2/session"
"github.com/nspcc-dev/neofs-sdk-go/accounting" "github.com/nspcc-dev/neofs-sdk-go/accounting"
"github.com/nspcc-dev/neofs-sdk-go/client" sdkClient "github.com/nspcc-dev/neofs-sdk-go/client"
"github.com/nspcc-dev/neofs-sdk-go/container" "github.com/nspcc-dev/neofs-sdk-go/container"
cid "github.com/nspcc-dev/neofs-sdk-go/container/id" cid "github.com/nspcc-dev/neofs-sdk-go/container/id"
"github.com/nspcc-dev/neofs-sdk-go/eacl" "github.com/nspcc-dev/neofs-sdk-go/eacl"
"github.com/nspcc-dev/neofs-sdk-go/netmap"
"github.com/nspcc-dev/neofs-sdk-go/object" "github.com/nspcc-dev/neofs-sdk-go/object"
"github.com/nspcc-dev/neofs-sdk-go/object/address" "github.com/nspcc-dev/neofs-sdk-go/object/address"
oid "github.com/nspcc-dev/neofs-sdk-go/object/id" oid "github.com/nspcc-dev/neofs-sdk-go/object/id"
@ -30,27 +31,27 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
// Client is a wrapper for client.Client to generate mock. // client is a wrapper for sdkClient.Client to generate mock.
type Client interface { type client interface {
BalanceGet(context.Context, client.PrmBalanceGet) (*client.ResBalanceGet, error) BalanceGet(context.Context, sdkClient.PrmBalanceGet) (*sdkClient.ResBalanceGet, error)
ContainerPut(context.Context, client.PrmContainerPut) (*client.ResContainerPut, error) ContainerPut(context.Context, sdkClient.PrmContainerPut) (*sdkClient.ResContainerPut, error)
ContainerGet(context.Context, client.PrmContainerGet) (*client.ResContainerGet, error) ContainerGet(context.Context, sdkClient.PrmContainerGet) (*sdkClient.ResContainerGet, error)
ContainerList(context.Context, client.PrmContainerList) (*client.ResContainerList, error) ContainerList(context.Context, sdkClient.PrmContainerList) (*sdkClient.ResContainerList, error)
ContainerDelete(context.Context, client.PrmContainerDelete) (*client.ResContainerDelete, error) ContainerDelete(context.Context, sdkClient.PrmContainerDelete) (*sdkClient.ResContainerDelete, error)
ContainerEACL(context.Context, client.PrmContainerEACL) (*client.ResContainerEACL, error) ContainerEACL(context.Context, sdkClient.PrmContainerEACL) (*sdkClient.ResContainerEACL, error)
ContainerSetEACL(context.Context, client.PrmContainerSetEACL) (*client.ResContainerSetEACL, error) ContainerSetEACL(context.Context, sdkClient.PrmContainerSetEACL) (*sdkClient.ResContainerSetEACL, error)
EndpointInfo(context.Context, client.PrmEndpointInfo) (*client.ResEndpointInfo, error) EndpointInfo(context.Context, sdkClient.PrmEndpointInfo) (*sdkClient.ResEndpointInfo, error)
NetworkInfo(context.Context, client.PrmNetworkInfo) (*client.ResNetworkInfo, error) NetworkInfo(context.Context, sdkClient.PrmNetworkInfo) (*sdkClient.ResNetworkInfo, error)
ObjectPutInit(context.Context, client.PrmObjectPutInit) (*client.ObjectWriter, error) ObjectPutInit(context.Context, sdkClient.PrmObjectPutInit) (*sdkClient.ObjectWriter, error)
ObjectDelete(context.Context, client.PrmObjectDelete) (*client.ResObjectDelete, error) ObjectDelete(context.Context, sdkClient.PrmObjectDelete) (*sdkClient.ResObjectDelete, error)
ObjectGetInit(context.Context, client.PrmObjectGet) (*client.ObjectReader, error) ObjectGetInit(context.Context, sdkClient.PrmObjectGet) (*sdkClient.ObjectReader, error)
ObjectHead(context.Context, client.PrmObjectHead) (*client.ResObjectHead, error) ObjectHead(context.Context, sdkClient.PrmObjectHead) (*sdkClient.ResObjectHead, error)
ObjectRangeInit(context.Context, client.PrmObjectRange) (*client.ObjectRangeReader, error) ObjectRangeInit(context.Context, sdkClient.PrmObjectRange) (*sdkClient.ObjectRangeReader, error)
ObjectSearchInit(context.Context, client.PrmObjectSearch) (*client.ObjectListReader, error) ObjectSearchInit(context.Context, sdkClient.PrmObjectSearch) (*sdkClient.ObjectListReader, error)
SessionCreate(context.Context, client.PrmSessionCreate) (*client.ResSessionCreate, error) SessionCreate(context.Context, sdkClient.PrmSessionCreate) (*sdkClient.ResSessionCreate, error)
} }
// InitParameters contains options used to create connection Pool. // InitParameters contains values used to initialize connection Pool.
type InitParameters struct { type InitParameters struct {
Key *ecdsa.PrivateKey Key *ecdsa.PrivateKey
Logger *zap.Logger Logger *zap.Logger
@ -61,7 +62,7 @@ type InitParameters struct {
SessionExpirationDuration uint64 SessionExpirationDuration uint64
NodeParams []NodeParam NodeParams []NodeParam
clientBuilder func(endpoint string) (Client, error) clientBuilder func(endpoint string) (client, error)
} }
type rebalanceParameters struct { type rebalanceParameters struct {
@ -99,7 +100,7 @@ func DefaultPollingParams() *ContainerPollingParams {
} }
type clientPack struct { type clientPack struct {
client Client client client
healthy bool healthy bool
address string address string
} }
@ -325,7 +326,7 @@ func (x *PrmBalanceGet) SetOwnerID(ownerID *owner.ID) {
// Pool represents virtual connection to the NeoFS network to communicate // Pool represents virtual connection to the NeoFS network to communicate
// with multiple NeoFS servers without thinking about switching between servers // with multiple NeoFS servers without thinking about switching between servers
// due to load balancing proportions or their unavailability. // due to load balancing proportions or their unavailability.
// It is designed to provide a convenient abstraction from the multiple client.Client types. // It is designed to provide a convenient abstraction from the multiple sdkClient.client types.
// //
// Pool can be created and initialized using NewPool function. // Pool can be created and initialized using NewPool function.
// Before executing the NeoFS operations using the Pool, connection to the // Before executing the NeoFS operations using the Pool, connection to the
@ -338,9 +339,9 @@ func (x *PrmBalanceGet) SetOwnerID(ownerID *owner.ID) {
// //
// Each method which produces a NeoFS API call may return an error. // Each method which produces a NeoFS API call may return an error.
// Status of underlying server response is casted to built-in error instance. // Status of underlying server response is casted to built-in error instance.
// Certain statuses can be checked using `client` and standard `errors` packages. // Certain statuses can be checked using `sdkClient` and standard `errors` packages.
// Note that package provides some helper functions to work with status returns // Note that package provides some helper functions to work with status returns
// (e.g. client.IsErrContainerNotFound, client.IsErrObjectNotFound). // (e.g. sdkClient.IsErrContainerNotFound, sdkClient.IsErrObjectNotFound).
// //
// See pool package overview to get some examples. // See pool package overview to get some examples.
type Pool struct { type Pool struct {
@ -353,7 +354,7 @@ type Pool struct {
stokenDuration uint64 stokenDuration uint64
stokenThreshold time.Duration stokenThreshold time.Duration
rebalanceParams rebalanceParameters rebalanceParams rebalanceParameters
clientBuilder func(endpoint string) (Client, error) clientBuilder func(endpoint string) (client, error)
logger *zap.Logger logger *zap.Logger
} }
@ -475,16 +476,16 @@ func fillDefaultInitParams(params *InitParameters) {
} }
if params.clientBuilder == nil { if params.clientBuilder == nil {
params.clientBuilder = func(addr string) (Client, error) { params.clientBuilder = func(addr string) (client, error) {
var c client.Client var c sdkClient.Client
var prmInit client.PrmInit var prmInit sdkClient.PrmInit
prmInit.ResolveNeoFSFailures() prmInit.ResolveNeoFSFailures()
prmInit.SetDefaultPrivateKey(*params.Key) prmInit.SetDefaultPrivateKey(*params.Key)
c.Init(prmInit) c.Init(prmInit)
var prmDial client.PrmDial var prmDial sdkClient.PrmDial
prmDial.SetServerURI(addr) prmDial.SetServerURI(addr)
prmDial.SetTimeout(params.NodeConnectionTimeout) prmDial.SetTimeout(params.NodeConnectionTimeout)
@ -565,11 +566,11 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights
healthyChanged := false healthyChanged := false
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
var prmEndpoint client.PrmEndpointInfo var prmEndpoint sdkClient.PrmEndpointInfo
for j, cPack := range pool.clientPacks { for j, cPack := range pool.clientPacks {
wg.Add(1) wg.Add(1)
go func(j int, cli Client) { go func(j int, cli client) {
defer wg.Done() defer wg.Done()
ok := true ok := true
tctx, c := context.WithTimeout(ctx, options.nodeRequestTimeout) tctx, c := context.WithTimeout(ctx, options.nodeRequestTimeout)
@ -634,16 +635,6 @@ func adjustWeights(weights []float64) []float64 {
return adjusted return adjusted
} }
func (p *Pool) Connection() (Client, *session.Token, error) {
cp, err := p.connection()
if err != nil {
return nil, nil, err
}
tok := p.cache.Get(formCacheKey(cp.address, p.key))
return cp.client, tok, nil
}
func (p *Pool) connection() (*clientPack, error) { func (p *Pool) connection() (*clientPack, error) {
for _, inner := range p.innerPools { for _, inner := range p.innerPools {
cp, err := inner.connection() cp, err := inner.connection()
@ -732,15 +723,15 @@ func (p *Pool) checkSessionTokenErr(err error, address string) bool {
return false return false
} }
func createSessionTokenForDuration(ctx context.Context, c Client, dur uint64) (*client.ResSessionCreate, error) { func createSessionTokenForDuration(ctx context.Context, c client, dur uint64) (*sdkClient.ResSessionCreate, error) {
ni, err := c.NetworkInfo(ctx, client.PrmNetworkInfo{}) ni, err := c.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
epoch := ni.Info().CurrentEpoch() epoch := ni.Info().CurrentEpoch()
var prm client.PrmSessionCreate var prm sdkClient.PrmSessionCreate
if math.MaxUint64-epoch < dur { if math.MaxUint64-epoch < dur {
prm.SetExp(math.MaxUint64) prm.SetExp(math.MaxUint64)
} else { } else {
@ -773,7 +764,7 @@ type callContext struct {
// base context for RPC // base context for RPC
context.Context context.Context
client Client client client
// client endpoint // client endpoint
endpoint string endpoint string
@ -921,7 +912,7 @@ func (p *Pool) PutObject(ctx context.Context, prm PrmObjectPut) (*oid.ID, error)
return nil, fmt.Errorf("init call context") return nil, fmt.Errorf("init call context")
} }
var cliPrm client.PrmObjectPutInit var cliPrm sdkClient.PrmObjectPutInit
wObj, err := ctxCall.client.ObjectPutInit(ctx, cliPrm) wObj, err := ctxCall.client.ObjectPutInit(ctx, cliPrm)
if err != nil { if err != nil {
@ -1009,7 +1000,7 @@ func (p *Pool) DeleteObject(ctx context.Context, prm PrmObjectDelete) error {
prm.useVerb(sessionv2.ObjectVerbDelete) prm.useVerb(sessionv2.ObjectVerbDelete)
prm.useAddress(&prm.addr) prm.useAddress(&prm.addr)
var cliPrm client.PrmObjectDelete var cliPrm sdkClient.PrmObjectDelete
var cc callContextWithRetry var cc callContextWithRetry
@ -1041,16 +1032,16 @@ func (p *Pool) DeleteObject(ctx context.Context, prm PrmObjectDelete) error {
}) })
} }
type objectReadCloser client.ObjectReader type objectReadCloser sdkClient.ObjectReader
// Read implements io.Reader of the object payload. // Read implements io.Reader of the object payload.
func (x *objectReadCloser) Read(p []byte) (int, error) { func (x *objectReadCloser) Read(p []byte) (int, error) {
return (*client.ObjectReader)(x).Read(p) return (*sdkClient.ObjectReader)(x).Read(p)
} }
// Close implements io.Closer of the object payload. // Close implements io.Closer of the object payload.
func (x *objectReadCloser) Close() error { func (x *objectReadCloser) Close() error {
_, err := (*client.ObjectReader)(x).Close() _, err := (*sdkClient.ObjectReader)(x).Close()
return err return err
} }
@ -1067,7 +1058,7 @@ func (p *Pool) GetObject(ctx context.Context, prm PrmObjectGet) (*ResGetObject,
prm.useVerb(sessionv2.ObjectVerbGet) prm.useVerb(sessionv2.ObjectVerbGet)
prm.useAddress(&prm.addr) prm.useAddress(&prm.addr)
var cliPrm client.PrmObjectGet var cliPrm sdkClient.PrmObjectGet
var cc callContextWithRetry var cc callContextWithRetry
@ -1119,7 +1110,7 @@ func (p *Pool) HeadObject(ctx context.Context, prm PrmObjectHead) (*object.Objec
prm.useVerb(sessionv2.ObjectVerbHead) prm.useVerb(sessionv2.ObjectVerbHead)
prm.useAddress(&prm.addr) prm.useAddress(&prm.addr)
var cliPrm client.PrmObjectHead var cliPrm sdkClient.PrmObjectHead
var cc callContextWithRetry var cc callContextWithRetry
@ -1168,7 +1159,7 @@ func (p *Pool) HeadObject(ctx context.Context, prm PrmObjectHead) (*object.Objec
// Must be initialized using Pool.ObjectRange, any other // Must be initialized using Pool.ObjectRange, any other
// usage is unsafe. // usage is unsafe.
type ResObjectRange struct { type ResObjectRange struct {
payload *client.ObjectRangeReader payload *sdkClient.ObjectRangeReader
} }
// Read implements io.Reader of the object payload. // Read implements io.Reader of the object payload.
@ -1190,7 +1181,7 @@ func (p *Pool) ObjectRange(ctx context.Context, prm PrmObjectRange) (*ResObjectR
prm.useVerb(sessionv2.ObjectVerbRange) prm.useVerb(sessionv2.ObjectVerbRange)
prm.useAddress(&prm.addr) prm.useAddress(&prm.addr)
var cliPrm client.PrmObjectRange var cliPrm sdkClient.PrmObjectRange
cliPrm.SetOffset(prm.off) cliPrm.SetOffset(prm.off)
cliPrm.SetLength(prm.ln) cliPrm.SetLength(prm.ln)
@ -1238,7 +1229,7 @@ func (p *Pool) ObjectRange(ctx context.Context, prm PrmObjectRange) (*ResObjectR
// //
// Must be initialized using Pool.SearchObjects, any other usage is unsafe. // Must be initialized using Pool.SearchObjects, any other usage is unsafe.
type ResObjectSearch struct { type ResObjectSearch struct {
r *client.ObjectListReader r *sdkClient.ObjectListReader
} }
// Read reads another list of the object identifiers. // Read reads another list of the object identifiers.
@ -1280,7 +1271,7 @@ func (p *Pool) SearchObjects(ctx context.Context, prm PrmObjectSearch) (*ResObje
prm.useVerb(sessionv2.ObjectVerbSearch) prm.useVerb(sessionv2.ObjectVerbSearch)
prm.useAddress(newAddressFromCnrID(&prm.cnrID)) prm.useAddress(newAddressFromCnrID(&prm.cnrID))
var cliPrm client.PrmObjectSearch var cliPrm sdkClient.PrmObjectSearch
cliPrm.InContainer(prm.cnrID) cliPrm.InContainer(prm.cnrID)
cliPrm.SetFilters(prm.filters) cliPrm.SetFilters(prm.filters)
@ -1328,7 +1319,7 @@ func (p *Pool) PutContainer(ctx context.Context, prm PrmContainerPut) (*cid.ID,
return nil, err return nil, err
} }
var cliPrm client.PrmContainerPut var cliPrm sdkClient.PrmContainerPut
if prm.cnr != nil { if prm.cnr != nil {
cliPrm.SetContainer(*prm.cnr) cliPrm.SetContainer(*prm.cnr)
@ -1349,7 +1340,7 @@ func (p *Pool) GetContainer(ctx context.Context, prm PrmContainerGet) (*containe
return nil, err return nil, err
} }
var cliPrm client.PrmContainerGet var cliPrm sdkClient.PrmContainerGet
if prm.cnrID != nil { if prm.cnrID != nil {
cliPrm.SetContainer(*prm.cnrID) cliPrm.SetContainer(*prm.cnrID)
@ -1370,7 +1361,7 @@ func (p *Pool) ListContainers(ctx context.Context, prm PrmContainerList) ([]cid.
return nil, err return nil, err
} }
var cliPrm client.PrmContainerList var cliPrm sdkClient.PrmContainerList
if prm.ownerID != nil { if prm.ownerID != nil {
cliPrm.SetAccount(*prm.ownerID) cliPrm.SetAccount(*prm.ownerID)
@ -1396,7 +1387,7 @@ func (p *Pool) DeleteContainer(ctx context.Context, prm PrmContainerDelete) erro
return err return err
} }
var cliPrm client.PrmContainerDelete var cliPrm sdkClient.PrmContainerDelete
if prm.cnrID != nil { if prm.cnrID != nil {
cliPrm.SetContainer(*prm.cnrID) cliPrm.SetContainer(*prm.cnrID)
@ -1420,7 +1411,7 @@ func (p *Pool) GetEACL(ctx context.Context, prm PrmContainerEACL) (*eacl.Table,
return nil, err return nil, err
} }
var cliPrm client.PrmContainerEACL var cliPrm sdkClient.PrmContainerEACL
if prm.cnrID != nil { if prm.cnrID != nil {
cliPrm.SetContainer(*prm.cnrID) cliPrm.SetContainer(*prm.cnrID)
@ -1446,7 +1437,7 @@ func (p *Pool) SetEACL(ctx context.Context, prm PrmContainerSetEACL) error {
return err return err
} }
var cliPrm client.PrmContainerSetEACL var cliPrm sdkClient.PrmContainerSetEACL
if prm.table != nil { if prm.table != nil {
cliPrm.SetTable(*prm.table) cliPrm.SetTable(*prm.table)
@ -1466,7 +1457,7 @@ func (p *Pool) Balance(ctx context.Context, prm PrmBalanceGet) (*accounting.Deci
return nil, err return nil, err
} }
var cliPrm client.PrmBalanceGet var cliPrm sdkClient.PrmBalanceGet
if prm.ownerID != nil { if prm.ownerID != nil {
cliPrm.SetAccount(*prm.ownerID) cliPrm.SetAccount(*prm.ownerID)
@ -1482,7 +1473,7 @@ func (p *Pool) Balance(ctx context.Context, prm PrmBalanceGet) (*accounting.Deci
// WaitForContainerPresence waits until the container is found on the NeoFS network. // WaitForContainerPresence waits until the container is found on the NeoFS network.
func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollParams *ContainerPollingParams) error { func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollParams *ContainerPollingParams) error {
conn, _, err := p.Connection() cp, err := p.connection()
if err != nil { if err != nil {
return err return err
} }
@ -1493,7 +1484,7 @@ func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollPa
wdone := wctx.Done() wdone := wctx.Done()
done := ctx.Done() done := ctx.Done()
var cliPrm client.PrmContainerGet var cliPrm sdkClient.PrmContainerGet
if cid != nil { if cid != nil {
cliPrm.SetContainer(*cid) cliPrm.SetContainer(*cid)
@ -1506,7 +1497,7 @@ func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollPa
case <-wdone: case <-wdone:
return wctx.Err() return wctx.Err()
case <-ticker.C: case <-ticker.C:
_, err = conn.ContainerGet(ctx, cliPrm) _, err = cp.client.ContainerGet(ctx, cliPrm)
if err == nil { if err == nil {
return nil return nil
} }
@ -1515,6 +1506,20 @@ func (p *Pool) WaitForContainerPresence(ctx context.Context, cid *cid.ID, pollPa
} }
} }
func (p *Pool) NetworkInfo(ctx context.Context) (*netmap.NetworkInfo, error) {
cp, err := p.connection()
if err != nil {
return nil, err
}
res, err := cp.client.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{})
if err != nil { // here err already carries both status and client errors
return nil, err
}
return res.Info(), nil
}
// Close closes the Pool and releases all the associated resources. // Close closes the Pool and releases all the associated resources.
func (p *Pool) Close() { func (p *Pool) Close() {
p.cancel() p.cancel()
@ -1522,12 +1527,12 @@ func (p *Pool) Close() {
} }
// creates new session token from SessionCreate call result. // creates new session token from SessionCreate call result.
func (p *Pool) newSessionToken(cliRes *client.ResSessionCreate) *session.Token { func (p *Pool) newSessionToken(cliRes *sdkClient.ResSessionCreate) *session.Token {
return sessionTokenForOwner(p.owner, cliRes) return sessionTokenForOwner(p.owner, cliRes)
} }
// creates new session token with specified owner from SessionCreate call result. // creates new session token with specified owner from SessionCreate call result.
func sessionTokenForOwner(id *owner.ID, cliRes *client.ResSessionCreate) *session.Token { func sessionTokenForOwner(id *owner.ID, cliRes *sdkClient.ResSessionCreate) *session.Token {
st := session.NewToken() st := session.NewToken()
st.SetOwnerID(id) st.SetOwnerID(id)
st.SetID(cliRes.ID()) st.SetID(cliRes.ID())

View file

@ -13,7 +13,7 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neofs-sdk-go/client" sdkClient "github.com/nspcc-dev/neofs-sdk-go/client"
"github.com/nspcc-dev/neofs-sdk-go/netmap" "github.com/nspcc-dev/neofs-sdk-go/netmap"
"github.com/nspcc-dev/neofs-sdk-go/object" "github.com/nspcc-dev/neofs-sdk-go/object"
"github.com/nspcc-dev/neofs-sdk-go/object/address" "github.com/nspcc-dev/neofs-sdk-go/object/address"
@ -24,7 +24,7 @@ import (
) )
func TestBuildPoolClientFailed(t *testing.T) { func TestBuildPoolClientFailed(t *testing.T) {
clientBuilder := func(_ string) (Client, error) { clientBuilder := func(_ string) (client, error) {
return nil, fmt.Errorf("error") return nil, fmt.Errorf("error")
} }
@ -46,11 +46,11 @@ func TestBuildPoolCreateSessionFailed(t *testing.T) {
ni := &netmap.NodeInfo{} ni := &netmap.NodeInfo{}
ni.SetAddresses("addr1", "addr2") ni.SetAddresses("addr1", "addr2")
clientBuilder := func(_ string) (Client, error) { clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl) mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error session")).AnyTimes() mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error session")).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
return mockClient, nil return mockClient, nil
} }
@ -83,7 +83,7 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
var expectedToken *session.Token var expectedToken *session.Token
clientCount := -1 clientCount := -1
clientBuilder := func(_ string) (Client, error) { clientBuilder := func(_ string) (client, error) {
clientCount++ clientCount++
mockClient := NewMockClient(ctrl) mockClient := NewMockClient(ctrl)
mockInvokes := 0 mockInvokes := 0
@ -97,12 +97,12 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
}).AnyTimes() }).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
mockClient2 := NewMockClient(ctrl2) mockClient2 := NewMockClient(ctrl2)
mockClient2.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient2.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
if clientCount == 0 { if clientCount == 0 {
return mockClient, nil return mockClient, nil
@ -130,8 +130,12 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
t.Cleanup(clientPool.Close) t.Cleanup(clientPool.Close)
condition := func() bool { condition := func() bool {
_, st, err := clientPool.Connection() cp, err := clientPool.connection()
return err == nil && st == expectedToken if err != nil {
return false
}
st := clientPool.cache.Get(formCacheKey(cp.address, clientPool.key))
return st == expectedToken
} }
require.Never(t, condition, 900*time.Millisecond, 100*time.Millisecond) require.Never(t, condition, 900*time.Millisecond, 100*time.Millisecond)
require.Eventually(t, condition, 3*time.Second, 300*time.Millisecond) require.Eventually(t, condition, 3*time.Second, 300*time.Millisecond)
@ -155,11 +159,11 @@ func TestOneNode(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
tok.SetID(uid) tok.SetID(uid)
clientBuilder := func(_ string) (Client, error) { clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl) mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(tok, nil) mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(tok, nil)
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
return mockClient, nil return mockClient, nil
} }
@ -175,8 +179,9 @@ func TestOneNode(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(pool.Close) t.Cleanup(pool.Close)
_, st, err := pool.Connection() cp, err := pool.connection()
require.NoError(t, err) require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Equal(t, tok, st) require.Equal(t, tok, st)
} }
@ -186,7 +191,7 @@ func TestTwoNodes(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
var tokens []*session.Token var tokens []*session.Token
clientBuilder := func(_ string) (Client, error) { clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl) mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
tok := session.NewToken() tok := session.NewToken()
@ -196,8 +201,8 @@ func TestTwoNodes(t *testing.T) {
tokens = append(tokens, tok) tokens = append(tokens, tok)
return tok, err return tok, err
}) })
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
return mockClient, nil return mockClient, nil
} }
@ -216,8 +221,9 @@ func TestTwoNodes(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(pool.Close) t.Cleanup(pool.Close)
_, st, err := pool.Connection() cp, err := pool.connection()
require.NoError(t, err) require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Contains(t, tokens, st) require.Contains(t, tokens, st)
} }
@ -229,7 +235,7 @@ func TestOneOfTwoFailed(t *testing.T) {
var tokens []*session.Token var tokens []*session.Token
clientCount := -1 clientCount := -1
clientBuilder := func(_ string) (Client, error) { clientBuilder := func(_ string) (client, error) {
clientCount++ clientCount++
mockClient := NewMockClient(ctrl) mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
@ -238,7 +244,7 @@ func TestOneOfTwoFailed(t *testing.T) {
return tok, nil return tok, nil
}).AnyTimes() }).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
mockClient2 := NewMockClient(ctrl2) mockClient2 := NewMockClient(ctrl2)
mockClient2.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { mockClient2.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
@ -246,10 +252,10 @@ func TestOneOfTwoFailed(t *testing.T) {
tokens = append(tokens, tok) tokens = append(tokens, tok)
return tok, nil return tok, nil
}).AnyTimes() }).AnyTimes()
mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*client.ResEndpointInfo, error) { mockClient2.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*sdkClient.ResEndpointInfo, error) {
return nil, fmt.Errorf("error") return nil, fmt.Errorf("error")
}).AnyTimes() }).AnyTimes()
mockClient2.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*client.ResNetworkInfo, error) { mockClient2.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).DoAndReturn(func(_ interface{}, _ ...interface{}) (*sdkClient.ResNetworkInfo, error) {
return nil, fmt.Errorf("error") return nil, fmt.Errorf("error")
}).AnyTimes() }).AnyTimes()
@ -280,8 +286,9 @@ func TestOneOfTwoFailed(t *testing.T) {
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
_, st, err := pool.Connection() cp, err := pool.connection()
require.NoError(t, err) require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Equal(t, tokens[0], st) require.Equal(t, tokens[0], st)
} }
} }
@ -291,7 +298,7 @@ func TestTwoFailed(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
clientBuilder := func(_ string) (Client, error) { clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl) mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")).AnyTimes()
@ -318,7 +325,7 @@ func TestTwoFailed(t *testing.T) {
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
_, _, err = pool.Connection() _, err = pool.connection()
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "no healthy") require.Contains(t, err.Error(), "no healthy")
} }
@ -329,7 +336,7 @@ func TestSessionCache(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
var tokens []*session.Token var tokens []*session.Token
clientBuilder := func(_ string) (Client, error) { clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl) mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
tok := session.NewToken() tok := session.NewToken()
@ -365,8 +372,9 @@ func TestSessionCache(t *testing.T) {
t.Cleanup(pool.Close) t.Cleanup(pool.Close)
// cache must contain session token // cache must contain session token
_, st, err := pool.Connection() cp, err := pool.connection()
require.NoError(t, err) require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Contains(t, tokens, st) require.Contains(t, tokens, st)
var prm PrmObjectGet var prm PrmObjectGet
@ -376,8 +384,9 @@ func TestSessionCache(t *testing.T) {
require.Error(t, err) require.Error(t, err)
// cache must not contain session token // cache must not contain session token
_, st, err = pool.Connection() cp, err = pool.connection()
require.NoError(t, err) require.NoError(t, err)
st = pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Nil(t, st) require.Nil(t, st)
var prm2 PrmObjectPut var prm2 PrmObjectPut
@ -387,8 +396,9 @@ func TestSessionCache(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// cache must contain session token // cache must contain session token
_, st, err = pool.Connection() cp, err = pool.connection()
require.NoError(t, err) require.NoError(t, err)
st = pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Contains(t, tokens, st) require.Contains(t, tokens, st)
} }
@ -400,7 +410,7 @@ func TestPriority(t *testing.T) {
var tokens []*session.Token var tokens []*session.Token
clientCount := -1 clientCount := -1
clientBuilder := func(_ string) (Client, error) { clientBuilder := func(_ string) (client, error) {
clientCount++ clientCount++
mockClient := NewMockClient(ctrl) mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
@ -446,13 +456,15 @@ func TestPriority(t *testing.T) {
t.Cleanup(pool.Close) t.Cleanup(pool.Close)
firstNode := func() bool { firstNode := func() bool {
_, st, err := pool.Connection() cp, err := pool.connection()
require.NoError(t, err) require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
return st == tokens[0] return st == tokens[0]
} }
secondNode := func() bool { secondNode := func() bool {
_, st, err := pool.Connection() cp, err := pool.connection()
require.NoError(t, err) require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
return st == tokens[1] return st == tokens[1]
} }
require.Never(t, secondNode, time.Second, 200*time.Millisecond) require.Never(t, secondNode, time.Second, 200*time.Millisecond)
@ -467,7 +479,7 @@ func TestSessionCacheWithKey(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
var tokens []*session.Token var tokens []*session.Token
clientBuilder := func(_ string) (Client, error) { clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl) mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) { mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
tok := session.NewToken() tok := session.NewToken()
@ -501,8 +513,9 @@ func TestSessionCacheWithKey(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// cache must contain session token // cache must contain session token
_, st, err := pool.Connection() cp, err := pool.connection()
require.NoError(t, err) require.NoError(t, err)
st := pool.cache.Get(formCacheKey(cp.address, pool.key))
require.Contains(t, tokens, st) require.Contains(t, tokens, st)
var prm PrmObjectGet var prm PrmObjectGet
@ -525,11 +538,11 @@ func newToken(t *testing.T) *session.Token {
func TestSessionTokenOwner(t *testing.T) { func TestSessionTokenOwner(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
clientBuilder := func(_ string) (Client, error) { clientBuilder := func(_ string) (client, error) {
mockClient := NewMockClient(ctrl) mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(&client.ResSessionCreate{}, nil).AnyTimes() mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(&sdkClient.ResSessionCreate{}, nil).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&client.ResEndpointInfo{}, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResEndpointInfo{}, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
return mockClient, nil return mockClient, nil
} }
@ -569,7 +582,7 @@ func TestWaitPresence(t *testing.T) {
mockClient := NewMockClient(ctrl) mockClient := NewMockClient(ctrl)
mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient.EXPECT().EndpointInfo(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&client.ResNetworkInfo{}, nil).AnyTimes() mockClient.EXPECT().NetworkInfo(gomock.Any(), gomock.Any()).Return(&sdkClient.ResNetworkInfo{}, nil).AnyTimes()
mockClient.EXPECT().GetContainer(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockClient.EXPECT().GetContainer(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
cache, err := newCache() cache, err := newCache()

View file

@ -6,7 +6,7 @@ import (
"math/rand" "math/rand"
"testing" "testing"
"github.com/nspcc-dev/neofs-sdk-go/client" sdkClient "github.com/nspcc-dev/neofs-sdk-go/client"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -43,16 +43,16 @@ func TestSamplerStability(t *testing.T) {
} }
type clientMock struct { type clientMock struct {
client.Client sdkClient.Client
name string name string
err error err error
} }
func (c *clientMock) EndpointInfo(context.Context, client.PrmEndpointInfo) (*client.ResEndpointInfo, error) { func (c *clientMock) EndpointInfo(context.Context, sdkClient.PrmEndpointInfo) (*sdkClient.ResEndpointInfo, error) {
return nil, nil return nil, nil
} }
func (c *clientMock) NetworkInfo(context.Context, client.PrmNetworkInfo) (*client.ResNetworkInfo, error) { func (c *clientMock) NetworkInfo(context.Context, sdkClient.PrmNetworkInfo) (*sdkClient.ResNetworkInfo, error) {
return nil, nil return nil, nil
} }
@ -88,16 +88,16 @@ func TestHealthyReweight(t *testing.T) {
} }
// check getting first node connection before rebalance happened // check getting first node connection before rebalance happened
connection0, _, err := p.Connection() connection0, err := p.connection()
require.NoError(t, err) require.NoError(t, err)
mock0 := connection0.(*clientMock) mock0 := connection0.client.(*clientMock)
require.Equal(t, names[0], mock0.name) require.Equal(t, names[0], mock0.name)
p.updateInnerNodesHealth(context.TODO(), 0, buffer) p.updateInnerNodesHealth(context.TODO(), 0, buffer)
connection1, _, err := p.Connection() connection1, err := p.connection()
require.NoError(t, err) require.NoError(t, err)
mock1 := connection1.(*clientMock) mock1 := connection1.client.(*clientMock)
require.Equal(t, names[1], mock1.name) require.Equal(t, names[1], mock1.name)
// enabled first node again // enabled first node again
@ -108,9 +108,9 @@ func TestHealthyReweight(t *testing.T) {
p.updateInnerNodesHealth(context.TODO(), 0, buffer) p.updateInnerNodesHealth(context.TODO(), 0, buffer)
inner.sampler = newSampler(weights, rand.NewSource(0)) inner.sampler = newSampler(weights, rand.NewSource(0))
connection0, _, err = p.Connection() connection0, err = p.connection()
require.NoError(t, err) require.NoError(t, err)
mock0 = connection0.(*clientMock) mock0 = connection0.client.(*clientMock)
require.Equal(t, names[0], mock0.name) require.Equal(t, names[0], mock0.name)
} }