container: implement SignedDataSource on PutRequest message

This commit is contained in:
Leonard Lyubich 2020-05-11 14:37:16 +03:00
parent 9327c5f816
commit 2d53ebf9c4
4 changed files with 251 additions and 0 deletions

66
container/sign.go Normal file
View file

@ -0,0 +1,66 @@
package container
import (
"encoding/binary"
"io"
)
var requestEndianness = binary.BigEndian
// SignedData returns payload bytes of the request.
func (m PutRequest) SignedData() ([]byte, error) {
data := make([]byte, m.SignedDataSize())
if _, err := m.ReadSignedData(data); err != nil {
return nil, err
}
return data, nil
}
// SignedDataSize returns payload size of the request.
func (m PutRequest) SignedDataSize() (sz int) {
sz += m.GetMessageID().Size()
sz += 8
sz += m.GetOwnerID().Size()
rules := m.GetRules()
sz += rules.Size()
sz += 4
return
}
// ReadSignedData copies payload bytes to passed buffer.
//
// If the Request size is insufficient, io.ErrUnexpectedEOF returns.
func (m PutRequest) ReadSignedData(p []byte) (int, error) {
if len(p) < m.SignedDataSize() {
return 0, io.ErrUnexpectedEOF
}
var off int
off += copy(p[off:], m.GetMessageID().Bytes())
requestEndianness.PutUint64(p[off:], m.GetCapacity())
off += 8
off += copy(p[off:], m.GetOwnerID().Bytes())
rules := m.GetRules()
// FIXME: implement and use stable functions
n, err := rules.MarshalTo(p[off:])
off += n
if err != nil {
return off, err
}
requestEndianness.PutUint32(p[off:], m.GetBasicACL())
off += 4
return off, nil
}

98
container/sign_test.go Normal file
View file

@ -0,0 +1,98 @@
package container
import (
"testing"
"github.com/nspcc-dev/neofs-api-go/service"
"github.com/nspcc-dev/neofs-crypto/test"
"github.com/stretchr/testify/require"
)
func TestRequestSign(t *testing.T) {
sk := test.DecodeKey(0)
type sigType interface {
service.SignedDataWithToken
service.SignKeyPairAccumulator
service.SignKeyPairSource
SetToken(*service.Token)
}
items := []struct {
constructor func() sigType
payloadCorrupt []func(sigType)
}{
{ // Request
constructor: func() sigType {
return new(PutRequest)
},
payloadCorrupt: []func(sigType){
func(s sigType) {
req := s.(*PutRequest)
id := req.GetMessageID()
id[0]++
req.SetMessageID(id)
},
func(s sigType) {
req := s.(*PutRequest)
req.SetCapacity(req.GetCapacity() + 1)
},
func(s sigType) {
req := s.(*PutRequest)
owner := req.GetOwnerID()
owner[0]++
req.SetOwnerID(owner)
},
func(s sigType) {
req := s.(*PutRequest)
rules := req.GetRules()
rules.ReplFactor++
req.SetRules(rules)
},
func(s sigType) {
req := s.(*PutRequest)
req.SetBasicACL(req.GetBasicACL() + 1)
},
},
},
}
for _, item := range items {
{ // token corruptions
v := item.constructor()
token := new(service.Token)
v.SetToken(token)
require.NoError(t, service.SignDataWithSessionToken(sk, v))
require.NoError(t, service.VerifyAccumulatedSignaturesWithToken(v))
token.SetSessionKey(append(token.GetSessionKey(), 1))
require.Error(t, service.VerifyAccumulatedSignaturesWithToken(v))
}
{ // payload corruptions
for _, corruption := range item.payloadCorrupt {
v := item.constructor()
require.NoError(t, service.SignDataWithSessionToken(sk, v))
require.NoError(t, service.VerifyAccumulatedSignaturesWithToken(v))
corruption(v)
require.Error(t, service.VerifyAccumulatedSignaturesWithToken(v))
}
}
}
}

View file

@ -93,3 +93,38 @@ func NewTestContainer() (*Container, error) {
},
})
}
// GetMessageID is a MessageID field getter.
func (m PutRequest) GetMessageID() MessageID {
return m.MessageID
}
// SetMessageID is a MessageID field getter.
func (m *PutRequest) SetMessageID(id MessageID) {
m.MessageID = id
}
// SetCapacity is a Capacity field setter.
func (m *PutRequest) SetCapacity(c uint64) {
m.Capacity = c
}
// GetOwnerID is an OwnerID field getter.
func (m PutRequest) GetOwnerID() OwnerID {
return m.OwnerID
}
// SetOwnerID is an OwnerID field setter.
func (m *PutRequest) SetOwnerID(owner OwnerID) {
m.OwnerID = owner
}
// SetRules is a Rules field setter.
func (m *PutRequest) SetRules(rules netmap.PlacementRule) {
m.Rules = rules
}
// SetBasicACL is a BasicACL field setter.
func (m *PutRequest) SetBasicACL(acl uint32) {
m.BasicACL = acl
}

View file

@ -55,3 +55,55 @@ func TestCID(t *testing.T) {
require.Equal(t, cid1, cid2)
})
}
func TestPutRequestGettersSetters(t *testing.T) {
t.Run("owner", func(t *testing.T) {
owner := OwnerID{1, 2, 3}
m := new(PutRequest)
m.SetOwnerID(owner)
require.Equal(t, owner, m.GetOwnerID())
})
t.Run("capacity", func(t *testing.T) {
cp := uint64(3)
m := new(PutRequest)
m.SetCapacity(cp)
require.Equal(t, cp, m.GetCapacity())
})
t.Run("message ID", func(t *testing.T) {
id, err := refs.NewUUID()
require.NoError(t, err)
m := new(PutRequest)
m.SetMessageID(id)
require.Equal(t, id, m.GetMessageID())
})
t.Run("rules", func(t *testing.T) {
rules := netmap.PlacementRule{
ReplFactor: 1,
}
m := new(PutRequest)
m.SetRules(rules)
require.Equal(t, rules, m.GetRules())
})
t.Run("basic ACL", func(t *testing.T) {
bACL := uint32(5)
m := new(PutRequest)
m.SetBasicACL(bACL)
require.Equal(t, bACL, m.GetBasicACL())
})
}