smartcontract: add CreateDefaultMultiSigRedeemScript

And use it where appropriate. Some of our code was just plain wrong (like the
one in GAS contract) and unification is always useful here.
This commit is contained in:
Roman Khimov 2020-08-10 18:58:11 +03:00
parent 80302c5c07
commit dba248236c
6 changed files with 62 additions and 15 deletions

View file

@ -51,11 +51,7 @@ func (bc *Blockchain) newBlock(txs ...*transaction.Transaction) *block.Block {
func newBlock(cfg config.ProtocolConfiguration, index uint32, prev util.Uint256, txs ...*transaction.Transaction) *block.Block {
validators, _ := validatorsFromConfig(cfg)
vlen := len(validators)
valScript, _ := smartcontract.CreateMultiSigRedeemScript(
vlen-(vlen-1)/3,
validators,
)
valScript, _ := smartcontract.CreateDefaultMultiSigRedeemScript(validators)
witness := transaction.Witness{
VerificationScript: valScript,
}
@ -393,7 +389,7 @@ func addSigners(txs ...*transaction.Transaction) {
func signTx(bc *Blockchain, txs ...*transaction.Transaction) error {
validators := bc.GetStandByValidators()
rawScript, err := smartcontract.CreateMultiSigRedeemScript(bc.config.ValidatorsCount/2+1, validators)
rawScript, err := smartcontract.CreateDefaultMultiSigRedeemScript(validators)
if err != nil {
return fmt.Errorf("failed to sign tx: %w", err)
}

View file

@ -103,8 +103,7 @@ func (g *GAS) OnPersist(ic *interop.Context) error {
}
func getStandbyValidatorsHash(ic *interop.Context) (util.Uint160, error) {
vs := ic.Chain.GetStandByValidators()
s, err := smartcontract.CreateMultiSigRedeemScript(len(vs)/2+1, vs)
s, err := smartcontract.CreateDefaultMultiSigRedeemScript(ic.Chain.GetStandByValidators())
if err != nil {
return util.Uint160{}, err
}

View file

@ -118,11 +118,7 @@ func committeeFromConfig(cfg config.ProtocolConfiguration) ([]*keys.PublicKey, e
}
func getNextConsensusAddress(validators []*keys.PublicKey) (val util.Uint160, err error) {
vlen := len(validators)
raw, err := smartcontract.CreateMultiSigRedeemScript(
vlen-(vlen-1)/3,
validators,
)
raw, err := smartcontract.CreateDefaultMultiSigRedeemScript(validators)
if err != nil {
return val, err
}

View file

@ -64,7 +64,7 @@ func MultisigVerificationScript() []byte {
pubs = append(pubs, priv.PublicKey())
}
script, err := smartcontract.CreateMultiSigRedeemScript(3, pubs)
script, err := smartcontract.CreateDefaultMultiSigRedeemScript(pubs)
if err != nil {
panic(err)
}

View file

@ -10,7 +10,8 @@ import (
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
)
// CreateMultiSigRedeemScript creates a script runnable by the VM.
// CreateMultiSigRedeemScript creates an "m out of n" type verification script
// where n is the length of publicKeys.
func CreateMultiSigRedeemScript(m int, publicKeys keys.PublicKeys) ([]byte, error) {
if m < 1 {
return nil, fmt.Errorf("param m cannot be smaller or equal to 1 got %d", m)
@ -34,3 +35,11 @@ func CreateMultiSigRedeemScript(m int, publicKeys keys.PublicKeys) ([]byte, erro
return buf.Bytes(), nil
}
// CreateDefaultMultiSigRedeemScript creates an "m out of n" type verification script
// using publicKeys length with the default BFT assumptions of (n - (n-1)/3) for m.
func CreateDefaultMultiSigRedeemScript(publicKeys keys.PublicKeys) ([]byte, error) {
n := len(publicKeys)
m := n - (n-1)/3
return CreateMultiSigRedeemScript(m, publicKeys)
}

View file

@ -36,3 +36,50 @@ func TestCreateMultiSigRedeemScript(t *testing.T) {
assert.Equal(t, opcode.SYSCALL, opcode.Opcode(br.ReadB()))
assert.Equal(t, emit.InteropNameToID([]byte("Neo.Crypto.CheckMultisigWithECDsaSecp256r1")), br.ReadU32LE())
}
func TestCreateDefaultMultiSigRedeemScript(t *testing.T) {
var validators = make([]*keys.PublicKey, 0)
var addKey = func() {
key, err := keys.NewPrivateKey()
require.NoError(t, err)
validators = append(validators, key.PublicKey())
}
var checkM = func(m int) {
validScript, err := CreateMultiSigRedeemScript(m, validators)
require.NoError(t, err)
defaultScript, err := CreateDefaultMultiSigRedeemScript(validators)
require.NoError(t, err)
require.Equal(t, validScript, defaultScript)
}
// 1 out of 1
addKey()
checkM(1)
// 2 out of 2
addKey()
checkM(2)
// 3 out of 4
for i := 0; i < 2; i++ {
addKey()
}
checkM(3)
// 5 out of 6
for i := 0; i < 2; i++ {
addKey()
}
checkM(5)
// 5 out of 7
addKey()
checkM(5)
// 7 out of 10
for i := 0; i < 3; i++ {
addKey()
}
checkM(7)
}