[#358] pool: Start even if not all node healthy

Signed-off-by: Denis Kirillov <denis@nspcc.ru>
This commit is contained in:
Denis Kirillov 2022-11-03 17:58:38 +03:00 committed by Alex Vanin
parent 7a2a76af95
commit 2eefdab0e4
4 changed files with 300 additions and 164 deletions

View file

@ -23,6 +23,7 @@ type mockClient struct {
key ecdsa.PrivateKey key ecdsa.PrivateKey
clientStatusMonitor clientStatusMonitor
errorOnDial bool
errorOnCreateSession bool errorOnCreateSession bool
errorOnEndpointInfo bool errorOnEndpointInfo bool
errorOnNetworkInfo bool errorOnNetworkInfo bool
@ -52,6 +53,13 @@ func (m *mockClient) errOnNetworkInfo() {
m.errorOnEndpointInfo = true m.errorOnEndpointInfo = true
} }
func (m *mockClient) errOnDial() {
m.errorOnDial = true
m.errOnCreateSession()
m.errOnEndpointInfo()
m.errOnNetworkInfo()
}
func (m *mockClient) statusOnGetObject(st apistatus.Status) { func (m *mockClient) statusOnGetObject(st apistatus.Status) {
m.stOnGetObject = st m.stOnGetObject = st
} }
@ -160,3 +168,22 @@ func (m *mockClient) sessionCreate(context.Context, prmCreateSession) (resCreate
sessionKey: v2tok.GetBody().GetSessionKey(), sessionKey: v2tok.GetBody().GetSessionKey(),
}, nil }, nil
} }
func (m *mockClient) dial(context.Context) error {
if m.errorOnDial {
return errors.New("dial error")
}
return nil
}
func (m *mockClient) restartIfUnhealthy(ctx context.Context) (healthy bool, changed bool) {
_, err := m.endpointInfo(ctx, prmEndpointInfo{})
healthy = err == nil
changed = healthy != m.isHealthy()
if healthy {
m.setHealthy()
} else {
m.setUnhealthy()
}
return
}

View file

