diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index d22326398..32361f61f 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -431,20 +431,23 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } // Process TX inputs that are grouped by previous hash. - for prevHash, inputs := range tx.GroupInputsByPrevHash() { + for _, inputs := range transaction.GroupInputsByPrevHash(tx.Inputs) { + prevHash := inputs[0].PrevHash prevTX, prevTXHeight, err := bc.dao.GetTransaction(prevHash) if err != nil { return fmt.Errorf("could not find previous TX: %s", prevHash) } + unspent, err := cache.GetUnspentCoinStateOrNew(prevHash) + if err != nil { + return err + } + spentCoin, err := cache.GetSpentCoinsOrNew(prevHash, prevTXHeight) + if err != nil { + return err + } + oldSpentCoinLen := len(spentCoin.items) for _, input := range inputs { - unspent, err := cache.GetUnspentCoinStateOrNew(input.PrevHash) - if err != nil { - return err - } unspent.states[input.PrevIndex] = state.CoinSpent - if err = cache.PutUnspentCoinState(input.PrevHash, unspent); err != nil { - return err - } prevTXOutput := prevTX.Outputs[input.PrevIndex] account, err := cache.GetAccountStateOrNew(prevTXOutput.ScriptHash) if err != nil { @@ -452,11 +455,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } if prevTXOutput.AssetID.Equals(GoverningTokenID()) { - spentCoin := NewSpentCoinState(input.PrevHash, prevTXHeight) spentCoin.items[input.PrevIndex] = block.Index - if err = cache.PutSpentCoinState(input.PrevHash, spentCoin); err != nil { - return err - } if err = processTXWithValidatorsSubtract(&prevTXOutput, account, cache); err != nil { return err } @@ -482,6 +481,14 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { return err } } + if err = cache.PutUnspentCoinState(prevHash, unspent); err != nil { + return err + } + if oldSpentCoinLen != len(spentCoin.items) { + if err = cache.PutSpentCoinState(prevHash, spentCoin); err != nil { + return err + } + } } // Process the underlying type of the TX. @@ -517,18 +524,34 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { // Remove claimed NEO from spent coins making it unavalaible for // additional claims. for _, input := range t.Claims { - scs, err := cache.GetSpentCoinsOrNew(input.PrevHash) - if err != nil { - return err + scs, err := cache.GetSpentCoinState(input.PrevHash) + if err == nil { + _, ok := scs.items[input.PrevIndex] + if !ok { + err = errors.New("no spent coin state") + } } - if scs.txHash == input.PrevHash { - // Existing scs. - delete(scs.items, input.PrevIndex) + if err != nil { + // We can't really do anything about it + // as it's a transaction in a signed block. + bc.log.Warn("DOUBLE CLAIM", + zap.String("PrevHash", input.PrevHash.StringLE()), + zap.Uint16("PrevIndex", input.PrevIndex), + zap.String("tx", tx.Hash().StringLE()), + zap.Uint32("block", block.Index), + ) + // "Strict" mode. + if bc.config.VerifyTransactions { + return err + } + break + } + delete(scs.items, input.PrevIndex) + if len(scs.items) > 0 { if err = cache.PutSpentCoinState(input.PrevHash, scs); err != nil { return err } } else { - // Uninitialized, new, forget about it. if err = cache.DeleteSpentCoinState(input.PrevHash); err != nil { return err } @@ -986,27 +1009,34 @@ func (bc *Blockchain) GetConfig() config.ProtocolConfiguration { return bc.config } -// References returns a map with input coin reference (prevhash and index) as key -// and transaction output as value from a transaction t. +// References maps transaction's inputs into a slice of InOuts, effectively +// joining each Input with the corresponding Output. // @TODO: unfortunately we couldn't attach this method to the Transaction struct in the // transaction package because of a import cycle problem. Perhaps we should think to re-design // the code base to avoid this situation. -func (bc *Blockchain) References(t *transaction.Transaction) map[transaction.Input]*transaction.Output { - references := make(map[transaction.Input]*transaction.Output) +func (bc *Blockchain) References(t *transaction.Transaction) ([]transaction.InOut, error) { + return bc.references(t.Inputs) +} - for prevHash, inputs := range t.GroupInputsByPrevHash() { +// references is an internal implementation of References that operates directly +// on a slice of Input. +func (bc *Blockchain) references(ins []transaction.Input) ([]transaction.InOut, error) { + references := make([]transaction.InOut, 0, len(ins)) + + for _, inputs := range transaction.GroupInputsByPrevHash(ins) { + prevHash := inputs[0].PrevHash tx, _, err := bc.dao.GetTransaction(prevHash) if err != nil { - return nil + return nil, errors.New("bad input reference") } for _, in := range inputs { if int(in.PrevIndex) > len(tx.Outputs)-1 { - return nil + return nil, errors.New("bad input reference") } - references[*in] = &tx.Outputs[in.PrevIndex] + references = append(references, transaction.InOut{In: *in, Out: tx.Outputs[in.PrevIndex]}) } } - return references + return references, nil } // FeePerByte returns network fee divided by the size of the transaction. @@ -1017,16 +1047,20 @@ func (bc *Blockchain) FeePerByte(t *transaction.Transaction) util.Fixed8 { // NetworkFee returns network fee. func (bc *Blockchain) NetworkFee(t *transaction.Transaction) util.Fixed8 { inputAmount := util.Fixed8FromInt64(0) - for _, txOutput := range bc.References(t) { - if txOutput.AssetID == UtilityTokenID() { - inputAmount.Add(txOutput.Amount) + refs, err := bc.References(t) + if err != nil { + return inputAmount + } + for i := range refs { + if refs[i].Out.AssetID == UtilityTokenID() { + inputAmount = inputAmount.Add(refs[i].Out.Amount) } } outputAmount := util.Fixed8FromInt64(0) for _, txOutput := range t.Outputs { if txOutput.AssetID == UtilityTokenID() { - outputAmount.Add(txOutput.Amount) + outputAmount = outputAmount.Add(txOutput.Amount) } } @@ -1087,7 +1121,7 @@ func (bc *Blockchain) verifyTx(t *transaction.Transaction, block *block.Block) e if io.GetVarSize(t) > transaction.MaxTransactionSize { return errors.Errorf("invalid transaction size = %d. It shoud be less then MaxTransactionSize = %d", io.GetVarSize(t), transaction.MaxTransactionSize) } - if ok := bc.verifyInputs(t); !ok { + if transaction.HaveDuplicateInputs(t.Inputs) { return errors.New("invalid transaction's inputs") } if block == nil { @@ -1111,6 +1145,16 @@ func (bc *Blockchain) verifyTx(t *transaction.Transaction, block *block.Block) e } } + if t.Type == transaction.ClaimType { + claim := t.Data.(*transaction.ClaimTX) + if transaction.HaveDuplicateInputs(claim.Claims) { + return errors.New("duplicate claims") + } + if bc.dao.IsDoubleClaim(claim) { + return errors.New("double claim") + } + } + return bc.verifyTxWitnesses(t, block) } @@ -1130,6 +1174,12 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction) bool { if bc.dao.IsDoubleSpend(t) { return false } + if t.Type == transaction.ClaimType { + claim := t.Data.(*transaction.ClaimTX) + if bc.dao.IsDoubleClaim(claim) { + return false + } + } for i := range t.Scripts { if !vm.IsStandardContract(t.Scripts[i].VerificationScript) { recheckWitness = true @@ -1189,18 +1239,6 @@ func (bc *Blockchain) PoolTx(t *transaction.Transaction) error { return nil } -func (bc *Blockchain) verifyInputs(t *transaction.Transaction) bool { - for i := 1; i < len(t.Inputs); i++ { - for j := 0; j < i; j++ { - if t.Inputs[i].PrevHash == t.Inputs[j].PrevHash && t.Inputs[i].PrevIndex == t.Inputs[j].PrevIndex { - return false - } - } - } - - return true -} - func (bc *Blockchain) verifyOutputs(t *transaction.Transaction) error { for assetID, outputs := range t.GroupOutputByAssetID() { assetState := bc.GetAssetState(assetID) @@ -1286,14 +1324,14 @@ func (bc *Blockchain) GetTransactionResults(t *transaction.Transaction) []*trans var results []*transaction.Result tempGroupResult := make(map[util.Uint256]util.Fixed8) - references := bc.References(t) - if references == nil { + references, err := bc.References(t) + if err != nil { return nil } - for _, output := range references { + for _, inout := range references { tempResults = append(tempResults, &transaction.Result{ - AssetID: output.AssetID, - Amount: output.Amount, + AssetID: inout.Out.AssetID, + Amount: inout.Out.Amount, }) } for _, output := range t.Outputs { @@ -1323,39 +1361,6 @@ func (bc *Blockchain) GetTransactionResults(t *transaction.Transaction) []*trans return results } -// GetScriptHashesForVerifyingClaim returns all ScriptHashes of Claim transaction -// which has a different implementation from generic GetScriptHashesForVerifying. -func (bc *Blockchain) GetScriptHashesForVerifyingClaim(t *transaction.Transaction) ([]util.Uint160, error) { - // Avoiding duplicates. - hashmap := make(map[util.Uint160]bool) - - claim := t.Data.(*transaction.ClaimTX) - clGroups := make(map[util.Uint256][]*transaction.Input) - for _, in := range claim.Claims { - clGroups[in.PrevHash] = append(clGroups[in.PrevHash], in) - } - for group, inputs := range clGroups { - refTx, _, err := bc.dao.GetTransaction(group) - if err != nil { - return nil, err - } - for _, input := range inputs { - if len(refTx.Outputs) <= int(input.PrevIndex) { - return nil, fmt.Errorf("wrong PrevIndex reference") - } - hashmap[refTx.Outputs[input.PrevIndex].ScriptHash] = true - } - } - if len(hashmap) > 0 { - hashes := make([]util.Uint160, 0, len(hashmap)) - for k := range hashmap { - hashes = append(hashes, k) - } - return hashes, nil - } - return nil, fmt.Errorf("no hashes found") -} - //GetStandByValidators returns validators from the configuration. func (bc *Blockchain) GetStandByValidators() (keys.PublicKeys, error) { return getValidators(bc.config) @@ -1507,19 +1512,13 @@ func processEnrollmentTX(dao *cachedDao, tx *transaction.EnrollmentTX) error { // to verify whether the transaction is bonafide or not. // Golang implementation of GetScriptHashesForVerifying method in C# (https://github.com/neo-project/neo/blob/master/neo/Network/P2P/Payloads/Transaction.cs#L190) func (bc *Blockchain) GetScriptHashesForVerifying(t *transaction.Transaction) ([]util.Uint160, error) { - if t.Type == transaction.ClaimType { - return bc.GetScriptHashesForVerifyingClaim(t) - } - references := bc.References(t) - if references == nil { - return nil, errors.New("invalid inputs") + references, err := bc.References(t) + if err != nil { + return nil, err } hashes := make(map[util.Uint160]bool) - for _, i := range t.Inputs { - h := references[i].ScriptHash - if _, ok := hashes[h]; !ok { - hashes[h] = true - } + for i := range references { + hashes[references[i].Out.ScriptHash] = true } for _, a := range t.Attributes { if a.Usage == transaction.Script { @@ -1547,6 +1546,20 @@ func (bc *Blockchain) GetScriptHashesForVerifying(t *transaction.Transaction) ([ } } } + switch t.Type { + case transaction.ClaimType: + claim := t.Data.(*transaction.ClaimTX) + refs, err := bc.references(claim.Claims) + if err != nil { + return nil, err + } + for i := range refs { + hashes[refs[i].Out.ScriptHash] = true + } + case transaction.EnrollmentType: + etx := t.Data.(*transaction.EnrollmentTX) + hashes[etx.PublicKey.GetScriptHash()] = true + } // convert hashes to []util.Uint160 hashesResult := make([]util.Uint160, 0, len(hashes)) for h := range hashes { diff --git a/pkg/core/blockchainer.go b/pkg/core/blockchainer.go index dae78c48b..822c80fec 100644 --- a/pkg/core/blockchainer.go +++ b/pkg/core/blockchainer.go @@ -39,7 +39,7 @@ type Blockchainer interface { GetTestVM() (*vm.VM, storage.Store) GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) GetUnspentCoinState(util.Uint256) *UnspentCoinState - References(t *transaction.Transaction) map[transaction.Input]*transaction.Output + References(t *transaction.Transaction) ([]transaction.InOut, error) mempool.Feer // fee interface PoolTx(*transaction.Transaction) error VerifyTx(*transaction.Transaction, *block.Block) error diff --git a/pkg/core/dao.go b/pkg/core/dao.go index 1ff5d5e5e..d84e50816 100644 --- a/pkg/core/dao.go +++ b/pkg/core/dao.go @@ -175,15 +175,13 @@ func (dao *dao) PutUnspentCoinState(hash util.Uint256, ucs *UnspentCoinState) er // -- start spent coins. // GetSpentCoinsOrNew returns spent coins from store. -func (dao *dao) GetSpentCoinsOrNew(hash util.Uint256) (*SpentCoinState, error) { +func (dao *dao) GetSpentCoinsOrNew(hash util.Uint256, height uint32) (*SpentCoinState, error) { spent, err := dao.GetSpentCoinState(hash) if err != nil { if err != storage.ErrKeyNotFound { return nil, err } - spent = &SpentCoinState{ - items: make(map[uint16]uint32), - } + spent = NewSpentCoinState(hash, height) } return spent, nil } @@ -549,7 +547,8 @@ func (dao *dao) IsDoubleSpend(tx *transaction.Transaction) bool { if len(tx.Inputs) == 0 { return false } - for prevHash, inputs := range tx.GroupInputsByPrevHash() { + for _, inputs := range transaction.GroupInputsByPrevHash(tx.Inputs) { + prevHash := inputs[0].PrevHash unspent, err := dao.GetUnspentCoinState(prevHash) if err != nil { return false @@ -563,6 +562,27 @@ func (dao *dao) IsDoubleSpend(tx *transaction.Transaction) bool { return false } +// IsDoubleClaim verifies that given claim inputs are not already claimed by another tx. +func (dao *dao) IsDoubleClaim(claim *transaction.ClaimTX) bool { + if len(claim.Claims) == 0 { + return false + } + for _, inputs := range transaction.GroupInputsByPrevHash(claim.Claims) { + prevHash := inputs[0].PrevHash + scs, err := dao.GetSpentCoinState(prevHash) + if err != nil { + return true + } + for _, input := range inputs { + _, ok := scs.items[input.PrevIndex] + if !ok { + return true + } + } + } + return false +} + // Persist flushes all the changes made into the (supposedly) persistent // underlying store. func (dao *dao) Persist() (int, error) { diff --git a/pkg/core/dao_test.go b/pkg/core/dao_test.go index 52df535bc..3b4a30343 100644 --- a/pkg/core/dao_test.go +++ b/pkg/core/dao_test.go @@ -124,7 +124,7 @@ func TestPutGetUnspentCoinState(t *testing.T) { func TestGetSpentCoinStateOrNew_New(t *testing.T) { dao := newDao(storage.NewMemoryStore()) hash := random.Uint256() - spentCoinState, err := dao.GetSpentCoinsOrNew(hash) + spentCoinState, err := dao.GetSpentCoinsOrNew(hash, 1) require.NoError(t, err) require.NotNil(t, spentCoinState) } diff --git a/pkg/core/interop_neo.go b/pkg/core/interop_neo.go index b3018e51c..1eb8e9b3c 100644 --- a/pkg/core/interop_neo.go +++ b/pkg/core/interop_neo.go @@ -34,12 +34,6 @@ const ( DefaultAssetLifetime = 1 + BlocksPerYear ) -// txInOut is used to pushed one key-value pair from References() onto the stack. -type txInOut struct { - in transaction.Input - out transaction.Output -} - // headerGetVersion returns version from the header. func (ic *interopContext) headerGetVersion(v *vm.VM) error { header, err := popHeaderFromVM(v) @@ -141,14 +135,16 @@ func (ic *interopContext) txGetReferences(v *vm.VM) error { if !ok { return fmt.Errorf("type mismatch: %T is not a Transaction", txInterface) } - refs := ic.bc.References(tx) + refs, err := ic.bc.References(tx) + if err != nil { + return err + } if len(refs) > vm.MaxArraySize { return errors.New("too many references") } stackrefs := make([]vm.StackItem, 0, len(refs)) - for _, k := range tx.Inputs { - tio := txInOut{k, *refs[k]} + for _, tio := range refs { stackrefs = append(stackrefs, vm.NewInteropItem(tio)) } v.Estack().PushVal(stackrefs) @@ -243,11 +239,11 @@ func popInputFromVM(v *vm.VM) (*transaction.Input, error) { inInterface := v.Estack().Pop().Value() input, ok := inInterface.(*transaction.Input) if !ok { - txio, ok := inInterface.(txInOut) + txio, ok := inInterface.(transaction.InOut) if !ok { - return nil, fmt.Errorf("type mismatch: %T is not an Input or txInOut", inInterface) + return nil, fmt.Errorf("type mismatch: %T is not an Input or InOut", inInterface) } - input = &txio.in + input = &txio.In } return input, nil } @@ -277,11 +273,11 @@ func popOutputFromVM(v *vm.VM) (*transaction.Output, error) { outInterface := v.Estack().Pop().Value() output, ok := outInterface.(*transaction.Output) if !ok { - txio, ok := outInterface.(txInOut) + txio, ok := outInterface.(transaction.InOut) if !ok { - return nil, fmt.Errorf("type mismatch: %T is not an Output or txInOut", outInterface) + return nil, fmt.Errorf("type mismatch: %T is not an Output or InOut", outInterface) } - output = &txio.out + output = &txio.Out } return output, nil } diff --git a/pkg/core/transaction/claim.go b/pkg/core/transaction/claim.go index 423a4bc79..05f9b1a16 100644 --- a/pkg/core/transaction/claim.go +++ b/pkg/core/transaction/claim.go @@ -6,7 +6,7 @@ import ( // ClaimTX represents a claim transaction. type ClaimTX struct { - Claims []*Input + Claims []Input } // DecodeBinary implements Serializable interface. diff --git a/pkg/core/transaction/inout.go b/pkg/core/transaction/inout.go new file mode 100644 index 000000000..75bc026ac --- /dev/null +++ b/pkg/core/transaction/inout.go @@ -0,0 +1,8 @@ +package transaction + +// InOut represents an Input bound to its corresponding Output which is a useful +// combination for many purposes. +type InOut struct { + In Input + Out Output +} diff --git a/pkg/core/transaction/input.go b/pkg/core/transaction/input.go index da390cfe5..3b7758ba3 100644 --- a/pkg/core/transaction/input.go +++ b/pkg/core/transaction/input.go @@ -1,6 +1,8 @@ package transaction import ( + "sort" + "github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/util" ) @@ -25,3 +27,62 @@ func (in *Input) EncodeBinary(bw *io.BinWriter) { bw.WriteBytes(in.PrevHash[:]) bw.WriteU16LE(in.PrevIndex) } + +// MapInputsToSorted maps given slice of inputs into a new slice of pointers +// to inputs sorted by their PrevHash and PrevIndex. +func MapInputsToSorted(ins []Input) []*Input { + ptrs := make([]*Input, len(ins)) + for i := range ins { + ptrs[i] = &ins[i] + } + sort.Slice(ptrs, func(i, j int) bool { + hashcmp := ptrs[i].PrevHash.CompareTo(ptrs[j].PrevHash) + if hashcmp == 0 { + return ptrs[i].PrevIndex < ptrs[j].PrevIndex + } + return hashcmp < 0 + }) + return ptrs +} + +// GroupInputsByPrevHash groups all TX inputs by their previous hash into +// several slices (which actually are subslices of one new slice with pointers). +// Each of these slices contains at least one element. +func GroupInputsByPrevHash(ins []Input) [][]*Input { + if len(ins) == 0 { + return nil + } + + ptrs := MapInputsToSorted(ins) + var first int + res := make([][]*Input, 0) + currentHash := ptrs[0].PrevHash + + for i := range ptrs { + if !currentHash.Equals(ptrs[i].PrevHash) { + res = append(res, ptrs[first:i]) + first = i + currentHash = ptrs[i].PrevHash + } + } + res = append(res, ptrs[first:]) + return res +} + +// HaveDuplicateInputs checks inputs for duplicates and returns true if there are +// any. +func HaveDuplicateInputs(ins []Input) bool { + if len(ins) < 2 { + return false + } + if len(ins) == 2 { + return ins[0] == ins[1] + } + ptrs := MapInputsToSorted(ins) + for i := 1; i < len(ptrs); i++ { + if *ptrs[i] == *ptrs[i-1] { + return true + } + } + return false +} diff --git a/pkg/core/transaction/input_test.go b/pkg/core/transaction/input_test.go new file mode 100644 index 000000000..0375a4559 --- /dev/null +++ b/pkg/core/transaction/input_test.go @@ -0,0 +1,144 @@ +package transaction + +import ( + "testing" + + "github.com/CityOfZion/neo-go/pkg/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGroupInputsByPrevHash0(t *testing.T) { + inputs := make([]Input, 0) + res := GroupInputsByPrevHash(inputs) + require.Equal(t, 0, len(res)) +} + +func TestGroupInputsByPrevHash1(t *testing.T) { + inputs := make([]Input, 0) + hash, err := util.Uint256DecodeStringLE("46168f963d6d8168a870405f66cc9e13a235791013b8ee2f90cc20a8293bd1af") + require.NoError(t, err) + inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 42}) + res := GroupInputsByPrevHash(inputs) + require.Equal(t, 1, len(res)) + require.Equal(t, 1, len(res[0])) + assert.Equal(t, hash, res[0][0].PrevHash) + assert.Equal(t, uint16(42), res[0][0].PrevIndex) +} + +func TestGroupInputsByPrevHashMany(t *testing.T) { + hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25") + require.NoError(t, err) + hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9") + require.NoError(t, err) + hash3, err := util.Uint256DecodeStringBE("caa41245c3e48ddc13dabe989ba8fbc59418e9228fef9efb62855b0b17d7448b") + require.NoError(t, err) + inputs := make([]Input, 0) + for i := 0; i < 10; i++ { + inputs = append(inputs, Input{PrevHash: hash1, PrevIndex: uint16(i)}) + inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: uint16(i)}) + inputs = append(inputs, Input{PrevHash: hash3, PrevIndex: uint16(i)}) + } + for i := 15; i < 20; i++ { + inputs = append(inputs, Input{PrevHash: hash3, PrevIndex: uint16(i)}) + } + for i := 10; i < 15; i++ { + inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: uint16(i)}) + inputs = append(inputs, Input{PrevHash: hash3, PrevIndex: uint16(i)}) + } + seen := make(map[uint16]bool) + res := GroupInputsByPrevHash(inputs) + require.Equal(t, 3, len(res)) + assert.Equal(t, hash2, res[0][0].PrevHash) + assert.Equal(t, 15, len(res[0])) + for i := range res[0] { + assert.Equal(t, res[0][i].PrevHash, res[0][0].PrevHash) + assert.Equal(t, false, seen[res[0][i].PrevIndex]) + seen[res[0][i].PrevIndex] = true + } + seen = make(map[uint16]bool) + assert.Equal(t, hash1, res[1][0].PrevHash) + assert.Equal(t, 10, len(res[1])) + for i := range res[1] { + assert.Equal(t, res[1][i].PrevHash, res[1][0].PrevHash) + assert.Equal(t, false, seen[res[1][i].PrevIndex]) + seen[res[1][i].PrevIndex] = true + } + seen = make(map[uint16]bool) + assert.Equal(t, hash3, res[2][0].PrevHash) + assert.Equal(t, 20, len(res[2])) + for i := range res[2] { + assert.Equal(t, res[2][i].PrevHash, res[2][0].PrevHash) + assert.Equal(t, false, seen[res[2][i].PrevIndex]) + seen[res[2][i].PrevIndex] = true + } +} + +func TestHaveDuplicateInputs0(t *testing.T) { + inputs := make([]Input, 0) + require.False(t, HaveDuplicateInputs(inputs)) +} + +func TestHaveDuplicateInputs1(t *testing.T) { + inputs := make([]Input, 0) + hash, err := util.Uint256DecodeStringLE("46168f963d6d8168a870405f66cc9e13a235791013b8ee2f90cc20a8293bd1af") + require.NoError(t, err) + inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 42}) + require.False(t, HaveDuplicateInputs(inputs)) +} + +func TestHaveDuplicateInputs2True(t *testing.T) { + inputs := make([]Input, 0) + hash, err := util.Uint256DecodeStringLE("46168f963d6d8168a870405f66cc9e13a235791013b8ee2f90cc20a8293bd1af") + require.NoError(t, err) + inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 42}) + inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 42}) + require.True(t, HaveDuplicateInputs(inputs)) +} + +func TestHaveDuplicateInputs2FalseInd(t *testing.T) { + inputs := make([]Input, 0) + hash, err := util.Uint256DecodeStringLE("46168f963d6d8168a870405f66cc9e13a235791013b8ee2f90cc20a8293bd1af") + require.NoError(t, err) + inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 42}) + inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 41}) + require.False(t, HaveDuplicateInputs(inputs)) +} + +func TestHaveDuplicateInputs2FalseHash(t *testing.T) { + inputs := make([]Input, 0) + hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25") + require.NoError(t, err) + hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9") + require.NoError(t, err) + inputs = append(inputs, Input{PrevHash: hash1, PrevIndex: 42}) + inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: 42}) + require.False(t, HaveDuplicateInputs(inputs)) +} + +func TestHaveDuplicateInputsMFalse(t *testing.T) { + inputs := make([]Input, 0) + hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25") + require.NoError(t, err) + hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9") + require.NoError(t, err) + for i := 0; i < 10; i++ { + inputs = append(inputs, Input{PrevHash: hash1, PrevIndex: uint16(i)}) + inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: uint16(i)}) + } + require.False(t, HaveDuplicateInputs(inputs)) +} + +func TestHaveDuplicateInputsMTrue(t *testing.T) { + inputs := make([]Input, 0) + hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25") + require.NoError(t, err) + hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9") + require.NoError(t, err) + for i := 0; i < 10; i++ { + inputs = append(inputs, Input{PrevHash: hash1, PrevIndex: uint16(i)}) + inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: uint16(i)}) + } + inputs = append(inputs, Input{PrevHash: hash1, PrevIndex: 0}) + require.True(t, HaveDuplicateInputs(inputs)) +} diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index 46c0bee77..cdabfa3c0 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -187,16 +187,6 @@ func (t *Transaction) createHash() error { return nil } -// GroupInputsByPrevHash groups all TX inputs by their previous hash. -func (t *Transaction) GroupInputsByPrevHash() map[util.Uint256][]*Input { - m := make(map[util.Uint256][]*Input) - for i := range t.Inputs { - hash := t.Inputs[i].PrevHash - m[hash] = append(m[hash], &t.Inputs[i]) - } - return m -} - // GroupOutputByAssetID groups all TX outputs by their assetID. func (t Transaction) GroupOutputByAssetID() map[util.Uint256][]*Output { m := make(map[util.Uint256][]*Output) diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 2759c37df..d21eb2ae6 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -33,7 +33,7 @@ func (chain testChain) GetConfig() config.ProtocolConfiguration { panic("TODO") } -func (chain testChain) References(t *transaction.Transaction) map[transaction.Input]*transaction.Output { +func (chain testChain) References(t *transaction.Transaction) ([]transaction.InOut, error) { panic("TODO") }