package searchsvc import ( "context" "crypto/rand" "crypto/sha256" "errors" "fmt" "strconv" "testing" clientcore "github.com/TrueCloudLab/frostfs-node/pkg/core/client" netmapcore "github.com/TrueCloudLab/frostfs-node/pkg/core/netmap" "github.com/TrueCloudLab/frostfs-node/pkg/network" "github.com/TrueCloudLab/frostfs-node/pkg/services/object/util" "github.com/TrueCloudLab/frostfs-node/pkg/services/object_manager/placement" "github.com/TrueCloudLab/frostfs-node/pkg/util/logger/test" "github.com/TrueCloudLab/frostfs-sdk-go/container" cid "github.com/TrueCloudLab/frostfs-sdk-go/container/id" cidtest "github.com/TrueCloudLab/frostfs-sdk-go/container/id/test" "github.com/TrueCloudLab/frostfs-sdk-go/netmap" oid "github.com/TrueCloudLab/frostfs-sdk-go/object/id" "github.com/stretchr/testify/require" ) type idsErr struct { ids []oid.ID err error } type testStorage struct { items map[string]idsErr } type testTraverserGenerator struct { c container.Container b map[uint64]placement.Builder } type testPlacementBuilder struct { vectors map[string][][]netmap.NodeInfo } type testClientCache struct { clients map[string]*testStorage } type simpleIDWriter struct { ids []oid.ID } type testEpochReceiver uint64 func (e testEpochReceiver) Epoch() (uint64, error) { return uint64(e), nil } func (s *simpleIDWriter) WriteIDs(ids []oid.ID) error { s.ids = append(s.ids, ids...) return nil } func newTestStorage() *testStorage { return &testStorage{ items: make(map[string]idsErr), } } func (g *testTraverserGenerator) generateTraverser(_ cid.ID, epoch uint64) (*placement.Traverser, error) { return placement.NewTraverser( placement.ForContainer(g.c), placement.UseBuilder(g.b[epoch]), placement.WithoutSuccessTracking(), ) } func (p *testPlacementBuilder) BuildPlacement(cnr cid.ID, obj *oid.ID, _ netmap.PlacementPolicy) ([][]netmap.NodeInfo, error) { var addr oid.Address addr.SetContainer(cnr) if obj != nil { addr.SetObject(*obj) } vs, ok := p.vectors[addr.EncodeToString()] if !ok { return nil, errors.New("vectors for address not found") } res := make([][]netmap.NodeInfo, len(vs)) copy(res, vs) return res, nil } func (c *testClientCache) get(info clientcore.NodeInfo) (searchClient, error) { v, ok := c.clients[network.StringifyGroup(info.AddressGroup())] if !ok { return nil, errors.New("could not construct client") } return v, nil } func (s *testStorage) search(exec *execCtx) ([]oid.ID, error) { v, ok := s.items[exec.prm.cnr.EncodeToString()] if !ok { return nil, nil } return v.ids, v.err } func (c *testStorage) searchObjects(exec *execCtx, _ clientcore.NodeInfo) ([]oid.ID, error) { v, ok := c.items[exec.prm.cnr.EncodeToString()] if !ok { return nil, nil } return v.ids, v.err } func (c *testStorage) addResult(addr cid.ID, ids []oid.ID, err error) { c.items[addr.EncodeToString()] = idsErr{ ids: ids, err: err, } } func testSHA256() (cs [sha256.Size]byte) { rand.Read(cs[:]) return cs } func generateIDs(num int) []oid.ID { res := make([]oid.ID, num) for i := 0; i < num; i++ { res[i].SetSHA256(testSHA256()) } return res } func TestGetLocalOnly(t *testing.T) { ctx := context.Background() newSvc := func(storage *testStorage) *Service { svc := &Service{cfg: new(cfg)} svc.log = test.NewLogger(false) svc.localStorage = storage return svc } newPrm := func(cnr cid.ID, w IDListWriter) Prm { p := Prm{} p.WithContainerID(cnr) p.SetWriter(w) p.common = new(util.CommonPrm).WithLocalOnly(true) return p } t.Run("OK", func(t *testing.T) { storage := newTestStorage() svc := newSvc(storage) cnr := cidtest.ID() ids := generateIDs(10) storage.addResult(cnr, ids, nil) w := new(simpleIDWriter) p := newPrm(cnr, w) err := svc.Search(ctx, p) require.NoError(t, err) require.Equal(t, ids, w.ids) }) t.Run("FAIL", func(t *testing.T) { storage := newTestStorage() svc := newSvc(storage) cnr := cidtest.ID() testErr := errors.New("any error") storage.addResult(cnr, nil, testErr) w := new(simpleIDWriter) p := newPrm(cnr, w) err := svc.Search(ctx, p) require.ErrorIs(t, err, testErr) }) } func testNodeMatrix(t testing.TB, dim []int) ([][]netmap.NodeInfo, [][]string) { mNodes := make([][]netmap.NodeInfo, len(dim)) mAddr := make([][]string, len(dim)) for i := range dim { ns := make([]netmap.NodeInfo, dim[i]) as := make([]string, dim[i]) for j := 0; j < dim[i]; j++ { a := fmt.Sprintf("/ip4/192.168.0.%s/tcp/%s", strconv.Itoa(i), strconv.Itoa(60000+j), ) var ni netmap.NodeInfo ni.SetNetworkEndpoints(a) var na network.AddressGroup err := na.FromIterator(netmapcore.Node(ni)) require.NoError(t, err) as[j] = network.StringifyGroup(na) ns[j] = ni } mNodes[i] = ns mAddr[i] = as } return mNodes, mAddr } func TestGetRemoteSmall(t *testing.T) { ctx := context.Background() placementDim := []int{2} rs := make([]netmap.ReplicaDescriptor, len(placementDim)) for i := range placementDim { rs[i].SetNumberOfObjects(uint32(placementDim[i])) } var pp netmap.PlacementPolicy pp.AddReplicas(rs...) var cnr container.Container cnr.SetPlacementPolicy(pp) var id cid.ID container.CalculateID(&id, cnr) newSvc := func(b *testPlacementBuilder, c *testClientCache) *Service { svc := &Service{cfg: new(cfg)} svc.log = test.NewLogger(false) svc.localStorage = newTestStorage() const curEpoch = 13 svc.traverserGenerator = &testTraverserGenerator{ c: cnr, b: map[uint64]placement.Builder{ curEpoch: b, }, } svc.clientConstructor = c svc.epochSource = testEpochReceiver(curEpoch) return svc } newPrm := func(id cid.ID, w IDListWriter) Prm { p := Prm{} p.WithContainerID(id) p.SetWriter(w) p.common = new(util.CommonPrm).WithLocalOnly(false) return p } t.Run("OK", func(t *testing.T) { var addr oid.Address addr.SetContainer(id) ns, as := testNodeMatrix(t, placementDim) builder := &testPlacementBuilder{ vectors: map[string][][]netmap.NodeInfo{ addr.EncodeToString(): ns, }, } c1 := newTestStorage() ids1 := generateIDs(10) c1.addResult(id, ids1, nil) c2 := newTestStorage() ids2 := generateIDs(10) c2.addResult(id, ids2, nil) svc := newSvc(builder, &testClientCache{ clients: map[string]*testStorage{ as[0][0]: c1, as[0][1]: c2, }, }) w := new(simpleIDWriter) p := newPrm(id, w) err := svc.Search(ctx, p) require.NoError(t, err) require.Len(t, w.ids, len(ids1)+len(ids2)) for _, id := range append(ids1, ids2...) { require.Contains(t, w.ids, id) } }) } func TestGetFromPastEpoch(t *testing.T) { ctx := context.Background() placementDim := []int{2, 2} rs := make([]netmap.ReplicaDescriptor, len(placementDim)) for i := range placementDim { rs[i].SetNumberOfObjects(uint32(placementDim[i])) } var pp netmap.PlacementPolicy pp.AddReplicas(rs...) var cnr container.Container cnr.SetPlacementPolicy(pp) var idCnr cid.ID container.CalculateID(&idCnr, cnr) var addr oid.Address addr.SetContainer(idCnr) ns, as := testNodeMatrix(t, placementDim) c11 := newTestStorage() ids11 := generateIDs(10) c11.addResult(idCnr, ids11, nil) c12 := newTestStorage() ids12 := generateIDs(10) c12.addResult(idCnr, ids12, nil) c21 := newTestStorage() ids21 := generateIDs(10) c21.addResult(idCnr, ids21, nil) c22 := newTestStorage() ids22 := generateIDs(10) c22.addResult(idCnr, ids22, nil) svc := &Service{cfg: new(cfg)} svc.log = test.NewLogger(false) svc.localStorage = newTestStorage() const curEpoch = 13 svc.traverserGenerator = &testTraverserGenerator{ c: cnr, b: map[uint64]placement.Builder{ curEpoch: &testPlacementBuilder{ vectors: map[string][][]netmap.NodeInfo{ addr.EncodeToString(): ns[:1], }, }, curEpoch - 1: &testPlacementBuilder{ vectors: map[string][][]netmap.NodeInfo{ addr.EncodeToString(): ns[1:], }, }, }, } svc.clientConstructor = &testClientCache{ clients: map[string]*testStorage{ as[0][0]: c11, as[0][1]: c12, as[1][0]: c21, as[1][1]: c22, }, } svc.epochSource = testEpochReceiver(curEpoch) w := new(simpleIDWriter) p := Prm{} p.WithContainerID(idCnr) p.SetWriter(w) commonPrm := new(util.CommonPrm) p.SetCommonParameters(commonPrm) assertContains := func(idsList ...[]oid.ID) { var sz int for _, ids := range idsList { sz += len(ids) for _, id := range ids { require.Contains(t, w.ids, id) } } require.Len(t, w.ids, sz) } err := svc.Search(ctx, p) require.NoError(t, err) assertContains(ids11, ids12) commonPrm.SetNetmapLookupDepth(1) w = new(simpleIDWriter) p.SetWriter(w) err = svc.Search(ctx, p) require.NoError(t, err) assertContains(ids11, ids12, ids21, ids22) }