From d00131d6d0097f84d56a2524dff11220e663dc5c Mon Sep 17 00:00:00 2001 From: Alexander Chuprov Date: Wed, 12 Mar 2025 14:59:20 +0300 Subject: [PATCH] =?UTF-8?q?[#346]=20pool:=20'Close'=20waits=20for=20a?= =?UTF-8?q?ll=20client=20operations=20to=20complete?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Alexander Chuprov --- pool/client.go | 16 +++++++ pool/connection_manager.go | 18 +++++++- pool/mock_test.go | 12 ++++-- pool/pool.go | 47 ++++++++++++++++++++- pool/pool_test.go | 85 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 171 insertions(+), 7 deletions(-) diff --git a/pool/client.go b/pool/client.go index f1abbf2..7e4a9ce 100644 --- a/pool/client.go +++ b/pool/client.go @@ -40,6 +40,10 @@ type clientStatusMonitor struct { currentErrorCount uint32 overallErrorCount uint64 methods []*MethodStatus + + // RLock means the client is being used for an operation. + // Lock means the client is marked as closed. + status sync.RWMutex } // values for healthy status of clientStatusMonitor. @@ -446,6 +450,18 @@ func (c *clientWrapper) containerList(ctx context.Context, prm PrmContainerList) return res.Containers(), nil } +func (c *clientStatusMonitor) startOperation() bool { return c.status.TryRLock() } + +func (c *clientStatusMonitor) endOperation() { c.status.RUnlock() } + +func (c *clientStatusMonitor) markClientClosed() { c.status.Lock() } + // PrmListStream groups parameters of ListContainersStream operation. type PrmListStream struct { OwnerID user.ID diff --git a/pool/connection_manager.go b/pool/connection_manager.go index b142529..c4d9a1b 100644 --- a/pool/connection_manager.go +++ b/pool/connection_manager.go @@ -250,12 +250,25 @@ func adjustWeights(weights []float64) []float64 { return adjusted } +func (cm *connectionManager) returnConnection(cp client) { cp.endOperation() } + +// connection returns a healthy client. +// User MUST return the used client by calling returnConnection.
func (cm *connectionManager) connection() (client, error) { for _, inner := range cm.innerPools { cp, err := inner.connection() - if err == nil { - return cp, nil + if err != nil { + continue } + + if !cp.startOperation() { + cm.log(zap.DebugLevel, "pool contains clients with a closed status") + continue + } + + return cp, nil } return nil, errors.New("no healthy client") @@ -324,6 +337,7 @@ func (cm *connectionManager) close() { // close all clients for _, pools := range cm.innerPools { for _, cli := range pools.clients { + cli.markClientClosed() _ = cli.close() } } diff --git a/pool/mock_test.go b/pool/mock_test.go index 4731108..ce315af 100644 --- a/pool/mock_test.go +++ b/pool/mock_test.go @@ -4,9 +4,6 @@ import ( "context" "crypto/ecdsa" "errors" - - "go.uber.org/zap" - "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/accounting" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/ape" sessionv2 "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/session" @@ -18,12 +15,14 @@ import ( "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/session" "github.com/google/uuid" + "go.uber.org/zap" ) type mockClient struct { key ecdsa.PrivateKey clientStatusMonitor + endOperationSignal chan interface{} errorOnDial bool errorOnCreateSession bool errorOnEndpointInfo error @@ -88,6 +87,12 @@ func newToken(key ecdsa.PrivateKey) *session.Object { return &tok } +func (m *mockClient) waitExecution() { + if m.endOperationSignal != nil { + <-m.endOperationSignal + } +} + func (m *mockClient) balanceGet(context.Context, PrmBalanceGet) (accounting.Decimal, error) { return accounting.Decimal{}, nil } @@ -168,6 +173,7 @@ func (m *mockClient) objectDelete(context.Context, PrmObjectDelete) error { } func (m *mockClient) objectGet(ctx context.Context, _ PrmObjectGet) (ResGetObject, error) { + m.waitExecution() var res ResGetObject if m.stOnGetObject == nil { diff --git a/pool/pool.go b/pool/pool.go index 2f30ae4..e2fc7e7 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -104,6 +104,12 @@ type clientStatus interface { overallErrorRate() uint64 // methodsStatus returns statistic for all used methods. methodsStatus() []StatusSnapshot + // startOperation increases the counter of ongoing operations. + startOperation() bool + // endOperation decreases the counter of ongoing operations. + endOperation() + // markClientClosed marks the client as closed. + markClientClosed() } // InitParameters contains values used to initialize connection Pool. 
@@ -982,6 +988,7 @@ type callContext struct { func (p *Pool) initCall(ctxCall *callContext, cfg prmCommon, prmCtx prmContext) error { p.fillAppropriateKey(&cfg) cp, err := p.manager.connection() + if err != nil { return err } @@ -1089,7 +1096,10 @@ func (p *Pool) PatchObject(ctx context.Context, prm PrmObjectPatch) (ResPatchObj prmCtx.useContainer(prm.addr.Container()) var ctxCall callContext - if err := p.initCall(&ctxCall, prm.prmCommon, prmCtx); err != nil { + err := p.initCall(&ctxCall, prm.prmCommon, prmCtx) + defer p.manager.returnConnection(ctxCall.client) + + if err != nil { return ResPatchObject{}, fmt.Errorf("init call context: %w", err) } @@ -1121,7 +1131,10 @@ func (p *Pool) PutObject(ctx context.Context, prm PrmObjectPut) (ResPutObject, e var ctxCall callContext ctxCall.sessionClientCut = prm.clientCut - if err := p.initCall(&ctxCall, prm.prmCommon, prmCtx); err != nil { + err := p.initCall(&ctxCall, prm.prmCommon, prmCtx) + defer p.manager.returnConnection(ctxCall.client) + + if err != nil { return ResPutObject{}, fmt.Errorf("init call context: %w", err) } @@ -1178,6 +1191,8 @@ func (p *Pool) DeleteObject(ctx context.Context, prm PrmObjectDelete) error { cc.sessionTarget = prm.UseSession err := p.initCall(&cc, prm.prmCommon, prmCtx) + defer p.manager.returnConnection(cc.client) + if err != nil { return err } @@ -1225,6 +1240,8 @@ func (p *Pool) GetObject(ctx context.Context, prm PrmObjectGet) (ResGetObject, e var res ResGetObject err := p.initCall(&cc, prm.prmCommon, prmContext{}) + defer p.manager.returnConnection(cc.client) + if err != nil { return res, err } @@ -1246,6 +1263,8 @@ func (p *Pool) HeadObject(ctx context.Context, prm PrmObjectHead) (object.Object var obj object.Object err := p.initCall(&cc, prm.prmCommon, prmContext{}) + defer p.manager.returnConnection(cc.client) + if err != nil { return obj, err } @@ -1293,6 +1312,8 @@ func (p *Pool) ObjectRange(ctx context.Context, prm PrmObjectRange) (ResObjectRa var res ResObjectRange err := p.initCall(&cc, prm.prmCommon, prmContext{}) + defer p.manager.returnConnection(cc.client) + if err != nil { return res, err } @@ -1383,6 +1404,8 @@ func (p *Pool) SearchObjects(ctx context.Context, prm PrmObjectSearch) (ResObjec // Success can be verified by reading by identifier (see GetContainer). func (p *Pool) PutContainer(ctx context.Context, prm PrmContainerPut) (cid.ID, error) { cp, err := p.manager.connection() + defer p.manager.returnConnection(cp) + if err != nil { return cid.ID{}, err } @@ -1398,6 +1421,8 @@ func (p *Pool) PutContainer(ctx context.Context, prm PrmContainerPut) (cid.ID, e // GetContainer reads FrostFS container by ID. func (p *Pool) GetContainer(ctx context.Context, prm PrmContainerGet) (container.Container, error) { cp, err := p.manager.connection() + defer p.manager.returnConnection(cp) + if err != nil { return container.Container{}, err } @@ -1413,6 +1438,8 @@ func (p *Pool) GetContainer(ctx context.Context, prm PrmContainerGet) (container // ListContainers requests identifiers of the account-owned containers. func (p *Pool) ListContainers(ctx context.Context, prm PrmContainerList) ([]cid.ID, error) { cp, err := p.manager.connection() + defer p.manager.returnConnection(cp) + if err != nil { return nil, err } @@ -1429,6 +1456,8 @@ func (p *Pool) ListContainers(ctx context.Context, prm PrmContainerList) ([]cid. 
func (p *Pool) ListContainersStream(ctx context.Context, prm PrmListStream) (ResListStream, error) { var res ResListStream cp, err := p.manager.connection() + defer p.manager.returnConnection(cp) + if err != nil { return res, err } @@ -1451,6 +1480,8 @@ func (p *Pool) ListContainersStream(ctx context.Context, prm PrmListStream) (Res // Success can be verified by reading by identifier (see GetContainer). func (p *Pool) DeleteContainer(ctx context.Context, prm PrmContainerDelete) error { cp, err := p.manager.connection() + defer p.manager.returnConnection(cp) + if err != nil { return err } @@ -1466,6 +1497,8 @@ func (p *Pool) DeleteContainer(ctx context.Context, prm PrmContainerDelete) erro // AddAPEChain sends a request to set APE chain rules for a target (basically, for a container). func (p *Pool) AddAPEChain(ctx context.Context, prm PrmAddAPEChain) error { cp, err := p.manager.connection() + defer p.manager.returnConnection(cp) + if err != nil { return err } @@ -1481,6 +1514,8 @@ func (p *Pool) AddAPEChain(ctx context.Context, prm PrmAddAPEChain) error { // RemoveAPEChain sends a request to remove APE chain rules for a target. func (p *Pool) RemoveAPEChain(ctx context.Context, prm PrmRemoveAPEChain) error { cp, err := p.manager.connection() + defer p.manager.returnConnection(cp) + if err != nil { return err } @@ -1496,6 +1531,8 @@ func (p *Pool) RemoveAPEChain(ctx context.Context, prm PrmRemoveAPEChain) error // ListAPEChains sends a request to list APE chains rules for a target. func (p *Pool) ListAPEChains(ctx context.Context, prm PrmListAPEChains) ([]ape.Chain, error) { cp, err := p.manager.connection() + defer p.manager.returnConnection(cp) + if err != nil { return nil, err } @@ -1511,6 +1548,8 @@ func (p *Pool) ListAPEChains(ctx context.Context, prm PrmListAPEChains) ([]ape.C // Balance requests current balance of the FrostFS account. func (p *Pool) Balance(ctx context.Context, prm PrmBalanceGet) (accounting.Decimal, error) { cp, err := p.manager.connection() + defer p.manager.returnConnection(cp) + if err != nil { return accounting.Decimal{}, err } @@ -1570,6 +1609,8 @@ func waitFor(ctx context.Context, params *WaitParams, condition func(context.Con // NetworkInfo requests information about the FrostFS network of which the remote server is a part. func (p *Pool) NetworkInfo(ctx context.Context) (netmap.NetworkInfo, error) { cp, err := p.manager.connection() + defer p.manager.returnConnection(cp) + if err != nil { return netmap.NetworkInfo{}, err } @@ -1585,6 +1626,8 @@ func (p *Pool) NetworkInfo(ctx context.Context) (netmap.NetworkInfo, error) { // NetMapSnapshot requests information about the FrostFS network map. 
func (p *Pool) NetMapSnapshot(ctx context.Context) (netmap.NetMap, error) { cp, err := p.manager.connection() + defer p.manager.returnConnection(cp) + if err != nil { return netmap.NetMap{}, err } diff --git a/pool/pool_test.go b/pool/pool_test.go index b063294..881e0a8 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -5,6 +5,7 @@ import ( "crypto/ecdsa" "errors" "math/rand" + "sync/atomic" "testing" "time" @@ -105,6 +106,7 @@ func TestBuildPoolOneNodeFailed(t *testing.T) { expectedAuthKey := frostfsecdsa.PublicKey(clientKeys[1].PublicKey) condition := func() bool { cp, err := clientPool.manager.connection() + defer clientPool.manager.returnConnection(cp) if err != nil { return false } @@ -142,12 +144,83 @@ func TestOneNode(t *testing.T) { t.Cleanup(pool.Close) cp, err := pool.manager.connection() + defer pool.manager.returnConnection(cp) require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) expectedAuthKey := frostfsecdsa.PublicKey(key1.PublicKey) require.True(t, st.AssertAuthKey(&expectedAuthKey)) } +func TestWaitAllConnection(t *testing.T) { + ctx := context.Background() + key1 := newPrivateKey(t) + ch := make(chan interface{}) + mockClientBuilder := func(addr string) client { + mock := newMockClient(addr, *key1) + mock.endOperationSignal = ch + return mock + } + + opts := InitParameters{ + key: newPrivateKey(t), + nodeParams: []NodeParam{{1, "peer0", 1}}, + } + opts.setClientBuilder(mockClientBuilder) + + pool, err := NewPool(opts) + require.NoError(t, err) + err = pool.Dial(ctx) + require.NoError(t, err) + + var operation, close atomic.Uint64 + go func() { + operation.Store(1) + pool.GetObject(ctx, PrmObjectGet{}) + operation.Store(0) + }() + + require.Eventually(t, func() bool { + if operation.Load() == 1 { + return true + } + return false + }, time.Second, 10*time.Millisecond) + + go func() { + close.Store(1) + pool.Close() + close.Store(0) + }() + + require.Eventually(t, func() bool { + if close.Load() == 1 { + return true + } + return false + }, time.Second, 10*time.Millisecond) + + require.Equal(t, operation.Load(), uint64(1)) + require.Equal(t, close.Load(), uint64(1)) + + ch <- true + + require.Eventually(t, func() bool { + if operation.Load() == 0 { + return true + } + return false + }, time.Second, 10*time.Millisecond) + + require.Equal(t, close.Load(), uint64(0)) + + require.Eventually(t, func() bool { + if close.Load() == 0 { + return true + } + return false + }, time.Second, 10*time.Millisecond) +} + func TestTwoNodes(t *testing.T) { var clientKeys []*ecdsa.PrivateKey mockClientBuilder := func(addr string) client { @@ -172,6 +245,8 @@ func TestTwoNodes(t *testing.T) { t.Cleanup(pool.Close) cp, err := pool.manager.connection() + defer pool.manager.returnConnection(cp) + require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) require.True(t, assertAuthKeyForAny(st, clientKeys)) @@ -226,6 +301,7 @@ func TestOneOfTwoFailed(t *testing.T) { for range 5 { cp, err := pool.manager.connection() + defer pool.manager.returnConnection(cp) require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) require.True(t, assertAuthKeyForAny(st, clientKeys)) @@ -469,6 +545,7 @@ func TestSessionCache(t *testing.T) { // cache must contain session token cp, err := pool.manager.connection() + defer pool.manager.returnConnection(cp) require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) require.True(t, st.AssertAuthKey(&expectedAuthKey)) @@ 
-482,6 +559,7 @@ func TestSessionCache(t *testing.T) { // cache must not contain session token cp, err = pool.manager.connection() + defer pool.manager.returnConnection(cp) require.NoError(t, err) _, ok := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) require.False(t, ok) @@ -494,6 +572,7 @@ func TestSessionCache(t *testing.T) { // cache must contain session token cp, err = pool.manager.connection() + defer pool.manager.returnConnection(cp) require.NoError(t, err) st, _ = pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) require.True(t, st.AssertAuthKey(&expectedAuthKey)) @@ -538,6 +617,7 @@ func TestPriority(t *testing.T) { expectedAuthKey1 := frostfsecdsa.PublicKey(clientKeys[0].PublicKey) firstNode := func() bool { cp, err := pool.manager.connection() + defer pool.manager.returnConnection(cp) require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) return st.AssertAuthKey(&expectedAuthKey1) @@ -546,6 +626,7 @@ func TestPriority(t *testing.T) { expectedAuthKey2 := frostfsecdsa.PublicKey(clientKeys[1].PublicKey) secondNode := func() bool { cp, err := pool.manager.connection() + defer pool.manager.returnConnection(cp) require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) return st.AssertAuthKey(&expectedAuthKey2) @@ -583,6 +664,7 @@ func TestSessionCacheWithKey(t *testing.T) { // cache must contain session token cp, err := pool.manager.connection() + defer pool.manager.returnConnection(cp) require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) require.True(t, st.AssertAuthKey(&expectedAuthKey)) @@ -636,6 +718,7 @@ func TestSessionTokenOwner(t *testing.T) { tkn = tok } err = p.initCall(&cc, prm, prmCtx) + defer p.manager.returnConnection(cc.client) require.NoError(t, err) err = p.openDefaultSession(ctx, &cc) require.NoError(t, err) @@ -921,6 +1004,7 @@ func TestSwitchAfterErrorThreshold(t *testing.T) { for range errorThreshold { conn, err := pool.manager.connection() + defer pool.manager.returnConnection(conn) require.NoError(t, err) require.Equal(t, nodes[0].address, conn.address()) _, err = conn.objectGet(ctx, PrmObjectGet{}) @@ -928,6 +1012,7 @@ func TestSwitchAfterErrorThreshold(t *testing.T) { } conn, err := pool.manager.connection() + defer pool.manager.returnConnection(conn) require.NoError(t, err) require.Equal(t, nodes[1].address, conn.address()) _, err = conn.objectGet(ctx, PrmObjectGet{})
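The patch gates every client on a sync.RWMutex: connection() acquires a read lock via TryRLock before handing a client out, returnConnection releases it, and markClientClosed (called from connectionManager.close) takes the write lock, which blocks until all in-flight operations finish and keeps failing any later TryRLock. The standalone sketch below illustrates that synchronization pattern in isolation; it is not part of the SDK, the closableResource type and its methods are illustrative stand-ins for clientStatusMonitor, and it assumes Go 1.18+ for sync.RWMutex.TryRLock.

package main

import (
	"fmt"
	"sync"
	"time"
)

// closableResource mimics the role of clientStatusMonitor in the patch:
// a held read lock means "operation in flight", the write lock means "closed".
type closableResource struct {
	status sync.RWMutex
}

// startOperation reports whether the resource may still be used.
// Once close has been called (or is waiting for the write lock),
// TryRLock fails, so new operations are rejected rather than delaying shutdown.
func (r *closableResource) startOperation() bool {
	return r.status.TryRLock()
}

// endOperation releases the read lock taken by startOperation.
func (r *closableResource) endOperation() {
	r.status.RUnlock()
}

// close blocks until every startOperation has been paired with an
// endOperation, then holds the write lock so the resource stays unusable.
func (r *closableResource) close() {
	r.status.Lock()
}

func main() {
	res := &closableResource{}
	var wg sync.WaitGroup

	// Start a few "operations"; each holds the read lock while it works.
	for i := 0; i < 3; i++ {
		if !res.startOperation() {
			continue // already closed, skip the work
		}
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			defer res.endOperation()
			time.Sleep(100 * time.Millisecond) // simulated request
			fmt.Println("operation", id, "finished")
		}(i)
	}

	res.close() // returns only after the three operations above complete
	fmt.Println("closed; further operations rejected:", !res.startOperation())
	wg.Wait()
}

In the patch the same flow is split across connectionManager.connection (TryRLock, with clients whose lock cannot be taken skipped as "closed"), returnConnection (RUnlock, deferred by every Pool method after acquiring a connection), and markClientClosed (Lock, taken in connectionManager.close just before each client is closed); this is what makes Pool.Close wait for outstanding operations. A useful property of Go's read-write mutex here is that new read locks are refused once a writer is waiting, so clients stop accepting new work as soon as Close starts, not only after it finally acquires the lock.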