From b785eb710a157cea40c85d7c993826c43e830584 Mon Sep 17 00:00:00 2001 From: Leonard Lyubich Date: Mon, 4 May 2020 13:38:27 +0300 Subject: [PATCH] service: transfer TTL code to a separate file --- service/meta.go | 67 +----------------------------- service/meta_test.go | 83 ------------------------------------- service/ttl.go | 73 ++++++++++++++++++++++++++++++++ service/ttl_test.go | 99 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 174 insertions(+), 148 deletions(-) create mode 100644 service/ttl.go create mode 100644 service/ttl_test.go diff --git a/service/meta.go b/service/meta.go index ea1a83d6..2675b79e 100644 --- a/service/meta.go +++ b/service/meta.go @@ -1,11 +1,5 @@ package service -import ( - "github.com/pkg/errors" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - type ( // MetaHeader contains meta information of request. // It provides methods to get or set meta information meta header. @@ -15,9 +9,8 @@ type ( ResetMeta() RequestMetaHeader RestoreMeta(RequestMetaHeader) - // TTLRequest to verify and update ttl requests. - GetTTL() uint32 - SetTTL(uint32) + // TTLHeader allows to get and set TTL value of request. + TTLHeader // EpochHeader gives possibility to get or set epoch in RPC Requests. EpochHeader @@ -46,20 +39,6 @@ type ( GetRaw() bool SetRaw(bool) } - - // TTLCondition is closure, that allows to validate request with ttl. - TTLCondition func(ttl uint32) error -) - -const ( - // ZeroTTL is empty ttl, should produce ErrZeroTTL. - ZeroTTL = iota - - // NonForwardingTTL is a ttl that allows direct connections only. - NonForwardingTTL - - // SingleForwardingTTL is a ttl that allows connections through another node. - SingleForwardingTTL ) // SetVersion sets protocol version to ResponseMetaHeader. @@ -71,9 +50,6 @@ func (m *ResponseMetaHeader) SetEpoch(v uint64) { m.Epoch = v } // SetVersion sets protocol version to RequestMetaHeader. func (m *RequestMetaHeader) SetVersion(v uint32) { m.Version = v } -// SetTTL sets TTL to RequestMetaHeader. -func (m *RequestMetaHeader) SetTTL(v uint32) { m.TTL = v } - // SetEpoch sets Epoch to RequestMetaHeader. func (m *RequestMetaHeader) SetEpoch(v uint64) { m.Epoch = v } @@ -91,42 +67,3 @@ func (m *RequestMetaHeader) ResetMeta() RequestMetaHeader { // RestoreMeta sets current RequestMetaHeader to passed value. func (m *RequestMetaHeader) RestoreMeta(v RequestMetaHeader) { *m = v } - -// IRNonForwarding condition that allows NonForwardingTTL only for IR -func IRNonForwarding(role NodeRole) TTLCondition { - return func(ttl uint32) error { - if ttl == NonForwardingTTL && role != InnerRingNode { - return ErrInvalidTTL - } - - return nil - } -} - -// ProcessRequestTTL validates and update ttl requests. -func ProcessRequestTTL(req MetaHeader, cond ...TTLCondition) error { - ttl := req.GetTTL() - - if ttl == ZeroTTL { - return status.New(codes.InvalidArgument, ErrInvalidTTL.Error()).Err() - } - - for i := range cond { - if cond[i] == nil { - continue - } - - // check specific condition: - if err := cond[i](ttl); err != nil { - if st, ok := status.FromError(errors.Cause(err)); ok { - return st.Err() - } - - return status.New(codes.InvalidArgument, err.Error()).Err() - } - } - - req.SetTTL(ttl - 1) - - return nil -} diff --git a/service/meta_test.go b/service/meta_test.go index de77ac81..fb7fb171 100644 --- a/service/meta_test.go +++ b/service/meta_test.go @@ -3,92 +3,9 @@ package service import ( "testing" - "github.com/pkg/errors" "github.com/stretchr/testify/require" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) -type mockedRequest struct { - msg string - name string - code codes.Code - handler TTLCondition - RequestMetaHeader -} - -func TestMetaRequest(t *testing.T) { - tests := []mockedRequest{ - { - name: "direct to ir node", - handler: IRNonForwarding(InnerRingNode), - RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, - }, - { - code: codes.InvalidArgument, - msg: ErrInvalidTTL.Error(), - name: "direct to storage node", - handler: IRNonForwarding(StorageNode), - RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, - }, - { - msg: ErrInvalidTTL.Error(), - code: codes.InvalidArgument, - name: "zero ttl", - handler: IRNonForwarding(StorageNode), - RequestMetaHeader: RequestMetaHeader{TTL: ZeroTTL}, - }, - { - name: "default to ir node", - handler: IRNonForwarding(InnerRingNode), - RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, - }, - { - name: "default to storage node", - handler: IRNonForwarding(StorageNode), - RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, - }, - { - msg: "not found", - code: codes.NotFound, - name: "custom status error", - RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, - handler: func(_ uint32) error { return status.Error(codes.NotFound, "not found") }, - }, - { - msg: "not found", - code: codes.NotFound, - name: "custom wrapped status error", - RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, - handler: func(_ uint32) error { - err := status.Error(codes.NotFound, "not found") - err = errors.Wrap(err, "some error context") - err = errors.Wrap(err, "another error context") - return err - }, - }, - } - - for i := range tests { - tt := tests[i] - t.Run(tt.name, func(t *testing.T) { - before := tt.GetTTL() - err := ProcessRequestTTL(&tt, tt.handler) - if tt.msg != "" { - require.Errorf(t, err, tt.msg) - - state, ok := status.FromError(err) - require.True(t, ok) - require.Equal(t, tt.code, state.Code()) - require.Equal(t, tt.msg, state.Message()) - } else { - require.NoError(t, err) - require.NotEqualf(t, before, tt.GetTTL(), "ttl should be changed: %d vs %d", before, tt.GetTTL()) - } - }) - } -} - func TestRequestMetaHeader_SetEpoch(t *testing.T) { m := new(ResponseMetaHeader) epoch := uint64(3) diff --git a/service/ttl.go b/service/ttl.go new file mode 100644 index 00000000..f069f547 --- /dev/null +++ b/service/ttl.go @@ -0,0 +1,73 @@ +package service + +import ( + "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// TTLHeader is an interface of the container of a numerical TTL value. +type TTLHeader interface { + GetTTL() uint32 + SetTTL(uint32) +} + +// TTLCondition is a function type that used to verify that TTL value match a specific criterion. +// Nil error indicates compliance with the criterion. +type TTLCondition func(ttl uint32) error + +// TTL constants. +const ( + // ZeroTTL is an upper bound of invalid TTL values. + ZeroTTL = iota + + // NonForwardingTTL is a TTL value that does not imply a request forwarding. + NonForwardingTTL + + // SingleForwardingTTL is a TTL value that imply potential forwarding with NonForwardingTTL. + SingleForwardingTTL +) + +// SetTTL is a TTL field setter. +func (m *RequestMetaHeader) SetTTL(v uint32) { + m.TTL = v +} + +// IRNonForwarding condition that allows NonForwardingTTL only for IR. +func IRNonForwarding(role NodeRole) TTLCondition { + return func(ttl uint32) error { + if ttl == NonForwardingTTL && role != InnerRingNode { + return ErrInvalidTTL + } + + return nil + } +} + +// ProcessRequestTTL validates and updates requests with TTL. +func ProcessRequestTTL(req TTLHeader, cond ...TTLCondition) error { + ttl := req.GetTTL() + + if ttl == ZeroTTL { + return status.New(codes.InvalidArgument, ErrInvalidTTL.Error()).Err() + } + + for i := range cond { + if cond[i] == nil { + continue + } + + // check specific condition: + if err := cond[i](ttl); err != nil { + if st, ok := status.FromError(errors.Cause(err)); ok { + return st.Err() + } + + return status.New(codes.InvalidArgument, err.Error()).Err() + } + } + + req.SetTTL(ttl - 1) + + return nil +} diff --git a/service/ttl_test.go b/service/ttl_test.go new file mode 100644 index 00000000..1c982f55 --- /dev/null +++ b/service/ttl_test.go @@ -0,0 +1,99 @@ +package service + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type mockedRequest struct { + msg string + name string + code codes.Code + handler TTLCondition + RequestMetaHeader +} + +func TestMetaRequest(t *testing.T) { + tests := []mockedRequest{ + { + name: "direct to ir node", + handler: IRNonForwarding(InnerRingNode), + RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, + }, + { + code: codes.InvalidArgument, + msg: ErrInvalidTTL.Error(), + name: "direct to storage node", + handler: IRNonForwarding(StorageNode), + RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, + }, + { + msg: ErrInvalidTTL.Error(), + code: codes.InvalidArgument, + name: "zero ttl", + handler: IRNonForwarding(StorageNode), + RequestMetaHeader: RequestMetaHeader{TTL: ZeroTTL}, + }, + { + name: "default to ir node", + handler: IRNonForwarding(InnerRingNode), + RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, + }, + { + name: "default to storage node", + handler: IRNonForwarding(StorageNode), + RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, + }, + { + msg: "not found", + code: codes.NotFound, + name: "custom status error", + RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, + handler: func(_ uint32) error { return status.Error(codes.NotFound, "not found") }, + }, + { + msg: "not found", + code: codes.NotFound, + name: "custom wrapped status error", + RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, + handler: func(_ uint32) error { + err := status.Error(codes.NotFound, "not found") + err = errors.Wrap(err, "some error context") + err = errors.Wrap(err, "another error context") + return err + }, + }, + } + + for i := range tests { + tt := tests[i] + t.Run(tt.name, func(t *testing.T) { + before := tt.GetTTL() + err := ProcessRequestTTL(&tt, tt.handler) + if tt.msg != "" { + require.Errorf(t, err, tt.msg) + + state, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, tt.code, state.Code()) + require.Equal(t, tt.msg, state.Message()) + } else { + require.NoError(t, err) + require.NotEqualf(t, before, tt.GetTTL(), "ttl should be changed: %d vs %d", before, tt.GetTTL()) + } + }) + } +} + +func TestRequestMetaHeader_SetTTL(t *testing.T) { + m := new(RequestMetaHeader) + ttl := uint32(3) + + m.SetTTL(ttl) + + require.Equal(t, ttl, m.GetTTL()) +}