frostfs-node/pkg/services/object/search/search_test.go

428 lines
8.7 KiB
Go
Raw Permalink Normal View History

package searchsvc
import (
"context"
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"strconv"
"testing"
clientcore "github.com/TrueCloudLab/frostfs-node/pkg/core/client"
netmapcore "github.com/TrueCloudLab/frostfs-node/pkg/core/netmap"
"github.com/TrueCloudLab/frostfs-node/pkg/network"
"github.com/TrueCloudLab/frostfs-node/pkg/services/object/util"
"github.com/TrueCloudLab/frostfs-node/pkg/services/object_manager/placement"
"github.com/TrueCloudLab/frostfs-node/pkg/util/logger/test"
"github.com/TrueCloudLab/frostfs-sdk-go/container"
cid "github.com/TrueCloudLab/frostfs-sdk-go/container/id"
cidtest "github.com/TrueCloudLab/frostfs-sdk-go/container/id/test"
"github.com/TrueCloudLab/frostfs-sdk-go/netmap"
oid "github.com/TrueCloudLab/frostfs-sdk-go/object/id"
"github.com/stretchr/testify/require"
)
// idsErr pairs a list of object IDs with an error so that a complete canned
// search outcome can be stored as a single map value in testStorage.
type idsErr struct {
// ids is the object ID list handed back by testStorage lookups.
ids []oid.ID
// err is the error handed back together with ids.
err error
}
// testStorage is an in-memory stub that serves canned search results.
// It provides both search and searchObjects, and the tests in this file plug
// it in as the service's local storage as well as the per-node remote client.
type testStorage struct {
// items maps an encoded container ID to the outcome registered via addResult.
items map[string]idsErr
}
// testTraverserGenerator creates placement traversers for tests from a fixed
// container and a per-epoch set of placement builders.
type testTraverserGenerator struct {
// c is the container passed to placement.ForContainer.
c container.Container
// b maps an epoch number to the placement builder used for that epoch.
b map[uint64]placement.Builder
}
// testPlacementBuilder is a placement builder stub that returns predefined
// node vectors instead of computing a placement.
type testPlacementBuilder struct {
// vectors maps an encoded object address (oid.Address.EncodeToString) to the
// node vectors that BuildPlacement returns for it.
vectors map[string][][]netmap.NodeInfo
}
// testClientCache is a client constructor stub that resolves node info to a
// pre-registered testStorage.
type testClientCache struct {
// clients maps a stringified address group (network.StringifyGroup) to the
// storage stub returned by get for that node.
clients map[string]*testStorage
}
// simpleIDWriter is a result writer that accumulates every object ID passed
// to WriteIDs so tests can inspect the full search output afterwards.
type simpleIDWriter struct {
// ids holds all IDs written so far, in the order they were written.
ids []oid.ID
}
// testEpochReceiver is a fixed epoch number; its currentEpoch method always
// reports this value.
type testEpochReceiver uint64
// currentEpoch reports the fixed epoch number stored in the receiver.
// The returned error is always nil.
func (e testEpochReceiver) currentEpoch() (uint64, error) {
	epoch := uint64(e)

	return epoch, nil
}
// WriteIDs appends the given batch of object IDs to the IDs accumulated so
// far. The returned error is always nil.
func (s *simpleIDWriter) WriteIDs(ids []oid.ID) error {
	s.ids = append(
		s.ids,
		ids...,
	)

	return nil
}
// newTestStorage returns an empty storage stub whose result map is already
// allocated, so addResult can be called on it right away.
func newTestStorage() *testStorage {
	st := new(testStorage)
	st.items = make(map[string]idsErr)

	return st
}
// generateTraverser creates a traverser over the generator's container using
// the placement builder registered for the given epoch, with success tracking
// disabled. The container ID argument is ignored.
func (g *testTraverserGenerator) generateTraverser(_ cid.ID, epoch uint64) (*placement.Traverser, error) {
	// A missing epoch yields the map's zero value, exactly as a direct lookup
	// in the argument list would.
	builder := g.b[epoch]

	return placement.NewTraverser(
		placement.ForContainer(g.c),
		placement.UseBuilder(builder),
		placement.WithoutSuccessTracking(),
	)
}
// BuildPlacement returns the predefined node vectors registered for the
// address made of cnr and, when obj is non-nil, the object ID. The placement
// policy is ignored.
//
// The outer slice is copied so that callers cannot reorder or replace the
// registered vectors; the inner slices are shared with the registered ones.
// An error is returned when nothing is registered for the address.
func (p *testPlacementBuilder) BuildPlacement(cnr cid.ID, obj *oid.ID, _ netmap.PlacementPolicy) ([][]netmap.NodeInfo, error) {
	var key oid.Address

	key.SetContainer(cnr)

	if obj != nil {
		key.SetObject(*obj)
	}

	if stored, found := p.vectors[key.EncodeToString()]; found {
		// Appending to a fresh non-nil slice keeps the result non-nil even
		// when no vectors are registered.
		return append(make([][]netmap.NodeInfo, 0, len(stored)), stored...), nil
	}

	return nil, errors.New("vectors for address not found")
}
// get looks up the storage stub registered for the node's stringified address
// group and returns it as a search client. An error is returned when no stub
// is registered for that node.
func (c *testClientCache) get(info clientcore.NodeInfo) (searchClient, error) {
	key := network.StringifyGroup(info.AddressGroup())

	if cli, found := c.clients[key]; found {
		return cli, nil
	}

	return nil, errors.New("could not construct client")
}
// search returns the canned IDs and error registered for the container that
// exec addresses. A container with no registered result yields (nil, nil).
func (s *testStorage) search(exec *execCtx) ([]oid.ID, error) {
	key := exec.containerID().EncodeToString()

	if res, found := s.items[key]; found {
		return res.ids, res.err
	}

	return nil, nil
}
// searchObjects is the remote-client counterpart of search: it ignores the
// node info and returns the same canned IDs and error that search returns for
// the container addressed by exec. A container with no registered result
// yields (nil, nil).
//
// The lookup is delegated to search so the two entry points cannot drift
// apart, and the receiver is named s like the other testStorage methods.
func (s *testStorage) searchObjects(exec *execCtx, _ clientcore.NodeInfo) ([]oid.ID, error) {
	return s.search(exec)
}
// addResult registers the outcome that search and searchObjects return for a
// container. Despite its name, addr is a container ID; its encoded string
// form is used as the map key. A previously registered outcome for the same
// container is overwritten.
//
// The receiver is named s for consistency with the search method of
// testStorage.
func (s *testStorage) addResult(addr cid.ID, ids []oid.ID, err error) {
	s.items[addr.EncodeToString()] = idsErr{
		ids: ids,
		err: err,
	}
}
// testSHA256 returns a SHA-256-sized array filled with cryptographically
// random bytes. The bytes are not a digest of anything; they only serve as a
// unique, well-formed value for building test object IDs.
//
// It panics if the random source fails: silently continuing would return an
// all-zero or partially filled array and make generated IDs collide.
func testSHA256() (cs [sha256.Size]byte) {
	if _, err := rand.Read(cs[:]); err != nil {
		panic("read random bytes for test checksum: " + err.Error())
	}

	return cs
}
// generateIDs returns num object IDs, each initialised from a fresh random
// SHA-256-sized value produced by testSHA256.
func generateIDs(num int) []oid.ID {
	ids := make([]oid.ID, num)

	for i := range ids {
		ids[i].SetSHA256(testSHA256())
	}

	return ids
}
// TestGetLocalOnly checks a search restricted to the local storage:
//   - "OK": the IDs registered in the storage stub reach the writer unchanged;
//   - "FAIL": an error registered in the storage stub is propagated by Search.
func TestGetLocalOnly(t *testing.T) {
	ctx := context.Background()

	// makeService wires a minimal Service around the given local storage stub.
	makeService := func(st *testStorage) *Service {
		s := &Service{cfg: new(cfg)}
		s.log = test.NewLogger(false)
		s.localStorage = st

		return s
	}

	// makePrm builds search parameters for the container with the local-only
	// flag set.
	makePrm := func(cnr cid.ID, w IDListWriter) Prm {
		var prm Prm

		prm.WithContainerID(cnr)
		prm.SetWriter(w)
		prm.common = new(util.CommonPrm).WithLocalOnly(true)

		return prm
	}

	t.Run("OK", func(t *testing.T) {
		var (
			st   = newTestStorage()
			cnr  = cidtest.ID()
			want = generateIDs(10)
			out  = new(simpleIDWriter)
		)

		st.addResult(cnr, want, nil)

		require.NoError(t, makeService(st).Search(ctx, makePrm(cnr, out)))
		require.Equal(t, want, out.ids)
	})

	t.Run("FAIL", func(t *testing.T) {
		var (
			st      = newTestStorage()
			cnr     = cidtest.ID()
			wantErr = errors.New("any error")
		)

		st.addResult(cnr, nil, wantErr)

		err := makeService(st).Search(ctx, makePrm(cnr, new(simpleIDWriter)))
		require.ErrorIs(t, err, wantErr)
	})
}
// testNodeMatrix builds a matrix of test nodes together with their
// stringified network address groups.
//
// dim gives the size of each row: row i holds dim[i] nodes, and node (i, j)
// is assigned the single endpoint /ip4/192.168.0.<i>/tcp/<60000+j>.
//
// It returns the node matrix and, at the same indices, the
// network.StringifyGroup form of each node's address group. The test fails
// immediately if an address group cannot be built from a node.
func testNodeMatrix(t testing.TB, dim []int) ([][]netmap.NodeInfo, [][]string) {
	// Report assertion failures at the caller's line rather than inside this
	// helper.
	t.Helper()

	mNodes := make([][]netmap.NodeInfo, len(dim))
	mAddr := make([][]string, len(dim))

	for i, width := range dim {
		ns := make([]netmap.NodeInfo, width)
		as := make([]string, width)

		for j := range ns {
			a := fmt.Sprintf("/ip4/192.168.0.%s/tcp/%s",
				strconv.Itoa(i),
				strconv.Itoa(60000+j),
			)

			var ni netmap.NodeInfo
			ni.SetNetworkEndpoints(a)

			var na network.AddressGroup
			err := na.FromIterator(netmapcore.Node(ni))
			require.NoError(t, err)

			as[j] = network.StringifyGroup(na)
			ns[j] = ni
		}

		mNodes[i] = ns
		mAddr[i] = as
	}

	return mNodes, mAddr
}
// TestGetRemoteSmall checks a non-local search over a container whose
// placement consists of a single vector of two nodes: the writer must receive
// exactly the IDs reported by both remote node stubs.
func TestGetRemoteSmall(t *testing.T) {
	ctx := context.Background()

	// One replica group holding two nodes.
	placementDim := []int{2}

	replicas := make([]netmap.ReplicaDescriptor, len(placementDim))
	for i, n := range placementDim {
		replicas[i].SetNumberOfObjects(uint32(n))
	}

	var policy netmap.PlacementPolicy
	policy.AddReplicas(replicas...)

	var cnr container.Container
	cnr.SetPlacementPolicy(policy)

	var cnrID cid.ID
	container.CalculateID(&cnrID, cnr)

	const curEpoch = 13

	// makeService wires a Service that sees an empty local storage, the given
	// placement builder for the current epoch and the given remote clients.
	makeService := func(b *testPlacementBuilder, c *testClientCache) *Service {
		s := &Service{cfg: new(cfg)}
		s.log = test.NewLogger(false)
		s.localStorage = newTestStorage()
		s.traverserGenerator = &testTraverserGenerator{
			c: cnr,
			b: map[uint64]placement.Builder{
				curEpoch: b,
			},
		}
		s.clientConstructor = c
		s.currentEpochReceiver = testEpochReceiver(curEpoch)

		return s
	}

	// makePrm builds search parameters for the container with the local-only
	// flag cleared.
	makePrm := func(target cid.ID, w IDListWriter) Prm {
		var prm Prm

		prm.WithContainerID(target)
		prm.SetWriter(w)
		prm.common = new(util.CommonPrm).WithLocalOnly(false)

		return prm
	}

	t.Run("OK", func(t *testing.T) {
		var addr oid.Address
		addr.SetContainer(cnrID)

		nodes, addrs := testNodeMatrix(t, placementDim)

		builder := &testPlacementBuilder{
			vectors: map[string][][]netmap.NodeInfo{
				addr.EncodeToString(): nodes,
			},
		}

		first := newTestStorage()
		firstIDs := generateIDs(10)
		first.addResult(cnrID, firstIDs, nil)

		second := newTestStorage()
		secondIDs := generateIDs(10)
		second.addResult(cnrID, secondIDs, nil)

		clients := &testClientCache{
			clients: map[string]*testStorage{
				addrs[0][0]: first,
				addrs[0][1]: second,
			},
		}

		out := new(simpleIDWriter)

		require.NoError(t, makeService(builder, clients).Search(ctx, makePrm(cnrID, out)))
		require.Len(t, out.ids, len(firstIDs)+len(secondIDs))

		for _, expected := range [][]oid.ID{firstIDs, secondIDs} {
			for _, objID := range expected {
				require.Contains(t, out.ids, objID)
			}
		}
	})
}
// TestGetFromPastEpoch checks how the netmap lookup depth affects a search.
// The placement for the current epoch contains only the first node vector and
// the placement for the previous epoch contains only the second one. With the
// default common parameters the search must return the IDs of the first
// vector's nodes only; after the lookup depth is set to 1 it must return the
// IDs of all four nodes.
func TestGetFromPastEpoch(t *testing.T) {
ctx := context.Background()
// Two node vectors with two nodes each.
placementDim := []int{2, 2}
rs := make([]netmap.ReplicaDescriptor, len(placementDim))
for i := range placementDim {
rs[i].SetNumberOfObjects(uint32(placementDim[i]))
}
var pp netmap.PlacementPolicy
pp.AddReplicas(rs...)
var cnr container.Container
cnr.SetPlacementPolicy(pp)
var idCnr cid.ID
container.CalculateID(&idCnr, cnr)
// The address carries only the container ID; it is the key under which the
// placement builders below register their vectors.
var addr oid.Address
addr.SetContainer(idCnr)
ns, as := testNodeMatrix(t, placementDim)
// One storage stub per node; cXY serves node Y of vector X with its own
// set of ten random IDs.
c11 := newTestStorage()
ids11 := generateIDs(10)
c11.addResult(idCnr, ids11, nil)
c12 := newTestStorage()
ids12 := generateIDs(10)
c12.addResult(idCnr, ids12, nil)
c21 := newTestStorage()
ids21 := generateIDs(10)
c21.addResult(idCnr, ids21, nil)
c22 := newTestStorage()
ids22 := generateIDs(10)
c22.addResult(idCnr, ids22, nil)
svc := &Service{cfg: new(cfg)}
svc.log = test.NewLogger(false)
// The local storage is empty, so every expected ID comes from a node stub.
svc.localStorage = newTestStorage()
const curEpoch = 13
svc.traverserGenerator = &testTraverserGenerator{
c: cnr,
b: map[uint64]placement.Builder{
// Current epoch: only the first vector (nodes served by c11, c12).
curEpoch: &testPlacementBuilder{
vectors: map[string][][]netmap.NodeInfo{
addr.EncodeToString(): ns[:1],
},
},
// Previous epoch: only the second vector (nodes served by c21, c22).
curEpoch - 1: &testPlacementBuilder{
vectors: map[string][][]netmap.NodeInfo{
addr.EncodeToString(): ns[1:],
},
},
},
}
svc.clientConstructor = &testClientCache{
clients: map[string]*testStorage{
as[0][0]: c11,
as[0][1]: c12,
as[1][0]: c21,
as[1][1]: c22,
},
}
svc.currentEpochReceiver = testEpochReceiver(curEpoch)
w := new(simpleIDWriter)
p := Prm{}
p.WithContainerID(idCnr)
p.SetWriter(w)
commonPrm := new(util.CommonPrm)
p.SetCommonParameters(commonPrm)
// assertContains verifies that the writer holds exactly the union of the
// given ID lists: every ID is present and the total count matches. The
// closure captures the variable w itself, so it always inspects the writer
// that w refers to at call time, including after w is reassigned below.
assertContains := func(idsList ...[]oid.ID) {
var sz int
for _, ids := range idsList {
sz += len(ids)
for _, id := range ids {
require.Contains(t, w.ids, id)
}
}
require.Len(t, w.ids, sz)
}
// First pass with default common parameters: only the current epoch's nodes
// are expected to contribute.
err := svc.Search(ctx, p)
require.NoError(t, err)
assertContains(ids11, ids12)
// Second pass: the lookup depth is changed through the commonPrm pointer
// without calling SetCommonParameters again.
// NOTE(review): this presumably relies on p keeping the same *util.CommonPrm
// pointer rather than a copy; confirm against Prm.SetCommonParameters.
commonPrm.SetNetmapLookupDepth(1)
// A fresh writer is used so that results of the first pass are not counted.
w = new(simpleIDWriter)
p.SetWriter(w)
err = svc.Search(ctx, p)
require.NoError(t, err)
assertContains(ids11, ids12, ids21, ids22)
}