diff --git a/service/meta.go b/service/meta.go index 6b5b8b7..abffdad 100644 --- a/service/meta.go +++ b/service/meta.go @@ -101,6 +101,10 @@ func ProcessRequestTTL(req MetaHeader, cond ...TTLCondition) error { // check specific condition: if err := cond[i](ttl); err != nil { + if st, ok := status.FromError(err); ok { + return st.Err() + } + return status.New(codes.InvalidArgument, err.Error()).Err() } } diff --git a/service/meta_test.go b/service/meta_test.go index 893ca5e..496ea51 100644 --- a/service/meta_test.go +++ b/service/meta_test.go @@ -9,58 +9,65 @@ import ( ) type mockedRequest struct { - msg string - name string - role NodeRole - code codes.Code + msg string + name string + code codes.Code + handler TTLCondition RequestMetaHeader } func TestMetaRequest(t *testing.T) { tests := []mockedRequest{ { - role: InnerRingNode, name: "direct to ir node", + handler: IRNonForwarding(InnerRingNode), RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, }, { - role: StorageNode, code: codes.InvalidArgument, msg: ErrIncorrectTTL.Error(), name: "direct to storage node", + handler: IRNonForwarding(StorageNode), RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, }, { - role: StorageNode, msg: ErrZeroTTL.Error(), code: codes.InvalidArgument, name: "zero ttl", + handler: IRNonForwarding(StorageNode), RequestMetaHeader: RequestMetaHeader{TTL: ZeroTTL}, }, { - role: InnerRingNode, name: "default to ir node", + handler: IRNonForwarding(InnerRingNode), RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, }, { - role: StorageNode, name: "default to storage node", + handler: IRNonForwarding(StorageNode), RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, }, + { + msg: "not found", + code: codes.NotFound, + name: "custom status error", + RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, + handler: func(_ uint32) error { return status.Error(codes.NotFound, "not found") }, + }, } for i := range tests { tt := tests[i] t.Run(tt.name, func(t *testing.T) { before := tt.GetTTL() - err := ProcessRequestTTL(&tt, IRNonForwarding(tt.role)) + err := ProcessRequestTTL(&tt, tt.handler) if tt.msg != "" { require.Errorf(t, err, tt.msg) state, ok := status.FromError(err) require.True(t, ok) - require.Equal(t, state.Code(), tt.code) - require.Equal(t, state.Message(), tt.msg) + require.Equal(t, tt.code, state.Code()) + require.Equal(t, tt.msg, state.Message()) } else { require.NoError(t, err) require.NotEqualf(t, before, tt.GetTTL(), "ttl should be changed: %d vs %d", before, tt.GetTTL())