diff --git a/netmap/subnet.go b/netmap/subnet.go new file mode 100644 index 00000000..09b7a57f --- /dev/null +++ b/netmap/subnet.go @@ -0,0 +1,68 @@ +package netmap + +import ( + "errors" + + "github.com/nspcc-dev/neofs-api-go/v2/netmap" + "github.com/nspcc-dev/neofs-api-go/v2/refs" + subnetid "github.com/nspcc-dev/neofs-sdk-go/subnet/id" +) + +// EnterSubnet writes to NodeInfo the intention to enter the subnet. Must not be called on nil. +// Zero NodeInfo belongs to zero subnet. +func (i *NodeInfo) EnterSubnet(id subnetid.ID) { + var ( + idv2 refs.SubnetID + info netmap.NodeSubnetInfo + ) + + id.WriteToV2(&idv2) + + info.SetID(&idv2) + info.SetEntryFlag(true) + + netmap.WriteSubnetInfo((*netmap.NodeInfo)(i), info) +} + +// ErrRemoveSubnet is returned when a node needs to leave the subnet. +var ErrRemoveSubnet = netmap.ErrRemoveSubnet + +// IterateSubnets iterates over all subnets the node belongs to and passes the IDs to f. +// Must not be called on nil. Handler must not be nil. +// +// If f returns ErrRemoveSubnet, then removes subnet entry. Note that this leads to an instant mutation of NodeInfo. +// Breaks on any other non-nil error and returns it. +// +// Returns an error if subnet incorrectly enabled/disabled. +// Returns an error if the node is not included in any subnet by the end of the loop. +func (i *NodeInfo) IterateSubnets(f func(subnetid.ID) error) error { + var id subnetid.ID + + return netmap.IterateSubnets((*netmap.NodeInfo)(i), func(idv2 refs.SubnetID) error { + id.FromV2(idv2) + + err := f(id) + if errors.Is(err, ErrRemoveSubnet) { + return netmap.ErrRemoveSubnet + } + + return err + }) +} + +var errAbortSubnetIter = errors.New("abort subnet iterator") + +// BelongsToSubnet checks if node belongs to subnet by ID. +// +// Function is NPE-safe: nil NodeInfo always belongs to zero subnet only. +func BelongsToSubnet(node *NodeInfo, id subnetid.ID) bool { + err := node.IterateSubnets(func(id_ subnetid.ID) error { + if id.Equals(&id_) { + return errAbortSubnetIter + } + + return nil + }) + + return errors.Is(err, errAbortSubnetIter) +} diff --git a/netmap/subnet_test.go b/netmap/subnet_test.go new file mode 100644 index 00000000..cb1c0c47 --- /dev/null +++ b/netmap/subnet_test.go @@ -0,0 +1,102 @@ +package netmap_test + +import ( + "testing" + + "github.com/nspcc-dev/neofs-sdk-go/netmap" + subnetid "github.com/nspcc-dev/neofs-sdk-go/subnet/id" + "github.com/stretchr/testify/require" +) + +func TestNodeInfoSubnets(t *testing.T) { + t.Run("enter subnet", func(t *testing.T) { + var id subnetid.ID + + id.SetNumber(13) + + var node netmap.NodeInfo + + node.EnterSubnet(id) + + mIDs := make(map[string]struct{}) + + err := node.IterateSubnets(func(id subnetid.ID) error { + mIDs[id.String()] = struct{}{} + return nil + }) + + require.NoError(t, err) + + _, ok := mIDs[id.String()] + require.True(t, ok) + }) + + t.Run("iterate with removal", func(t *testing.T) { + t.Run("not last", func(t *testing.T) { + var id, idrm subnetid.ID + + id.SetNumber(13) + idrm.SetNumber(23) + + var node netmap.NodeInfo + + node.EnterSubnet(id) + node.EnterSubnet(idrm) + + err := node.IterateSubnets(func(id subnetid.ID) error { + if subnetid.IsZero(id) || id.Equals(&idrm) { + return netmap.ErrRemoveSubnet + } + + return nil + }) + + require.NoError(t, err) + + mIDs := make(map[string]struct{}) + + err = node.IterateSubnets(func(id subnetid.ID) error { + mIDs[id.String()] = struct{}{} + return nil + }) + + require.NoError(t, err) + + var zeroID subnetid.ID + + _, ok := mIDs[zeroID.String()] + require.False(t, ok) + + _, ok = mIDs[idrm.String()] + require.False(t, ok) + + _, ok = mIDs[id.String()] + require.True(t, ok) + }) + + t.Run("last", func(t *testing.T) { + var node netmap.NodeInfo + + err := node.IterateSubnets(func(id subnetid.ID) error { + return netmap.ErrRemoveSubnet + }) + + require.Error(t, err) + }) + }) +} + +func TestBelongsToSubnet(t *testing.T) { + var id, idMiss, idZero subnetid.ID + + id.SetNumber(13) + idMiss.SetNumber(23) + + var node netmap.NodeInfo + + node.EnterSubnet(id) + + require.True(t, netmap.BelongsToSubnet(&node, idZero)) + require.True(t, netmap.BelongsToSubnet(&node, id)) + require.False(t, netmap.BelongsToSubnet(&node, idMiss)) +}