From 8c492a7712b44f6fd49ceb5c977801697370e2b6 Mon Sep 17 00:00:00 2001 From: Leonard Lyubich Date: Mon, 11 May 2020 13:14:31 +0300 Subject: [PATCH] accounting: implement SignedDataSource on PutRequest message --- accounting/sign.go | 59 +++++++++++++++++++++++++++++++++++++++- accounting/sign_test.go | 44 ++++++++++++++++++++++++++++++ accounting/types.go | 30 ++++++++++++++++++++ accounting/types_test.go | 39 ++++++++++++++++++++++++++ 4 files changed, 171 insertions(+), 1 deletion(-) diff --git a/accounting/sign.go b/accounting/sign.go index 8faeb96d..8da8cf6e 100644 --- a/accounting/sign.go +++ b/accounting/sign.go @@ -1,6 +1,9 @@ package accounting -import "io" +import ( + "encoding/binary" + "io" +) // SignedData returns payload bytes of the request. func (m BalanceRequest) SignedData() ([]byte, error) { @@ -65,3 +68,57 @@ func (m GetRequest) ReadSignedData(p []byte) (int, error) { return sz, nil } + +// SignedData returns payload bytes of the request. +func (m PutRequest) SignedData() ([]byte, error) { + data := make([]byte, m.SignedDataSize()) + + if _, err := m.ReadSignedData(data); err != nil { + return nil, err + } + + return data, nil +} + +// SignedDataSize returns payload size of the request. +func (m PutRequest) SignedDataSize() (sz int) { + sz += m.GetOwnerID().Size() + + sz += m.GetMessageID().Size() + + sz += 8 + + if amount := m.GetAmount(); amount != nil { + sz += amount.Size() + } + + return +} + +// ReadSignedData copies payload bytes to passed buffer. +// +// If the buffer size is insufficient, io.ErrUnexpectedEOF returns. +func (m PutRequest) ReadSignedData(p []byte) (int, error) { + if len(p) < m.SignedDataSize() { + return 0, io.ErrUnexpectedEOF + } + + var off int + + off += copy(p[off:], m.GetOwnerID().Bytes()) + + off += copy(p[off:], m.GetMessageID().Bytes()) + + binary.BigEndian.PutUint64(p[off:], m.GetHeight()) + off += 8 + + if amount := m.GetAmount(); amount != nil { + n, err := amount.MarshalTo(p[off:]) + off += n + if err != nil { + return off + n, err + } + } + + return off, nil +} diff --git a/accounting/sign_test.go b/accounting/sign_test.go index d059ab47..1f88dcfc 100644 --- a/accounting/sign_test.go +++ b/accounting/sign_test.go @@ -3,6 +3,7 @@ package accounting import ( "testing" + "github.com/nspcc-dev/neofs-api-go/decimal" "github.com/nspcc-dev/neofs-api-go/service" "github.com/nspcc-dev/neofs-crypto/test" "github.com/stretchr/testify/require" @@ -60,6 +61,49 @@ func TestSignBalanceRequest(t *testing.T) { }, }, }, + { // PutRequest + constructor: func() sigType { + req := new(PutRequest) + + amount := decimal.New(1) + req.SetAmount(amount) + + return req + }, + payloadCorrupt: []func(sigType){ + func(s sigType) { + req := s.(*PutRequest) + + owner := req.GetOwnerID() + owner[0]++ + + req.SetOwnerID(owner) + }, + func(s sigType) { + req := s.(*PutRequest) + + mid := req.GetMessageID() + mid[0]++ + + req.SetMessageID(mid) + }, + func(s sigType) { + req := s.(*PutRequest) + + req.SetHeight(req.GetHeight() + 1) + }, + func(s sigType) { + req := s.(*PutRequest) + + amount := req.GetAmount() + if amount == nil { + req.SetAmount(decimal.New(0)) + } else { + req.SetAmount(amount.Add(decimal.New(amount.GetValue()))) + } + }, + }, + }, } for _, item := range items { diff --git a/accounting/types.go b/accounting/types.go index 1e4e80a2..3a4b15ea 100644 --- a/accounting/types.go +++ b/accounting/types.go @@ -381,3 +381,33 @@ func (m GetRequest) GetOwnerID() OwnerID { func (m *GetRequest) SetOwnerID(id OwnerID) { m.OwnerID = id } + +// GetOwnerID is an OwnerID field getter. +func (m PutRequest) GetOwnerID() OwnerID { + return m.OwnerID +} + +// SetOwnerID is an OwnerID field setter. +func (m *PutRequest) SetOwnerID(id OwnerID) { + m.OwnerID = id +} + +// GetMessageID is a MessageID field getter. +func (m PutRequest) GetMessageID() MessageID { + return m.MessageID +} + +// SetMessageID is a MessageID field setter. +func (m *PutRequest) SetMessageID(id MessageID) { + m.MessageID = id +} + +// SetAmount is an Amount field setter. +func (m *PutRequest) SetAmount(amount *decimal.Decimal) { + m.Amount = amount +} + +// SetHeight is a Height field setter. +func (m *PutRequest) SetHeight(h uint64) { + m.Height = h +} diff --git a/accounting/types_test.go b/accounting/types_test.go index ea17a8ac..844ea70b 100644 --- a/accounting/types_test.go +++ b/accounting/types_test.go @@ -113,3 +113,42 @@ func TestGetRequestGettersSetters(t *testing.T) { require.Equal(t, id, m.GetOwnerID()) }) } + +func TestPutRequestGettersSetters(t *testing.T) { + t.Run("owner", func(t *testing.T) { + id := OwnerID{1, 2, 3} + m := new(PutRequest) + + m.SetOwnerID(id) + + require.Equal(t, id, m.GetOwnerID()) + }) + + t.Run("message ID", func(t *testing.T) { + id, err := refs.NewUUID() + require.NoError(t, err) + + m := new(PutRequest) + m.SetMessageID(id) + + require.Equal(t, id, m.GetMessageID()) + }) + + t.Run("amount", func(t *testing.T) { + amount := decimal.New(1) + m := new(PutRequest) + + m.SetAmount(amount) + + require.Equal(t, amount, m.GetAmount()) + }) + + t.Run("height", func(t *testing.T) { + h := uint64(3) + m := new(PutRequest) + + m.SetHeight(h) + + require.Equal(t, h, m.GetHeight()) + }) +}