netmap: Use CID to prevent incorrect usage in netmap

Signed-off-by: Evgenii Baidakov <evgenii@nspcc.io>
This commit is contained in:
Evgenii Baidakov 2023-04-18 10:56:09 +04:00
parent c6bda422fc
commit 4b6965f209
No known key found for this signature in database
GPG key ID: 8733EE3D72CDB4DE
5 changed files with 59 additions and 18 deletions

View file

@ -6,6 +6,8 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
cid "github.com/nspcc-dev/neofs-sdk-go/container/id"
oid "github.com/nspcc-dev/neofs-sdk-go/object/id"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -37,6 +39,20 @@ func compareNodes(t testing.TB, expected [][]int, nodes nodes, actual [][]NodeIn
} }
} }
func compareNodesIgnoreOrder(t testing.TB, expected [][]int, nodes nodes, actual [][]NodeInfo) {
require.Equal(t, len(expected), len(actual))
for i := range expected {
require.Equal(t, len(expected[i]), len(actual[i]))
var expectedNodes []NodeInfo
for _, index := range expected[i] {
expectedNodes = append(expectedNodes, nodes[index])
}
require.ElementsMatch(t, expectedNodes, actual[i])
}
}
func TestPlacementPolicy_Interopability(t *testing.T) { func TestPlacementPolicy_Interopability(t *testing.T) {
const testsDir = "./json_tests" const testsDir = "./json_tests"
@ -62,7 +78,10 @@ func TestPlacementPolicy_Interopability(t *testing.T) {
for name, tt := range tc.Tests { for name, tt := range tc.Tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
v, err := nm.ContainerNodes(tt.Policy, tt.Pivot) var pivot cid.ID
copy(pivot[:], tt.Pivot)
v, err := nm.ContainerNodes(tt.Policy, pivot)
if tt.Result == nil { if tt.Result == nil {
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), tt.Error) require.Contains(t, err.Error(), tt.Error)
@ -70,10 +89,13 @@ func TestPlacementPolicy_Interopability(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, srcNodes, tc.Nodes) require.Equal(t, srcNodes, tc.Nodes)
compareNodes(t, tt.Result, tc.Nodes, v) compareNodesIgnoreOrder(t, tt.Result, tc.Nodes, v)
if tt.Placement.Result != nil { if tt.Placement.Result != nil {
res, err := nm.PlacementVectors(v, tt.Placement.Pivot) var placementPivot oid.ID
copy(placementPivot[:], tt.Placement.Pivot)
res, err := nm.PlacementVectors(v, placementPivot)
require.NoError(t, err) require.NoError(t, err)
compareNodes(t, tt.Placement.Result, tc.Nodes, res) compareNodes(t, tt.Placement.Result, tc.Nodes, res)
require.Equal(t, srcNodes, tc.Nodes) require.Equal(t, srcNodes, tc.Nodes)
@ -108,11 +130,14 @@ func BenchmarkPlacementPolicyInteropability(b *testing.B) {
for name, tt := range tc.Tests { for name, tt := range tc.Tests {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
var pivot cid.ID
copy(pivot[:], tt.Pivot)
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
b.StartTimer() b.StartTimer()
v, err := nm.ContainerNodes(tt.Policy, tt.Pivot) v, err := nm.ContainerNodes(tt.Policy, pivot)
b.StopTimer() b.StopTimer()
if tt.Result == nil { if tt.Result == nil {
require.Error(b, err) require.Error(b, err)
@ -120,11 +145,14 @@ func BenchmarkPlacementPolicyInteropability(b *testing.B) {
} else { } else {
require.NoError(b, err) require.NoError(b, err)
compareNodes(b, tt.Result, tc.Nodes, v) compareNodesIgnoreOrder(b, tt.Result, tc.Nodes, v)
if tt.Placement.Result != nil { if tt.Placement.Result != nil {
var placementPivot oid.ID
copy(placementPivot[:], tt.Placement.Pivot)
b.StartTimer() b.StartTimer()
res, err := nm.PlacementVectors(v, tt.Placement.Pivot) res, err := nm.PlacementVectors(v, placementPivot)
b.StopTimer() b.StopTimer()
require.NoError(b, err) require.NoError(b, err)
compareNodes(b, tt.Placement.Result, tc.Nodes, res) compareNodes(b, tt.Placement.Result, tc.Nodes, res)
@ -150,11 +178,14 @@ func BenchmarkManySelects(b *testing.B) {
var nm NetMap var nm NetMap
nm.SetNodes(tc.Nodes) nm.SetNodes(tc.Nodes)
var pivot cid.ID
copy(pivot[:], tt.Pivot)
b.ResetTimer() b.ResetTimer()
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err = nm.ContainerNodes(tt.Policy, tt.Pivot) _, err = nm.ContainerNodes(tt.Policy, pivot)
if err != nil { if err != nil {
b.FailNow() b.FailNow()
} }

View file

@ -183,9 +183,9 @@
"policy": {"replicas":[{"count":1,"selector":"SameRU"},{"count":1,"selector":"DistinctRU"},{"count":1,"selector":"Good"},{"count":1,"selector":"Main"}],"containerBackupFactor":2,"selectors":[{"name":"SameRU","count":2,"clause":"SAME","attribute":"City","filter":"FromRU"},{"name":"DistinctRU","count":2,"clause":"DISTINCT","attribute":"City","filter":"FromRU"},{"name":"Good","count":2,"clause":"DISTINCT","attribute":"Country","filter":"Good"},{"name":"Main","count":3,"clause":"DISTINCT","attribute":"Country","filter":"*"}],"filters":[{"name":"FromRU","key":"Country","op":"EQ","value":"Russia"},{"name":"Good","key":"Rating","op":"GE","value":"4"}]}, "policy": {"replicas":[{"count":1,"selector":"SameRU"},{"count":1,"selector":"DistinctRU"},{"count":1,"selector":"Good"},{"count":1,"selector":"Main"}],"containerBackupFactor":2,"selectors":[{"name":"SameRU","count":2,"clause":"SAME","attribute":"City","filter":"FromRU"},{"name":"DistinctRU","count":2,"clause":"DISTINCT","attribute":"City","filter":"FromRU"},{"name":"Good","count":2,"clause":"DISTINCT","attribute":"Country","filter":"Good"},{"name":"Main","count":3,"clause":"DISTINCT","attribute":"Country","filter":"*"}],"filters":[{"name":"FromRU","key":"Country","op":"EQ","value":"Russia"},{"name":"Good","key":"Rating","op":"GE","value":"4"}]},
"result": [ "result": [
[0, 5, 9, 10], [0, 5, 9, 10],
[2, 6, 0, 5], [0, 5, 2, 6],
[1, 8, 2, 5], [1, 8, 2, 5],
[3, 4, 1, 7, 0, 2] [0, 2, 1, 7, 3, 4]
] ]
} }
} }

View file

@ -321,10 +321,10 @@
4 4
], ],
[ [
8,
12,
5, 5,
10 10,
8,
12
] ]
] ]
} }

View file

@ -1,10 +1,13 @@
package netmap package netmap
import ( import (
"crypto/sha256"
"fmt" "fmt"
"github.com/nspcc-dev/hrw" "github.com/nspcc-dev/hrw"
"github.com/nspcc-dev/neofs-api-go/v2/netmap" "github.com/nspcc-dev/neofs-api-go/v2/netmap"
cid "github.com/nspcc-dev/neofs-sdk-go/container/id"
oid "github.com/nspcc-dev/neofs-sdk-go/object/id"
) )
// NetMap represents NeoFS network map. It includes information about all // NetMap represents NeoFS network map. It includes information about all
@ -140,11 +143,14 @@ func flattenNodes(ns []nodes) nodes {
} }
// PlacementVectors sorts container nodes returned by ContainerNodes method // PlacementVectors sorts container nodes returned by ContainerNodes method
// and returns placement vectors for the entity identified by the given pivot. // and returns placement vectors for the entity identified by the given object id.
// For example, in order to build node list to store the object, binary-encoded // For example, in order to build node list to store the object, binary-encoded
// object identifier can be used as pivot. Result is deterministic for // object identifier can be used as pivot. Result is deterministic for
// the fixed NetMap and parameters. // the fixed NetMap and parameters.
func (m NetMap) PlacementVectors(vectors [][]NodeInfo, pivot []byte) ([][]NodeInfo, error) { func (m NetMap) PlacementVectors(vectors [][]NodeInfo, objectID oid.ID) ([][]NodeInfo, error) {
pivot := make([]byte, sha256.Size)
objectID.Encode(pivot)
h := hrw.Hash(pivot) h := hrw.Hash(pivot)
wf := defaultWeightFunc(m.nodes) wf := defaultWeightFunc(m.nodes)
result := make([][]NodeInfo, len(vectors)) result := make([][]NodeInfo, len(vectors))
@ -166,11 +172,14 @@ func (m NetMap) PlacementVectors(vectors [][]NodeInfo, pivot []byte) ([][]NodeIn
// the fixed NetMap and parameters. // the fixed NetMap and parameters.
// //
// Result can be used in PlacementVectors. // Result can be used in PlacementVectors.
func (m NetMap) ContainerNodes(p PlacementPolicy, pivot []byte) ([][]NodeInfo, error) { func (m NetMap) ContainerNodes(p PlacementPolicy, containerID cid.ID) ([][]NodeInfo, error) {
c := newContext(m) c := newContext(m)
c.setPivot(pivot)
c.setCBF(p.backupFactor) c.setCBF(p.backupFactor)
pivot := make([]byte, sha256.Size)
containerID.Encode(pivot)
c.setPivot(pivot)
if err := c.processFilters(p); err != nil { if err := c.processFilters(p); err != nil {
return nil, err return nil, err
} }

View file

@ -9,6 +9,7 @@ import (
"github.com/nspcc-dev/hrw" "github.com/nspcc-dev/hrw"
"github.com/nspcc-dev/neofs-api-go/v2/netmap" "github.com/nspcc-dev/neofs-api-go/v2/netmap"
cid "github.com/nspcc-dev/neofs-sdk-go/container/id"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -129,7 +130,7 @@ func BenchmarkPolicyHRWType(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := nm.ContainerNodes(p, []byte{1}) _, err := nm.ContainerNodes(p, cid.ID{1})
if err != nil { if err != nil {
b.Fatal() b.Fatal()
} }
@ -173,7 +174,7 @@ func TestPlacementPolicy_DeterministicOrder(t *testing.T) {
nm.SetNodes(nodeList) nm.SetNodes(nodeList)
getIndices := func(t *testing.T) (uint64, uint64) { getIndices := func(t *testing.T) (uint64, uint64) {
v, err := nm.ContainerNodes(p, []byte{1}) v, err := nm.ContainerNodes(p, cid.ID{1})
require.NoError(t, err) require.NoError(t, err)
nss := make([]nodes, len(v)) nss := make([]nodes, len(v))