service: transfer TTL code to a separate file
This commit is contained in:
parent
fc177c4ce3
commit
b785eb710a
4 changed files with 174 additions and 148 deletions
|
@ -1,11 +1,5 @@
|
||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
type (
|
||||||
// MetaHeader contains meta information of request.
|
// MetaHeader contains meta information of request.
|
||||||
// It provides methods to get or set meta information meta header.
|
// It provides methods to get or set meta information meta header.
|
||||||
|
@ -15,9 +9,8 @@ type (
|
||||||
ResetMeta() RequestMetaHeader
|
ResetMeta() RequestMetaHeader
|
||||||
RestoreMeta(RequestMetaHeader)
|
RestoreMeta(RequestMetaHeader)
|
||||||
|
|
||||||
// TTLRequest to verify and update ttl requests.
|
// TTLHeader allows to get and set TTL value of request.
|
||||||
GetTTL() uint32
|
TTLHeader
|
||||||
SetTTL(uint32)
|
|
||||||
|
|
||||||
// EpochHeader gives possibility to get or set epoch in RPC Requests.
|
// EpochHeader gives possibility to get or set epoch in RPC Requests.
|
||||||
EpochHeader
|
EpochHeader
|
||||||
|
@ -46,20 +39,6 @@ type (
|
||||||
GetRaw() bool
|
GetRaw() bool
|
||||||
SetRaw(bool)
|
SetRaw(bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TTLCondition is closure, that allows to validate request with ttl.
|
|
||||||
TTLCondition func(ttl uint32) error
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ZeroTTL is empty ttl, should produce ErrZeroTTL.
|
|
||||||
ZeroTTL = iota
|
|
||||||
|
|
||||||
// NonForwardingTTL is a ttl that allows direct connections only.
|
|
||||||
NonForwardingTTL
|
|
||||||
|
|
||||||
// SingleForwardingTTL is a ttl that allows connections through another node.
|
|
||||||
SingleForwardingTTL
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetVersion sets protocol version to ResponseMetaHeader.
|
// SetVersion sets protocol version to ResponseMetaHeader.
|
||||||
|
@ -71,9 +50,6 @@ func (m *ResponseMetaHeader) SetEpoch(v uint64) { m.Epoch = v }
|
||||||
// SetVersion sets protocol version to RequestMetaHeader.
|
// SetVersion sets protocol version to RequestMetaHeader.
|
||||||
func (m *RequestMetaHeader) SetVersion(v uint32) { m.Version = v }
|
func (m *RequestMetaHeader) SetVersion(v uint32) { m.Version = v }
|
||||||
|
|
||||||
// SetTTL sets TTL to RequestMetaHeader.
|
|
||||||
func (m *RequestMetaHeader) SetTTL(v uint32) { m.TTL = v }
|
|
||||||
|
|
||||||
// SetEpoch sets Epoch to RequestMetaHeader.
|
// SetEpoch sets Epoch to RequestMetaHeader.
|
||||||
func (m *RequestMetaHeader) SetEpoch(v uint64) { m.Epoch = v }
|
func (m *RequestMetaHeader) SetEpoch(v uint64) { m.Epoch = v }
|
||||||
|
|
||||||
|
@ -91,42 +67,3 @@ func (m *RequestMetaHeader) ResetMeta() RequestMetaHeader {
|
||||||
|
|
||||||
// RestoreMeta sets current RequestMetaHeader to passed value.
|
// RestoreMeta sets current RequestMetaHeader to passed value.
|
||||||
func (m *RequestMetaHeader) RestoreMeta(v RequestMetaHeader) { *m = v }
|
func (m *RequestMetaHeader) RestoreMeta(v RequestMetaHeader) { *m = v }
|
||||||
|
|
||||||
// IRNonForwarding condition that allows NonForwardingTTL only for IR
|
|
||||||
func IRNonForwarding(role NodeRole) TTLCondition {
|
|
||||||
return func(ttl uint32) error {
|
|
||||||
if ttl == NonForwardingTTL && role != InnerRingNode {
|
|
||||||
return ErrInvalidTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessRequestTTL validates and update ttl requests.
|
|
||||||
func ProcessRequestTTL(req MetaHeader, cond ...TTLCondition) error {
|
|
||||||
ttl := req.GetTTL()
|
|
||||||
|
|
||||||
if ttl == ZeroTTL {
|
|
||||||
return status.New(codes.InvalidArgument, ErrInvalidTTL.Error()).Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range cond {
|
|
||||||
if cond[i] == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// check specific condition:
|
|
||||||
if err := cond[i](ttl); err != nil {
|
|
||||||
if st, ok := status.FromError(errors.Cause(err)); ok {
|
|
||||||
return st.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
return status.New(codes.InvalidArgument, err.Error()).Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req.SetTTL(ttl - 1)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -3,92 +3,9 @@ package service
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockedRequest struct {
|
|
||||||
msg string
|
|
||||||
name string
|
|
||||||
code codes.Code
|
|
||||||
handler TTLCondition
|
|
||||||
RequestMetaHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMetaRequest(t *testing.T) {
|
|
||||||
tests := []mockedRequest{
|
|
||||||
{
|
|
||||||
name: "direct to ir node",
|
|
||||||
handler: IRNonForwarding(InnerRingNode),
|
|
||||||
RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
code: codes.InvalidArgument,
|
|
||||||
msg: ErrInvalidTTL.Error(),
|
|
||||||
name: "direct to storage node",
|
|
||||||
handler: IRNonForwarding(StorageNode),
|
|
||||||
RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
msg: ErrInvalidTTL.Error(),
|
|
||||||
code: codes.InvalidArgument,
|
|
||||||
name: "zero ttl",
|
|
||||||
handler: IRNonForwarding(StorageNode),
|
|
||||||
RequestMetaHeader: RequestMetaHeader{TTL: ZeroTTL},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "default to ir node",
|
|
||||||
handler: IRNonForwarding(InnerRingNode),
|
|
||||||
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "default to storage node",
|
|
||||||
handler: IRNonForwarding(StorageNode),
|
|
||||||
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
msg: "not found",
|
|
||||||
code: codes.NotFound,
|
|
||||||
name: "custom status error",
|
|
||||||
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
|
|
||||||
handler: func(_ uint32) error { return status.Error(codes.NotFound, "not found") },
|
|
||||||
},
|
|
||||||
{
|
|
||||||
msg: "not found",
|
|
||||||
code: codes.NotFound,
|
|
||||||
name: "custom wrapped status error",
|
|
||||||
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
|
|
||||||
handler: func(_ uint32) error {
|
|
||||||
err := status.Error(codes.NotFound, "not found")
|
|
||||||
err = errors.Wrap(err, "some error context")
|
|
||||||
err = errors.Wrap(err, "another error context")
|
|
||||||
return err
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range tests {
|
|
||||||
tt := tests[i]
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
before := tt.GetTTL()
|
|
||||||
err := ProcessRequestTTL(&tt, tt.handler)
|
|
||||||
if tt.msg != "" {
|
|
||||||
require.Errorf(t, err, tt.msg)
|
|
||||||
|
|
||||||
state, ok := status.FromError(err)
|
|
||||||
require.True(t, ok)
|
|
||||||
require.Equal(t, tt.code, state.Code())
|
|
||||||
require.Equal(t, tt.msg, state.Message())
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEqualf(t, before, tt.GetTTL(), "ttl should be changed: %d vs %d", before, tt.GetTTL())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestMetaHeader_SetEpoch(t *testing.T) {
|
func TestRequestMetaHeader_SetEpoch(t *testing.T) {
|
||||||
m := new(ResponseMetaHeader)
|
m := new(ResponseMetaHeader)
|
||||||
epoch := uint64(3)
|
epoch := uint64(3)
|
||||||
|
|
73
service/ttl.go
Normal file
73
service/ttl.go
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TTLHeader is an interface of the container of a numerical TTL value.
|
||||||
|
type TTLHeader interface {
|
||||||
|
GetTTL() uint32
|
||||||
|
SetTTL(uint32)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTLCondition is a function type that used to verify that TTL value match a specific criterion.
|
||||||
|
// Nil error indicates compliance with the criterion.
|
||||||
|
type TTLCondition func(ttl uint32) error
|
||||||
|
|
||||||
|
// TTL constants.
|
||||||
|
const (
|
||||||
|
// ZeroTTL is an upper bound of invalid TTL values.
|
||||||
|
ZeroTTL = iota
|
||||||
|
|
||||||
|
// NonForwardingTTL is a TTL value that does not imply a request forwarding.
|
||||||
|
NonForwardingTTL
|
||||||
|
|
||||||
|
// SingleForwardingTTL is a TTL value that imply potential forwarding with NonForwardingTTL.
|
||||||
|
SingleForwardingTTL
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetTTL is a TTL field setter.
|
||||||
|
func (m *RequestMetaHeader) SetTTL(v uint32) {
|
||||||
|
m.TTL = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// IRNonForwarding condition that allows NonForwardingTTL only for IR.
|
||||||
|
func IRNonForwarding(role NodeRole) TTLCondition {
|
||||||
|
return func(ttl uint32) error {
|
||||||
|
if ttl == NonForwardingTTL && role != InnerRingNode {
|
||||||
|
return ErrInvalidTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessRequestTTL validates and updates requests with TTL.
|
||||||
|
func ProcessRequestTTL(req TTLHeader, cond ...TTLCondition) error {
|
||||||
|
ttl := req.GetTTL()
|
||||||
|
|
||||||
|
if ttl == ZeroTTL {
|
||||||
|
return status.New(codes.InvalidArgument, ErrInvalidTTL.Error()).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range cond {
|
||||||
|
if cond[i] == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// check specific condition:
|
||||||
|
if err := cond[i](ttl); err != nil {
|
||||||
|
if st, ok := status.FromError(errors.Cause(err)); ok {
|
||||||
|
return st.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
return status.New(codes.InvalidArgument, err.Error()).Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetTTL(ttl - 1)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
99
service/ttl_test.go
Normal file
99
service/ttl_test.go
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockedRequest struct {
|
||||||
|
msg string
|
||||||
|
name string
|
||||||
|
code codes.Code
|
||||||
|
handler TTLCondition
|
||||||
|
RequestMetaHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetaRequest(t *testing.T) {
|
||||||
|
tests := []mockedRequest{
|
||||||
|
{
|
||||||
|
name: "direct to ir node",
|
||||||
|
handler: IRNonForwarding(InnerRingNode),
|
||||||
|
RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
code: codes.InvalidArgument,
|
||||||
|
msg: ErrInvalidTTL.Error(),
|
||||||
|
name: "direct to storage node",
|
||||||
|
handler: IRNonForwarding(StorageNode),
|
||||||
|
RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msg: ErrInvalidTTL.Error(),
|
||||||
|
code: codes.InvalidArgument,
|
||||||
|
name: "zero ttl",
|
||||||
|
handler: IRNonForwarding(StorageNode),
|
||||||
|
RequestMetaHeader: RequestMetaHeader{TTL: ZeroTTL},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default to ir node",
|
||||||
|
handler: IRNonForwarding(InnerRingNode),
|
||||||
|
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default to storage node",
|
||||||
|
handler: IRNonForwarding(StorageNode),
|
||||||
|
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msg: "not found",
|
||||||
|
code: codes.NotFound,
|
||||||
|
name: "custom status error",
|
||||||
|
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
|
||||||
|
handler: func(_ uint32) error { return status.Error(codes.NotFound, "not found") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
msg: "not found",
|
||||||
|
code: codes.NotFound,
|
||||||
|
name: "custom wrapped status error",
|
||||||
|
RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL},
|
||||||
|
handler: func(_ uint32) error {
|
||||||
|
err := status.Error(codes.NotFound, "not found")
|
||||||
|
err = errors.Wrap(err, "some error context")
|
||||||
|
err = errors.Wrap(err, "another error context")
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range tests {
|
||||||
|
tt := tests[i]
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
before := tt.GetTTL()
|
||||||
|
err := ProcessRequestTTL(&tt, tt.handler)
|
||||||
|
if tt.msg != "" {
|
||||||
|
require.Errorf(t, err, tt.msg)
|
||||||
|
|
||||||
|
state, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, tt.code, state.Code())
|
||||||
|
require.Equal(t, tt.msg, state.Message())
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEqualf(t, before, tt.GetTTL(), "ttl should be changed: %d vs %d", before, tt.GetTTL())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestMetaHeader_SetTTL(t *testing.T) {
|
||||||
|
m := new(RequestMetaHeader)
|
||||||
|
ttl := uint32(3)
|
||||||
|
|
||||||
|
m.SetTTL(ttl)
|
||||||
|
|
||||||
|
require.Equal(t, ttl, m.GetTTL())
|
||||||
|
}
|
Loading…
Reference in a new issue