[#569] Support context cancellation in tree node streaming

Signed-off-by: Nikita Zinkevich <n.zinkevich@yadro.com>
This commit is contained in:
Nikita Zinkevich 2024-12-10 13:28:23 +03:00 committed by Alexey Vanin
parent 16eb289929
commit d46f1d3bfa
5 changed files with 105 additions and 15 deletions

View file

@ -62,7 +62,7 @@ func (e ExtendedObjectInfo) Version() string {
// Basically used for "system" object. // Basically used for "system" object.
type BaseNodeVersion struct { type BaseNodeVersion struct {
ID uint64 ID uint64
ParenID uint64 ParentID uint64
OID oid.ID OID oid.ID
Timestamp uint64 Timestamp uint64
Size uint64 Size uint64

View file

@ -152,7 +152,12 @@ type SubTreeStreamImpl struct {
const bufSize = 1000 const bufSize = 1000
func (s *SubTreeStreamImpl) Next() (tree.NodeResponse, error) { func (s *SubTreeStreamImpl) Next(ctx context.Context) (tree.NodeResponse, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
if s.index != -1 { if s.index != -1 {
node := s.buffer[s.index] node := s.buffer[s.index]
s.index++ s.index++
@ -177,7 +182,7 @@ func (s *SubTreeStreamImpl) Next() (tree.NodeResponse, error) {
s.index = 0 s.index = 0
} }
return s.Next() return s.Next(ctx)
} }
func (w *PoolWrapper) GetSubTreeStream(ctx context.Context, bktInfo *data.BucketInfo, treeID string, rootID []uint64, depth uint32) (tree.SubTreeStream, error) { func (w *PoolWrapper) GetSubTreeStream(ctx context.Context, bktInfo *data.BucketInfo, treeID string, rootID []uint64, depth uint32) (tree.SubTreeStream, error) {

View file

@ -43,7 +43,7 @@ type (
} }
SubTreeStream interface { SubTreeStream interface {
Next() (NodeResponse, error) Next(ctx context.Context) (NodeResponse, error)
} }
treeNode struct { treeNode struct {
@ -242,7 +242,7 @@ func newNodeVersionFromTreeNode(log *zap.Logger, filePath string, treeNode *tree
version := &data.NodeVersion{ version := &data.NodeVersion{
BaseNodeVersion: data.BaseNodeVersion{ BaseNodeVersion: data.BaseNodeVersion{
ID: treeNode.ID[0], ID: treeNode.ID[0],
ParenID: treeNode.ParentID[0], ParentID: treeNode.ParentID[0],
OID: treeNode.ObjID, OID: treeNode.ObjID,
Timestamp: treeNode.TimeStamp[0], Timestamp: treeNode.TimeStamp[0],
ETag: eTag, ETag: eTag,
@ -891,7 +891,7 @@ type DummySubTreeStream struct {
read bool read bool
} }
func (s *DummySubTreeStream) Next() (NodeResponse, error) { func (s *DummySubTreeStream) Next(context.Context) (NodeResponse, error) {
if s.read { if s.read {
return nil, io.EOF return nil, io.EOF
} }
@ -935,14 +935,14 @@ type VersionsByPrefixStreamImpl struct {
log *zap.Logger log *zap.Logger
} }
func (s *VersionsByPrefixStreamImpl) Next(context.Context) (*data.NodeVersion, error) { func (s *VersionsByPrefixStreamImpl) Next(ctx context.Context) (*data.NodeVersion, error) {
if s.ended { if s.ended {
return nil, io.EOF return nil, io.EOF
} }
for { for {
if s.innerStream == nil { if s.innerStream == nil {
node, err := s.getNodeFromMainStream() node, err := s.getNodeFromMainStream(ctx)
if err != nil { if err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
s.ended = true s.ended = true
@ -958,7 +958,7 @@ func (s *VersionsByPrefixStreamImpl) Next(context.Context) (*data.NodeVersion, e
} }
} }
nodeVersion, err := s.getNodeVersionFromInnerStream() nodeVersion, err := s.getNodeVersionFromInnerStream(ctx)
if err != nil { if err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
s.innerStream = nil s.innerStream = nil
@ -974,9 +974,9 @@ func (s *VersionsByPrefixStreamImpl) Next(context.Context) (*data.NodeVersion, e
} }
} }
func (s *VersionsByPrefixStreamImpl) getNodeFromMainStream() (NodeResponse, error) { func (s *VersionsByPrefixStreamImpl) getNodeFromMainStream(ctx context.Context) (NodeResponse, error) {
for { for {
node, err := s.mainStream.Next() node, err := s.mainStream.Next(ctx)
if err != nil { if err != nil {
if errors.Is(err, tree.ErrNodeNotFound) { if errors.Is(err, tree.ErrNodeNotFound) {
return nil, io.EOF return nil, io.EOF
@ -1007,9 +1007,9 @@ func (s *VersionsByPrefixStreamImpl) initInnerStream(node NodeResponse) (err err
return nil return nil
} }
func (s *VersionsByPrefixStreamImpl) getNodeVersionFromInnerStream() (*data.NodeVersion, error) { func (s *VersionsByPrefixStreamImpl) getNodeVersionFromInnerStream(ctx context.Context) (*data.NodeVersion, error) {
for { for {
node, err := s.innerStream.Next() node, err := s.innerStream.Next(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("inner stream: %w", err) return nil, fmt.Errorf("inner stream: %w", err)
} }
@ -1721,7 +1721,7 @@ func (c *Tree) addVersion(ctx context.Context, bktInfo *data.BucketInfo, treeID
node, err := c.getUnversioned(ctx, bktInfo, treeID, version.FilePath) node, err := c.getUnversioned(ctx, bktInfo, treeID, version.FilePath)
if err == nil { if err == nil {
if err = c.service.MoveNode(ctx, bktInfo, treeID, node.ID, node.ParenID, meta); err != nil { if err = c.service.MoveNode(ctx, bktInfo, treeID, node.ID, node.ParentID, meta); err != nil {
return 0, err return 0, err
} }

View file

@ -269,7 +269,12 @@ type SubTreeStreamMemoryImpl struct {
err error err error
} }
func (s *SubTreeStreamMemoryImpl) Next() (NodeResponse, error) { func (s *SubTreeStreamMemoryImpl) Next(ctx context.Context) (NodeResponse, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
if s.err != nil { if s.err != nil {
return nil, s.err return nil, s.err
} }

View file

@ -2,6 +2,8 @@ package tree
import ( import (
"context" "context"
"io"
"sort"
"testing" "testing"
"time" "time"
@ -359,3 +361,81 @@ func TestSplitTreeMultiparts(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, parts, 1) require.Len(t, parts, 1)
} }
func TestVersionsByPrefixStreamImpl_Next(t *testing.T) {
ctx := context.Background()
memCli, err := NewTreeServiceClientMemory()
require.NoError(t, err)
treeService := NewTree(memCli, zaptest.NewLogger(t))
bktInfo := &data.BucketInfo{
CID: cidtest.ID(),
}
ownerID := usertest.ID()
created := time.Now()
versions := []*data.NodeVersion{
{
BaseNodeVersion: data.BaseNodeVersion{
ID: 1,
OID: oidtest.ID(),
FilePath: "foo",
Owner: &ownerID,
Created: &created,
},
},
{
BaseNodeVersion: data.BaseNodeVersion{
ID: 2,
OID: oidtest.ID(),
FilePath: "bar",
Owner: &ownerID,
Created: &created,
},
},
{
BaseNodeVersion: data.BaseNodeVersion{
ID: 3,
OID: oidtest.ID(),
FilePath: "test",
Owner: &ownerID,
Created: &created,
},
},
}
for _, v := range versions {
_, err = treeService.AddVersion(ctx, bktInfo, v)
require.NoError(t, err)
}
sort.Slice(versions, func(i, j int) bool {
return versions[i].FilePath < versions[j].FilePath
})
t.Run("basic", func(t *testing.T) {
stream, err := treeService.InitVersionsByPrefixStream(ctx, bktInfo, "", false)
require.NoError(t, err)
for i := range len(versions) {
node, err := stream.Next(ctx)
require.NoError(t, err)
require.Equal(t, versions[i].ID, node.ID)
require.Equal(t, versions[i].FilePath, node.FilePath)
}
node, err := stream.Next(ctx)
require.Nil(t, node)
require.ErrorIs(t, err, io.EOF)
})
t.Run("context cancel", func(t *testing.T) {
stream, err := treeService.InitVersionsByPrefixStream(ctx, bktInfo, "", false)
require.NoError(t, err)
cancelCtx, cancel := context.WithCancel(ctx)
cancel()
node, err := stream.Next(cancelCtx)
require.Nil(t, node)
require.ErrorIs(t, err, context.Canceled)
})
}