diff --git a/cmd/frostfs-node/qos.go b/cmd/frostfs-node/qos.go index 9663fc6ae..6394b668b 100644 --- a/cmd/frostfs-node/qos.go +++ b/cmd/frostfs-node/qos.go @@ -43,6 +43,9 @@ func initQoSService(c *cfg) { func (s *cfgQoSService) AdjustIncomingTag(ctx context.Context, requestSignPublicKey []byte) context.Context { rawTag, defined := qosTagging.IOTagFromContext(ctx) if !defined { + if s.isInternalIOTagPublicKey(ctx, requestSignPublicKey) { + return qosTagging.ContextWithIOTag(ctx, qos.IOTagInternal.String()) + } return qosTagging.ContextWithIOTag(ctx, qos.IOTagClient.String()) } ioTag, err := qos.FromRawString(rawTag) @@ -73,20 +76,8 @@ func (s *cfgQoSService) AdjustIncomingTag(ctx context.Context, requestSignPublic s.logger.Debug(ctx, logs.FailedToValidateIncomingIOTag) return qosTagging.ContextWithIOTag(ctx, qos.IOTagClient.String()) case qos.IOTagInternal: - for _, pk := range s.allowedInternalPubs { - if bytes.Equal(pk, requestSignPublicKey) { - return ctx - } - } - nm, err := s.netmapSource.GetNetMap(ctx, 0) - if err != nil { - s.logger.Debug(ctx, logs.FailedToGetNetmapToAdjustIOTag, zap.Error(err)) - return qosTagging.ContextWithIOTag(ctx, qos.IOTagClient.String()) - } - for _, node := range nm.Nodes() { - if bytes.Equal(node.PublicKey(), requestSignPublicKey) { - return ctx - } + if s.isInternalIOTagPublicKey(ctx, requestSignPublicKey) { + return ctx } s.logger.Debug(ctx, logs.FailedToValidateIncomingIOTag) return qosTagging.ContextWithIOTag(ctx, qos.IOTagClient.String()) @@ -95,3 +86,23 @@ func (s *cfgQoSService) AdjustIncomingTag(ctx context.Context, requestSignPublic return qosTagging.ContextWithIOTag(ctx, qos.IOTagClient.String()) } } + +func (s *cfgQoSService) isInternalIOTagPublicKey(ctx context.Context, publicKey []byte) bool { + for _, pk := range s.allowedInternalPubs { + if bytes.Equal(pk, publicKey) { + return true + } + } + nm, err := s.netmapSource.GetNetMap(ctx, 0) + if err != nil { + s.logger.Debug(ctx, logs.FailedToGetNetmapToAdjustIOTag, zap.Error(err)) + return false + } + for _, node := range nm.Nodes() { + if bytes.Equal(node.PublicKey(), publicKey) { + return true + } + } + + return false +} diff --git a/cmd/frostfs-node/qos_test.go b/cmd/frostfs-node/qos_test.go new file mode 100644 index 000000000..971f9eebf --- /dev/null +++ b/cmd/frostfs-node/qos_test.go @@ -0,0 +1,226 @@ +package main + +import ( + "context" + "testing" + + "git.frostfs.info/TrueCloudLab/frostfs-node/internal/qos" + "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger/test" + utilTesting "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/testing" + "git.frostfs.info/TrueCloudLab/frostfs-qos/tagging" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap" + "github.com/nspcc-dev/neo-go/pkg/crypto/keys" + "github.com/stretchr/testify/require" +) + +func TestQoSService_Client(t *testing.T) { + t.Parallel() + s, pk := testQoSServicePrepare(t) + t.Run("IO tag client defined", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagClient.String()) + ctx = s.AdjustIncomingTag(ctx, pk.Request) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("no IO tag defined, signed with unknown key", func(t *testing.T) { + ctx := s.AdjustIncomingTag(context.Background(), pk.Request) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("no IO tag defined, signed with allowed critical key", func(t *testing.T) { + ctx := s.AdjustIncomingTag(context.Background(), pk.Critical) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("unknown IO tag, signed with unknown key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), "some IO tag we don't know") + ctx = s.AdjustIncomingTag(ctx, pk.Request) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("unknown IO tag, signed with netmap key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), "some IO tag we don't know") + ctx = s.AdjustIncomingTag(ctx, pk.NetmapNode) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("unknown IO tag, signed with allowed internal key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), "some IO tag we don't know") + ctx = s.AdjustIncomingTag(ctx, pk.Internal) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("unknown IO tag, signed with allowed critical key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), "some IO tag we don't know") + ctx = s.AdjustIncomingTag(ctx, pk.Critical) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("IO tag internal defined, signed with unknown key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagInternal.String()) + ctx = s.AdjustIncomingTag(ctx, pk.Request) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("IO tag internal defined, signed with allowed critical key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagInternal.String()) + ctx = s.AdjustIncomingTag(ctx, pk.Critical) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("IO tag critical defined, signed with unknown key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagCritical.String()) + ctx = s.AdjustIncomingTag(ctx, pk.Request) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("IO tag critical defined, signed with allowed internal key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagCritical.String()) + ctx = s.AdjustIncomingTag(ctx, pk.Internal) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) +} + +func TestQoSService_Internal(t *testing.T) { + t.Parallel() + s, pk := testQoSServicePrepare(t) + t.Run("IO tag internal defined, signed with netmap key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagInternal.String()) + ctx = s.AdjustIncomingTag(ctx, pk.NetmapNode) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagInternal.String(), tag) + }) + t.Run("IO tag internal defined, signed with allowed internal key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagInternal.String()) + ctx = s.AdjustIncomingTag(ctx, pk.Internal) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagInternal.String(), tag) + }) + t.Run("no IO tag defined, signed with netmap key", func(t *testing.T) { + ctx := s.AdjustIncomingTag(context.Background(), pk.NetmapNode) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagInternal.String(), tag) + }) + t.Run("no IO tag defined, signed with allowed internal key", func(t *testing.T) { + ctx := s.AdjustIncomingTag(context.Background(), pk.Internal) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagInternal.String(), tag) + }) +} + +func TestQoSService_Critical(t *testing.T) { + t.Parallel() + s, pk := testQoSServicePrepare(t) + t.Run("IO tag critical defined, signed with netmap key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagCritical.String()) + ctx = s.AdjustIncomingTag(ctx, pk.NetmapNode) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagCritical.String(), tag) + }) + t.Run("IO tag critical defined, signed with allowed critical key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagCritical.String()) + ctx = s.AdjustIncomingTag(ctx, pk.Critical) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagCritical.String(), tag) + }) +} + +func TestQoSService_NetmapGetError(t *testing.T) { + t.Parallel() + s, pk := testQoSServicePrepare(t) + s.netmapSource = &utilTesting.TestNetmapSource{} + t.Run("IO tag internal defined, signed with netmap key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagInternal.String()) + ctx = s.AdjustIncomingTag(ctx, pk.NetmapNode) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("IO tag critical defined, signed with netmap key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), qos.IOTagCritical.String()) + ctx = s.AdjustIncomingTag(ctx, pk.NetmapNode) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("no IO tag defined, signed with netmap key", func(t *testing.T) { + ctx := s.AdjustIncomingTag(context.Background(), pk.NetmapNode) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) + t.Run("unknown IO tag, signed with netmap key", func(t *testing.T) { + ctx := tagging.ContextWithIOTag(context.Background(), "some IO tag we don't know") + ctx = s.AdjustIncomingTag(ctx, pk.NetmapNode) + tag, ok := tagging.IOTagFromContext(ctx) + require.True(t, ok) + require.Equal(t, qos.IOTagClient.String(), tag) + }) +} + +func testQoSServicePrepare(t *testing.T) (*cfgQoSService, *testQoSServicePublicKeys) { + nmSigner, err := keys.NewPrivateKey() + require.NoError(t, err) + + reqSigner, err := keys.NewPrivateKey() + require.NoError(t, err) + + allowedCritSigner, err := keys.NewPrivateKey() + require.NoError(t, err) + + allowedIntSigner, err := keys.NewPrivateKey() + require.NoError(t, err) + + var node netmap.NodeInfo + node.SetPublicKey(nmSigner.PublicKey().Bytes()) + nm := &netmap.NetMap{} + nm.SetEpoch(100) + nm.SetNodes([]netmap.NodeInfo{node}) + + return &cfgQoSService{ + logger: test.NewLogger(t), + netmapSource: &utilTesting.TestNetmapSource{ + Netmaps: map[uint64]*netmap.NetMap{ + 100: nm, + }, + CurrentEpoch: 100, + }, + allowedCriticalPubs: [][]byte{ + allowedCritSigner.PublicKey().Bytes(), + }, + allowedInternalPubs: [][]byte{ + allowedIntSigner.PublicKey().Bytes(), + }, + }, + &testQoSServicePublicKeys{ + NetmapNode: nmSigner.PublicKey().Bytes(), + Request: reqSigner.PublicKey().Bytes(), + Internal: allowedIntSigner.PublicKey().Bytes(), + Critical: allowedCritSigner.PublicKey().Bytes(), + } +} + +type testQoSServicePublicKeys struct { + NetmapNode []byte + Request []byte + Internal []byte + Critical []byte +} diff --git a/internal/logs/logs.go b/internal/logs/logs.go index 3503c922e..5b42b25ba 100644 --- a/internal/logs/logs.go +++ b/internal/logs/logs.go @@ -512,7 +512,7 @@ const ( FailedToUpdateMultinetConfiguration = "failed to update multinet configuration" FailedToParseIncomingIOTag = "failed to parse incoming IO tag" NotSupportedIncomingIOTagReplacedWithClient = "incoming IO tag is not supported, replaced with `client`" - FailedToGetNetmapToAdjustIOTag = "failed to get netmap to adjust IO tag, replaced with `client`" + FailedToGetNetmapToAdjustIOTag = "failed to get netmap to adjust IO tag" FailedToValidateIncomingIOTag = "failed to validate incoming IO tag, replaced with `client`" WriteCacheFailedToAcquireRPSQuota = "writecache failed to acquire RPS quota to flush object" ) diff --git a/pkg/core/object/fmt_test.go b/pkg/core/object/fmt_test.go index 239a9f389..dc336eb34 100644 --- a/pkg/core/object/fmt_test.go +++ b/pkg/core/object/fmt_test.go @@ -9,6 +9,7 @@ import ( "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/container" "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger" + utilTesting "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/testing" objectV2 "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/object" containerSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container" cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id" @@ -410,11 +411,11 @@ func TestFormatValidator_ValidateTokenIssuer(t *testing.T) { }, ), WithNetmapSource( - &testNetmapSource{ - netmaps: map[uint64]*netmap.NetMap{ + &utilTesting.TestNetmapSource{ + Netmaps: map[uint64]*netmap.NetMap{ curEpoch: currentEpochNM, }, - currentEpoch: curEpoch, + CurrentEpoch: curEpoch, }, ), WithLogger(logger.NewLoggerWrapper(zaptest.NewLogger(t))), @@ -483,12 +484,12 @@ func TestFormatValidator_ValidateTokenIssuer(t *testing.T) { }, ), WithNetmapSource( - &testNetmapSource{ - netmaps: map[uint64]*netmap.NetMap{ + &utilTesting.TestNetmapSource{ + Netmaps: map[uint64]*netmap.NetMap{ curEpoch: currentEpochNM, curEpoch - 1: previousEpochNM, }, - currentEpoch: curEpoch, + CurrentEpoch: curEpoch, }, ), WithLogger(logger.NewLoggerWrapper(zaptest.NewLogger(t))), @@ -559,12 +560,12 @@ func TestFormatValidator_ValidateTokenIssuer(t *testing.T) { }, ), WithNetmapSource( - &testNetmapSource{ - netmaps: map[uint64]*netmap.NetMap{ + &utilTesting.TestNetmapSource{ + Netmaps: map[uint64]*netmap.NetMap{ curEpoch: currentEpochNM, curEpoch - 1: previousEpochNM, }, - currentEpoch: curEpoch, + CurrentEpoch: curEpoch, }, ), WithLogger(logger.NewLoggerWrapper(zaptest.NewLogger(t))), @@ -596,26 +597,3 @@ func (s *testContainerSource) Get(ctx context.Context, cnrID cid.ID) (*container func (s *testContainerSource) DeletionInfo(context.Context, cid.ID) (*container.DelInfo, error) { return nil, nil } - -type testNetmapSource struct { - netmaps map[uint64]*netmap.NetMap - currentEpoch uint64 -} - -func (s *testNetmapSource) GetNetMap(ctx context.Context, diff uint64) (*netmap.NetMap, error) { - if diff >= s.currentEpoch { - return nil, fmt.Errorf("invalid diff") - } - return s.GetNetMapByEpoch(ctx, s.currentEpoch-diff) -} - -func (s *testNetmapSource) GetNetMapByEpoch(ctx context.Context, epoch uint64) (*netmap.NetMap, error) { - if nm, found := s.netmaps[epoch]; found { - return nm, nil - } - return nil, fmt.Errorf("netmap not found") -} - -func (s *testNetmapSource) Epoch(ctx context.Context) (uint64, error) { - return s.currentEpoch, nil -} diff --git a/pkg/util/testing/netmap_source.go b/pkg/util/testing/netmap_source.go new file mode 100644 index 000000000..7373e538f --- /dev/null +++ b/pkg/util/testing/netmap_source.go @@ -0,0 +1,36 @@ +package testing + +import ( + "context" + "errors" + + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap" +) + +var ( + errInvalidDiff = errors.New("invalid diff") + errNetmapNotFound = errors.New("netmap not found") +) + +type TestNetmapSource struct { + Netmaps map[uint64]*netmap.NetMap + CurrentEpoch uint64 +} + +func (s *TestNetmapSource) GetNetMap(ctx context.Context, diff uint64) (*netmap.NetMap, error) { + if diff >= s.CurrentEpoch { + return nil, errInvalidDiff + } + return s.GetNetMapByEpoch(ctx, s.CurrentEpoch-diff) +} + +func (s *TestNetmapSource) GetNetMapByEpoch(_ context.Context, epoch uint64) (*netmap.NetMap, error) { + if nm, found := s.Netmaps[epoch]; found { + return nm, nil + } + return nil, errNetmapNotFound +} + +func (s *TestNetmapSource) Epoch(context.Context) (uint64, error) { + return s.CurrentEpoch, nil +}