From 2d53ebf9c42dcfc95c8675c1083570a06c631614 Mon Sep 17 00:00:00 2001 From: Leonard Lyubich Date: Mon, 11 May 2020 14:37:16 +0300 Subject: [PATCH] container: implement SignedDataSource on PutRequest message --- container/sign.go | 66 +++++++++++++++++++++++++++ container/sign_test.go | 98 +++++++++++++++++++++++++++++++++++++++++ container/types.go | 35 +++++++++++++++ container/types_test.go | 52 ++++++++++++++++++++++ 4 files changed, 251 insertions(+) create mode 100644 container/sign.go create mode 100644 container/sign_test.go diff --git a/container/sign.go b/container/sign.go new file mode 100644 index 0000000..f551989 --- /dev/null +++ b/container/sign.go @@ -0,0 +1,66 @@ +package container + +import ( + "encoding/binary" + "io" +) + +var requestEndianness = binary.BigEndian + +// SignedData returns payload bytes of the request. +func (m PutRequest) SignedData() ([]byte, error) { + data := make([]byte, m.SignedDataSize()) + + if _, err := m.ReadSignedData(data); err != nil { + return nil, err + } + + return data, nil +} + +// SignedDataSize returns payload size of the request. +func (m PutRequest) SignedDataSize() (sz int) { + sz += m.GetMessageID().Size() + + sz += 8 + + sz += m.GetOwnerID().Size() + + rules := m.GetRules() + sz += rules.Size() + + sz += 4 + + return +} + +// ReadSignedData copies payload bytes to passed buffer. +// +// If the Request size is insufficient, io.ErrUnexpectedEOF returns. +func (m PutRequest) ReadSignedData(p []byte) (int, error) { + if len(p) < m.SignedDataSize() { + return 0, io.ErrUnexpectedEOF + } + + var off int + + off += copy(p[off:], m.GetMessageID().Bytes()) + + requestEndianness.PutUint64(p[off:], m.GetCapacity()) + off += 8 + + off += copy(p[off:], m.GetOwnerID().Bytes()) + + rules := m.GetRules() + // FIXME: implement and use stable functions + n, err := rules.MarshalTo(p[off:]) + off += n + if err != nil { + return off, err + } + + requestEndianness.PutUint32(p[off:], m.GetBasicACL()) + off += 4 + + return off, nil +} diff --git a/container/sign_test.go b/container/sign_test.go new file mode 100644 index 0000000..f1476ed --- /dev/null +++ b/container/sign_test.go @@ -0,0 +1,98 @@ +package container + +import ( + "testing" + + "github.com/nspcc-dev/neofs-api-go/service" + "github.com/nspcc-dev/neofs-crypto/test" + "github.com/stretchr/testify/require" +) + +func TestRequestSign(t *testing.T) { + sk := test.DecodeKey(0) + + type sigType interface { + service.SignedDataWithToken + service.SignKeyPairAccumulator + service.SignKeyPairSource + SetToken(*service.Token) + } + + items := []struct { + constructor func() sigType + payloadCorrupt []func(sigType) + }{ + { // Request + constructor: func() sigType { + return new(PutRequest) + }, + payloadCorrupt: []func(sigType){ + func(s sigType) { + req := s.(*PutRequest) + + id := req.GetMessageID() + id[0]++ + + req.SetMessageID(id) + }, + func(s sigType) { + req := s.(*PutRequest) + + req.SetCapacity(req.GetCapacity() + 1) + }, + func(s sigType) { + req := s.(*PutRequest) + + owner := req.GetOwnerID() + owner[0]++ + + req.SetOwnerID(owner) + }, + func(s sigType) { + req := s.(*PutRequest) + + rules := req.GetRules() + rules.ReplFactor++ + + req.SetRules(rules) + }, + func(s sigType) { + req := s.(*PutRequest) + + req.SetBasicACL(req.GetBasicACL() + 1) + }, + }, + }, + } + + for _, item := range items { + { // token corruptions + v := item.constructor() + + token := new(service.Token) + v.SetToken(token) + + require.NoError(t, service.SignDataWithSessionToken(sk, v)) + + require.NoError(t, service.VerifyAccumulatedSignaturesWithToken(v)) + + token.SetSessionKey(append(token.GetSessionKey(), 1)) + + require.Error(t, service.VerifyAccumulatedSignaturesWithToken(v)) + } + + { // payload corruptions + for _, corruption := range item.payloadCorrupt { + v := item.constructor() + + require.NoError(t, service.SignDataWithSessionToken(sk, v)) + + require.NoError(t, service.VerifyAccumulatedSignaturesWithToken(v)) + + corruption(v) + + require.Error(t, service.VerifyAccumulatedSignaturesWithToken(v)) + } + } + } +} diff --git a/container/types.go b/container/types.go index e358e6d..39cef43 100644 --- a/container/types.go +++ b/container/types.go @@ -93,3 +93,38 @@ func NewTestContainer() (*Container, error) { }, }) } + +// GetMessageID is a MessageID field getter. +func (m PutRequest) GetMessageID() MessageID { + return m.MessageID +} + +// SetMessageID is a MessageID field getter. +func (m *PutRequest) SetMessageID(id MessageID) { + m.MessageID = id +} + +// SetCapacity is a Capacity field setter. +func (m *PutRequest) SetCapacity(c uint64) { + m.Capacity = c +} + +// GetOwnerID is an OwnerID field getter. +func (m PutRequest) GetOwnerID() OwnerID { + return m.OwnerID +} + +// SetOwnerID is an OwnerID field setter. +func (m *PutRequest) SetOwnerID(owner OwnerID) { + m.OwnerID = owner +} + +// SetRules is a Rules field setter. +func (m *PutRequest) SetRules(rules netmap.PlacementRule) { + m.Rules = rules +} + +// SetBasicACL is a BasicACL field setter. +func (m *PutRequest) SetBasicACL(acl uint32) { + m.BasicACL = acl +} diff --git a/container/types_test.go b/container/types_test.go index fddccb3..07298bc 100644 --- a/container/types_test.go +++ b/container/types_test.go @@ -55,3 +55,55 @@ func TestCID(t *testing.T) { require.Equal(t, cid1, cid2) }) } + +func TestPutRequestGettersSetters(t *testing.T) { + t.Run("owner", func(t *testing.T) { + owner := OwnerID{1, 2, 3} + m := new(PutRequest) + + m.SetOwnerID(owner) + + require.Equal(t, owner, m.GetOwnerID()) + }) + + t.Run("capacity", func(t *testing.T) { + cp := uint64(3) + m := new(PutRequest) + + m.SetCapacity(cp) + + require.Equal(t, cp, m.GetCapacity()) + }) + + t.Run("message ID", func(t *testing.T) { + id, err := refs.NewUUID() + require.NoError(t, err) + + m := new(PutRequest) + + m.SetMessageID(id) + + require.Equal(t, id, m.GetMessageID()) + }) + + t.Run("rules", func(t *testing.T) { + rules := netmap.PlacementRule{ + ReplFactor: 1, + } + + m := new(PutRequest) + + m.SetRules(rules) + + require.Equal(t, rules, m.GetRules()) + }) + + t.Run("basic ACL", func(t *testing.T) { + bACL := uint32(5) + m := new(PutRequest) + + m.SetBasicACL(bACL) + + require.Equal(t, bACL, m.GetBasicACL()) + }) +}