accounting: implement SignedDataSource on PutRequest message

This commit is contained in:
Leonard Lyubich 2020-05-11 13:14:31 +03:00
parent 6832d71d48
commit 8c492a7712
4 changed files with 171 additions and 1 deletions

View file

@ -1,6 +1,9 @@
package accounting package accounting
import "io" import (
"encoding/binary"
"io"
)
// SignedData returns payload bytes of the request. // SignedData returns payload bytes of the request.
func (m BalanceRequest) SignedData() ([]byte, error) { func (m BalanceRequest) SignedData() ([]byte, error) {
@ -65,3 +68,57 @@ func (m GetRequest) ReadSignedData(p []byte) (int, error) {
return sz, nil return sz, nil
} }
// SignedData returns payload bytes of the request.
func (m PutRequest) SignedData() ([]byte, error) {
data := make([]byte, m.SignedDataSize())
if _, err := m.ReadSignedData(data); err != nil {
return nil, err
}
return data, nil
}
// SignedDataSize returns payload size of the request.
func (m PutRequest) SignedDataSize() (sz int) {
sz += m.GetOwnerID().Size()
sz += m.GetMessageID().Size()
sz += 8
if amount := m.GetAmount(); amount != nil {
sz += amount.Size()
}
return
}
// ReadSignedData copies payload bytes to passed buffer.
//
// If the buffer size is insufficient, io.ErrUnexpectedEOF returns.
func (m PutRequest) ReadSignedData(p []byte) (int, error) {
if len(p) < m.SignedDataSize() {
return 0, io.ErrUnexpectedEOF
}
var off int
off += copy(p[off:], m.GetOwnerID().Bytes())
off += copy(p[off:], m.GetMessageID().Bytes())
binary.BigEndian.PutUint64(p[off:], m.GetHeight())
off += 8
if amount := m.GetAmount(); amount != nil {
n, err := amount.MarshalTo(p[off:])
off += n
if err != nil {
return off + n, err
}
}
return off, nil
}

View file

@ -3,6 +3,7 @@ package accounting
import ( import (
"testing" "testing"
"github.com/nspcc-dev/neofs-api-go/decimal"
"github.com/nspcc-dev/neofs-api-go/service" "github.com/nspcc-dev/neofs-api-go/service"
"github.com/nspcc-dev/neofs-crypto/test" "github.com/nspcc-dev/neofs-crypto/test"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -60,6 +61,49 @@ func TestSignBalanceRequest(t *testing.T) {
}, },
}, },
}, },
{ // PutRequest
constructor: func() sigType {
req := new(PutRequest)
amount := decimal.New(1)
req.SetAmount(amount)
return req
},
payloadCorrupt: []func(sigType){
func(s sigType) {
req := s.(*PutRequest)
owner := req.GetOwnerID()
owner[0]++
req.SetOwnerID(owner)
},
func(s sigType) {
req := s.(*PutRequest)
mid := req.GetMessageID()
mid[0]++
req.SetMessageID(mid)
},
func(s sigType) {
req := s.(*PutRequest)
req.SetHeight(req.GetHeight() + 1)
},
func(s sigType) {
req := s.(*PutRequest)
amount := req.GetAmount()
if amount == nil {
req.SetAmount(decimal.New(0))
} else {
req.SetAmount(amount.Add(decimal.New(amount.GetValue())))
}
},
},
},
} }
for _, item := range items { for _, item := range items {

View file

@ -381,3 +381,33 @@ func (m GetRequest) GetOwnerID() OwnerID {
func (m *GetRequest) SetOwnerID(id OwnerID) { func (m *GetRequest) SetOwnerID(id OwnerID) {
m.OwnerID = id m.OwnerID = id
} }
// GetOwnerID is an OwnerID field getter.
func (m PutRequest) GetOwnerID() OwnerID {
return m.OwnerID
}
// SetOwnerID is an OwnerID field setter.
func (m *PutRequest) SetOwnerID(id OwnerID) {
m.OwnerID = id
}
// GetMessageID is a MessageID field getter.
func (m PutRequest) GetMessageID() MessageID {
return m.MessageID
}
// SetMessageID is a MessageID field setter.
func (m *PutRequest) SetMessageID(id MessageID) {
m.MessageID = id
}
// SetAmount is an Amount field setter.
func (m *PutRequest) SetAmount(amount *decimal.Decimal) {
m.Amount = amount
}
// SetHeight is a Height field setter.
func (m *PutRequest) SetHeight(h uint64) {
m.Height = h
}

View file

@ -113,3 +113,42 @@ func TestGetRequestGettersSetters(t *testing.T) {
require.Equal(t, id, m.GetOwnerID()) require.Equal(t, id, m.GetOwnerID())
}) })
} }
func TestPutRequestGettersSetters(t *testing.T) {
t.Run("owner", func(t *testing.T) {
id := OwnerID{1, 2, 3}
m := new(PutRequest)
m.SetOwnerID(id)
require.Equal(t, id, m.GetOwnerID())
})
t.Run("message ID", func(t *testing.T) {
id, err := refs.NewUUID()
require.NoError(t, err)
m := new(PutRequest)
m.SetMessageID(id)
require.Equal(t, id, m.GetMessageID())
})
t.Run("amount", func(t *testing.T) {
amount := decimal.New(1)
m := new(PutRequest)
m.SetAmount(amount)
require.Equal(t, amount, m.GetAmount())
})
t.Run("height", func(t *testing.T) {
h := uint64(3)
m := new(PutRequest)
m.SetHeight(h)
require.Equal(t, h, m.GetHeight())
})
}