@ -70,15 +70,19 @@ type client interface {
sessionCreate(context.Context, prmCreateSession) (resCreateSession, error) sessionCreate(context.Context, prmCreateSession) (resCreateSession, error)
clientStatus clientStatus
// see clientWrapper.dial.
dial(ctx context.Context) error
// see clientWrapper.restartIfUnhealthy.
restartIfUnhealthy(ctx context.Context) (bool, bool)
} }
// clientStatus provide access to some metrics for connection. // clientStatus provide access to some metrics for connection.
type clientStatus interface { type clientStatus interface {
// isHealthy checks if the connection can handle requests. // isHealthy checks if the connection can handle requests.
isHealthy() bool isHealthy() bool
// setHealthy allows set healthy status for connection. // setUnhealthy marks client as unhealthy.
// It's used to update status during Pool.startRebalance routing. setUnhealthy()
setHealthy(bool) bool
// address return address of endpoint. // address return address of endpoint.
address() string address() string
// currentErrorRate returns current errors rate. // currentErrorRate returns current errors rate.
@ -91,6 +95,9 @@ type clientStatus interface {
methodsStatus() []statusSnapshot methodsStatus() []statusSnapshot
} }
// ErrPoolClientUnhealthy is an error to indicate that client in pool is unhealthy.
var ErrPoolClientUnhealthy = errors.New("pool client unhealthy")
// clientStatusMonitor count error rate and other statistics for connection. // clientStatusMonitor count error rate and other statistics for connection.
type clientStatusMonitor struct { type clientStatusMonitor struct {
addr string addr string
@ -207,8 +214,10 @@ func (m *methodStatus) incRequests(elapsed time.Duration) {
// clientWrapper is used by default, alternative implementations are intended for testing purposes only. // clientWrapper is used by default, alternative implementations are intended for testing purposes only.
type clientWrapper struct { type clientWrapper struct {
client sdkClient.Client clientMutex sync.RWMutex
key ecdsa.PrivateKey client *sdkClient.Client
prm wrapperPrm
clientStatusMonitor clientStatusMonitor
} }
@ -219,7 +228,6 @@ type wrapperPrm struct {
timeout time.Duration timeout time.Duration
errorThreshold uint32 errorThreshold uint32
responseInfoCallback func(sdkClient.ResponseMetaInfo) error responseInfoCallback func(sdkClient.ResponseMetaInfo) error
dialCtx context.Context
} }
// setAddress sets endpoint to connect in NeoFS network. // setAddress sets endpoint to connect in NeoFS network.
@ -248,44 +256,107 @@ func (x *wrapperPrm) setResponseInfoCallback(f func(sdkClient.ResponseMetaInfo)
x.responseInfoCallback = f x.responseInfoCallback = f
} }
// setDialContext specifies context for client dial.
func (x *wrapperPrm) setDialContext(ctx context.Context) {
x.dialCtx = ctx
}
// newWrapper creates a clientWrapper that implements the client interface. // newWrapper creates a clientWrapper that implements the client interface.
func newWrapper(prm wrapperPrm) (*clientWrapper, error) { func newWrapper(prm wrapperPrm) *clientWrapper {
var cl sdkClient.Client
var prmInit sdkClient.PrmInit var prmInit sdkClient.PrmInit
prmInit.SetDefaultPrivateKey(prm.key) prmInit.SetDefaultPrivateKey(prm.key)
prmInit.SetResponseInfoCallback(prm.responseInfoCallback) prmInit.SetResponseInfoCallback(prm.responseInfoCallback)
cl.Init(prmInit)
res := &clientWrapper{ res := &clientWrapper{
key: prm.key, client: &cl,
clientStatusMonitor: newClientStatusMonitor(prm.address, prm.errorThreshold), clientStatusMonitor: newClientStatusMonitor(prm.address, prm.errorThreshold),
prm: prm,
} }
res.client.Init(prmInit) return res
}
// dial establishes a connection to the server from the NeoFS network.
// Returns an error describing failure reason. If failed, the client
// SHOULD NOT be used.
func (c *clientWrapper) dial(ctx context.Context) error {
cl, err := c.getClient()
if err != nil {
return err
}
var prmDial sdkClient.PrmDial var prmDial sdkClient.PrmDial
prmDial.SetServerURI(prm.address) prmDial.SetServerURI(c.prm.address)
prmDial.SetTimeout(prm.timeout) prmDial.SetTimeout(c.prm.timeout)
prmDial.SetContext(prm.dialCtx) prmDial.SetContext(ctx)
err := res.client.Dial(prmDial) if err = cl.Dial(prmDial); err != nil {
if err != nil { c.setUnhealthy()
return nil, fmt.Errorf("client dial: %w", err) return err
} }
return res, nil return nil
}
// restartIfUnhealthy checks healthy status of client and recreate it if status is unhealthy.
// Return current healthy status and indicating if status was changed by this function call.
func (c *clientWrapper) restartIfUnhealthy(ctx context.Context) (healthy, changed bool) {
var wasHealthy bool
if _, err := c.endpointInfo(ctx, prmEndpointInfo{}); err == nil {
return true, false
} else if !errors.Is(err, ErrPoolClientUnhealthy) {
wasHealthy = true
}
var cl sdkClient.Client
var prmInit sdkClient.PrmInit
prmInit.SetDefaultPrivateKey(c.prm.key)
prmInit.SetResponseInfoCallback(c.prm.responseInfoCallback)
cl.Init(prmInit)
var prmDial sdkClient.PrmDial
prmDial.SetServerURI(c.prm.address)
prmDial.SetTimeout(c.prm.timeout)
prmDial.SetContext(ctx)
if err := cl.Dial(prmDial); err != nil {
c.setUnhealthy()
return false, wasHealthy
}
c.clientMutex.Lock()
c.client = &cl
c.clientMutex.Unlock()
if _, err := cl.EndpointInfo(ctx, sdkClient.PrmEndpointInfo{}); err != nil {
c.setUnhealthy()
return false, wasHealthy
}
c.setHealthy()
return true, !wasHealthy
}
func (c *clientWrapper) getClient() (*sdkClient.Client, error) {
c.clientMutex.RLock()
defer c.clientMutex.RUnlock()
if c.isHealthy() {
return c.client, nil
}
return nil, ErrPoolClientUnhealthy
} }
// balanceGet invokes sdkClient.BalanceGet parse response status to error and return result as is. // balanceGet invokes sdkClient.BalanceGet parse response status to error and return result as is.
func (c *clientWrapper) balanceGet(ctx context.Context, prm PrmBalanceGet) (accounting.Decimal, error) { func (c *clientWrapper) balanceGet(ctx context.Context, prm PrmBalanceGet) (accounting.Decimal, error) {
cl, err := c.getClient()
if err != nil {
return accounting.Decimal{}, err
}
var cliPrm sdkClient.PrmBalanceGet var cliPrm sdkClient.PrmBalanceGet
cliPrm.SetAccount(prm.account) cliPrm.SetAccount(prm.account)
start := time.Now() start := time.Now()
res, err := c.client.BalanceGet(ctx, cliPrm) res, err := cl.BalanceGet(ctx, cliPrm)
c.incRequests(time.Since(start), methodBalanceGet) c.incRequests(time.Since(start), methodBalanceGet)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -301,8 +372,13 @@ func (c *clientWrapper) balanceGet(ctx context.Context, prm PrmBalanceGet) (acco
// containerPut invokes sdkClient.ContainerPut parse response status to error and return result as is. // containerPut invokes sdkClient.ContainerPut parse response status to error and return result as is.
// It also waits for the container to appear on the network. // It also waits for the container to appear on the network.
func (c *clientWrapper) containerPut(ctx context.Context, prm PrmContainerPut) (cid.ID, error) { func (c *clientWrapper) containerPut(ctx context.Context, prm PrmContainerPut) (cid.ID, error) {
cl, err := c.getClient()
if err != nil {
return cid.ID{}, err
}
start := time.Now() start := time.Now()
res, err := c.client.ContainerPut(ctx, prm.prmClient) res, err := cl.ContainerPut(ctx, prm.prmClient)
c.incRequests(time.Since(start), methodContainerPut) c.incRequests(time.Since(start), methodContainerPut)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -328,11 +404,16 @@ func (c *clientWrapper) containerPut(ctx context.Context, prm PrmContainerPut) (
// containerGet invokes sdkClient.ContainerGet parse response status to error and return result as is. // containerGet invokes sdkClient.ContainerGet parse response status to error and return result as is.
func (c *clientWrapper) containerGet(ctx context.Context, prm PrmContainerGet) (container.Container, error) { func (c *clientWrapper) containerGet(ctx context.Context, prm PrmContainerGet) (container.Container, error) {
cl, err := c.getClient()
if err != nil {
return container.Container{}, err
}
var cliPrm sdkClient.PrmContainerGet var cliPrm sdkClient.PrmContainerGet
cliPrm.SetContainer(prm.cnrID) cliPrm.SetContainer(prm.cnrID)
start := time.Now() start := time.Now()
res, err := c.client.ContainerGet(ctx, cliPrm) res, err := cl.ContainerGet(ctx, cliPrm)
c.incRequests(time.Since(start), methodContainerGet) c.incRequests(time.Since(start), methodContainerGet)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -347,11 +428,16 @@ func (c *clientWrapper) containerGet(ctx context.Context, prm PrmContainerGet) (
// containerList invokes sdkClient.ContainerList parse response status to error and return result as is. // containerList invokes sdkClient.ContainerList parse response status to error and return result as is.
func (c *clientWrapper) containerList(ctx context.Context, prm PrmContainerList) ([]cid.ID, error) { func (c *clientWrapper) containerList(ctx context.Context, prm PrmContainerList) ([]cid.ID, error) {
cl, err := c.getClient()
if err != nil {
return nil, err
}
var cliPrm sdkClient.PrmContainerList var cliPrm sdkClient.PrmContainerList
cliPrm.SetAccount(prm.ownerID) cliPrm.SetAccount(prm.ownerID)
start := time.Now() start := time.Now()
res, err := c.client.ContainerList(ctx, cliPrm) res, err := cl.ContainerList(ctx, cliPrm)
c.incRequests(time.Since(start), methodContainerList) c.incRequests(time.Since(start), methodContainerList)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -366,6 +452,11 @@ func (c *clientWrapper) containerList(ctx context.Context, prm PrmContainerList)
// containerDelete invokes sdkClient.ContainerDelete parse response status to error. // containerDelete invokes sdkClient.ContainerDelete parse response status to error.
// It also waits for the container to be removed from the network. // It also waits for the container to be removed from the network.
func (c *clientWrapper) containerDelete(ctx context.Context, prm PrmContainerDelete) error { func (c *clientWrapper) containerDelete(ctx context.Context, prm PrmContainerDelete) error {
cl, err := c.getClient()
if err != nil {
return err
}
var cliPrm sdkClient.PrmContainerDelete var cliPrm sdkClient.PrmContainerDelete
cliPrm.SetContainer(prm.cnrID) cliPrm.SetContainer(prm.cnrID)
if prm.stokenSet { if prm.stokenSet {
@ -373,7 +464,7 @@ func (c *clientWrapper) containerDelete(ctx context.Context, prm PrmContainerDel
} }
start := time.Now() start := time.Now()
res, err := c.client.ContainerDelete(ctx, cliPrm) res, err := cl.ContainerDelete(ctx, cliPrm)
c.incRequests(time.Since(start), methodContainerDelete) c.incRequests(time.Since(start), methodContainerDelete)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -392,11 +483,16 @@ func (c *clientWrapper) containerDelete(ctx context.Context, prm PrmContainerDel
// containerEACL invokes sdkClient.ContainerEACL parse response status to error and return result as is. // containerEACL invokes sdkClient.ContainerEACL parse response status to error and return result as is.
func (c *clientWrapper) containerEACL(ctx context.Context, prm PrmContainerEACL) (eacl.Table, error) { func (c *clientWrapper) containerEACL(ctx context.Context, prm PrmContainerEACL) (eacl.Table, error) {
cl, err := c.getClient()
if err != nil {
return eacl.Table{}, err
}
var cliPrm sdkClient.PrmContainerEACL var cliPrm sdkClient.PrmContainerEACL
cliPrm.SetContainer(prm.cnrID) cliPrm.SetContainer(prm.cnrID)
start := time.Now() start := time.Now()
res, err := c.client.ContainerEACL(ctx, cliPrm) res, err := cl.ContainerEACL(ctx, cliPrm)
c.incRequests(time.Since(start), methodContainerEACL) c.incRequests(time.Since(start), methodContainerEACL)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -412,6 +508,11 @@ func (c *clientWrapper) containerEACL(ctx context.Context, prm PrmContainerEACL)
// containerSetEACL invokes sdkClient.ContainerSetEACL parse response status to error. // containerSetEACL invokes sdkClient.ContainerSetEACL parse response status to error.
// It also waits for the EACL to appear on the network. // It also waits for the EACL to appear on the network.
func (c *clientWrapper) containerSetEACL(ctx context.Context, prm PrmContainerSetEACL) error { func (c *clientWrapper) containerSetEACL(ctx context.Context, prm PrmContainerSetEACL) error {
cl, err := c.getClient()
if err != nil {
return err
}
var cliPrm sdkClient.PrmContainerSetEACL var cliPrm sdkClient.PrmContainerSetEACL
cliPrm.SetTable(prm.table) cliPrm.SetTable(prm.table)
@ -420,7 +521,7 @@ func (c *clientWrapper) containerSetEACL(ctx context.Context, prm PrmContainerSe
} }
start := time.Now() start := time.Now()
res, err := c.client.ContainerSetEACL(ctx, cliPrm) res, err := cl.ContainerSetEACL(ctx, cliPrm)
c.incRequests(time.Since(start), methodContainerSetEACL) c.incRequests(time.Since(start), methodContainerSetEACL)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -449,8 +550,13 @@ func (c *clientWrapper) containerSetEACL(ctx context.Context, prm PrmContainerSe
// endpointInfo invokes sdkClient.EndpointInfo parse response status to error and return result as is. // endpointInfo invokes sdkClient.EndpointInfo parse response status to error and return result as is.
func (c *clientWrapper) endpointInfo(ctx context.Context, _ prmEndpointInfo) (netmap.NodeInfo, error) { func (c *clientWrapper) endpointInfo(ctx context.Context, _ prmEndpointInfo) (netmap.NodeInfo, error) {
cl, err := c.getClient()
if err != nil {
return netmap.NodeInfo{}, err
}
start := time.Now() start := time.Now()
res, err := c.client.EndpointInfo(ctx, sdkClient.PrmEndpointInfo{}) res, err := cl.EndpointInfo(ctx, sdkClient.PrmEndpointInfo{})
c.incRequests(time.Since(start), methodEndpointInfo) c.incRequests(time.Since(start), methodEndpointInfo)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -465,8 +571,13 @@ func (c *clientWrapper) endpointInfo(ctx context.Context, _ prmEndpointInfo) (ne
// networkInfo invokes sdkClient.NetworkInfo parse response status to error and return result as is. // networkInfo invokes sdkClient.NetworkInfo parse response status to error and return result as is.
func (c *clientWrapper) networkInfo(ctx context.Context, _ prmNetworkInfo) (netmap.NetworkInfo, error) { func (c *clientWrapper) networkInfo(ctx context.Context, _ prmNetworkInfo) (netmap.NetworkInfo, error) {
cl, err := c.getClient()
if err != nil {
return netmap.NetworkInfo{}, err
}
start := time.Now() start := time.Now()
res, err := c.client.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{}) res, err := cl.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{})
c.incRequests(time.Since(start), methodNetworkInfo) c.incRequests(time.Since(start), methodNetworkInfo)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -481,6 +592,11 @@ func (c *clientWrapper) networkInfo(ctx context.Context, _ prmNetworkInfo) (netm
// objectPut writes object to NeoFS. // objectPut writes object to NeoFS.
func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID, error) { func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID, error) {
cl, err := c.getClient()
if err != nil {
return oid.ID{}, err
}
var cliPrm sdkClient.PrmObjectPutInit var cliPrm sdkClient.PrmObjectPutInit
cliPrm.SetCopiesNumber(prm.copiesNumber) cliPrm.SetCopiesNumber(prm.copiesNumber)
if prm.stoken != nil { if prm.stoken != nil {
@ -494,7 +610,7 @@ func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID
} }
start := time.Now() start := time.Now()
wObj, err := c.client.ObjectPutInit(ctx, cliPrm) wObj, err := cl.ObjectPutInit(ctx, cliPrm)
c.incRequests(time.Since(start), methodObjectPut) c.incRequests(time.Since(start), methodObjectPut)
if err = c.handleError(nil, err); err != nil { if err = c.handleError(nil, err); err != nil {
return oid.ID{}, fmt.Errorf("init writing on API client: %w", err) return oid.ID{}, fmt.Errorf("init writing on API client: %w", err)
@ -559,6 +675,11 @@ func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (oid.ID
// objectDelete invokes sdkClient.ObjectDelete parse response status to error. // objectDelete invokes sdkClient.ObjectDelete parse response status to error.
func (c *clientWrapper) objectDelete(ctx context.Context, prm PrmObjectDelete) error { func (c *clientWrapper) objectDelete(ctx context.Context, prm PrmObjectDelete) error {
cl, err := c.getClient()
if err != nil {
return err
}
var cliPrm sdkClient.PrmObjectDelete var cliPrm sdkClient.PrmObjectDelete
cliPrm.FromContainer(prm.addr.Container()) cliPrm.FromContainer(prm.addr.Container())
cliPrm.ByID(prm.addr.Object()) cliPrm.ByID(prm.addr.Object())
@ -576,7 +697,7 @@ func (c *clientWrapper) objectDelete(ctx context.Context, prm PrmObjectDelete) e
} }
start := time.Now() start := time.Now()
res, err := c.client.ObjectDelete(ctx, cliPrm) res, err := cl.ObjectDelete(ctx, cliPrm)
c.incRequests(time.Since(start), methodObjectDelete) c.incRequests(time.Since(start), methodObjectDelete)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -590,6 +711,11 @@ func (c *clientWrapper) objectDelete(ctx context.Context, prm PrmObjectDelete) e
// objectGet returns reader for object. // objectGet returns reader for object.
func (c *clientWrapper) objectGet(ctx context.Context, prm PrmObjectGet) (ResGetObject, error) { func (c *clientWrapper) objectGet(ctx context.Context, prm PrmObjectGet) (ResGetObject, error) {
cl, err := c.getClient()
if err != nil {
return ResGetObject{}, err
}
var cliPrm sdkClient.PrmObjectGet var cliPrm sdkClient.PrmObjectGet
cliPrm.FromContainer(prm.addr.Container()) cliPrm.FromContainer(prm.addr.Container())
cliPrm.ByID(prm.addr.Object()) cliPrm.ByID(prm.addr.Object())
@ -608,7 +734,7 @@ func (c *clientWrapper) objectGet(ctx context.Context, prm PrmObjectGet) (ResGet
var res ResGetObject var res ResGetObject
rObj, err := c.client.ObjectGetInit(ctx, cliPrm) rObj, err := cl.ObjectGetInit(ctx, cliPrm)
if err = c.handleError(nil, err); err != nil { if err = c.handleError(nil, err); err != nil {
return ResGetObject{}, fmt.Errorf("init object reading on client: %w", err) return ResGetObject{}, fmt.Errorf("init object reading on client: %w", err)
} }
@ -638,6 +764,11 @@ func (c *clientWrapper) objectGet(ctx context.Context, prm PrmObjectGet) (ResGet
// objectHead invokes sdkClient.ObjectHead parse response status to error and return result as is. // objectHead invokes sdkClient.ObjectHead parse response status to error and return result as is.
func (c *clientWrapper) objectHead(ctx context.Context, prm PrmObjectHead) (object.Object, error) { func (c *clientWrapper) objectHead(ctx context.Context, prm PrmObjectHead) (object.Object, error) {
cl, err := c.getClient()
if err != nil {
return object.Object{}, err
}
var cliPrm sdkClient.PrmObjectHead var cliPrm sdkClient.PrmObjectHead
cliPrm.FromContainer(prm.addr.Container()) cliPrm.FromContainer(prm.addr.Container())
cliPrm.ByID(prm.addr.Object()) cliPrm.ByID(prm.addr.Object())
@ -657,7 +788,7 @@ func (c *clientWrapper) objectHead(ctx context.Context, prm PrmObjectHead) (obje
var obj object.Object var obj object.Object
start := time.Now() start := time.Now()
res, err := c.client.ObjectHead(ctx, cliPrm) res, err := cl.ObjectHead(ctx, cliPrm)
c.incRequests(time.Since(start), methodObjectHead) c.incRequests(time.Since(start), methodObjectHead)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -675,6 +806,11 @@ func (c *clientWrapper) objectHead(ctx context.Context, prm PrmObjectHead) (obje
// objectRange returns object range reader. // objectRange returns object range reader.
func (c *clientWrapper) objectRange(ctx context.Context, prm PrmObjectRange) (ResObjectRange, error) { func (c *clientWrapper) objectRange(ctx context.Context, prm PrmObjectRange) (ResObjectRange, error) {
cl, err := c.getClient()
if err != nil {
return ResObjectRange{}, err
}
var cliPrm sdkClient.PrmObjectRange var cliPrm sdkClient.PrmObjectRange
cliPrm.FromContainer(prm.addr.Container()) cliPrm.FromContainer(prm.addr.Container())
cliPrm.ByID(prm.addr.Object()) cliPrm.ByID(prm.addr.Object())
@ -694,7 +830,7 @@ func (c *clientWrapper) objectRange(ctx context.Context, prm PrmObjectRange) (Re
} }
start := time.Now() start := time.Now()
res, err := c.client.ObjectRangeInit(ctx, cliPrm) res, err := cl.ObjectRangeInit(ctx, cliPrm)
c.incRequests(time.Since(start), methodObjectRange) c.incRequests(time.Since(start), methodObjectRange)
if err = c.handleError(nil, err); err != nil { if err = c.handleError(nil, err); err != nil {
return ResObjectRange{}, fmt.Errorf("init payload range reading on client: %w", err) return ResObjectRange{}, fmt.Errorf("init payload range reading on client: %w", err)
@ -710,6 +846,11 @@ func (c *clientWrapper) objectRange(ctx context.Context, prm PrmObjectRange) (Re
// objectSearch invokes sdkClient.ObjectSearchInit parse response status to error and return result as is. // objectSearch invokes sdkClient.ObjectSearchInit parse response status to error and return result as is.
func (c *clientWrapper) objectSearch(ctx context.Context, prm PrmObjectSearch) (ResObjectSearch, error) { func (c *clientWrapper) objectSearch(ctx context.Context, prm PrmObjectSearch) (ResObjectSearch, error) {
cl, err := c.getClient()
if err != nil {
return ResObjectSearch{}, err
}
var cliPrm sdkClient.PrmObjectSearch var cliPrm sdkClient.PrmObjectSearch
cliPrm.InContainer(prm.cnrID) cliPrm.InContainer(prm.cnrID)
@ -727,7 +868,7 @@ func (c *clientWrapper) objectSearch(ctx context.Context, prm PrmObjectSearch) (
cliPrm.UseKey(*prm.key) cliPrm.UseKey(*prm.key)
} }
res, err := c.client.ObjectSearchInit(ctx, cliPrm) res, err := cl.ObjectSearchInit(ctx, cliPrm)
if err = c.handleError(nil, err); err != nil { if err = c.handleError(nil, err); err != nil {
return ResObjectSearch{}, fmt.Errorf("init object searching on client: %w", err) return ResObjectSearch{}, fmt.Errorf("init object searching on client: %w", err)
} }
@ -737,12 +878,17 @@ func (c *clientWrapper) objectSearch(ctx context.Context, prm PrmObjectSearch) (
// sessionCreate invokes sdkClient.SessionCreate parse response status to error and return result as is. // sessionCreate invokes sdkClient.SessionCreate parse response status to error and return result as is.
func (c *clientWrapper) sessionCreate(ctx context.Context, prm prmCreateSession) (resCreateSession, error) { func (c *clientWrapper) sessionCreate(ctx context.Context, prm prmCreateSession) (resCreateSession, error) {
cl, err := c.getClient()
if err != nil {
return resCreateSession{}, err
}
var cliPrm sdkClient.PrmSessionCreate var cliPrm sdkClient.PrmSessionCreate
cliPrm.SetExp(prm.exp) cliPrm.SetExp(prm.exp)
cliPrm.UseKey(prm.key) cliPrm.UseKey(prm.key)
start := time.Now() start := time.Now()
res, err := c.client.SessionCreate(ctx, cliPrm) res, err := cl.SessionCreate(ctx, cliPrm)
c.incRequests(time.Since(start), methodSessionCreate) c.incRequests(time.Since(start), methodSessionCreate)
var st apistatus.Status var st apistatus.Status
if res != nil { if res != nil {
@ -762,8 +908,12 @@ func (c *clientStatusMonitor) isHealthy() bool {
return c.healthy.Load() return c.healthy.Load()
} }
func (c *clientStatusMonitor) setHealthy(val bool) bool { func (c *clientStatusMonitor) setHealthy() {
return c.healthy.Swap(val) != val c.healthy.Store(true)
}
func (c *clientStatusMonitor) setUnhealthy() {
c.healthy.Store(false)
} }
func (c *clientStatusMonitor) address() string { func (c *clientStatusMonitor) address() string {
@ -776,7 +926,7 @@ func (c *clientStatusMonitor) incErrorRate() {
c.currentErrorCount++ c.currentErrorCount++
c.overallErrorCount++ c.overallErrorCount++
if c.currentErrorCount >= c.errorThreshold { if c.currentErrorCount >= c.errorThreshold {
c.setHealthy(false) c.setUnhealthy()
c.currentErrorCount = 0 c.currentErrorCount = 0
} }
} }
@ -827,11 +977,7 @@ func (c *clientStatusMonitor) handleError(st apistatus.Status, err error) error
// clientBuilder is a type alias of client constructors which open connection // clientBuilder is a type alias of client constructors which open connection
// to the given endpoint. // to the given endpoint.
type clientBuilder = func(endpoint string) (client, error) type clientBuilder = func(endpoint string) client
// clientBuilderContext is a type alias of client constructors which open
// connection to the given endpoint using provided context.
type clientBuilderContext = func(ctx context.Context, endpoint string) (client, error)
// InitParameters contains values used to initialize connection Pool. // InitParameters contains values used to initialize connection Pool.
type InitParameters struct { type InitParameters struct {
@ -844,7 +990,7 @@ type InitParameters struct {
errorThreshold uint32 errorThreshold uint32
nodeParams []NodeParam nodeParams []NodeParam
clientBuilder clientBuilderContext clientBuilder clientBuilder
} }
// SetKey specifies default key to be used for the protocol communication by default. // SetKey specifies default key to be used for the protocol communication by default.
@ -894,13 +1040,6 @@ func (x *InitParameters) AddNode(nodeParam NodeParam) {
// setClientBuilder sets clientBuilder used for client construction. // setClientBuilder sets clientBuilder used for client construction.
// Wraps setClientBuilderContext without a context. // Wraps setClientBuilderContext without a context.
func (x *InitParameters) setClientBuilder(builder clientBuilder) { func (x *InitParameters) setClientBuilder(builder clientBuilder) {
x.setClientBuilderContext(func(_ context.Context, endpoint string) (client, error) {
return builder(endpoint)
})
}
// setClientBuilderContext sets clientBuilderContext used for client construction.
func (x *InitParameters) setClientBuilderContext(builder clientBuilderContext) {
x.clientBuilder = builder x.clientBuilder = builder
} }
@ -1336,7 +1475,7 @@ type Pool struct {
cache *sessionCache cache *sessionCache
stokenDuration uint64 stokenDuration uint64
rebalanceParams rebalanceParameters rebalanceParams rebalanceParameters
clientBuilder clientBuilderContext clientBuilder clientBuilder
logger *zap.Logger logger *zap.Logger
} }
@ -1404,22 +1543,26 @@ func (p *Pool) Dial(ctx context.Context) error {
for i, params := range p.rebalanceParams.nodesParams { for i, params := range p.rebalanceParams.nodesParams {
clients := make([]client, len(params.weights)) clients := make([]client, len(params.weights))
for j, addr := range params.addresses { for j, addr := range params.addresses {
c, err := p.clientBuilder(ctx, addr) c := p.clientBuilder(addr)
if err != nil { if err := c.dial(ctx); err != nil {
return err if p.logger != nil {
p.logger.Warn("failed to build client", zap.String("address", addr), zap.Error(err))
}
} }
var healthy bool
var st session.Object var st session.Object
err = initSessionForDuration(ctx, &st, c, p.rebalanceParams.sessionExpirationDuration, *p.key) err := initSessionForDuration(ctx, &st, c, p.rebalanceParams.sessionExpirationDuration, *p.key)
if err != nil && p.logger != nil { if err != nil {
p.logger.Warn("failed to create neofs session token for client", c.setUnhealthy()
zap.String("Address", addr), if p.logger != nil {
zap.Error(err)) p.logger.Warn("failed to create neofs session token for client",
} else if err == nil { zap.String("address", addr), zap.Error(err))
healthy, atLeastOneHealthy = true, true }
} else {
atLeastOneHealthy = true
_ = p.cache.Put(formCacheKey(addr, p.key), st) _ = p.cache.Put(formCacheKey(addr, p.key), st)
} }
c.setHealthy(healthy)
clients[j] = c clients[j] = c
} }
source := rand.NewSource(time.Now().UnixNano()) source := rand.NewSource(time.Now().UnixNano())
@ -1462,7 +1605,7 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) {
} }
if params.isMissingClientBuilder() { if params.isMissingClientBuilder() {
params.setClientBuilderContext(func(ctx context.Context, addr string) (client, error) { params.setClientBuilder(func(addr string) client {
var prm wrapperPrm var prm wrapperPrm
prm.setAddress(addr) prm.setAddress(addr)
prm.setKey(*params.key) prm.setKey(*params.key)
@ -1472,7 +1615,6 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) {
cache.updateEpoch(info.Epoch()) cache.updateEpoch(info.Epoch())
return nil return nil
}) })
prm.setDialContext(ctx)
return newWrapper(prm) return newWrapper(prm)
}) })
} }
@ -1551,29 +1693,23 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights
healthyChanged := atomic.NewBool(false) healthyChanged := atomic.NewBool(false)
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
var prmEndpoint prmEndpointInfo
for j, cli := range pool.clients { for j, cli := range pool.clients {
wg.Add(1) wg.Add(1)
go func(j int, cli client) { go func(j int, cli client) {
defer wg.Done() defer wg.Done()
ok := true
tctx, c := context.WithTimeout(ctx, options.nodeRequestTimeout) tctx, c := context.WithTimeout(ctx, options.nodeRequestTimeout)
defer c() defer c()
// TODO (@kirillovdenis) : #283 consider reconnect to the node on failure healthy, changed := cli.restartIfUnhealthy(tctx)
if _, err := cli.endpointInfo(tctx, prmEndpoint); err != nil { if healthy {
ok = false
bufferWeights[j] = 0
}
if ok {
bufferWeights[j] = options.nodesParams[i].weights[j] bufferWeights[j] = options.nodesParams[i].weights[j]
} else { } else {
bufferWeights[j] = 0
p.cache.DeleteByPrefix(cli.address()) p.cache.DeleteByPrefix(cli.address())
} }
if cli.setHealthy(ok) { if changed {
healthyChanged.Store(true) healthyChanged.Store(true)
} }
}(j, cli) }(j, cli)
@ -1616,7 +1752,7 @@ func (p *Pool) connection() (client, error) {
} }
func (p *innerPool) connection() (client, error) { func (p *innerPool) connection() (client, error) {
p.lock.RLock() // TODO(@kirillovdenis): #283 consider remove this lock because using client should be thread safe p.lock.RLock() // need lock because of using p.sampler
defer p.lock.RUnlock() defer p.lock.RUnlock()
if len(p.clients) == 1 { if len(p.clients) == 1 {
cp := p.clients[0] cp := p.clients[0]

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/ecdsa" "crypto/ecdsa"
"errors" "errors"
"fmt"
"strconv" "strconv"
"testing" "testing"
"time" "time"
@ -22,15 +21,17 @@ import (
) )
func TestBuildPoolClientFailed(t *testing.T) { func TestBuildPoolClientFailed(t *testing.T) {
clientBuilder := func(string) (client, error) { mockClientBuilder := func(addr string) client {
return nil, fmt.Errorf("error") mockCli := newMockClient(addr, *newPrivateKey(t))
mockCli.errOnDial()
return mockCli
} }
opts := InitParameters{ opts := InitParameters{
key: newPrivateKey(t), key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}}, nodeParams: []NodeParam{{1, "peer0", 1}},
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(mockClientBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -39,17 +40,17 @@ func TestBuildPoolClientFailed(t *testing.T) {
} }
func TestBuildPoolCreateSessionFailed(t *testing.T) { func TestBuildPoolCreateSessionFailed(t *testing.T) {
clientBuilder := func(addr string) (client, error) { clientMockBuilder := func(addr string) client {
mockCli := newMockClient(addr, *newPrivateKey(t)) mockCli := newMockClient(addr, *newPrivateKey(t))
mockCli.errOnCreateSession() mockCli.errOnCreateSession()
return mockCli, nil return mockCli
} }
opts := InitParameters{ opts := InitParameters{
key: newPrivateKey(t), key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}}, nodeParams: []NodeParam{{1, "peer0", 1}},
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(clientMockBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -70,17 +71,17 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
} }
var clientKeys []*ecdsa.PrivateKey var clientKeys []*ecdsa.PrivateKey
clientBuilder := func(addr string) (client, error) { mockClientBuilder := func(addr string) client {
key := newPrivateKey(t) key := newPrivateKey(t)
clientKeys = append(clientKeys, key) clientKeys = append(clientKeys, key)
if addr == nodes[0].address { if addr == nodes[0].address {
mockCli := newMockClient(addr, *key) mockCli := newMockClient(addr, *key)
mockCli.errOnEndpointInfo() mockCli.errOnEndpointInfo()
return mockCli, nil return mockCli
} }
return newMockClient(addr, *key), nil return newMockClient(addr, *key)
} }
log, err := zap.NewProduction() log, err := zap.NewProduction()
@ -91,7 +92,7 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
logger: log, logger: log,
nodeParams: nodes, nodeParams: nodes,
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(mockClientBuilder)
clientPool, err := NewPool(opts) clientPool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -122,15 +123,15 @@ func TestBuildPoolZeroNodes(t *testing.T) {
func TestOneNode(t *testing.T) { func TestOneNode(t *testing.T) {
key1 := newPrivateKey(t) key1 := newPrivateKey(t)
clientBuilder := func(addr string) (client, error) { mockClientBuilder := func(addr string) client {
return newMockClient(addr, *key1), nil return newMockClient(addr, *key1)
} }
opts := InitParameters{ opts := InitParameters{
key: newPrivateKey(t), key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}}, nodeParams: []NodeParam{{1, "peer0", 1}},
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(mockClientBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -147,10 +148,10 @@ func TestOneNode(t *testing.T) {
func TestTwoNodes(t *testing.T) { func TestTwoNodes(t *testing.T) {
var clientKeys []*ecdsa.PrivateKey var clientKeys []*ecdsa.PrivateKey
clientBuilder := func(addr string) (client, error) { mockClientBuilder := func(addr string) client {
key := newPrivateKey(t) key := newPrivateKey(t)
clientKeys = append(clientKeys, key) clientKeys = append(clientKeys, key)
return newMockClient(addr, *key), nil return newMockClient(addr, *key)
} }
opts := InitParameters{ opts := InitParameters{
@ -160,7 +161,7 @@ func TestTwoNodes(t *testing.T) {
{1, "peer1", 1}, {1, "peer1", 1},
}, },
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(mockClientBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -191,18 +192,18 @@ func TestOneOfTwoFailed(t *testing.T) {
} }
var clientKeys []*ecdsa.PrivateKey var clientKeys []*ecdsa.PrivateKey
clientBuilder := func(addr string) (client, error) { mockClientBuilder := func(addr string) client {
key := newPrivateKey(t) key := newPrivateKey(t)
clientKeys = append(clientKeys, key) clientKeys = append(clientKeys, key)
if addr == nodes[0].address { if addr == nodes[0].address {
return newMockClient(addr, *key), nil return newMockClient(addr, *key)
} }
mockCli := newMockClient(addr, *key) mockCli := newMockClient(addr, *key)
mockCli.errOnEndpointInfo() mockCli.errOnEndpointInfo()
mockCli.errOnNetworkInfo() mockCli.errOnNetworkInfo()
return mockCli, nil return mockCli
} }
opts := InitParameters{ opts := InitParameters{
@ -210,7 +211,7 @@ func TestOneOfTwoFailed(t *testing.T) {
nodeParams: nodes, nodeParams: nodes,
clientRebalanceInterval: 200 * time.Millisecond, clientRebalanceInterval: 200 * time.Millisecond,
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(mockClientBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -232,12 +233,12 @@ func TestOneOfTwoFailed(t *testing.T) {
func TestTwoFailed(t *testing.T) { func TestTwoFailed(t *testing.T) {
var clientKeys []*ecdsa.PrivateKey var clientKeys []*ecdsa.PrivateKey
clientBuilder := func(addr string) (client, error) { mockClientBuilder := func(addr string) client {
key := newPrivateKey(t) key := newPrivateKey(t)
clientKeys = append(clientKeys, key) clientKeys = append(clientKeys, key)
mockCli := newMockClient(addr, *key) mockCli := newMockClient(addr, *key)
mockCli.errOnEndpointInfo() mockCli.errOnEndpointInfo()
return mockCli, nil return mockCli
} }
opts := InitParameters{ opts := InitParameters{
@ -248,7 +249,7 @@ func TestTwoFailed(t *testing.T) {
}, },
clientRebalanceInterval: 200 * time.Millisecond, clientRebalanceInterval: 200 * time.Millisecond,
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(mockClientBuilder)
pool, err := NewPool(opts) pool, err := NewPool(opts)
require.NoError(t, err) require.NoError(t, err)
@ -268,10 +269,10 @@ func TestSessionCache(t *testing.T) {
key := newPrivateKey(t) key := newPrivateKey(t)
expectedAuthKey := neofsecdsa.PublicKey(key.PublicKey) expectedAuthKey := neofsecdsa.PublicKey(key.PublicKey)
clientBuilder := func(addr string) (client, error) { mockClientBuilder := func(addr string) client {
mockCli := newMockClient(addr, *key) mockCli := newMockClient(addr, *key)
mockCli.statusOnGetObject(apistatus.SessionTokenNotFound{}) mockCli.statusOnGetObject(apistatus.SessionTokenNotFound{})
return mockCli, nil return mockCli
} }
opts := InitParameters{ opts := InitParameters{
@ -281,7 +282,7 @@ func TestSessionCache(t *testing.T) {
}, },
clientRebalanceInterval: 30 * time.Second, clientRebalanceInterval: 30 * time.Second,
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(mockClientBuilder)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -331,17 +332,17 @@ func TestPriority(t *testing.T) {
} }
var clientKeys []*ecdsa.PrivateKey var clientKeys []*ecdsa.PrivateKey
clientBuilder := func(addr string) (client, error) { mockClientBuilder := func(addr string) client {
key := newPrivateKey(t) key := newPrivateKey(t)
clientKeys = append(clientKeys, key) clientKeys = append(clientKeys, key)
if addr == nodes[0].address { if addr == nodes[0].address {
mockCli := newMockClient(addr, *key) mockCli := newMockClient(addr, *key)
mockCli.errOnEndpointInfo() mockCli.errOnEndpointInfo()
return mockCli, nil return mockCli
} }
return newMockClient(addr, *key), nil return newMockClient(addr, *key)
} }
opts := InitParameters{ opts := InitParameters{
@ -349,7 +350,7 @@ func TestPriority(t *testing.T) {
nodeParams: nodes, nodeParams: nodes,
clientRebalanceInterval: 1500 * time.Millisecond, clientRebalanceInterval: 1500 * time.Millisecond,
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(mockClientBuilder)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -385,8 +386,8 @@ func TestSessionCacheWithKey(t *testing.T) {
key := newPrivateKey(t) key := newPrivateKey(t)
expectedAuthKey := neofsecdsa.PublicKey(key.PublicKey) expectedAuthKey := neofsecdsa.PublicKey(key.PublicKey)
clientBuilder := func(addr string) (client, error) { mockClientBuilder := func(addr string) client {
return newMockClient(addr, *key), nil return newMockClient(addr, *key)
} }
opts := InitParameters{ opts := InitParameters{
@ -396,7 +397,7 @@ func TestSessionCacheWithKey(t *testing.T) {
}, },
clientRebalanceInterval: 30 * time.Second, clientRebalanceInterval: 30 * time.Second,
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(mockClientBuilder)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -424,9 +425,9 @@ func TestSessionCacheWithKey(t *testing.T) {
} }
func TestSessionTokenOwner(t *testing.T) { func TestSessionTokenOwner(t *testing.T) {
clientBuilder := func(addr string) (client, error) { mockClientBuilder := func(addr string) client {
key := newPrivateKey(t) key := newPrivateKey(t)
return newMockClient(addr, *key), nil return newMockClient(addr, *key)
} }
opts := InitParameters{ opts := InitParameters{
@ -435,7 +436,7 @@ func TestSessionTokenOwner(t *testing.T) {
{1, "peer0", 1}, {1, "peer0", 1},
}, },
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(mockClientBuilder)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -620,7 +621,7 @@ func TestSwitchAfterErrorThreshold(t *testing.T) {
errorThreshold := 5 errorThreshold := 5
var clientKeys []*ecdsa.PrivateKey var clientKeys []*ecdsa.PrivateKey
clientBuilder := func(addr string) (client, error) { mockClientBuilder := func(addr string) client {
key := newPrivateKey(t) key := newPrivateKey(t)
clientKeys = append(clientKeys, key) clientKeys = append(clientKeys, key)
@ -628,10 +629,10 @@ func TestSwitchAfterErrorThreshold(t *testing.T) {
mockCli := newMockClient(addr, *key) mockCli := newMockClient(addr, *key)
mockCli.setThreshold(uint32(errorThreshold)) mockCli.setThreshold(uint32(errorThreshold))
mockCli.statusOnGetObject(apistatus.ServerInternal{}) mockCli.statusOnGetObject(apistatus.ServerInternal{})
return mockCli, nil return mockCli
} }
return newMockClient(addr, *key), nil return newMockClient(addr, *key)
} }
opts := InitParameters{ opts := InitParameters{
@ -639,7 +640,7 @@ func TestSwitchAfterErrorThreshold(t *testing.T) {
nodeParams: nodes, nodeParams: nodes,
clientRebalanceInterval: 30 * time.Second, clientRebalanceInterval: 30 * time.Second,
} }
opts.setClientBuilder(clientBuilder) opts.setClientBuilder(mockClientBuilder)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()

View file

@ -2,11 +2,9 @@ package pool
import ( import (
"context" "context"
"fmt"
"math/rand" "math/rand"
"testing" "testing"
"github.com/nspcc-dev/neofs-sdk-go/netmap"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -42,34 +40,6 @@ func TestSamplerStability(t *testing.T) {
} }
} }
type clientMock struct {
clientWrapper
name string
err error
}
func (c *clientMock) endpointInfo(context.Context, prmEndpointInfo) (netmap.NodeInfo, error) {
return netmap.NodeInfo{}, nil
}
func (c *clientMock) networkInfo(context.Context, prmNetworkInfo) (netmap.NetworkInfo, error) {
return netmap.NetworkInfo{}, nil
}
func newNetmapMock(name string, needErr bool) *clientMock {
var err error
if needErr {
err = fmt.Errorf("not available")
}
return &clientMock{
clientWrapper: clientWrapper{
clientStatusMonitor: newClientStatusMonitor("", 10),
},
name: name,
err: err,
}
}
func TestHealthyReweight(t *testing.T) { func TestHealthyReweight(t *testing.T) {
var ( var (
weights = []float64{0.9, 0.1} weights = []float64{0.9, 0.1}
@ -80,12 +50,14 @@ func TestHealthyReweight(t *testing.T) {
cache, err := newCache() cache, err := newCache()
require.NoError(t, err) require.NoError(t, err)
client1 := newMockClient(names[0], *newPrivateKey(t))
client1.errOnDial()
client2 := newMockClient(names[1], *newPrivateKey(t))
inner := &innerPool{ inner := &innerPool{
sampler: newSampler(weights, rand.NewSource(0)), sampler: newSampler(weights, rand.NewSource(0)),
clients: []client{ clients: []client{client1, client2},
newNetmapMock(names[0], true),
newNetmapMock(names[1], false),
},
} }
p := &Pool{ p := &Pool{
innerPools: []*innerPool{inner}, innerPools: []*innerPool{inner},
@ -97,19 +69,19 @@ func TestHealthyReweight(t *testing.T) {
// check getting first node connection before rebalance happened // check getting first node connection before rebalance happened
connection0, err := p.connection() connection0, err := p.connection()
require.NoError(t, err) require.NoError(t, err)
mock0 := connection0.(*clientMock) mock0 := connection0.(*mockClient)
require.Equal(t, names[0], mock0.name) require.Equal(t, names[0], mock0.address())
p.updateInnerNodesHealth(context.TODO(), 0, buffer) p.updateInnerNodesHealth(context.TODO(), 0, buffer)
connection1, err := p.connection() connection1, err := p.connection()
require.NoError(t, err) require.NoError(t, err)
mock1 := connection1.(*clientMock) mock1 := connection1.(*mockClient)
require.Equal(t, names[1], mock1.name) require.Equal(t, names[1], mock1.address())
// enabled first node again // enabled first node again
inner.lock.Lock() inner.lock.Lock()
inner.clients[0] = newNetmapMock(names[0], false) inner.clients[0] = newMockClient(names[0], *newPrivateKey(t))
inner.lock.Unlock() inner.lock.Unlock()
p.updateInnerNodesHealth(context.TODO(), 0, buffer) p.updateInnerNodesHealth(context.TODO(), 0, buffer)
@ -117,8 +89,8 @@ func TestHealthyReweight(t *testing.T) {
connection0, err = p.connection() connection0, err = p.connection()
require.NoError(t, err) require.NoError(t, err)
mock0 = connection0.(*clientMock) mock0 = connection0.(*mockClient)
require.Equal(t, names[0], mock0.name) require.Equal(t, names[0], mock0.address())
} }
func TestHealthyNoReweight(t *testing.T) { func TestHealthyNoReweight(t *testing.T) {
@ -132,8 +104,8 @@ func TestHealthyNoReweight(t *testing.T) {
inner := &innerPool{ inner := &innerPool{
sampler: sampl, sampler: sampl,
clients: []client{ clients: []client{
newNetmapMock(names[0], false), newMockClient(names[0], *newPrivateKey(t)),
newNetmapMock(names[1], false), newMockClient(names[1], *newPrivateKey(t)),
}, },
} }
p := &Pool{ p := &Pool{