From 4fd6839473a025f59b309c3671363da238129a4f Mon Sep 17 00:00:00 2001 From: Leonard Lyubich Date: Fri, 11 Sep 2020 18:19:36 +0300 Subject: [PATCH] [#140] sdk/object: Implement object format control functions Implement function for calculating, setting and checking object verification fields. Signed-off-by: Leonard Lyubich --- pkg/object/fmt.go | 157 +++++++++++++++++++++++++++++++++++++++++ pkg/object/fmt_test.go | 79 +++++++++++++++++++++ 2 files changed, 236 insertions(+) create mode 100644 pkg/object/fmt.go create mode 100644 pkg/object/fmt_test.go diff --git a/pkg/object/fmt.go b/pkg/object/fmt.go new file mode 100644 index 0000000..61a08b6 --- /dev/null +++ b/pkg/object/fmt.go @@ -0,0 +1,157 @@ +package object + +import ( + "crypto/ecdsa" + "crypto/sha256" + + "github.com/nspcc-dev/neofs-api-go/pkg" + "github.com/nspcc-dev/neofs-api-go/util/signature" + signatureV2 "github.com/nspcc-dev/neofs-api-go/v2/signature" + "github.com/pkg/errors" +) + +// CalculatePayloadChecksum calculates and returns checksum of +// object payload bytes. +func CalculatePayloadChecksum(payload []byte) *pkg.Checksum { + res := pkg.NewChecksum() + res.SetSHA256(sha256.Sum256(payload)) + + return res +} + +// CalculateAndSetPayloadChecksum calculates checksum of current +// object payload and writes it to the object. +func CalculateAndSetPayloadChecksum(obj *RawObject) { + obj.SetPayloadChecksum( + CalculatePayloadChecksum(obj.GetPayload()), + ) +} + +// VerifyPayloadChecksum checks if payload checksum in the object +// corresponds to its payload. +func VerifyPayloadChecksum(obj *Object) error { + if !pkg.EqualChecksums( + obj.GetPayloadChecksum(), + CalculatePayloadChecksum(obj.GetPayload()), + ) { + return errors.New("payload checksum mismatch") + } + + return nil +} + +// CalculateID calculates identifier for the object. +func CalculateID(obj *Object) (*ID, error) { + data, err := obj.ToV2().GetHeader().StableMarshal(nil) + if err != nil { + return nil, err + } + + id := NewID() + id.SetSHA256(sha256.Sum256(data)) + + return id, nil +} + +// CalculateAndSetID calculates identifier for the object +// and writes the result to it. +func CalculateAndSetID(obj *RawObject) error { + id, err := CalculateID(obj.Object()) + if err != nil { + return err + } + + obj.SetID(id) + + return nil +} + +// VerifyID checks if identifier in the object corresponds to +// its structure. +func VerifyID(obj *Object) error { + id, err := CalculateID(obj) + if err != nil { + return err + } + + if !id.Equal(obj.GetID()) { + return errors.New("incorrect object identifier") + } + + return nil +} + +func CalculateIDSignature(key *ecdsa.PrivateKey, id *ID) (*pkg.Signature, error) { + sig := pkg.NewSignature() + + if err := signature.SignDataWithHandler( + key, + signatureV2.StableMarshalerWrapper{ + SM: id.ToV2(), + }, + func(key, sign []byte) { + sig.SetKey(key) + sig.SetSign(sign) + }, + ); err != nil { + return nil, err + } + + return sig, nil +} + +func CalculateAndSetSignature(key *ecdsa.PrivateKey, obj *RawObject) error { + sig, err := CalculateIDSignature(key, obj.GetID()) + if err != nil { + return err + } + + obj.SetSignature(sig) + + return nil +} + +func VerifyIDSignature(obj *Object) error { + return signature.VerifyDataWithSource( + signatureV2.StableMarshalerWrapper{ + SM: obj.GetID().ToV2(), + }, + func() ([]byte, []byte) { + sig := obj.GetSignature() + + return sig.GetKey(), sig.GetSign() + }, + ) +} + +// SetVerificationFields calculates and sets all verification fields of the object. +func SetVerificationFields(key *ecdsa.PrivateKey, obj *RawObject) error { + CalculateAndSetPayloadChecksum(obj) + + if err := CalculateAndSetID(obj); err != nil { + return errors.Wrap(err, "could not set identifier") + } + + if err := CalculateAndSetSignature(key, obj); err != nil { + return errors.Wrap(err, "could not set signature") + } + + return nil +} + +// CheckVerificationFields checks all verification fields of the object. +func CheckVerificationFields(obj *Object) error { + if err := VerifyIDSignature(obj); err != nil { + return errors.Wrap(err, "invalid signature") + } + + if err := VerifyID(obj); err != nil { + return errors.Wrap(err, "invalid identifier") + } + + if err := VerifyPayloadChecksum(obj); err != nil { + return errors.Wrap(err, "invalid payload checksum") + } + + return nil +} diff --git a/pkg/object/fmt_test.go b/pkg/object/fmt_test.go new file mode 100644 index 0000000..5a65e49 --- /dev/null +++ b/pkg/object/fmt_test.go @@ -0,0 +1,79 @@ +package object + +import ( + "crypto/rand" + "testing" + + "github.com/nspcc-dev/neofs-crypto/test" + "github.com/stretchr/testify/require" +) + +func TestVerificationFields(t *testing.T) { + obj := NewRaw() + + payload := make([]byte, 10) + _, _ = rand.Read(payload) + + obj.SetPayload(payload) + obj.SetPayloadSize(uint64(len(payload))) + + require.NoError(t, SetVerificationFields(test.DecodeKey(-1), obj)) + + require.NoError(t, CheckVerificationFields(obj.Object())) + + items := []struct { + corrupt func() + restore func() + }{ + { + corrupt: func() { + payload[0]++ + }, + restore: func() { + payload[0]-- + }, + }, + { + corrupt: func() { + obj.SetPayloadSize(obj.GetPayloadSize() + 1) + }, + restore: func() { + obj.SetPayloadSize(obj.GetPayloadSize() - 1) + }, + }, + { + corrupt: func() { + obj.GetID().ToV2().GetValue()[0]++ + }, + restore: func() { + obj.GetID().ToV2().GetValue()[0]-- + }, + }, + { + corrupt: func() { + obj.GetSignature().GetKey()[0]++ + }, + restore: func() { + obj.GetSignature().GetKey()[0]-- + }, + }, + { + corrupt: func() { + obj.GetSignature().GetSign()[0]++ + }, + restore: func() { + obj.GetSignature().GetSign()[0]-- + }, + }, + } + + for _, item := range items { + item.corrupt() + + require.Error(t, CheckVerificationFields(obj.Object())) + + item.restore() + + require.NoError(t, CheckVerificationFields(obj.Object())) + } +}