From d46f1d3bfaa6b3c964b9c7bf7cb6b8fd819c6406 Mon Sep 17 00:00:00 2001 From: Nikita Zinkevich Date: Tue, 10 Dec 2024 13:28:23 +0300 Subject: [PATCH] [#569] Support context cancellation in tree node streaming Signed-off-by: Nikita Zinkevich --- api/data/tree.go | 2 +- internal/frostfs/services/pool_wrapper.go | 9 ++- pkg/service/tree/tree.go | 22 +++---- pkg/service/tree/tree_client_in_memory.go | 7 +- pkg/service/tree/tree_test.go | 80 +++++++++++++++++++++++ 5 files changed, 105 insertions(+), 15 deletions(-) diff --git a/api/data/tree.go b/api/data/tree.go index c75d936b..c796d310 100644 --- a/api/data/tree.go +++ b/api/data/tree.go @@ -62,7 +62,7 @@ func (e ExtendedObjectInfo) Version() string { // Basically used for "system" object. type BaseNodeVersion struct { ID uint64 - ParenID uint64 + ParentID uint64 OID oid.ID Timestamp uint64 Size uint64 diff --git a/internal/frostfs/services/pool_wrapper.go b/internal/frostfs/services/pool_wrapper.go index 82652cbb..0928d06a 100644 --- a/internal/frostfs/services/pool_wrapper.go +++ b/internal/frostfs/services/pool_wrapper.go @@ -152,7 +152,12 @@ type SubTreeStreamImpl struct { const bufSize = 1000 -func (s *SubTreeStreamImpl) Next() (tree.NodeResponse, error) { +func (s *SubTreeStreamImpl) Next(ctx context.Context) (tree.NodeResponse, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } if s.index != -1 { node := s.buffer[s.index] s.index++ @@ -177,7 +182,7 @@ func (s *SubTreeStreamImpl) Next() (tree.NodeResponse, error) { s.index = 0 } - return s.Next() + return s.Next(ctx) } func (w *PoolWrapper) GetSubTreeStream(ctx context.Context, bktInfo *data.BucketInfo, treeID string, rootID []uint64, depth uint32) (tree.SubTreeStream, error) { diff --git a/pkg/service/tree/tree.go b/pkg/service/tree/tree.go index 438b1297..862c0a3c 100644 --- a/pkg/service/tree/tree.go +++ b/pkg/service/tree/tree.go @@ -43,7 +43,7 @@ type ( } SubTreeStream interface { - Next() (NodeResponse, error) + Next(ctx context.Context) (NodeResponse, error) } treeNode struct { @@ -242,7 +242,7 @@ func newNodeVersionFromTreeNode(log *zap.Logger, filePath string, treeNode *tree version := &data.NodeVersion{ BaseNodeVersion: data.BaseNodeVersion{ ID: treeNode.ID[0], - ParenID: treeNode.ParentID[0], + ParentID: treeNode.ParentID[0], OID: treeNode.ObjID, Timestamp: treeNode.TimeStamp[0], ETag: eTag, @@ -891,7 +891,7 @@ type DummySubTreeStream struct { read bool } -func (s *DummySubTreeStream) Next() (NodeResponse, error) { +func (s *DummySubTreeStream) Next(context.Context) (NodeResponse, error) { if s.read { return nil, io.EOF } @@ -935,14 +935,14 @@ type VersionsByPrefixStreamImpl struct { log *zap.Logger } -func (s *VersionsByPrefixStreamImpl) Next(context.Context) (*data.NodeVersion, error) { +func (s *VersionsByPrefixStreamImpl) Next(ctx context.Context) (*data.NodeVersion, error) { if s.ended { return nil, io.EOF } for { if s.innerStream == nil { - node, err := s.getNodeFromMainStream() + node, err := s.getNodeFromMainStream(ctx) if err != nil { if errors.Is(err, io.EOF) { s.ended = true @@ -958,7 +958,7 @@ func (s *VersionsByPrefixStreamImpl) Next(context.Context) (*data.NodeVersion, e } } - nodeVersion, err := s.getNodeVersionFromInnerStream() + nodeVersion, err := s.getNodeVersionFromInnerStream(ctx) if err != nil { if errors.Is(err, io.EOF) { s.innerStream = nil @@ -974,9 +974,9 @@ func (s *VersionsByPrefixStreamImpl) Next(context.Context) (*data.NodeVersion, e } } -func (s *VersionsByPrefixStreamImpl) getNodeFromMainStream() (NodeResponse, error) { +func (s *VersionsByPrefixStreamImpl) getNodeFromMainStream(ctx context.Context) (NodeResponse, error) { for { - node, err := s.mainStream.Next() + node, err := s.mainStream.Next(ctx) if err != nil { if errors.Is(err, tree.ErrNodeNotFound) { return nil, io.EOF @@ -1007,9 +1007,9 @@ func (s *VersionsByPrefixStreamImpl) initInnerStream(node NodeResponse) (err err return nil } -func (s *VersionsByPrefixStreamImpl) getNodeVersionFromInnerStream() (*data.NodeVersion, error) { +func (s *VersionsByPrefixStreamImpl) getNodeVersionFromInnerStream(ctx context.Context) (*data.NodeVersion, error) { for { - node, err := s.innerStream.Next() + node, err := s.innerStream.Next(ctx) if err != nil { return nil, fmt.Errorf("inner stream: %w", err) } @@ -1721,7 +1721,7 @@ func (c *Tree) addVersion(ctx context.Context, bktInfo *data.BucketInfo, treeID node, err := c.getUnversioned(ctx, bktInfo, treeID, version.FilePath) if err == nil { - if err = c.service.MoveNode(ctx, bktInfo, treeID, node.ID, node.ParenID, meta); err != nil { + if err = c.service.MoveNode(ctx, bktInfo, treeID, node.ID, node.ParentID, meta); err != nil { return 0, err } diff --git a/pkg/service/tree/tree_client_in_memory.go b/pkg/service/tree/tree_client_in_memory.go index efa93c45..6d104fa0 100644 --- a/pkg/service/tree/tree_client_in_memory.go +++ b/pkg/service/tree/tree_client_in_memory.go @@ -269,7 +269,12 @@ type SubTreeStreamMemoryImpl struct { err error } -func (s *SubTreeStreamMemoryImpl) Next() (NodeResponse, error) { +func (s *SubTreeStreamMemoryImpl) Next(ctx context.Context) (NodeResponse, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } if s.err != nil { return nil, s.err } diff --git a/pkg/service/tree/tree_test.go b/pkg/service/tree/tree_test.go index 058e46a9..5b7b6656 100644 --- a/pkg/service/tree/tree_test.go +++ b/pkg/service/tree/tree_test.go @@ -2,6 +2,8 @@ package tree import ( "context" + "io" + "sort" "testing" "time" @@ -359,3 +361,81 @@ func TestSplitTreeMultiparts(t *testing.T) { require.NoError(t, err) require.Len(t, parts, 1) } + +func TestVersionsByPrefixStreamImpl_Next(t *testing.T) { + ctx := context.Background() + + memCli, err := NewTreeServiceClientMemory() + require.NoError(t, err) + treeService := NewTree(memCli, zaptest.NewLogger(t)) + bktInfo := &data.BucketInfo{ + CID: cidtest.ID(), + } + ownerID := usertest.ID() + created := time.Now() + versions := []*data.NodeVersion{ + { + BaseNodeVersion: data.BaseNodeVersion{ + ID: 1, + OID: oidtest.ID(), + FilePath: "foo", + Owner: &ownerID, + Created: &created, + }, + }, + { + BaseNodeVersion: data.BaseNodeVersion{ + ID: 2, + OID: oidtest.ID(), + FilePath: "bar", + Owner: &ownerID, + Created: &created, + }, + }, + { + BaseNodeVersion: data.BaseNodeVersion{ + ID: 3, + OID: oidtest.ID(), + FilePath: "test", + Owner: &ownerID, + Created: &created, + }, + }, + } + + for _, v := range versions { + _, err = treeService.AddVersion(ctx, bktInfo, v) + require.NoError(t, err) + } + + sort.Slice(versions, func(i, j int) bool { + return versions[i].FilePath < versions[j].FilePath + }) + + t.Run("basic", func(t *testing.T) { + stream, err := treeService.InitVersionsByPrefixStream(ctx, bktInfo, "", false) + require.NoError(t, err) + + for i := range len(versions) { + node, err := stream.Next(ctx) + require.NoError(t, err) + require.Equal(t, versions[i].ID, node.ID) + require.Equal(t, versions[i].FilePath, node.FilePath) + } + + node, err := stream.Next(ctx) + require.Nil(t, node) + require.ErrorIs(t, err, io.EOF) + }) + + t.Run("context cancel", func(t *testing.T) { + stream, err := treeService.InitVersionsByPrefixStream(ctx, bktInfo, "", false) + require.NoError(t, err) + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + + node, err := stream.Next(cancelCtx) + require.Nil(t, node) + require.ErrorIs(t, err, context.Canceled) + }) +}