Merge pull request #696 from nspcc-dev/tx-verification-fixes

Claim and enrollment TX verification fixes
This commit is contained in:
Roman Khimov 2020-02-27 12:45:36 +03:00 committed by GitHub
commit 7d59fa0066
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 360 additions and 128 deletions

View file

@ -431,20 +431,23 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
} }
// Process TX inputs that are grouped by previous hash. // Process TX inputs that are grouped by previous hash.
for prevHash, inputs := range tx.GroupInputsByPrevHash() { for _, inputs := range transaction.GroupInputsByPrevHash(tx.Inputs) {
prevHash := inputs[0].PrevHash
prevTX, prevTXHeight, err := bc.dao.GetTransaction(prevHash) prevTX, prevTXHeight, err := bc.dao.GetTransaction(prevHash)
if err != nil { if err != nil {
return fmt.Errorf("could not find previous TX: %s", prevHash) return fmt.Errorf("could not find previous TX: %s", prevHash)
} }
for _, input := range inputs { unspent, err := cache.GetUnspentCoinStateOrNew(prevHash)
unspent, err := cache.GetUnspentCoinStateOrNew(input.PrevHash)
if err != nil { if err != nil {
return err return err
} }
unspent.states[input.PrevIndex] = state.CoinSpent spentCoin, err := cache.GetSpentCoinsOrNew(prevHash, prevTXHeight)
if err = cache.PutUnspentCoinState(input.PrevHash, unspent); err != nil { if err != nil {
return err return err
} }
oldSpentCoinLen := len(spentCoin.items)
for _, input := range inputs {
unspent.states[input.PrevIndex] = state.CoinSpent
prevTXOutput := prevTX.Outputs[input.PrevIndex] prevTXOutput := prevTX.Outputs[input.PrevIndex]
account, err := cache.GetAccountStateOrNew(prevTXOutput.ScriptHash) account, err := cache.GetAccountStateOrNew(prevTXOutput.ScriptHash)
if err != nil { if err != nil {
@ -452,11 +455,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
} }
if prevTXOutput.AssetID.Equals(GoverningTokenID()) { if prevTXOutput.AssetID.Equals(GoverningTokenID()) {
spentCoin := NewSpentCoinState(input.PrevHash, prevTXHeight)
spentCoin.items[input.PrevIndex] = block.Index spentCoin.items[input.PrevIndex] = block.Index
if err = cache.PutSpentCoinState(input.PrevHash, spentCoin); err != nil {
return err
}
if err = processTXWithValidatorsSubtract(&prevTXOutput, account, cache); err != nil { if err = processTXWithValidatorsSubtract(&prevTXOutput, account, cache); err != nil {
return err return err
} }
@ -482,6 +481,14 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
return err return err
} }
} }
if err = cache.PutUnspentCoinState(prevHash, unspent); err != nil {
return err
}
if oldSpentCoinLen != len(spentCoin.items) {
if err = cache.PutSpentCoinState(prevHash, spentCoin); err != nil {
return err
}
}
} }
// Process the underlying type of the TX. // Process the underlying type of the TX.
@ -517,18 +524,34 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
// Remove claimed NEO from spent coins making it unavalaible for // Remove claimed NEO from spent coins making it unavalaible for
// additional claims. // additional claims.
for _, input := range t.Claims { for _, input := range t.Claims {
scs, err := cache.GetSpentCoinsOrNew(input.PrevHash) scs, err := cache.GetSpentCoinState(input.PrevHash)
if err == nil {
_, ok := scs.items[input.PrevIndex]
if !ok {
err = errors.New("no spent coin state")
}
}
if err != nil { if err != nil {
// We can't really do anything about it
// as it's a transaction in a signed block.
bc.log.Warn("DOUBLE CLAIM",
zap.String("PrevHash", input.PrevHash.StringLE()),
zap.Uint16("PrevIndex", input.PrevIndex),
zap.String("tx", tx.Hash().StringLE()),
zap.Uint32("block", block.Index),
)
// "Strict" mode.
if bc.config.VerifyTransactions {
return err return err
} }
if scs.txHash == input.PrevHash { break
// Existing scs. }
delete(scs.items, input.PrevIndex) delete(scs.items, input.PrevIndex)
if len(scs.items) > 0 {
if err = cache.PutSpentCoinState(input.PrevHash, scs); err != nil { if err = cache.PutSpentCoinState(input.PrevHash, scs); err != nil {
return err return err
} }
} else { } else {
// Uninitialized, new, forget about it.
if err = cache.DeleteSpentCoinState(input.PrevHash); err != nil { if err = cache.DeleteSpentCoinState(input.PrevHash); err != nil {
return err return err
} }
@ -986,27 +1009,34 @@ func (bc *Blockchain) GetConfig() config.ProtocolConfiguration {
return bc.config return bc.config
} }
// References returns a map with input coin reference (prevhash and index) as key // References maps transaction's inputs into a slice of InOuts, effectively
// and transaction output as value from a transaction t. // joining each Input with the corresponding Output.
// @TODO: unfortunately we couldn't attach this method to the Transaction struct in the // @TODO: unfortunately we couldn't attach this method to the Transaction struct in the
// transaction package because of a import cycle problem. Perhaps we should think to re-design // transaction package because of a import cycle problem. Perhaps we should think to re-design
// the code base to avoid this situation. // the code base to avoid this situation.
func (bc *Blockchain) References(t *transaction.Transaction) map[transaction.Input]*transaction.Output { func (bc *Blockchain) References(t *transaction.Transaction) ([]transaction.InOut, error) {
references := make(map[transaction.Input]*transaction.Output) return bc.references(t.Inputs)
}
for prevHash, inputs := range t.GroupInputsByPrevHash() { // references is an internal implementation of References that operates directly
// on a slice of Input.
func (bc *Blockchain) references(ins []transaction.Input) ([]transaction.InOut, error) {
references := make([]transaction.InOut, 0, len(ins))
for _, inputs := range transaction.GroupInputsByPrevHash(ins) {
prevHash := inputs[0].PrevHash
tx, _, err := bc.dao.GetTransaction(prevHash) tx, _, err := bc.dao.GetTransaction(prevHash)
if err != nil { if err != nil {
return nil return nil, errors.New("bad input reference")
} }
for _, in := range inputs { for _, in := range inputs {
if int(in.PrevIndex) > len(tx.Outputs)-1 { if int(in.PrevIndex) > len(tx.Outputs)-1 {
return nil return nil, errors.New("bad input reference")
} }
references[*in] = &tx.Outputs[in.PrevIndex] references = append(references, transaction.InOut{In: *in, Out: tx.Outputs[in.PrevIndex]})
} }
} }
return references return references, nil
} }
// FeePerByte returns network fee divided by the size of the transaction. // FeePerByte returns network fee divided by the size of the transaction.
@ -1017,16 +1047,20 @@ func (bc *Blockchain) FeePerByte(t *transaction.Transaction) util.Fixed8 {
// NetworkFee returns network fee. // NetworkFee returns network fee.
func (bc *Blockchain) NetworkFee(t *transaction.Transaction) util.Fixed8 { func (bc *Blockchain) NetworkFee(t *transaction.Transaction) util.Fixed8 {
inputAmount := util.Fixed8FromInt64(0) inputAmount := util.Fixed8FromInt64(0)
for _, txOutput := range bc.References(t) { refs, err := bc.References(t)
if txOutput.AssetID == UtilityTokenID() { if err != nil {
inputAmount.Add(txOutput.Amount) return inputAmount
}
for i := range refs {
if refs[i].Out.AssetID == UtilityTokenID() {
inputAmount = inputAmount.Add(refs[i].Out.Amount)
} }
} }
outputAmount := util.Fixed8FromInt64(0) outputAmount := util.Fixed8FromInt64(0)
for _, txOutput := range t.Outputs { for _, txOutput := range t.Outputs {
if txOutput.AssetID == UtilityTokenID() { if txOutput.AssetID == UtilityTokenID() {
outputAmount.Add(txOutput.Amount) outputAmount = outputAmount.Add(txOutput.Amount)
} }
} }
@ -1087,7 +1121,7 @@ func (bc *Blockchain) verifyTx(t *transaction.Transaction, block *block.Block) e
if io.GetVarSize(t) > transaction.MaxTransactionSize { if io.GetVarSize(t) > transaction.MaxTransactionSize {
return errors.Errorf("invalid transaction size = %d. It shoud be less then MaxTransactionSize = %d", io.GetVarSize(t), transaction.MaxTransactionSize) return errors.Errorf("invalid transaction size = %d. It shoud be less then MaxTransactionSize = %d", io.GetVarSize(t), transaction.MaxTransactionSize)
} }
if ok := bc.verifyInputs(t); !ok { if transaction.HaveDuplicateInputs(t.Inputs) {
return errors.New("invalid transaction's inputs") return errors.New("invalid transaction's inputs")
} }
if block == nil { if block == nil {
@ -1111,6 +1145,16 @@ func (bc *Blockchain) verifyTx(t *transaction.Transaction, block *block.Block) e
} }
} }
if t.Type == transaction.ClaimType {
claim := t.Data.(*transaction.ClaimTX)
if transaction.HaveDuplicateInputs(claim.Claims) {
return errors.New("duplicate claims")
}
if bc.dao.IsDoubleClaim(claim) {
return errors.New("double claim")
}
}
return bc.verifyTxWitnesses(t, block) return bc.verifyTxWitnesses(t, block)
} }
@ -1130,6 +1174,12 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction) bool {
if bc.dao.IsDoubleSpend(t) { if bc.dao.IsDoubleSpend(t) {
return false return false
} }
if t.Type == transaction.ClaimType {
claim := t.Data.(*transaction.ClaimTX)
if bc.dao.IsDoubleClaim(claim) {
return false
}
}
for i := range t.Scripts { for i := range t.Scripts {
if !vm.IsStandardContract(t.Scripts[i].VerificationScript) { if !vm.IsStandardContract(t.Scripts[i].VerificationScript) {
recheckWitness = true recheckWitness = true
@ -1189,18 +1239,6 @@ func (bc *Blockchain) PoolTx(t *transaction.Transaction) error {
return nil return nil
} }
func (bc *Blockchain) verifyInputs(t *transaction.Transaction) bool {
for i := 1; i < len(t.Inputs); i++ {
for j := 0; j < i; j++ {
if t.Inputs[i].PrevHash == t.Inputs[j].PrevHash && t.Inputs[i].PrevIndex == t.Inputs[j].PrevIndex {
return false
}
}
}
return true
}
func (bc *Blockchain) verifyOutputs(t *transaction.Transaction) error { func (bc *Blockchain) verifyOutputs(t *transaction.Transaction) error {
for assetID, outputs := range t.GroupOutputByAssetID() { for assetID, outputs := range t.GroupOutputByAssetID() {
assetState := bc.GetAssetState(assetID) assetState := bc.GetAssetState(assetID)
@ -1286,14 +1324,14 @@ func (bc *Blockchain) GetTransactionResults(t *transaction.Transaction) []*trans
var results []*transaction.Result var results []*transaction.Result
tempGroupResult := make(map[util.Uint256]util.Fixed8) tempGroupResult := make(map[util.Uint256]util.Fixed8)
references := bc.References(t) references, err := bc.References(t)
if references == nil { if err != nil {
return nil return nil
} }
for _, output := range references { for _, inout := range references {
tempResults = append(tempResults, &transaction.Result{ tempResults = append(tempResults, &transaction.Result{
AssetID: output.AssetID, AssetID: inout.Out.AssetID,
Amount: output.Amount, Amount: inout.Out.Amount,
}) })
} }
for _, output := range t.Outputs { for _, output := range t.Outputs {
@ -1323,39 +1361,6 @@ func (bc *Blockchain) GetTransactionResults(t *transaction.Transaction) []*trans
return results return results
} }
// GetScriptHashesForVerifyingClaim returns all ScriptHashes of Claim transaction
// which has a different implementation from generic GetScriptHashesForVerifying.
func (bc *Blockchain) GetScriptHashesForVerifyingClaim(t *transaction.Transaction) ([]util.Uint160, error) {
// Avoiding duplicates.
hashmap := make(map[util.Uint160]bool)
claim := t.Data.(*transaction.ClaimTX)
clGroups := make(map[util.Uint256][]*transaction.Input)
for _, in := range claim.Claims {
clGroups[in.PrevHash] = append(clGroups[in.PrevHash], in)
}
for group, inputs := range clGroups {
refTx, _, err := bc.dao.GetTransaction(group)
if err != nil {
return nil, err
}
for _, input := range inputs {
if len(refTx.Outputs) <= int(input.PrevIndex) {
return nil, fmt.Errorf("wrong PrevIndex reference")
}
hashmap[refTx.Outputs[input.PrevIndex].ScriptHash] = true
}
}
if len(hashmap) > 0 {
hashes := make([]util.Uint160, 0, len(hashmap))
for k := range hashmap {
hashes = append(hashes, k)
}
return hashes, nil
}
return nil, fmt.Errorf("no hashes found")
}
//GetStandByValidators returns validators from the configuration. //GetStandByValidators returns validators from the configuration.
func (bc *Blockchain) GetStandByValidators() (keys.PublicKeys, error) { func (bc *Blockchain) GetStandByValidators() (keys.PublicKeys, error) {
return getValidators(bc.config) return getValidators(bc.config)
@ -1507,19 +1512,13 @@ func processEnrollmentTX(dao *cachedDao, tx *transaction.EnrollmentTX) error {
// to verify whether the transaction is bonafide or not. // to verify whether the transaction is bonafide or not.
// Golang implementation of GetScriptHashesForVerifying method in C# (https://github.com/neo-project/neo/blob/master/neo/Network/P2P/Payloads/Transaction.cs#L190) // Golang implementation of GetScriptHashesForVerifying method in C# (https://github.com/neo-project/neo/blob/master/neo/Network/P2P/Payloads/Transaction.cs#L190)
func (bc *Blockchain) GetScriptHashesForVerifying(t *transaction.Transaction) ([]util.Uint160, error) { func (bc *Blockchain) GetScriptHashesForVerifying(t *transaction.Transaction) ([]util.Uint160, error) {
if t.Type == transaction.ClaimType { references, err := bc.References(t)
return bc.GetScriptHashesForVerifyingClaim(t) if err != nil {
} return nil, err
references := bc.References(t)
if references == nil {
return nil, errors.New("invalid inputs")
} }
hashes := make(map[util.Uint160]bool) hashes := make(map[util.Uint160]bool)
for _, i := range t.Inputs { for i := range references {
h := references[i].ScriptHash hashes[references[i].Out.ScriptHash] = true
if _, ok := hashes[h]; !ok {
hashes[h] = true
}
} }
for _, a := range t.Attributes { for _, a := range t.Attributes {
if a.Usage == transaction.Script { if a.Usage == transaction.Script {
@ -1547,6 +1546,20 @@ func (bc *Blockchain) GetScriptHashesForVerifying(t *transaction.Transaction) ([
} }
} }
} }
switch t.Type {
case transaction.ClaimType:
claim := t.Data.(*transaction.ClaimTX)
refs, err := bc.references(claim.Claims)
if err != nil {
return nil, err
}
for i := range refs {
hashes[refs[i].Out.ScriptHash] = true
}
case transaction.EnrollmentType:
etx := t.Data.(*transaction.EnrollmentTX)
hashes[etx.PublicKey.GetScriptHash()] = true
}
// convert hashes to []util.Uint160 // convert hashes to []util.Uint160
hashesResult := make([]util.Uint160, 0, len(hashes)) hashesResult := make([]util.Uint160, 0, len(hashes))
for h := range hashes { for h := range hashes {

View file

@ -39,7 +39,7 @@ type Blockchainer interface {
GetTestVM() (*vm.VM, storage.Store) GetTestVM() (*vm.VM, storage.Store)
GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error)
GetUnspentCoinState(util.Uint256) *UnspentCoinState GetUnspentCoinState(util.Uint256) *UnspentCoinState
References(t *transaction.Transaction) map[transaction.Input]*transaction.Output References(t *transaction.Transaction) ([]transaction.InOut, error)
mempool.Feer // fee interface mempool.Feer // fee interface
PoolTx(*transaction.Transaction) error PoolTx(*transaction.Transaction) error
VerifyTx(*transaction.Transaction, *block.Block) error VerifyTx(*transaction.Transaction, *block.Block) error

View file

@ -175,15 +175,13 @@ func (dao *dao) PutUnspentCoinState(hash util.Uint256, ucs *UnspentCoinState) er
// -- start spent coins. // -- start spent coins.
// GetSpentCoinsOrNew returns spent coins from store. // GetSpentCoinsOrNew returns spent coins from store.
func (dao *dao) GetSpentCoinsOrNew(hash util.Uint256) (*SpentCoinState, error) { func (dao *dao) GetSpentCoinsOrNew(hash util.Uint256, height uint32) (*SpentCoinState, error) {
spent, err := dao.GetSpentCoinState(hash) spent, err := dao.GetSpentCoinState(hash)
if err != nil { if err != nil {
if err != storage.ErrKeyNotFound { if err != storage.ErrKeyNotFound {
return nil, err return nil, err
} }
spent = &SpentCoinState{ spent = NewSpentCoinState(hash, height)
items: make(map[uint16]uint32),
}
} }
return spent, nil return spent, nil
} }
@ -549,7 +547,8 @@ func (dao *dao) IsDoubleSpend(tx *transaction.Transaction) bool {
if len(tx.Inputs) == 0 { if len(tx.Inputs) == 0 {
return false return false
} }
for prevHash, inputs := range tx.GroupInputsByPrevHash() { for _, inputs := range transaction.GroupInputsByPrevHash(tx.Inputs) {
prevHash := inputs[0].PrevHash
unspent, err := dao.GetUnspentCoinState(prevHash) unspent, err := dao.GetUnspentCoinState(prevHash)
if err != nil { if err != nil {
return false return false
@ -563,6 +562,27 @@ func (dao *dao) IsDoubleSpend(tx *transaction.Transaction) bool {
return false return false
} }
// IsDoubleClaim verifies that given claim inputs are not already claimed by another tx.
func (dao *dao) IsDoubleClaim(claim *transaction.ClaimTX) bool {
if len(claim.Claims) == 0 {
return false
}
for _, inputs := range transaction.GroupInputsByPrevHash(claim.Claims) {
prevHash := inputs[0].PrevHash
scs, err := dao.GetSpentCoinState(prevHash)
if err != nil {
return true
}
for _, input := range inputs {
_, ok := scs.items[input.PrevIndex]
if !ok {
return true
}
}
}
return false
}
// Persist flushes all the changes made into the (supposedly) persistent // Persist flushes all the changes made into the (supposedly) persistent
// underlying store. // underlying store.
func (dao *dao) Persist() (int, error) { func (dao *dao) Persist() (int, error) {

View file

@ -124,7 +124,7 @@ func TestPutGetUnspentCoinState(t *testing.T) {
func TestGetSpentCoinStateOrNew_New(t *testing.T) { func TestGetSpentCoinStateOrNew_New(t *testing.T) {
dao := newDao(storage.NewMemoryStore()) dao := newDao(storage.NewMemoryStore())
hash := random.Uint256() hash := random.Uint256()
spentCoinState, err := dao.GetSpentCoinsOrNew(hash) spentCoinState, err := dao.GetSpentCoinsOrNew(hash, 1)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, spentCoinState) require.NotNil(t, spentCoinState)
} }

View file

@ -34,12 +34,6 @@ const (
DefaultAssetLifetime = 1 + BlocksPerYear DefaultAssetLifetime = 1 + BlocksPerYear
) )
// txInOut is used to pushed one key-value pair from References() onto the stack.
type txInOut struct {
in transaction.Input
out transaction.Output
}
// headerGetVersion returns version from the header. // headerGetVersion returns version from the header.
func (ic *interopContext) headerGetVersion(v *vm.VM) error { func (ic *interopContext) headerGetVersion(v *vm.VM) error {
header, err := popHeaderFromVM(v) header, err := popHeaderFromVM(v)
@ -141,14 +135,16 @@ func (ic *interopContext) txGetReferences(v *vm.VM) error {
if !ok { if !ok {
return fmt.Errorf("type mismatch: %T is not a Transaction", txInterface) return fmt.Errorf("type mismatch: %T is not a Transaction", txInterface)
} }
refs := ic.bc.References(tx) refs, err := ic.bc.References(tx)
if err != nil {
return err
}
if len(refs) > vm.MaxArraySize { if len(refs) > vm.MaxArraySize {
return errors.New("too many references") return errors.New("too many references")
} }
stackrefs := make([]vm.StackItem, 0, len(refs)) stackrefs := make([]vm.StackItem, 0, len(refs))
for _, k := range tx.Inputs { for _, tio := range refs {
tio := txInOut{k, *refs[k]}
stackrefs = append(stackrefs, vm.NewInteropItem(tio)) stackrefs = append(stackrefs, vm.NewInteropItem(tio))
} }
v.Estack().PushVal(stackrefs) v.Estack().PushVal(stackrefs)
@ -243,11 +239,11 @@ func popInputFromVM(v *vm.VM) (*transaction.Input, error) {
inInterface := v.Estack().Pop().Value() inInterface := v.Estack().Pop().Value()
input, ok := inInterface.(*transaction.Input) input, ok := inInterface.(*transaction.Input)
if !ok { if !ok {
txio, ok := inInterface.(txInOut) txio, ok := inInterface.(transaction.InOut)
if !ok { if !ok {
return nil, fmt.Errorf("type mismatch: %T is not an Input or txInOut", inInterface) return nil, fmt.Errorf("type mismatch: %T is not an Input or InOut", inInterface)
} }
input = &txio.in input = &txio.In
} }
return input, nil return input, nil
} }
@ -277,11 +273,11 @@ func popOutputFromVM(v *vm.VM) (*transaction.Output, error) {
outInterface := v.Estack().Pop().Value() outInterface := v.Estack().Pop().Value()
output, ok := outInterface.(*transaction.Output) output, ok := outInterface.(*transaction.Output)
if !ok { if !ok {
txio, ok := outInterface.(txInOut) txio, ok := outInterface.(transaction.InOut)
if !ok { if !ok {
return nil, fmt.Errorf("type mismatch: %T is not an Output or txInOut", outInterface) return nil, fmt.Errorf("type mismatch: %T is not an Output or InOut", outInterface)
} }
output = &txio.out output = &txio.Out
} }
return output, nil return output, nil
} }

View file

@ -6,7 +6,7 @@ import (
// ClaimTX represents a claim transaction. // ClaimTX represents a claim transaction.
type ClaimTX struct { type ClaimTX struct {
Claims []*Input Claims []Input
} }
// DecodeBinary implements Serializable interface. // DecodeBinary implements Serializable interface.

View file

@ -0,0 +1,8 @@
package transaction
// InOut represents an Input bound to its corresponding Output which is a useful
// combination for many purposes.
type InOut struct {
In Input
Out Output
}

View file

@ -1,6 +1,8 @@
package transaction package transaction
import ( import (
"sort"
"github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/io"
"github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/util"
) )
@ -25,3 +27,62 @@ func (in *Input) EncodeBinary(bw *io.BinWriter) {
bw.WriteBytes(in.PrevHash[:]) bw.WriteBytes(in.PrevHash[:])
bw.WriteU16LE(in.PrevIndex) bw.WriteU16LE(in.PrevIndex)
} }
// MapInputsToSorted maps given slice of inputs into a new slice of pointers
// to inputs sorted by their PrevHash and PrevIndex.
func MapInputsToSorted(ins []Input) []*Input {
ptrs := make([]*Input, len(ins))
for i := range ins {
ptrs[i] = &ins[i]
}
sort.Slice(ptrs, func(i, j int) bool {
hashcmp := ptrs[i].PrevHash.CompareTo(ptrs[j].PrevHash)
if hashcmp == 0 {
return ptrs[i].PrevIndex < ptrs[j].PrevIndex
}
return hashcmp < 0
})
return ptrs
}
// GroupInputsByPrevHash groups all TX inputs by their previous hash into
// several slices (which actually are subslices of one new slice with pointers).
// Each of these slices contains at least one element.
func GroupInputsByPrevHash(ins []Input) [][]*Input {
if len(ins) == 0 {
return nil
}
ptrs := MapInputsToSorted(ins)
var first int
res := make([][]*Input, 0)
currentHash := ptrs[0].PrevHash
for i := range ptrs {
if !currentHash.Equals(ptrs[i].PrevHash) {
res = append(res, ptrs[first:i])
first = i
currentHash = ptrs[i].PrevHash
}
}
res = append(res, ptrs[first:])
return res
}
// HaveDuplicateInputs checks inputs for duplicates and returns true if there are
// any.
func HaveDuplicateInputs(ins []Input) bool {
if len(ins) < 2 {
return false
}
if len(ins) == 2 {
return ins[0] == ins[1]
}
ptrs := MapInputsToSorted(ins)
for i := 1; i < len(ptrs); i++ {
if *ptrs[i] == *ptrs[i-1] {
return true
}
}
return false
}

View file

@ -0,0 +1,144 @@
package transaction
import (
"testing"
"github.com/CityOfZion/neo-go/pkg/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGroupInputsByPrevHash0(t *testing.T) {
inputs := make([]Input, 0)
res := GroupInputsByPrevHash(inputs)
require.Equal(t, 0, len(res))
}
func TestGroupInputsByPrevHash1(t *testing.T) {
inputs := make([]Input, 0)
hash, err := util.Uint256DecodeStringLE("46168f963d6d8168a870405f66cc9e13a235791013b8ee2f90cc20a8293bd1af")
require.NoError(t, err)
inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 42})
res := GroupInputsByPrevHash(inputs)
require.Equal(t, 1, len(res))
require.Equal(t, 1, len(res[0]))
assert.Equal(t, hash, res[0][0].PrevHash)
assert.Equal(t, uint16(42), res[0][0].PrevIndex)
}
func TestGroupInputsByPrevHashMany(t *testing.T) {
hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25")
require.NoError(t, err)
hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9")
require.NoError(t, err)
hash3, err := util.Uint256DecodeStringBE("caa41245c3e48ddc13dabe989ba8fbc59418e9228fef9efb62855b0b17d7448b")
require.NoError(t, err)
inputs := make([]Input, 0)
for i := 0; i < 10; i++ {
inputs = append(inputs, Input{PrevHash: hash1, PrevIndex: uint16(i)})
inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: uint16(i)})
inputs = append(inputs, Input{PrevHash: hash3, PrevIndex: uint16(i)})
}
for i := 15; i < 20; i++ {
inputs = append(inputs, Input{PrevHash: hash3, PrevIndex: uint16(i)})
}
for i := 10; i < 15; i++ {
inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: uint16(i)})
inputs = append(inputs, Input{PrevHash: hash3, PrevIndex: uint16(i)})
}
seen := make(map[uint16]bool)
res := GroupInputsByPrevHash(inputs)
require.Equal(t, 3, len(res))
assert.Equal(t, hash2, res[0][0].PrevHash)
assert.Equal(t, 15, len(res[0]))
for i := range res[0] {
assert.Equal(t, res[0][i].PrevHash, res[0][0].PrevHash)
assert.Equal(t, false, seen[res[0][i].PrevIndex])
seen[res[0][i].PrevIndex] = true
}
seen = make(map[uint16]bool)
assert.Equal(t, hash1, res[1][0].PrevHash)
assert.Equal(t, 10, len(res[1]))
for i := range res[1] {
assert.Equal(t, res[1][i].PrevHash, res[1][0].PrevHash)
assert.Equal(t, false, seen[res[1][i].PrevIndex])
seen[res[1][i].PrevIndex] = true
}
seen = make(map[uint16]bool)
assert.Equal(t, hash3, res[2][0].PrevHash)
assert.Equal(t, 20, len(res[2]))
for i := range res[2] {
assert.Equal(t, res[2][i].PrevHash, res[2][0].PrevHash)
assert.Equal(t, false, seen[res[2][i].PrevIndex])
seen[res[2][i].PrevIndex] = true
}
}
func TestHaveDuplicateInputs0(t *testing.T) {
inputs := make([]Input, 0)
require.False(t, HaveDuplicateInputs(inputs))
}
func TestHaveDuplicateInputs1(t *testing.T) {
inputs := make([]Input, 0)
hash, err := util.Uint256DecodeStringLE("46168f963d6d8168a870405f66cc9e13a235791013b8ee2f90cc20a8293bd1af")
require.NoError(t, err)
inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 42})
require.False(t, HaveDuplicateInputs(inputs))
}
func TestHaveDuplicateInputs2True(t *testing.T) {
inputs := make([]Input, 0)
hash, err := util.Uint256DecodeStringLE("46168f963d6d8168a870405f66cc9e13a235791013b8ee2f90cc20a8293bd1af")
require.NoError(t, err)
inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 42})
inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 42})
require.True(t, HaveDuplicateInputs(inputs))
}
func TestHaveDuplicateInputs2FalseInd(t *testing.T) {
inputs := make([]Input, 0)
hash, err := util.Uint256DecodeStringLE("46168f963d6d8168a870405f66cc9e13a235791013b8ee2f90cc20a8293bd1af")
require.NoError(t, err)
inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 42})
inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 41})
require.False(t, HaveDuplicateInputs(inputs))
}
func TestHaveDuplicateInputs2FalseHash(t *testing.T) {
inputs := make([]Input, 0)
hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25")
require.NoError(t, err)
hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9")
require.NoError(t, err)
inputs = append(inputs, Input{PrevHash: hash1, PrevIndex: 42})
inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: 42})
require.False(t, HaveDuplicateInputs(inputs))
}
func TestHaveDuplicateInputsMFalse(t *testing.T) {
inputs := make([]Input, 0)
hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25")
require.NoError(t, err)
hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9")
require.NoError(t, err)
for i := 0; i < 10; i++ {
inputs = append(inputs, Input{PrevHash: hash1, PrevIndex: uint16(i)})
inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: uint16(i)})
}
require.False(t, HaveDuplicateInputs(inputs))
}
func TestHaveDuplicateInputsMTrue(t *testing.T) {
inputs := make([]Input, 0)
hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25")
require.NoError(t, err)
hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9")
require.NoError(t, err)
for i := 0; i < 10; i++ {
inputs = append(inputs, Input{PrevHash: hash1, PrevIndex: uint16(i)})
inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: uint16(i)})
}
inputs = append(inputs, Input{PrevHash: hash1, PrevIndex: 0})
require.True(t, HaveDuplicateInputs(inputs))
}

View file

@ -187,16 +187,6 @@ func (t *Transaction) createHash() error {
return nil return nil
} }
// GroupInputsByPrevHash groups all TX inputs by their previous hash.
func (t *Transaction) GroupInputsByPrevHash() map[util.Uint256][]*Input {
m := make(map[util.Uint256][]*Input)
for i := range t.Inputs {
hash := t.Inputs[i].PrevHash
m[hash] = append(m[hash], &t.Inputs[i])
}
return m
}
// GroupOutputByAssetID groups all TX outputs by their assetID. // GroupOutputByAssetID groups all TX outputs by their assetID.
func (t Transaction) GroupOutputByAssetID() map[util.Uint256][]*Output { func (t Transaction) GroupOutputByAssetID() map[util.Uint256][]*Output {
m := make(map[util.Uint256][]*Output) m := make(map[util.Uint256][]*Output)

View file

@ -33,7 +33,7 @@ func (chain testChain) GetConfig() config.ProtocolConfiguration {
panic("TODO") panic("TODO")
} }
func (chain testChain) References(t *transaction.Transaction) map[transaction.Input]*transaction.Output { func (chain testChain) References(t *transaction.Transaction) ([]transaction.InOut, error) {
panic("TODO") panic("TODO")
} }