service: transfer TTL code to a separate file

This commit is contained in:
Leonard Lyubich 2020-05-04 13:38:27 +03:00
parent fc177c4ce3
commit b785eb710a
4 changed files with 174 additions and 148 deletions

View file

@ -1,11 +1,5 @@
package service package service
import (
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type ( type (
// MetaHeader contains meta information of request. // MetaHeader contains meta information of request.
// It provides methods to get or set meta information meta header. // It provides methods to get or set meta information meta header.
@ -15,9 +9,8 @@ type (
ResetMeta() RequestMetaHeader ResetMeta() RequestMetaHeader
RestoreMeta(RequestMetaHeader) RestoreMeta(RequestMetaHeader)
// TTLRequest to verify and update ttl requests. // TTLHeader allows to get and set TTL value of request.
GetTTL() uint32 TTLHeader
SetTTL(uint32)
// EpochHeader gives possibility to get or set epoch in RPC Requests. // EpochHeader gives possibility to get or set epoch in RPC Requests.
EpochHeader EpochHeader
@ -46,20 +39,6 @@ type (
GetRaw() bool GetRaw() bool
SetRaw(bool) SetRaw(bool)
} }
// TTLCondition is closure, that allows to validate request with ttl.
TTLCondition func(ttl uint32) error
)
const (
// ZeroTTL is empty ttl, should produce ErrZeroTTL.
ZeroTTL = iota
// NonForwardingTTL is a ttl that allows direct connections only.
NonForwardingTTL
// SingleForwardingTTL is a ttl that allows connections through another node.
SingleForwardingTTL
) )
// SetVersion sets protocol version to ResponseMetaHeader. // SetVersion sets protocol version to ResponseMetaHeader.
@ -71,9 +50,6 @@ func (m *ResponseMetaHeader) SetEpoch(v uint64) { m.Epoch = v }
// SetVersion sets protocol version to RequestMetaHeader. // SetVersion sets protocol version to RequestMetaHeader.
func (m *RequestMetaHeader) SetVersion(v uint32) { m.Version = v } func (m *RequestMetaHeader) SetVersion(v uint32) { m.Version = v }
// SetTTL sets TTL to RequestMetaHeader.
func (m *RequestMetaHeader) SetTTL(v uint32) { m.TTL = v }
// SetEpoch sets Epoch to RequestMetaHeader. // SetEpoch sets Epoch to RequestMetaHeader.
func (m *RequestMetaHeader) SetEpoch(v uint64) { m.Epoch = v } func (m *RequestMetaHeader) SetEpoch(v uint64) { m.Epoch = v }
@ -91,42 +67,3 @@ func (m *RequestMetaHeader) ResetMeta() RequestMetaHeader {
// RestoreMeta sets current RequestMetaHeader to passed value. // RestoreMeta sets current RequestMetaHeader to passed value.
func (m *RequestMetaHeader) RestoreMeta(v RequestMetaHeader) { *m = v } func (m *RequestMetaHeader) RestoreMeta(v RequestMetaHeader) { *m = v }
// IRNonForwarding condition that allows NonForwardingTTL only for IR
func IRNonForwarding(role NodeRole) TTLCondition {
return func(ttl uint32) error {
if ttl == NonForwardingTTL && role != InnerRingNode {
return ErrInvalidTTL
}
return nil
}
}
// ProcessRequestTTL validates and update ttl requests.
func ProcessRequestTTL(req MetaHeader, cond ...TTLCondition) error {
ttl := req.GetTTL()
if ttl == ZeroTTL {
return status.New(codes.InvalidArgument, ErrInvalidTTL.Error()).Err()
}
for i := range cond {
if cond[i] == nil {
continue
}
// check specific condition:
if err := cond[i](ttl); err != nil {
if st, ok := status.FromError(errors.Cause(err)); ok {
return st.Err()
}
return status.New(codes.InvalidArgument, err.Error()).Err()
}
}
req.SetTTL(ttl - 1)
return nil
}

View file

@ -3,92 +3,9 @@ package service
import ( import (
"testing" "testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
) )
type mockedRequest struct {
msg string
name string
code codes.Code
handler TTLCondition
RequestMetaHeader
}
func TestMetaRequest(t *testing.T) {
tests := []mockedRequest{
{
name: "direct to ir node",
handler: IRNonForwarding(InnerRingNode),
RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL},
},
{
code: codes.InvalidArgument,
msg: ErrInvalidTTL.Error(),
name: "direct to storage node",
handler: IRNonForwarding(StorageNode),
RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL},
},
{
msg: ErrInvalidTTL.Error(),
code: codes.InvalidArgument,
name: "zero ttl",
handler: IRNonForwarding(StorageNode),
RequestMetaHeader: RequestMetaHeader{TTL: ZeroTTL},
},
{
name: "default to ir node",
handler: IRNonForwarding(InnerRingNode),
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
},
{
name: "default to storage node",
handler: IRNonForwarding(StorageNode),
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
},
{
msg: "not found",
code: codes.NotFound,
name: "custom status error",
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
handler: func(_ uint32) error { return status.Error(codes.NotFound, "not found") },
},
{
msg: "not found",
code: codes.NotFound,
name: "custom wrapped status error",
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
handler: func(_ uint32) error {
err := status.Error(codes.NotFound, "not found")
err = errors.Wrap(err, "some error context")
err = errors.Wrap(err, "another error context")
return err
},
},
}
for i := range tests {
tt := tests[i]
t.Run(tt.name, func(t *testing.T) {
before := tt.GetTTL()
err := ProcessRequestTTL(&tt, tt.handler)
if tt.msg != "" {
require.Errorf(t, err, tt.msg)
state, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, tt.code, state.Code())
require.Equal(t, tt.msg, state.Message())
} else {
require.NoError(t, err)
require.NotEqualf(t, before, tt.GetTTL(), "ttl should be changed: %d vs %d", before, tt.GetTTL())
}
})
}
}
func TestRequestMetaHeader_SetEpoch(t *testing.T) { func TestRequestMetaHeader_SetEpoch(t *testing.T) {
m := new(ResponseMetaHeader) m := new(ResponseMetaHeader)
epoch := uint64(3) epoch := uint64(3)

73
service/ttl.go Normal file
View file

@ -0,0 +1,73 @@
package service
import (
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TTLHeader is an interface of the container of a numerical TTL value.
type TTLHeader interface {
GetTTL() uint32
SetTTL(uint32)
}
// TTLCondition is a function type that used to verify that TTL value match a specific criterion.
// Nil error indicates compliance with the criterion.
type TTLCondition func(ttl uint32) error
// TTL constants.
const (
// ZeroTTL is an upper bound of invalid TTL values.
ZeroTTL = iota
// NonForwardingTTL is a TTL value that does not imply a request forwarding.
NonForwardingTTL
// SingleForwardingTTL is a TTL value that imply potential forwarding with NonForwardingTTL.
SingleForwardingTTL
)
// SetTTL is a TTL field setter.
func (m *RequestMetaHeader) SetTTL(v uint32) {
m.TTL = v
}
// IRNonForwarding condition that allows NonForwardingTTL only for IR.
func IRNonForwarding(role NodeRole) TTLCondition {
return func(ttl uint32) error {
if ttl == NonForwardingTTL && role != InnerRingNode {
return ErrInvalidTTL
}
return nil
}
}
// ProcessRequestTTL validates and updates requests with TTL.
func ProcessRequestTTL(req TTLHeader, cond ...TTLCondition) error {
ttl := req.GetTTL()
if ttl == ZeroTTL {
return status.New(codes.InvalidArgument, ErrInvalidTTL.Error()).Err()
}
for i := range cond {
if cond[i] == nil {
continue
}
// check specific condition:
if err := cond[i](ttl); err != nil {
if st, ok := status.FromError(errors.Cause(err)); ok {
return st.Err()
}
return status.New(codes.InvalidArgument, err.Error()).Err()
}
}
req.SetTTL(ttl - 1)
return nil
}

99
service/ttl_test.go Normal file
View file

@ -0,0 +1,99 @@
package service
import (
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type mockedRequest struct {
msg string
name string
code codes.Code
handler TTLCondition
RequestMetaHeader
}
func TestMetaRequest(t *testing.T) {
tests := []mockedRequest{
{
name: "direct to ir node",
handler: IRNonForwarding(InnerRingNode),
RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL},
},
{
code: codes.InvalidArgument,
msg: ErrInvalidTTL.Error(),
name: "direct to storage node",
handler: IRNonForwarding(StorageNode),
RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL},
},
{
msg: ErrInvalidTTL.Error(),
code: codes.InvalidArgument,
name: "zero ttl",
handler: IRNonForwarding(StorageNode),
RequestMetaHeader: RequestMetaHeader{TTL: ZeroTTL},
},
{
name: "default to ir node",
handler: IRNonForwarding(InnerRingNode),
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
},
{
name: "default to storage node",
handler: IRNonForwarding(StorageNode),
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
},
{
msg: "not found",
code: codes.NotFound,
name: "custom status error",
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
handler: func(_ uint32) error { return status.Error(codes.NotFound, "not found") },
},
{
msg: "not found",
code: codes.NotFound,
name: "custom wrapped status error",
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
handler: func(_ uint32) error {
err := status.Error(codes.NotFound, "not found")
err = errors.Wrap(err, "some error context")
err = errors.Wrap(err, "another error context")
return err
},
},
}
for i := range tests {
tt := tests[i]
t.Run(tt.name, func(t *testing.T) {
before := tt.GetTTL()
err := ProcessRequestTTL(&tt, tt.handler)
if tt.msg != "" {
require.Errorf(t, err, tt.msg)
state, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, tt.code, state.Code())
require.Equal(t, tt.msg, state.Message())
} else {
require.NoError(t, err)
require.NotEqualf(t, before, tt.GetTTL(), "ttl should be changed: %d vs %d", before, tt.GetTTL())
}
})
}
}
func TestRequestMetaHeader_SetTTL(t *testing.T) {
m := new(RequestMetaHeader)
ttl := uint32(3)
m.SetTTL(ttl)
require.Equal(t, ttl, m.GetTTL())
}