smartcontract: implement (*ParameterContext).AddSignature()

This commit is contained in:
Evgenii Stratonikov 2020-03-04 14:59:58 +03:00
parent 44901ca867
commit cd487e3ad4
2 changed files with 195 additions and 0 deletions

View file

@ -1,14 +1,21 @@
package context package context
import ( import (
"bytes"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"sort"
"strings" "strings"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/wallet"
) )
// ParameterContext represents smartcontract parameter's context. // ParameterContext represents smartcontract parameter's context.
@ -27,6 +34,95 @@ type paramContext struct {
Items map[string]json.RawMessage `json:"items"` Items map[string]json.RawMessage `json:"items"`
} }
type sigWithIndex struct {
index int
sig []byte
}
// NewParameterContext returns ParameterContext with the specified type and item to sign.
func NewParameterContext(typ string, verif io.Serializable) *ParameterContext {
return &ParameterContext{
Type: typ,
Verifiable: verif,
Items: make(map[util.Uint160]*Item),
}
}
// AddSignature adds a signature for the specified contract and public key.
func (c *ParameterContext) AddSignature(ctr *wallet.Contract, pub *keys.PublicKey, sig []byte) error {
item := c.getItemForContract(ctr)
if pubs, ok := vm.ParseMultiSigContract(ctr.Script); ok {
if item.GetSignature(pub) != nil {
return errors.New("signature is already added")
}
pubBytes := pub.Bytes()
var contained bool
for i := range pubs {
if bytes.Equal(pubBytes, pubs[i]) {
contained = true
break
}
}
if !contained {
return errors.New("public key is not present in script")
}
item.AddSignature(pub, sig)
if len(item.Signatures) == len(ctr.Parameters) {
indexMap := map[string]int{}
for i := range pubs {
indexMap[hex.EncodeToString(pubs[i])] = i
}
sigs := make([]sigWithIndex, 0, len(item.Signatures))
for pub, sig := range item.Signatures {
sigs = append(sigs, sigWithIndex{index: indexMap[pub], sig: sig})
}
sort.Slice(sigs, func(i, j int) bool {
return sigs[i].index < sigs[j].index
})
for i := range sigs {
item.Parameters[i] = smartcontract.Parameter{
Type: smartcontract.SignatureType,
Value: sigs[i].sig,
}
}
}
return nil
}
index := -1
for i := range ctr.Parameters {
if ctr.Parameters[i].Type == smartcontract.SignatureType {
if index >= 0 {
return errors.New("multiple signature parameters in non-multisig contract")
}
index = i
}
}
if index == -1 {
return errors.New("missing signature parameter")
}
item.Parameters[index].Value = sig
return nil
}
func (c *ParameterContext) getItemForContract(ctr *wallet.Contract) *Item {
h := ctr.ScriptHash()
if item, ok := c.Items[h]; ok {
return item
}
params := make([]smartcontract.Parameter, len(ctr.Parameters))
for i := range params {
params[i].Type = ctr.Parameters[i].Type
}
item := &Item{
Script: h,
Parameters: params,
Signatures: make(map[string][]byte),
}
c.Items[h] = item
return item
}
// MarshalJSON implements json.Marshaler interface. // MarshalJSON implements json.Marshaler interface.
func (c ParameterContext) MarshalJSON() ([]byte, error) { func (c ParameterContext) MarshalJSON() ([]byte, error) {
bw := io.NewBufBinWriter() bw := io.NewBufBinWriter()

View file

@ -9,9 +9,89 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/wallet"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestParameterContext_AddSignatureSimpleContract(t *testing.T) {
tx := getContractTx()
priv, err := keys.NewPrivateKey()
require.NoError(t, err)
pub := priv.PublicKey()
sig := priv.Sign(tx.GetSignedPart())
t.Run("invalid contract", func(t *testing.T) {
c := NewParameterContext("Neo.Core.ContractTransaction", tx)
ctr := &wallet.Contract{
Script: pub.GetVerificationScript(),
Parameters: []wallet.ContractParam{
newParam(smartcontract.SignatureType, "parameter0"),
newParam(smartcontract.SignatureType, "parameter1"),
},
}
require.Error(t, c.AddSignature(ctr, pub, sig))
if item := c.Items[ctr.ScriptHash()]; item != nil {
require.Nil(t, item.Parameters[0].Value)
}
ctr.Parameters = ctr.Parameters[:0]
require.Error(t, c.AddSignature(ctr, pub, sig))
if item := c.Items[ctr.ScriptHash()]; item != nil {
require.Nil(t, item.Parameters[0].Value)
}
})
c := NewParameterContext("Neo.Core.ContractTransaction", tx)
ctr := &wallet.Contract{
Script: pub.GetVerificationScript(),
Parameters: []wallet.ContractParam{newParam(smartcontract.SignatureType, "parameter0")},
}
require.NoError(t, c.AddSignature(ctr, pub, sig))
item := c.Items[ctr.ScriptHash()]
require.NotNil(t, item)
require.Equal(t, sig, item.Parameters[0].Value)
}
func TestParameterContext_AddSignatureMultisig(t *testing.T) {
tx := getContractTx()
c := NewParameterContext("Neo.Core.ContractTransaction", tx)
privs, pubs := getPrivateKeys(t, 4)
pubsCopy := make(keys.PublicKeys, len(pubs))
copy(pubsCopy, pubs)
script, err := smartcontract.CreateMultiSigRedeemScript(3, pubsCopy)
require.NoError(t, err)
ctr := &wallet.Contract{
Script: script,
Parameters: []wallet.ContractParam{
newParam(smartcontract.SignatureType, "parameter0"),
newParam(smartcontract.SignatureType, "parameter1"),
newParam(smartcontract.SignatureType, "parameter2"),
},
}
data := tx.GetSignedPart()
priv, err := keys.NewPrivateKey()
require.NoError(t, err)
sig := priv.Sign(data)
require.Error(t, c.AddSignature(ctr, priv.PublicKey(), sig))
indices := []int{2, 3, 0} // random order
for _, i := range indices {
sig := privs[i].Sign(data)
require.NoError(t, c.AddSignature(ctr, pubs[i], sig))
require.Error(t, c.AddSignature(ctr, pubs[i], sig))
item := c.Items[ctr.ScriptHash()]
require.NotNil(t, item)
require.Equal(t, sig, item.GetSignature(pubs[i]))
}
item := c.Items[ctr.ScriptHash()]
for i := range item.Parameters {
require.NotNil(t, item.Parameters[i].Value)
}
}
func TestParameterContext_MarshalJSON(t *testing.T) { func TestParameterContext_MarshalJSON(t *testing.T) {
priv, err := keys.NewPrivateKey() priv, err := keys.NewPrivateKey()
require.NoError(t, err) require.NoError(t, err)
@ -45,6 +125,25 @@ func TestParameterContext_MarshalJSON(t *testing.T) {
require.Equal(t, expected, actual) require.Equal(t, expected, actual)
} }
func getPrivateKeys(t *testing.T, n int) ([]*keys.PrivateKey, []*keys.PublicKey) {
privs := make([]*keys.PrivateKey, n)
pubs := make([]*keys.PublicKey, n)
for i := range privs {
var err error
privs[i], err = keys.NewPrivateKey()
require.NoError(t, err)
pubs[i] = privs[i].PublicKey()
}
return privs, pubs
}
func newParam(typ smartcontract.ParamType, name string) wallet.ContractParam {
return wallet.ContractParam{
Name: name,
Type: typ,
}
}
func getContractTx() *transaction.Transaction { func getContractTx() *transaction.Transaction {
tx := transaction.NewContractTX() tx := transaction.NewContractTX()
tx.AddInput(&transaction.Input{ tx.AddInput(&transaction.Input{