diff --git a/object/service.go b/object/service.go index 45a8d4b1..0e38d704 100644 --- a/object/service.go +++ b/object/service.go @@ -31,7 +31,7 @@ type ( // All object operations must have TTL, Epoch, Type, Container ID and // permission of usage previous network map. Request interface { - service.MetaHeader + service.SeizedRequestMetaContainer CID() CID Type() RequestType diff --git a/object/sign.go b/object/sign.go new file mode 100644 index 00000000..25d0b2f4 --- /dev/null +++ b/object/sign.go @@ -0,0 +1,272 @@ +package object + +import ( + "encoding/binary" + "io" +) + +// SignedData returns payload bytes of the request. +// +// If payload is nil, ErrHeaderNotFound returns. +func (m PutRequest) SignedData() ([]byte, error) { + sz := m.SignedDataSize() + if sz < 0 { + return nil, ErrHeaderNotFound + } + + data := make([]byte, sz) + + return data, m.ReadSignedData(data) +} + +// ReadSignedData copies payload bytes to passed buffer. +// +// If the buffer size is insufficient, io.ErrUnexpectedEOF returns. +func (m PutRequest) ReadSignedData(p []byte) error { + r := m.GetR() + if r == nil { + return ErrHeaderNotFound + } + + _, err := r.MarshalTo(p) + + return err +} + +// SignedDataSize returns the size of payload of the Put request. +// +// If payload is nil, -1 returns. +func (m PutRequest) SignedDataSize() int { + r := m.GetR() + if r == nil { + return -1 + } + + return r.Size() +} + +// SignedData returns payload bytes of the request. +func (m GetRequest) SignedData() ([]byte, error) { + data := make([]byte, m.SignedDataSize()) + + return data, m.ReadSignedData(data) +} + +// ReadSignedData copies payload bytes to passed buffer. +// +// If the buffer size is insufficient, io.ErrUnexpectedEOF returns. +func (m GetRequest) ReadSignedData(p []byte) error { + addr := m.GetAddress() + + if len(p) < m.SignedDataSize() { + return io.ErrUnexpectedEOF + } + + off := copy(p, addr.CID.Bytes()) + + copy(p[off:], addr.ObjectID.Bytes()) + + return nil +} + +// SignedDataSize returns payload size of the request. +func (m GetRequest) SignedDataSize() int { + return addressSize(m.GetAddress()) +} + +// SignedData returns payload bytes of the request. +func (m HeadRequest) SignedData() ([]byte, error) { + data := make([]byte, m.SignedDataSize()) + + return data, m.ReadSignedData(data) +} + +// ReadSignedData copies payload bytes to passed buffer. +// +// If the buffer size is insufficient, io.ErrUnexpectedEOF returns. +func (m HeadRequest) ReadSignedData(p []byte) error { + if len(p) < m.SignedDataSize() { + return io.ErrUnexpectedEOF + } + + if m.GetFullHeaders() { + p[0] = 1 + } + + off := 1 + copy(p[1:], m.Address.CID.Bytes()) + + copy(p[off:], m.Address.ObjectID.Bytes()) + + return nil +} + +// SignedDataSize returns payload size of the request. +func (m HeadRequest) SignedDataSize() int { + return addressSize(m.Address) + 1 +} + +// SignedData returns payload bytes of the request. +func (m DeleteRequest) SignedData() ([]byte, error) { + data := make([]byte, m.SignedDataSize()) + + return data, m.ReadSignedData(data) +} + +// ReadSignedData copies payload bytes to passed buffer. +// +// If the buffer size is insufficient, io.ErrUnexpectedEOF returns. +func (m DeleteRequest) ReadSignedData(p []byte) error { + if len(p) < m.SignedDataSize() { + return io.ErrUnexpectedEOF + } + + off := copy(p, m.OwnerID.Bytes()) + + copy(p[off:], addressBytes(m.Address)) + + return nil +} + +// SignedDataSize returns payload size of the request. +func (m DeleteRequest) SignedDataSize() int { + return m.OwnerID.Size() + addressSize(m.Address) +} + +// SignedData returns payload bytes of the request. +func (m GetRangeRequest) SignedData() ([]byte, error) { + data := make([]byte, m.SignedDataSize()) + + return data, m.ReadSignedData(data) +} + +// ReadSignedData copies payload bytes to passed buffer. +// +// If the buffer size is insufficient, io.ErrUnexpectedEOF returns. +func (m GetRangeRequest) ReadSignedData(p []byte) error { + if len(p) < m.SignedDataSize() { + return io.ErrUnexpectedEOF + } + + n, err := (&m.Range).MarshalTo(p) + if err != nil { + return err + } + + copy(p[n:], addressBytes(m.GetAddress())) + + return nil +} + +// SignedDataSize returns payload size of the request. +func (m GetRangeRequest) SignedDataSize() int { + return (&m.Range).Size() + addressSize(m.GetAddress()) +} + +// SignedData returns payload bytes of the request. +func (m GetRangeHashRequest) SignedData() ([]byte, error) { + data := make([]byte, m.SignedDataSize()) + + return data, m.ReadSignedData(data) +} + +// ReadSignedData copies payload bytes to passed buffer. +// +// If the buffer size is insufficient, io.ErrUnexpectedEOF returns. +func (m GetRangeHashRequest) ReadSignedData(p []byte) error { + if len(p) < m.SignedDataSize() { + return io.ErrUnexpectedEOF + } + + var off int + + off += copy(p[off:], addressBytes(m.GetAddress())) + + off += copy(p[off:], rangeSetBytes(m.GetRanges())) + + off += copy(p[off:], m.GetSalt()) + + return nil +} + +// SignedDataSize returns payload size of the request. +func (m GetRangeHashRequest) SignedDataSize() int { + var sz int + + sz += addressSize(m.GetAddress()) + + sz += rangeSetSize(m.GetRanges()) + + sz += len(m.GetSalt()) + + return sz +} + +// SignedData returns payload bytes of the request. +func (m SearchRequest) SignedData() ([]byte, error) { + data := make([]byte, m.SignedDataSize()) + + return data, m.ReadSignedData(data) +} + +// ReadSignedData copies payload bytes to passed buffer. +// +// If the buffer size is insufficient, io.ErrUnexpectedEOF returns. +func (m SearchRequest) ReadSignedData(p []byte) error { + if len(p) < m.SignedDataSize() { + return io.ErrUnexpectedEOF + } + + var off int + + off += copy(p[off:], m.CID().Bytes()) + + binary.BigEndian.PutUint32(p[off:], m.GetQueryVersion()) + off += 4 + + copy(p[off:], m.GetQuery()) + + return nil +} + +// SignedDataSize returns payload size of the request. +func (m SearchRequest) SignedDataSize() int { + var sz int + + sz += m.CID().Size() + + sz += 4 // uint32 Version + + sz += len(m.GetQuery()) + + return sz +} + +func rangeSetSize(rs []Range) int { + return 4 + len(rs)*16 // two uint64 fields +} + +func rangeSetBytes(rs []Range) []byte { + data := make([]byte, rangeSetSize(rs)) + + binary.BigEndian.PutUint32(data, uint32(len(rs))) + + off := 4 + + for i := range rs { + binary.BigEndian.PutUint64(data[off:], rs[i].Offset) + off += 8 + + binary.BigEndian.PutUint64(data[off:], rs[i].Length) + off += 8 + } + + return data +} + +func addressSize(addr Address) int { + return addr.CID.Size() + addr.ObjectID.Size() +} + +func addressBytes(addr Address) []byte { + return append(addr.CID.Bytes(), addr.ObjectID.Bytes()...) +} diff --git a/object/sign_test.go b/object/sign_test.go new file mode 100644 index 00000000..4df1c2b1 --- /dev/null +++ b/object/sign_test.go @@ -0,0 +1,189 @@ +package object + +import ( + "testing" + + "github.com/nspcc-dev/neofs-api-go/service" + "github.com/nspcc-dev/neofs-crypto/test" + "github.com/stretchr/testify/require" +) + +func TestSignVerifyRequests(t *testing.T) { + sk := test.DecodeKey(0) + + type sigType interface { + service.SignedDataWithToken + service.SignKeyPairAccumulator + service.SignKeyPairSource + SetToken(*Token) + } + + items := []struct { + constructor func() sigType + payloadCorrupt []func(sigType) + }{ + { // PutRequest.PutHeader + constructor: func() sigType { + return MakePutRequestHeader(new(Object)) + }, + payloadCorrupt: []func(sigType){ + func(s sigType) { + obj := s.(*PutRequest).GetR().(*PutRequest_Header).Header.GetObject() + obj.SystemHeader.PayloadLength++ + }, + }, + }, + { // PutRequest.Chunk + constructor: func() sigType { + return MakePutRequestChunk(make([]byte, 10)) + }, + payloadCorrupt: []func(sigType){ + func(s sigType) { + h := s.(*PutRequest).GetR().(*PutRequest_Chunk) + h.Chunk[0]++ + }, + }, + }, + { // GetRequest + constructor: func() sigType { + return new(GetRequest) + }, + payloadCorrupt: []func(sigType){ + func(s sigType) { + s.(*GetRequest).Address.CID[0]++ + }, + func(s sigType) { + s.(*GetRequest).Address.ObjectID[0]++ + }, + }, + }, + { // HeadRequest + constructor: func() sigType { + return new(HeadRequest) + }, + payloadCorrupt: []func(sigType){ + func(s sigType) { + s.(*HeadRequest).Address.CID[0]++ + }, + func(s sigType) { + s.(*HeadRequest).Address.ObjectID[0]++ + }, + func(s sigType) { + s.(*HeadRequest).FullHeaders = true + }, + }, + }, + { // DeleteRequest + constructor: func() sigType { + return new(DeleteRequest) + }, + payloadCorrupt: []func(sigType){ + func(s sigType) { + s.(*DeleteRequest).OwnerID[0]++ + }, + func(s sigType) { + s.(*DeleteRequest).Address.CID[0]++ + }, + func(s sigType) { + s.(*DeleteRequest).Address.ObjectID[0]++ + }, + }, + }, + { // GetRangeRequest + constructor: func() sigType { + return new(GetRangeRequest) + }, + payloadCorrupt: []func(sigType){ + func(s sigType) { + s.(*GetRangeRequest).Range.Length++ + }, + func(s sigType) { + s.(*GetRangeRequest).Range.Offset++ + }, + func(s sigType) { + s.(*GetRangeRequest).Address.CID[0]++ + }, + func(s sigType) { + s.(*GetRangeRequest).Address.ObjectID[0]++ + }, + }, + }, + { // GetRangeHashRequest + constructor: func() sigType { + return &GetRangeHashRequest{ + Ranges: []Range{{}}, + Salt: []byte{1, 2, 3}, + } + }, + payloadCorrupt: []func(sigType){ + func(s sigType) { + s.(*GetRangeHashRequest).Address.CID[0]++ + }, + func(s sigType) { + s.(*GetRangeHashRequest).Address.ObjectID[0]++ + }, + func(s sigType) { + s.(*GetRangeHashRequest).Salt[0]++ + }, + func(s sigType) { + s.(*GetRangeHashRequest).Ranges[0].Length++ + }, + func(s sigType) { + s.(*GetRangeHashRequest).Ranges[0].Offset++ + }, + func(s sigType) { + s.(*GetRangeHashRequest).Ranges = nil + }, + }, + }, + { // GetRangeHashRequest + constructor: func() sigType { + return &SearchRequest{ + Query: []byte{1, 2, 3}, + } + }, + payloadCorrupt: []func(sigType){ + func(s sigType) { + s.(*SearchRequest).ContainerID[0]++ + }, + func(s sigType) { + s.(*SearchRequest).Query[0]++ + }, + func(s sigType) { + s.(*SearchRequest).QueryVersion++ + }, + }, + }, + } + + for _, item := range items { + { // token corruptions + v := item.constructor() + + token := new(Token) + v.SetToken(token) + + require.NoError(t, service.SignDataWithSessionToken(sk, v)) + + require.NoError(t, service.VerifyAccumulatedSignaturesWithToken(v)) + + token.SetSessionKey(append(token.GetSessionKey(), 1)) + + require.Error(t, service.VerifyAccumulatedSignaturesWithToken(v)) + } + + { // payload corruptions + for _, corruption := range item.payloadCorrupt { + v := item.constructor() + + require.NoError(t, service.SignDataWithSessionToken(sk, v)) + + require.NoError(t, service.VerifyAccumulatedSignaturesWithToken(v)) + + corruption(v) + + require.Error(t, service.VerifyAccumulatedSignaturesWithToken(v)) + } + } + } +} diff --git a/service/alias.go b/service/alias.go index 6c22ecef..9a407027 100644 --- a/service/alias.go +++ b/service/alias.go @@ -4,11 +4,17 @@ import ( "github.com/nspcc-dev/neofs-api-go/refs" ) -// TokenID is type alias of UUID ref. +// TokenID is a type alias of UUID ref. type TokenID = refs.UUID -// OwnerID is type alias of OwnerID ref. +// OwnerID is a type alias of OwnerID ref. type OwnerID = refs.OwnerID -// Address is type alias of Address ref. +// Address is a type alias of Address ref. type Address = refs.Address + +// AddressContainer is a type alias of refs.AddressContainer. +type AddressContainer = refs.AddressContainer + +// OwnerIDContainer is a type alias of refs.OwnerIDContainer. +type OwnerIDContainer = refs.OwnerIDContainer diff --git a/service/epoch.go b/service/epoch.go new file mode 100644 index 00000000..7a7a556e --- /dev/null +++ b/service/epoch.go @@ -0,0 +1,11 @@ +package service + +// SetEpoch is an Epoch field setter. +func (m *ResponseMetaHeader) SetEpoch(v uint64) { + m.Epoch = v +} + +// SetEpoch is an Epoch field setter. +func (m *RequestMetaHeader) SetEpoch(v uint64) { + m.Epoch = v +} diff --git a/service/epoch_test.go b/service/epoch_test.go new file mode 100644 index 00000000..47316c05 --- /dev/null +++ b/service/epoch_test.go @@ -0,0 +1,21 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetSetEpoch(t *testing.T) { + v := uint64(5) + + items := []EpochContainer{ + new(ResponseMetaHeader), + new(RequestMetaHeader), + } + + for _, item := range items { + item.SetEpoch(v) + require.Equal(t, v, item.GetEpoch()) + } +} diff --git a/service/errors.go b/service/errors.go new file mode 100644 index 00000000..6241ad2c --- /dev/null +++ b/service/errors.go @@ -0,0 +1,45 @@ +package service + +import "github.com/nspcc-dev/neofs-api-go/internal" + +// ErrNilToken is returned by functions that expect +// a non-nil token argument, but received nil. +const ErrNilToken = internal.Error("token is nil") + +// ErrInvalidTTL means that the TTL value does not +// satisfy a specific criterion. +const ErrInvalidTTL = internal.Error("invalid TTL value") + +// ErrInvalidPublicKeyBytes means that the public key could not be unmarshaled. +const ErrInvalidPublicKeyBytes = internal.Error("cannot load public key") + +// ErrCannotFindOwner is raised when signatures empty in GetOwner. +const ErrCannotFindOwner = internal.Error("cannot find owner public key") + +// ErrWrongOwner is raised when passed OwnerID +// not equal to present PublicKey +const ErrWrongOwner = internal.Error("wrong owner") + +// ErrNilSignedDataSource returned by functions that expect a non-nil +// SignedDataSource, but received nil. +const ErrNilSignedDataSource = internal.Error("signed data source is nil") + +// ErrNilSignatureKeySource is returned by functions that expect a non-nil +// SignatureKeySource, but received nil. +const ErrNilSignatureKeySource = internal.Error("empty key-signature source") + +// ErrEmptyDataWithSignature is returned by functions that expect +// a non-nil DataWithSignature, but received nil. +const ErrEmptyDataWithSignature = internal.Error("empty data with signature") + +// ErrNegativeLength is returned by functions that received +// negative length for slice allocation. +const ErrNegativeLength = internal.Error("negative slice length") + +// ErrNilDataWithTokenSignAccumulator is returned by functions that expect +// a non-nil DataWithTokenSignAccumulator, but received nil. +const ErrNilDataWithTokenSignAccumulator = internal.Error("signed data with token is nil") + +// ErrNilSignatureKeySourceWithToken is returned by functions that expect +// a non-nil SignatureKeySourceWithToken, but received nil. +const ErrNilSignatureKeySourceWithToken = internal.Error("key-signature source with token is nil") diff --git a/service/meta.go b/service/meta.go index 8602dca7..3f017584 100644 --- a/service/meta.go +++ b/service/meta.go @@ -1,141 +1,13 @@ package service -import ( - "github.com/nspcc-dev/neofs-api-go/internal" - "github.com/pkg/errors" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -type ( - // MetaHeader contains meta information of request. - // It provides methods to get or set meta information meta header. - // Also contains methods to reset and restore meta header. - // Also contains methods to get or set request protocol version - MetaHeader interface { - ResetMeta() RequestMetaHeader - RestoreMeta(RequestMetaHeader) - - // TTLRequest to verify and update ttl requests. - GetTTL() uint32 - SetTTL(uint32) - - // EpochHeader gives possibility to get or set epoch in RPC Requests. - EpochHeader - - // VersionHeader allows get or set version of protocol request - VersionHeader - - // RawHeader allows to get and set raw option of request - RawHeader - } - - // EpochHeader interface gives possibility to get or set epoch in RPC Requests. - EpochHeader interface { - GetEpoch() uint64 - SetEpoch(v uint64) - } - - // VersionHeader allows get or set version of protocol request - VersionHeader interface { - GetVersion() uint32 - SetVersion(uint32) - } - - // RawHeader is an interface of the container of a boolean Raw value - RawHeader interface { - GetRaw() bool - SetRaw(bool) - } - - // TTLCondition is closure, that allows to validate request with ttl. - TTLCondition func(ttl uint32) error -) - -const ( - // ZeroTTL is empty ttl, should produce ErrZeroTTL. - ZeroTTL = iota - - // NonForwardingTTL is a ttl that allows direct connections only. - NonForwardingTTL - - // SingleForwardingTTL is a ttl that allows connections through another node. - SingleForwardingTTL -) - -const ( - // ErrZeroTTL is raised when zero ttl is passed. - ErrZeroTTL = internal.Error("zero ttl") - - // ErrIncorrectTTL is raised when NonForwardingTTL is passed and NodeRole != InnerRingNode. - ErrIncorrectTTL = internal.Error("incorrect ttl") -) - -// SetVersion sets protocol version to ResponseMetaHeader. -func (m *ResponseMetaHeader) SetVersion(v uint32) { m.Version = v } - -// SetEpoch sets Epoch to ResponseMetaHeader. -func (m *ResponseMetaHeader) SetEpoch(v uint64) { m.Epoch = v } - -// SetVersion sets protocol version to RequestMetaHeader. -func (m *RequestMetaHeader) SetVersion(v uint32) { m.Version = v } - -// SetTTL sets TTL to RequestMetaHeader. -func (m *RequestMetaHeader) SetTTL(v uint32) { m.TTL = v } - -// SetEpoch sets Epoch to RequestMetaHeader. -func (m *RequestMetaHeader) SetEpoch(v uint64) { m.Epoch = v } - -// SetRaw is a Raw field setter. -func (m *RequestMetaHeader) SetRaw(raw bool) { - m.Raw = raw -} - -// ResetMeta returns current value and sets RequestMetaHeader to empty value. -func (m *RequestMetaHeader) ResetMeta() RequestMetaHeader { +// CutMeta returns current value and sets RequestMetaHeader to empty value. +func (m *RequestMetaHeader) CutMeta() RequestMetaHeader { cp := *m m.Reset() return cp } // RestoreMeta sets current RequestMetaHeader to passed value. -func (m *RequestMetaHeader) RestoreMeta(v RequestMetaHeader) { *m = v } - -// IRNonForwarding condition that allows NonForwardingTTL only for IR -func IRNonForwarding(role NodeRole) TTLCondition { - return func(ttl uint32) error { - if ttl == NonForwardingTTL && role != InnerRingNode { - return ErrIncorrectTTL - } - - return nil - } -} - -// ProcessRequestTTL validates and update ttl requests. -func ProcessRequestTTL(req MetaHeader, cond ...TTLCondition) error { - ttl := req.GetTTL() - - if ttl == ZeroTTL { - return status.New(codes.InvalidArgument, ErrZeroTTL.Error()).Err() - } - - for i := range cond { - if cond[i] == nil { - continue - } - - // check specific condition: - if err := cond[i](ttl); err != nil { - if st, ok := status.FromError(errors.Cause(err)); ok { - return st.Err() - } - - return status.New(codes.InvalidArgument, err.Error()).Err() - } - } - - req.SetTTL(ttl - 1) - - return nil +func (m *RequestMetaHeader) RestoreMeta(v RequestMetaHeader) { + *m = v } diff --git a/service/meta_test.go b/service/meta_test.go index 388b6ce5..a0b85ef9 100644 --- a/service/meta_test.go +++ b/service/meta_test.go @@ -3,112 +3,23 @@ package service import ( "testing" - "github.com/pkg/errors" "github.com/stretchr/testify/require" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) -type mockedRequest struct { - msg string - name string - code codes.Code - handler TTLCondition - RequestMetaHeader -} - -func TestMetaRequest(t *testing.T) { - tests := []mockedRequest{ - { - name: "direct to ir node", - handler: IRNonForwarding(InnerRingNode), - RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, - }, - { - code: codes.InvalidArgument, - msg: ErrIncorrectTTL.Error(), - name: "direct to storage node", - handler: IRNonForwarding(StorageNode), - RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, - }, - { - msg: ErrZeroTTL.Error(), - code: codes.InvalidArgument, - name: "zero ttl", - handler: IRNonForwarding(StorageNode), - RequestMetaHeader: RequestMetaHeader{TTL: ZeroTTL}, - }, - { - name: "default to ir node", - handler: IRNonForwarding(InnerRingNode), - RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, - }, - { - name: "default to storage node", - handler: IRNonForwarding(StorageNode), - RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, - }, - { - msg: "not found", - code: codes.NotFound, - name: "custom status error", - RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, - handler: func(_ uint32) error { return status.Error(codes.NotFound, "not found") }, - }, - { - msg: "not found", - code: codes.NotFound, - name: "custom wrapped status error", - RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, - handler: func(_ uint32) error { - err := status.Error(codes.NotFound, "not found") - err = errors.Wrap(err, "some error context") - err = errors.Wrap(err, "another error context") - return err - }, +func TestCutRestoreMeta(t *testing.T) { + items := []func() SeizedMetaHeaderContainer{ + func() SeizedMetaHeaderContainer { + m := new(RequestMetaHeader) + m.SetEpoch(1) + return m }, } - for i := range tests { - tt := tests[i] - t.Run(tt.name, func(t *testing.T) { - before := tt.GetTTL() - err := ProcessRequestTTL(&tt, tt.handler) - if tt.msg != "" { - require.Errorf(t, err, tt.msg) + for _, item := range items { + v1 := item() + m1 := v1.CutMeta() + v1.RestoreMeta(m1) - state, ok := status.FromError(err) - require.True(t, ok) - require.Equal(t, tt.code, state.Code()) - require.Equal(t, tt.msg, state.Message()) - } else { - require.NoError(t, err) - require.NotEqualf(t, before, tt.GetTTL(), "ttl should be changed: %d vs %d", before, tt.GetTTL()) - } - }) + require.Equal(t, item(), v1) } } - -func TestRequestMetaHeader_SetEpoch(t *testing.T) { - m := new(ResponseMetaHeader) - epoch := uint64(3) - m.SetEpoch(epoch) - require.Equal(t, epoch, m.GetEpoch()) -} - -func TestRequestMetaHeader_SetVersion(t *testing.T) { - m := new(ResponseMetaHeader) - version := uint32(3) - m.SetVersion(version) - require.Equal(t, version, m.GetVersion()) -} - -func TestRequestMetaHeader_SetRaw(t *testing.T) { - m := new(RequestMetaHeader) - - m.SetRaw(true) - require.True(t, m.GetRaw()) - - m.SetRaw(false) - require.False(t, m.GetRaw()) -} diff --git a/service/raw.go b/service/raw.go new file mode 100644 index 00000000..0bb4b275 --- /dev/null +++ b/service/raw.go @@ -0,0 +1,6 @@ +package service + +// SetRaw is a Raw field setter. +func (m *RequestMetaHeader) SetRaw(raw bool) { + m.Raw = raw +} diff --git a/service/raw_test.go b/service/raw_test.go new file mode 100644 index 00000000..ad595edc --- /dev/null +++ b/service/raw_test.go @@ -0,0 +1,24 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetSetRaw(t *testing.T) { + items := []RawContainer{ + new(RequestMetaHeader), + } + + for _, item := range items { + // init with false + item.SetRaw(false) + + item.SetRaw(true) + require.True(t, item.GetRaw()) + + item.SetRaw(false) + require.False(t, item.GetRaw()) + } +} diff --git a/service/role.go b/service/role.go index 53bcdf55..4c405c12 100644 --- a/service/role.go +++ b/service/role.go @@ -1,8 +1,5 @@ package service -// NodeRole to identify in Bootstrap service. -type NodeRole int32 - const ( _ NodeRole = iota // InnerRingNode that work like IR node. diff --git a/service/sign.go b/service/sign.go new file mode 100644 index 00000000..f5cdc0b9 --- /dev/null +++ b/service/sign.go @@ -0,0 +1,222 @@ +package service + +import ( + "crypto/ecdsa" + + crypto "github.com/nspcc-dev/neofs-crypto" +) + +type keySign struct { + key *ecdsa.PublicKey + sign []byte +} + +// GetSignature is a sign field getter. +func (s keySign) GetSignature() []byte { + return s.sign +} + +// GetPublicKey is a key field getter, +func (s keySign) GetPublicKey() *ecdsa.PublicKey { + return s.key +} + +// Unites passed key with signature and returns SignKeyPair interface. +func newSignatureKeyPair(key *ecdsa.PublicKey, sign []byte) SignKeyPair { + return &keySign{ + key: key, + sign: sign, + } +} + +// Returns data from DataSignatureAccumulator for signature creation/verification. +// +// If passed DataSignatureAccumulator provides a SignedDataReader interface, data for signature is obtained +// using this interface for optimization. In this case, it is understood that reading into the slice D +// that the method DataForSignature returns does not change D. +// +// If returned length of data is negative, ErrNegativeLength returns. +func dataForSignature(src SignedDataSource) ([]byte, error) { + if src == nil { + return nil, ErrNilSignedDataSource + } + + r, ok := src.(SignedDataReader) + if !ok { + return src.SignedData() + } + + buf := bytesPool.Get().([]byte) + + if size := r.SignedDataSize(); size < 0 { + return nil, ErrNegativeLength + } else if size <= cap(buf) { + buf = buf[:size] + } else { + buf = make([]byte, size) + } + + n, err := r.ReadSignedData(buf) + if err != nil { + return nil, err + } + + return buf[:n], nil + +} + +// DataSignature returns the signature of data obtained using the private key. +// +// If passed data container is nil, ErrNilSignedDataSource returns. +// If passed private key is nil, crypto.ErrEmptyPrivateKey returns. +// If the data container or the signature function returns an error, it is returned directly. +func DataSignature(key *ecdsa.PrivateKey, src SignedDataSource) ([]byte, error) { + if key == nil { + return nil, crypto.ErrEmptyPrivateKey + } + + data, err := dataForSignature(src) + if err != nil { + return nil, err + } + defer bytesPool.Put(data) + + return crypto.Sign(key, data) +} + +// AddSignatureWithKey calculates the data signature and adds it to accumulator with public key. +// +// Any change of data provoke signature breakdown. +// +// Returns signing errors only. +func AddSignatureWithKey(key *ecdsa.PrivateKey, v DataWithSignKeyAccumulator) error { + sign, err := DataSignature(key, v) + if err != nil { + return err + } + + v.AddSignKey(sign, &key.PublicKey) + + return nil +} + +// Checks passed key-signature pairs for data from the passed container. +// +// If passed key-signatures pair set is empty, nil returns immediately. +func verifySignatures(src SignedDataSource, items ...SignKeyPair) error { + if len(items) <= 0 { + return nil + } + + data, err := dataForSignature(src) + if err != nil { + return err + } + defer bytesPool.Put(data) + + for _, signKey := range items { + if err := crypto.Verify( + signKey.GetPublicKey(), + data, + signKey.GetSignature(), + ); err != nil { + return err + } + } + + return nil +} + +// VerifySignatures checks passed key-signature pairs for data from the passed container. +// +// If passed data source is nil, ErrNilSignedDataSource returns. +// If check data is not ready, corresponding error returns. +// If at least one of the pairs is invalid, an error returns. +func VerifySignatures(src SignedDataSource, items ...SignKeyPair) error { + return verifySignatures(src, items...) +} + +// VerifyAccumulatedSignatures checks if accumulated key-signature pairs are valid. +// +// Behaves like VerifySignatures. +// If passed key-signature source is empty, ErrNilSignatureKeySource returns. +func VerifyAccumulatedSignatures(src DataWithSignKeySource) error { + if src == nil { + return ErrNilSignatureKeySource + } + + return verifySignatures(src, src.GetSignKeyPairs()...) +} + +// VerifySignatureWithKey checks data signature from the passed container with passed key. +// +// If passed data with signature is nil, ErrEmptyDataWithSignature returns. +// If passed key is nil, crypto.ErrEmptyPublicKey returns. +// A non-nil error returns if and only if the signature does not pass verification. +func VerifySignatureWithKey(key *ecdsa.PublicKey, src DataWithSignature) error { + if src == nil { + return ErrEmptyDataWithSignature + } else if key == nil { + return crypto.ErrEmptyPublicKey + } + + return verifySignatures( + src, + newSignatureKeyPair( + key, + src.GetSignature(), + ), + ) +} + +// SignDataWithSessionToken calculates data with token signature and adds it to accumulator. +// +// Any change of data or session token info provoke signature breakdown. +// +// If passed private key is nil, crypto.ErrEmptyPrivateKey returns. +// If passed DataWithTokenSignAccumulator is nil, ErrNilDataWithTokenSignAccumulator returns. +func SignDataWithSessionToken(key *ecdsa.PrivateKey, src DataWithTokenSignAccumulator) error { + if src == nil { + return ErrNilDataWithTokenSignAccumulator + } else if r, ok := src.(SignedDataReader); ok { + return AddSignatureWithKey(key, &signDataReaderWithToken{ + SignedDataSource: src, + SignKeyPairAccumulator: src, + + rdr: r, + token: src.GetSessionToken(), + }, + ) + } + + return AddSignatureWithKey(key, &signAccumWithToken{ + SignedDataSource: src, + SignKeyPairAccumulator: src, + + token: src.GetSessionToken(), + }) +} + +// VerifyAccumulatedSignaturesWithToken checks if accumulated key-signature pairs of data with token are valid. +// +// If passed DataWithTokenSignSource is nil, ErrNilSignatureKeySourceWithToken returns. +func VerifyAccumulatedSignaturesWithToken(src DataWithTokenSignSource) error { + if src == nil { + return ErrNilSignatureKeySourceWithToken + } else if r, ok := src.(SignedDataReader); ok { + return VerifyAccumulatedSignatures(&signDataReaderWithToken{ + SignedDataSource: src, + SignKeyPairSource: src, + + rdr: r, + token: src.GetSessionToken(), + }) + } + + return VerifyAccumulatedSignatures(&signAccumWithToken{ + SignedDataSource: src, + SignKeyPairSource: src, + + token: src.GetSessionToken(), + }) +} diff --git a/service/sign_test.go b/service/sign_test.go new file mode 100644 index 00000000..5cb7c409 --- /dev/null +++ b/service/sign_test.go @@ -0,0 +1,326 @@ +package service + +import ( + "crypto/ecdsa" + "crypto/rand" + "errors" + "io" + "testing" + + crypto "github.com/nspcc-dev/neofs-crypto" + "github.com/nspcc-dev/neofs-crypto/test" + "github.com/stretchr/testify/require" +) + +type testSignedDataSrc struct { + err error + data []byte + sig []byte + key *ecdsa.PublicKey + token SessionToken +} + +type testSignedDataReader struct { + *testSignedDataSrc +} + +func (s testSignedDataSrc) GetSignature() []byte { + return s.sig +} + +func (s testSignedDataSrc) GetSignKeyPairs() []SignKeyPair { + return []SignKeyPair{ + newSignatureKeyPair(s.key, s.sig), + } +} + +func (s testSignedDataSrc) SignedData() ([]byte, error) { + return s.data, s.err +} + +func (s *testSignedDataSrc) AddSignKey(sig []byte, key *ecdsa.PublicKey) { + s.key = key + s.sig = sig +} + +func testData(t *testing.T, sz int) []byte { + d := make([]byte, sz) + _, err := rand.Read(d) + require.NoError(t, err) + return d +} + +func (s testSignedDataSrc) GetSessionToken() SessionToken { + return s.token +} + +func (s testSignedDataReader) SignedDataSize() int { + return len(s.data) +} + +func (s testSignedDataReader) ReadSignedData(buf []byte) (int, error) { + if s.err != nil { + return 0, s.err + } + + var err error + if len(buf) < len(s.data) { + err = io.ErrUnexpectedEOF + } + return copy(buf, s.data), err +} + +func TestDataSignature(t *testing.T) { + var err error + + // nil private key + _, err = DataSignature(nil, nil) + require.EqualError(t, err, crypto.ErrEmptyPrivateKey.Error()) + + // create test private key + sk := test.DecodeKey(0) + + // nil private key + _, err = DataSignature(sk, nil) + require.EqualError(t, err, ErrNilSignedDataSource.Error()) + + t.Run("common signed data source", func(t *testing.T) { + // create test data source + src := &testSignedDataSrc{ + data: testData(t, 10), + } + + // create custom error for data source + src.err = errors.New("test error for data source") + + _, err = DataSignature(sk, src) + require.EqualError(t, err, src.err.Error()) + + // reset error to nil + src.err = nil + + // calculate data signature + sig, err := DataSignature(sk, src) + require.NoError(t, err) + + // ascertain that the signature passes verification + require.NoError(t, crypto.Verify(&sk.PublicKey, src.data, sig)) + }) + + t.Run("signed data reader", func(t *testing.T) { + // create test signed data reader + src := &testSignedDataSrc{ + data: testData(t, 10), + } + + // create custom error for signed data reader + src.err = errors.New("test error for signed data reader") + + sig, err := DataSignature(sk, src) + require.EqualError(t, err, src.err.Error()) + + // reset error to nil + src.err = nil + + // calculate data signature + sig, err = DataSignature(sk, src) + require.NoError(t, err) + + // ascertain that the signature passes verification + require.NoError(t, crypto.Verify(&sk.PublicKey, src.data, sig)) + }) +} + +func TestAddSignatureWithKey(t *testing.T) { + require.NoError(t, + AddSignatureWithKey( + test.DecodeKey(0), + &testSignedDataSrc{ + data: testData(t, 10), + }, + ), + ) +} + +func TestVerifySignatures(t *testing.T) { + // empty signatures + require.NoError(t, VerifySignatures(nil)) + + // create test signature source + src := &testSignedDataSrc{ + data: testData(t, 10), + } + + // create private key for test + sk := test.DecodeKey(0) + + // calculate a signature of the data + sig, err := crypto.Sign(sk, src.data) + require.NoError(t, err) + + // ascertain that verification is passed + require.NoError(t, + VerifySignatures( + src, + newSignatureKeyPair(&sk.PublicKey, sig), + ), + ) + + // break the signature + sig[0]++ + + require.Error(t, + VerifySignatures( + src, + newSignatureKeyPair(&sk.PublicKey, sig), + ), + ) + + // restore the signature + sig[0]-- + + // empty data source + require.EqualError(t, + VerifySignatures(nil, nil), + ErrNilSignedDataSource.Error(), + ) + +} + +func TestVerifyAccumulatedSignatures(t *testing.T) { + // nil signature source + require.EqualError(t, + VerifyAccumulatedSignatures(nil), + ErrNilSignatureKeySource.Error(), + ) + + // create test private key + sk := test.DecodeKey(0) + + // create signature source + src := &testSignedDataSrc{ + data: testData(t, 10), + key: &sk.PublicKey, + } + + var err error + + // calculate a signature + src.sig, err = crypto.Sign(sk, src.data) + require.NoError(t, err) + + // ascertain that verification is passed + require.NoError(t, VerifyAccumulatedSignatures(src)) + + // break the signature + src.sig[0]++ + + // ascertain that verification is failed + require.Error(t, VerifyAccumulatedSignatures(src)) +} + +func TestVerifySignatureWithKey(t *testing.T) { + // nil signature source + require.EqualError(t, + VerifySignatureWithKey(nil, nil), + ErrEmptyDataWithSignature.Error(), + ) + + // create test signature source + src := &testSignedDataSrc{ + data: testData(t, 10), + } + + // nil public key + require.EqualError(t, + VerifySignatureWithKey(nil, src), + crypto.ErrEmptyPublicKey.Error(), + ) + + // create test private key + sk := test.DecodeKey(0) + + var err error + + // calculate a signature + src.sig, err = crypto.Sign(sk, src.data) + require.NoError(t, err) + + // ascertain that verification is passed + require.NoError(t, VerifySignatureWithKey(&sk.PublicKey, src)) + + // break the signature + src.sig[0]++ + + // ascertain that verification is failed + require.Error(t, VerifySignatureWithKey(&sk.PublicKey, src)) +} + +func TestSignVerifyDataWithSessionToken(t *testing.T) { + // sign with empty DataWithTokenSignAccumulator + require.EqualError(t, + SignDataWithSessionToken(nil, nil), + ErrNilDataWithTokenSignAccumulator.Error(), + ) + + // verify with empty DataWithTokenSignSource + require.EqualError(t, + VerifyAccumulatedSignaturesWithToken(nil), + ErrNilSignatureKeySourceWithToken.Error(), + ) + + // create test session token + var ( + token = new(Token) + initVerb = Token_Info_Verb(1) + ) + + token.SetVerb(initVerb) + + // create test data with token + src := &testSignedDataSrc{ + data: testData(t, 10), + token: token, + } + + // create test private key + sk := test.DecodeKey(0) + + // sign with private key + require.NoError(t, SignDataWithSessionToken(sk, src)) + + // ascertain that verification is passed + require.NoError(t, VerifyAccumulatedSignaturesWithToken(src)) + + // break the data + src.data[0]++ + + // ascertain that verification is failed + require.Error(t, VerifyAccumulatedSignaturesWithToken(src)) + + // restore the data + src.data[0]-- + + // break the token + token.SetVerb(initVerb + 1) + + // ascertain that verification is failed + require.Error(t, VerifyAccumulatedSignaturesWithToken(src)) + + // restore the token + token.SetVerb(initVerb) + + // ascertain that verification is passed + require.NoError(t, VerifyAccumulatedSignaturesWithToken(src)) + + // wrap to data reader + rdr := &testSignedDataReader{ + testSignedDataSrc: src, + } + + // sign with private key + require.NoError(t, SignDataWithSessionToken(sk, rdr)) + + // ascertain that verification is passed + require.NoError(t, VerifyAccumulatedSignaturesWithToken(rdr)) +} diff --git a/service/token.go b/service/token.go index 077e672b..f431427b 100644 --- a/service/token.go +++ b/service/token.go @@ -3,69 +3,39 @@ package service import ( "crypto/ecdsa" "encoding/binary" + "io" - "github.com/nspcc-dev/neofs-api-go/internal" "github.com/nspcc-dev/neofs-api-go/refs" - crypto "github.com/nspcc-dev/neofs-crypto" ) -// VerbContainer is an interface of the container of a token verb value. -type VerbContainer interface { - GetVerb() Token_Info_Verb - SetVerb(Token_Info_Verb) +type signAccumWithToken struct { + SignedDataSource + SignKeyPairAccumulator + SignKeyPairSource + + token SessionToken } -// TokenIDContainer is an interface of the container of a token ID value. -type TokenIDContainer interface { - GetID() TokenID - SetID(TokenID) +type signDataReaderWithToken struct { + SignedDataSource + SignKeyPairAccumulator + SignKeyPairSource + + rdr SignedDataReader + + token SessionToken } -// CreationEpochContainer is an interface of the container of a creation epoch number. -type CreationEpochContainer interface { - CreationEpoch() uint64 - SetCreationEpoch(uint64) -} +const verbSize = 4 -// ExpirationEpochContainer is an interface of the container of an expiration epoch number. -type ExpirationEpochContainer interface { - ExpirationEpoch() uint64 - SetExpirationEpoch(uint64) -} - -// SessionKeyContainer is an interface of the container of session key bytes. -type SessionKeyContainer interface { - GetSessionKey() []byte - SetSessionKey([]byte) -} - -// SignatureContainer is an interface of the container of signature bytes. -type SignatureContainer interface { - GetSignature() []byte - SetSignature([]byte) -} - -// SessionTokenInfo is an interface that determines the information scope of session token. -type SessionTokenInfo interface { - TokenIDContainer - refs.OwnerIDContainer - VerbContainer - refs.AddressContainer - CreationEpochContainer - ExpirationEpochContainer - SessionKeyContainer -} - -// SessionToken is an interface of token information and signature pair. -type SessionToken interface { - SessionTokenInfo - SignatureContainer -} - -// ErrEmptyToken is raised when passed Token is nil. -const ErrEmptyToken = internal.Error("token is empty") - -var _ SessionToken = (*Token)(nil) +const fixedTokenDataSize = 0 + + refs.UUIDSize + + refs.OwnerIDSize + + verbSize + + refs.UUIDSize + + refs.CIDSize + + 8 + + 8 var tokenEndianness = binary.BigEndian @@ -134,88 +104,132 @@ func (m *Token) SetSignature(sig []byte) { m.Signature = sig } -// Returns byte slice that is used for creation/verification of the token signature. -func verificationTokenData(token SessionToken) []byte { - var sz int - - id := token.GetID() - sz += id.Size() - - ownerID := token.GetOwnerID() - sz += ownerID.Size() - - verb := uint32(token.GetVerb()) - sz += 4 - - addr := token.GetAddress() - sz += addr.CID.Size() + addr.ObjectID.Size() - - cEpoch := token.CreationEpoch() - sz += 8 - - fEpoch := token.ExpirationEpoch() - sz += 8 - - key := token.GetSessionKey() - sz += len(key) - - data := make([]byte, sz) - - var off int - - tokenEndianness.PutUint32(data, verb) - off += 4 - - tokenEndianness.PutUint64(data[off:], cEpoch) - off += 8 - - tokenEndianness.PutUint64(data[off:], fEpoch) - off += 8 - - off += copy(data[off:], id.Bytes()) - off += copy(data[off:], ownerID.Bytes()) - off += copy(data[off:], addr.CID.Bytes()) - off += copy(data[off:], addr.ObjectID.Bytes()) - off += copy(data[off:], key) +// Size returns the size of a binary representation of the verb. +func (x Token_Info_Verb) Size() int { + return verbSize +} +// Bytes returns a binary representation of the verb. +func (x Token_Info_Verb) Bytes() []byte { + data := make([]byte, verbSize) + tokenEndianness.PutUint32(data, uint32(x)) return data } -// SignToken calculates and stores the signature of token information. +// AddSignKey calls a Signature field setter with passed signature. +func (m *Token) AddSignKey(sig []byte, _ *ecdsa.PublicKey) { + m.SetSignature(sig) +} + +// SignedData returns token information in a binary representation. +func (m *Token) SignedData() ([]byte, error) { + data := make([]byte, m.SignedDataSize()) + + copyTokenSignedData(data, m) + + return data, nil +} + +// ReadSignedData copies a binary representation of the token information to passed buffer. // -// If passed token is nil, ErrEmptyToken returns. -// If passed private key is nil, crypto.ErrEmptyPrivateKey returns. -func SignToken(token SessionToken, key *ecdsa.PrivateKey) error { - if token == nil { - return ErrEmptyToken - } else if key == nil { - return crypto.ErrEmptyPrivateKey +// If buffer length is less than required, io.ErrUnexpectedEOF returns. +func (m *Token_Info) ReadSignedData(p []byte) (int, error) { + sz := m.SignedDataSize() + if len(p) < sz { + return 0, io.ErrUnexpectedEOF } - sig, err := crypto.Sign(key, verificationTokenData(token)) + copyTokenSignedData(p, m) + + return sz, nil +} + +// SignedDataSize returns the length of signed token information slice. +func (m *Token_Info) SignedDataSize() int { + return tokenInfoSize(m) +} + +func tokenInfoSize(v SessionKeySource) int { + if v == nil { + return 0 + } + return fixedTokenDataSize + len(v.GetSessionKey()) +} + +// Fills passed buffer with signing token information bytes. +// Does not check buffer length, it is understood that enough space is allocated in it. +// +// If passed SessionTokenInfo, buffer remains unchanged. +func copyTokenSignedData(buf []byte, token SessionTokenInfo) { + if token == nil { + return + } + + var off int + + off += copy(buf[off:], token.GetID().Bytes()) + + off += copy(buf[off:], token.GetOwnerID().Bytes()) + + off += copy(buf[off:], token.GetVerb().Bytes()) + + addr := token.GetAddress() + off += copy(buf[off:], addr.CID.Bytes()) + off += copy(buf[off:], addr.ObjectID.Bytes()) + + tokenEndianness.PutUint64(buf[off:], token.CreationEpoch()) + off += 8 + + tokenEndianness.PutUint64(buf[off:], token.ExpirationEpoch()) + off += 8 + + copy(buf[off:], token.GetSessionKey()) +} + +// SignedData concatenates signed data with session token information. Returns concatenation result. +// +// Token bytes are added if and only if token is not nil. +func (s signAccumWithToken) SignedData() ([]byte, error) { + data, err := s.SignedDataSource.SignedData() if err != nil { - return err + return nil, err } - token.SetSignature(sig) + tokenData := make([]byte, tokenInfoSize(s.token)) - return nil + copyTokenSignedData(tokenData, s.token) + + return append(data, tokenData...), nil } -// VerifyTokenSignature checks if token was signed correctly. -// -// If passed token is nil, ErrEmptyToken returns. -// If passed public key is nil, crypto.ErrEmptyPublicKey returns. -func VerifyTokenSignature(token SessionToken, key *ecdsa.PublicKey) error { - if token == nil { - return ErrEmptyToken - } else if key == nil { - return crypto.ErrEmptyPublicKey +func (s signDataReaderWithToken) SignedDataSize() int { + sz := s.rdr.SignedDataSize() + if sz < 0 { + return -1 } - return crypto.Verify( - key, - verificationTokenData(token), - token.GetSignature(), - ) + sz += tokenInfoSize(s.token) + + return sz +} + +func (s signDataReaderWithToken) ReadSignedData(p []byte) (int, error) { + dataSize := s.rdr.SignedDataSize() + if dataSize < 0 { + return 0, ErrNegativeLength + } + + sumSize := dataSize + tokenInfoSize(s.token) + + if len(p) < sumSize { + return 0, io.ErrUnexpectedEOF + } + + if n, err := s.rdr.ReadSignedData(p); err != nil { + return n, err + } + + copyTokenSignedData(p[dataSize:], s.token) + + return sumSize, nil } diff --git a/service/token_test.go b/service/token_test.go index 0b28084e..ce3d2c86 100644 --- a/service/token_test.go +++ b/service/token_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/nspcc-dev/neofs-api-go/refs" - crypto "github.com/nspcc-dev/neofs-crypto" "github.com/nspcc-dev/neofs-crypto/test" "github.com/stretchr/testify/require" ) @@ -90,29 +89,7 @@ func TestTokenGettersSetters(t *testing.T) { } func TestSignToken(t *testing.T) { - // nil token - require.EqualError(t, - SignToken(nil, nil), - ErrEmptyToken.Error(), - ) - - require.EqualError(t, - VerifyTokenSignature(nil, nil), - ErrEmptyToken.Error(), - ) - - var token SessionToken = new(Token) - - // nil key - require.EqualError(t, - SignToken(token, nil), - crypto.ErrEmptyPrivateKey.Error(), - ) - - require.EqualError(t, - VerifyTokenSignature(token, nil), - crypto.ErrEmptyPublicKey.Error(), - ) + token := new(Token) // create private key for signing sk := test.DecodeKey(0) @@ -150,8 +127,8 @@ func TestSignToken(t *testing.T) { token.SetSessionKey(sessionKey) // sign and verify token - require.NoError(t, SignToken(token, sk)) - require.NoError(t, VerifyTokenSignature(token, pk)) + require.NoError(t, AddSignatureWithKey(sk, token)) + require.NoError(t, VerifySignatureWithKey(pk, token)) items := []struct { corrupt func() @@ -235,8 +212,8 @@ func TestSignToken(t *testing.T) { for _, v := range items { v.corrupt() - require.Error(t, VerifyTokenSignature(token, pk)) + require.Error(t, VerifySignatureWithKey(pk, token)) v.restore() - require.NoError(t, VerifyTokenSignature(token, pk)) + require.NoError(t, VerifySignatureWithKey(pk, token)) } } diff --git a/service/ttl.go b/service/ttl.go new file mode 100644 index 00000000..28a50921 --- /dev/null +++ b/service/ttl.go @@ -0,0 +1,63 @@ +package service + +import ( + "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// TTL constants. +const ( + // ZeroTTL is an upper bound of invalid TTL values. + ZeroTTL = iota + + // NonForwardingTTL is a TTL value that does not imply a request forwarding. + NonForwardingTTL + + // SingleForwardingTTL is a TTL value that imply potential forwarding with NonForwardingTTL. + SingleForwardingTTL +) + +// SetTTL is a TTL field setter. +func (m *RequestMetaHeader) SetTTL(v uint32) { + m.TTL = v +} + +// IRNonForwarding condition that allows NonForwardingTTL only for IR. +func IRNonForwarding(role NodeRole) TTLCondition { + return func(ttl uint32) error { + if ttl == NonForwardingTTL && role != InnerRingNode { + return ErrInvalidTTL + } + + return nil + } +} + +// ProcessRequestTTL validates and updates requests with TTL. +func ProcessRequestTTL(req TTLContainer, cond ...TTLCondition) error { + ttl := req.GetTTL() + + if ttl == ZeroTTL { + return status.New(codes.InvalidArgument, ErrInvalidTTL.Error()).Err() + } + + for i := range cond { + if cond[i] == nil { + continue + } + + // check specific condition: + if err := cond[i](ttl); err != nil { + if st, ok := status.FromError(errors.Cause(err)); ok { + return st.Err() + } + + return status.New(codes.InvalidArgument, err.Error()).Err() + } + } + + req.SetTTL(ttl - 1) + + return nil +} diff --git a/service/ttl_test.go b/service/ttl_test.go new file mode 100644 index 00000000..1c982f55 --- /dev/null +++ b/service/ttl_test.go @@ -0,0 +1,99 @@ +package service + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type mockedRequest struct { + msg string + name string + code codes.Code + handler TTLCondition + RequestMetaHeader +} + +func TestMetaRequest(t *testing.T) { + tests := []mockedRequest{ + { + name: "direct to ir node", + handler: IRNonForwarding(InnerRingNode), + RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, + }, + { + code: codes.InvalidArgument, + msg: ErrInvalidTTL.Error(), + name: "direct to storage node", + handler: IRNonForwarding(StorageNode), + RequestMetaHeader: RequestMetaHeader{TTL: NonForwardingTTL}, + }, + { + msg: ErrInvalidTTL.Error(), + code: codes.InvalidArgument, + name: "zero ttl", + handler: IRNonForwarding(StorageNode), + RequestMetaHeader: RequestMetaHeader{TTL: ZeroTTL}, + }, + { + name: "default to ir node", + handler: IRNonForwarding(InnerRingNode), + RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, + }, + { + name: "default to storage node", + handler: IRNonForwarding(StorageNode), + RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, + }, + { + msg: "not found", + code: codes.NotFound, + name: "custom status error", + RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, + handler: func(_ uint32) error { return status.Error(codes.NotFound, "not found") }, + }, + { + msg: "not found", + code: codes.NotFound, + name: "custom wrapped status error", + RequestMetaHeader: RequestMetaHeader{TTL: SingleForwardingTTL}, + handler: func(_ uint32) error { + err := status.Error(codes.NotFound, "not found") + err = errors.Wrap(err, "some error context") + err = errors.Wrap(err, "another error context") + return err + }, + }, + } + + for i := range tests { + tt := tests[i] + t.Run(tt.name, func(t *testing.T) { + before := tt.GetTTL() + err := ProcessRequestTTL(&tt, tt.handler) + if tt.msg != "" { + require.Errorf(t, err, tt.msg) + + state, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, tt.code, state.Code()) + require.Equal(t, tt.msg, state.Message()) + } else { + require.NoError(t, err) + require.NotEqualf(t, before, tt.GetTTL(), "ttl should be changed: %d vs %d", before, tt.GetTTL()) + } + }) + } +} + +func TestRequestMetaHeader_SetTTL(t *testing.T) { + m := new(RequestMetaHeader) + ttl := uint32(3) + + m.SetTTL(ttl) + + require.Equal(t, ttl, m.GetTTL()) +} diff --git a/service/types.go b/service/types.go new file mode 100644 index 00000000..c3148a03 --- /dev/null +++ b/service/types.go @@ -0,0 +1,246 @@ +package service + +import ( + "crypto/ecdsa" +) + +// NodeRole to identify in Bootstrap service. +type NodeRole int32 + +// TTLCondition is a function type that used to verify that TTL values match a specific criterion. +// Nil error indicates compliance with the criterion. +type TTLCondition func(uint32) error + +// RawSource is an interface of the container of a boolean Raw value with read access. +type RawSource interface { + GetRaw() bool +} + +// RawContainer is an interface of the container of a boolean Raw value. +type RawContainer interface { + RawSource + SetRaw(bool) +} + +// VersionSource is an interface of the container of a numerical Version value with read access. +type VersionSource interface { + GetVersion() uint32 +} + +// VersionContainer is an interface of the container of a numerical Version value. +type VersionContainer interface { + VersionSource + SetVersion(uint32) +} + +// EpochSource is an interface of the container of a NeoFS epoch number with read access. +type EpochSource interface { + GetEpoch() uint64 +} + +// EpochContainer is an interface of the container of a NeoFS epoch number. +type EpochContainer interface { + EpochSource + SetEpoch(uint64) +} + +// TTLSource is an interface of the container of a numerical TTL value with read access. +type TTLSource interface { + GetTTL() uint32 +} + +// TTLContainer is an interface of the container of a numerical TTL value. +type TTLContainer interface { + TTLSource + SetTTL(uint32) +} + +// SeizedMetaHeaderContainer is an interface of container of RequestMetaHeader that can be cut and restored. +type SeizedMetaHeaderContainer interface { + CutMeta() RequestMetaHeader + RestoreMeta(RequestMetaHeader) +} + +// RequestMetaContainer is an interface of a fixed set of request meta value containers. +// Contains: +// - TTL value; +// - NeoFS epoch number; +// - Protocol version; +// - Raw toggle option. +type RequestMetaContainer interface { + TTLContainer + EpochContainer + VersionContainer + RawContainer +} + +// SeizedRequestMetaContainer is a RequestMetaContainer with seized meta. +type SeizedRequestMetaContainer interface { + RequestMetaContainer + SeizedMetaHeaderContainer +} + +// VerbSource is an interface of the container of a token verb value with read access. +type VerbSource interface { + GetVerb() Token_Info_Verb +} + +// VerbContainer is an interface of the container of a token verb value. +type VerbContainer interface { + VerbSource + SetVerb(Token_Info_Verb) +} + +// TokenIDSource is an interface of the container of a token ID value with read access. +type TokenIDSource interface { + GetID() TokenID +} + +// TokenIDContainer is an interface of the container of a token ID value. +type TokenIDContainer interface { + TokenIDSource + SetID(TokenID) +} + +// CreationEpochSource is an interface of the container of a creation epoch number with read access. +type CreationEpochSource interface { + CreationEpoch() uint64 +} + +// CreationEpochContainer is an interface of the container of a creation epoch number. +type CreationEpochContainer interface { + CreationEpochSource + SetCreationEpoch(uint64) +} + +// ExpirationEpochSource is an interface of the container of an expiration epoch number with read access. +type ExpirationEpochSource interface { + ExpirationEpoch() uint64 +} + +// ExpirationEpochContainer is an interface of the container of an expiration epoch number. +type ExpirationEpochContainer interface { + ExpirationEpochSource + SetExpirationEpoch(uint64) +} + +// SessionKeySource is an interface of the container of session key bytes with read access. +type SessionKeySource interface { + GetSessionKey() []byte +} + +// SessionKeyContainer is an interface of the container of public session key bytes. +type SessionKeyContainer interface { + SessionKeySource + SetSessionKey([]byte) +} + +// SignatureSource is an interface of the container of signature bytes with read access. +type SignatureSource interface { + GetSignature() []byte +} + +// SignatureContainer is an interface of the container of signature bytes. +type SignatureContainer interface { + SignatureSource + SetSignature([]byte) +} + +// SessionTokenSource is an interface of the container of a SessionToken with read access. +type SessionTokenSource interface { + GetSessionToken() SessionToken +} + +// SessionTokenInfo is an interface of a fixed set of token information value containers. +// Contains: +// - ID of the token; +// - ID of the token's owner; +// - verb of the session; +// - address of the session object; +// - creation epoch number of the token; +// - expiration epoch number of the token; +// - public session key bytes. +type SessionTokenInfo interface { + TokenIDContainer + OwnerIDContainer + VerbContainer + AddressContainer + CreationEpochContainer + ExpirationEpochContainer + SessionKeyContainer +} + +// SessionToken is an interface of token information and signature pair. +type SessionToken interface { + SessionTokenInfo + SignatureContainer +} + +// SignedDataSource is an interface of the container of a data for signing. +type SignedDataSource interface { + // Must return the required for signature byte slice. + // A non-nil error indicates that the data is not ready for signature. + SignedData() ([]byte, error) +} + +// SignedDataReader is an interface of signed data reader. +type SignedDataReader interface { + // Must return the minimum length of the slice for full reading. + // Must return a negative value if the length cannot be calculated. + SignedDataSize() int + + // Must behave like Read method of io.Reader and differ only in the reading of the signed data. + ReadSignedData([]byte) (int, error) +} + +// SignKeyPairAccumulator is an interface of a set of key-signature pairs with append access. +type SignKeyPairAccumulator interface { + AddSignKey([]byte, *ecdsa.PublicKey) +} + +// SignKeyPairSource is an interface of a set of key-signature pairs with read access. +type SignKeyPairSource interface { + GetSignKeyPairs() []SignKeyPair +} + +// SignKeyPair is an interface of key-signature pair with read access. +type SignKeyPair interface { + SignatureSource + GetPublicKey() *ecdsa.PublicKey +} + +// DataWithSignature is an interface of data-signature pair with read access. +type DataWithSignature interface { + SignedDataSource + SignatureSource +} + +// DataWithSignKeyAccumulator is an interface of data and key-signature accumulator pair. +type DataWithSignKeyAccumulator interface { + SignedDataSource + SignKeyPairAccumulator +} + +// DataWithSignKeySource is an interface of data and key-signature source pair. +type DataWithSignKeySource interface { + SignedDataSource + SignKeyPairSource +} + +// SignedDataWithToken is an interface of data-token pair with read access. +type SignedDataWithToken interface { + SignedDataSource + SessionTokenSource +} + +// DataWithTokenSignAccumulator is an interface of data-token pair with signature write access. +type DataWithTokenSignAccumulator interface { + SignedDataWithToken + SignKeyPairAccumulator +} + +// DataWithTokenSignSource is an interface of data-token pair with signature read access. +type DataWithTokenSignSource interface { + SignedDataWithToken + SignKeyPairSource +} diff --git a/service/verify.go b/service/verify.go index 182685d5..beca9925 100644 --- a/service/verify.go +++ b/service/verify.go @@ -35,16 +35,55 @@ type ( } ) -const ( - // ErrCannotLoadPublicKey is raised when cannot unmarshal public key from RequestVerificationHeader_Sign. - ErrCannotLoadPublicKey = internal.Error("cannot load public key") +// GetSessionToken returns SessionToken interface of Token field. +// +// If token field value is nil, nil returns. +func (m RequestVerificationHeader) GetSessionToken() SessionToken { + if t := m.GetToken(); t != nil { + return t + } - // ErrCannotFindOwner is raised when signatures empty in GetOwner. - ErrCannotFindOwner = internal.Error("cannot find owner public key") + return nil +} - // ErrWrongOwner is raised when passed OwnerID not equal to present PublicKey - ErrWrongOwner = internal.Error("wrong owner") -) +// AddSignKey adds new element to Signatures field. +// +// Sets Sign field to passed sign. Set Peer field to marshaled passed key. +func (m *RequestVerificationHeader) AddSignKey(sign []byte, key *ecdsa.PublicKey) { + m.SetSignatures( + append( + m.GetSignatures(), + &RequestVerificationHeader_Signature{ + Sign: sign, + Peer: crypto.MarshalPublicKey(key), + }, + ), + ) +} + +// GetSignKeyPairs returns the elements of Signatures field as SignKeyPair slice. +func (m RequestVerificationHeader) GetSignKeyPairs() []SignKeyPair { + var ( + signs = m.GetSignatures() + res = make([]SignKeyPair, len(signs)) + ) + + for i := range signs { + res[i] = signs[i] + } + + return res +} + +// GetSignature returns the result of a Sign field getter. +func (m RequestVerificationHeader_Signature) GetSignature() []byte { + return m.GetSign() +} + +// GetPublicKey unmarshals and returns the result of a Peer field getter. +func (m RequestVerificationHeader_Signature) GetPublicKey() *ecdsa.PublicKey { + return crypto.UnmarshalPublicKey(m.GetPeer()) +} // SetSignatures replaces signatures stored in RequestVerificationHeader. func (m *RequestVerificationHeader) SetSignatures(signatures []*RequestVerificationHeader_Signature) { @@ -81,7 +120,7 @@ func (m *RequestVerificationHeader) GetOwner() (*ecdsa.PublicKey, error) { return key, nil } - return nil, ErrCannotLoadPublicKey + return nil, ErrInvalidPublicKeyBytes } // GetLastPeer tries to get last peer public key from signatures. @@ -99,7 +138,7 @@ func (m *RequestVerificationHeader) GetLastPeer() (*ecdsa.PublicKey, error) { return key, nil } - return nil, ErrCannotLoadPublicKey + return nil, ErrInvalidPublicKeyBytes } } @@ -129,8 +168,8 @@ var bytesPool = sync.Pool{New: func() interface{} { // new signature to headers. If something went wrong, returns error. func SignRequestHeader(key *ecdsa.PrivateKey, msg VerifiableRequest) error { // ignore meta header - if meta, ok := msg.(MetaHeader); ok { - h := meta.ResetMeta() + if meta, ok := msg.(SeizedRequestMetaContainer); ok { + h := meta.CutMeta() defer func() { meta.RestoreMeta(h) @@ -168,8 +207,8 @@ func SignRequestHeader(key *ecdsa.PrivateKey, msg VerifiableRequest) error { // If something went wrong, returns error. func VerifyRequestHeader(msg VerifiableRequest) error { // ignore meta header - if meta, ok := msg.(MetaHeader); ok { - h := meta.ResetMeta() + if meta, ok := msg.(SeizedRequestMetaContainer); ok { + h := meta.CutMeta() defer func() { meta.RestoreMeta(h) @@ -190,7 +229,7 @@ func VerifyRequestHeader(msg VerifiableRequest) error { key := crypto.UnmarshalPublicKey(peer) if key == nil { - return errors.Wrapf(ErrCannotLoadPublicKey, "%d: %02x", i, peer) + return errors.Wrapf(ErrInvalidPublicKeyBytes, "%d: %02x", i, peer) } if size := msg.Size(); size <= cap(data) { diff --git a/service/version.go b/service/version.go new file mode 100644 index 00000000..6f4839cd --- /dev/null +++ b/service/version.go @@ -0,0 +1,11 @@ +package service + +// SetVersion is a Version field setter. +func (m *ResponseMetaHeader) SetVersion(v uint32) { + m.Version = v +} + +// SetVersion is a Version field setter. +func (m *RequestMetaHeader) SetVersion(v uint32) { + m.Version = v +} diff --git a/service/version_test.go b/service/version_test.go new file mode 100644 index 00000000..d102d305 --- /dev/null +++ b/service/version_test.go @@ -0,0 +1,21 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetSetVersion(t *testing.T) { + v := uint32(7) + + items := []VersionContainer{ + new(ResponseMetaHeader), + new(RequestMetaHeader), + } + + for _, item := range items { + item.SetVersion(v) + require.Equal(t, v, item.GetVersion()) + } +}