package service import ( "testing" "github.com/pkg/errors" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) type mockedRequest struct { msg string name string code codes.Code handler TTLCondition RequestMetaHeader } func TestMetaRequest(t *testing.T) { tests := []mockedRequest{ { name: "direct to ir node", handler: IRNonForwarding(InnerRingNode), RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, }, { code: codes.InvalidArgument, msg: ErrIncorrectTTL.Error(), name: "direct to storage node", handler: IRNonForwarding(StorageNode), RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, }, { msg: ErrZeroTTL.Error(), code: codes.InvalidArgument, name: "zero ttl", handler: IRNonForwarding(StorageNode), RequestMetaHeader: RequestMetaHeader{TTL: ZeroTTL}, }, { name: "default to ir node", handler: IRNonForwarding(InnerRingNode), RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, }, { name: "default to storage node", handler: IRNonForwarding(StorageNode), RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, }, { msg: "not found", code: codes.NotFound, name: "custom status error", RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, handler: func(_ uint32) error { return status.Error(codes.NotFound, "not found") }, }, { msg: "not found", code: codes.NotFound, name: "custom wrapped status error", RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, handler: func(_ uint32) error { err := status.Error(codes.NotFound, "not found") err = errors.Wrap(err, "some error context") err = errors.Wrap(err, "another error context") return err }, }, } for i := range tests { tt := tests[i] t.Run(tt.name, func(t *testing.T) { before := tt.GetTTL() err := ProcessRequestTTL(&tt, tt.handler) if tt.msg != "" { require.Errorf(t, err, tt.msg) state, ok := status.FromError(err) require.True(t, ok) require.Equal(t, tt.code, state.Code()) require.Equal(t, tt.msg, state.Message()) } else { require.NoError(t, err) require.NotEqualf(t, before, tt.GetTTL(), "ttl should be changed: %d vs %d", before, tt.GetTTL()) } }) } }