diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 31bb2aefa..38f8103c6 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -431,7 +431,8 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } // Process TX inputs that are grouped by previous hash. - for prevHash, inputs := range tx.GroupInputsByPrevHash() { + for _, inputs := range transaction.GroupInputsByPrevHash(tx.Inputs) { + prevHash := inputs[0].PrevHash prevTX, prevTXHeight, err := bc.dao.GetTransaction(prevHash) if err != nil { return fmt.Errorf("could not find previous TX: %s", prevHash) @@ -1016,7 +1017,8 @@ func (bc *Blockchain) GetConfig() config.ProtocolConfiguration { func (bc *Blockchain) References(t *transaction.Transaction) map[transaction.Input]*transaction.Output { references := make(map[transaction.Input]*transaction.Output) - for prevHash, inputs := range t.GroupInputsByPrevHash() { + for _, inputs := range transaction.GroupInputsByPrevHash(t.Inputs) { + prevHash := inputs[0].PrevHash tx, _, err := bc.dao.GetTransaction(prevHash) if err != nil { return nil diff --git a/pkg/core/dao.go b/pkg/core/dao.go index c5c2c1a76..cba1d59c5 100644 --- a/pkg/core/dao.go +++ b/pkg/core/dao.go @@ -547,7 +547,8 @@ func (dao *dao) IsDoubleSpend(tx *transaction.Transaction) bool { if len(tx.Inputs) == 0 { return false } - for prevHash, inputs := range tx.GroupInputsByPrevHash() { + for _, inputs := range transaction.GroupInputsByPrevHash(tx.Inputs) { + prevHash := inputs[0].PrevHash unspent, err := dao.GetUnspentCoinState(prevHash) if err != nil { return false diff --git a/pkg/core/transaction/input.go b/pkg/core/transaction/input.go index da390cfe5..b1e6b7694 100644 --- a/pkg/core/transaction/input.go +++ b/pkg/core/transaction/input.go @@ -1,6 +1,8 @@ package transaction import ( + "sort" + "github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/util" ) @@ -25,3 +27,34 @@ func (in *Input) EncodeBinary(bw *io.BinWriter) { bw.WriteBytes(in.PrevHash[:]) bw.WriteU16LE(in.PrevIndex) } + +// GroupInputsByPrevHash groups all TX inputs by their previous hash into +// several slices (which actually are subslices of one new slice with pointers). +// Each of these slices contains at least one element. +func GroupInputsByPrevHash(ins []Input) [][]*Input { + if len(ins) == 0 { + return nil + } + + ptrs := make([]*Input, len(ins)) + for i := range ins { + ptrs[i] = &ins[i] + } + sort.Slice(ptrs, func(i, j int) bool { + return ptrs[i].PrevHash.CompareTo(ptrs[j].PrevHash) < 0 + }) + + var first int + res := make([][]*Input, 0) + currentHash := ptrs[0].PrevHash + + for i := range ptrs { + if !currentHash.Equals(ptrs[i].PrevHash) { + res = append(res, ptrs[first:i]) + first = i + currentHash = ptrs[i].PrevHash + } + } + res = append(res, ptrs[first:]) + return res +} diff --git a/pkg/core/transaction/input_test.go b/pkg/core/transaction/input_test.go new file mode 100644 index 000000000..2b3f595ef --- /dev/null +++ b/pkg/core/transaction/input_test.go @@ -0,0 +1,75 @@ +package transaction + +import ( + "testing" + + "github.com/CityOfZion/neo-go/pkg/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGroupInputsByPrevHash0(t *testing.T) { + inputs := make([]Input, 0) + res := GroupInputsByPrevHash(inputs) + require.Equal(t, 0, len(res)) +} + +func TestGroupInputsByPrevHash1(t *testing.T) { + inputs := make([]Input, 0) + hash, err := util.Uint256DecodeStringLE("46168f963d6d8168a870405f66cc9e13a235791013b8ee2f90cc20a8293bd1af") + require.NoError(t, err) + inputs = append(inputs, Input{PrevHash: hash, PrevIndex: 42}) + res := GroupInputsByPrevHash(inputs) + require.Equal(t, 1, len(res)) + require.Equal(t, 1, len(res[0])) + assert.Equal(t, hash, res[0][0].PrevHash) + assert.Equal(t, uint16(42), res[0][0].PrevIndex) +} + +func TestGroupInputsByPrevHashMany(t *testing.T) { + hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25") + require.NoError(t, err) + hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9") + require.NoError(t, err) + hash3, err := util.Uint256DecodeStringBE("caa41245c3e48ddc13dabe989ba8fbc59418e9228fef9efb62855b0b17d7448b") + require.NoError(t, err) + inputs := make([]Input, 0) + for i := 0; i < 10; i++ { + inputs = append(inputs, Input{PrevHash: hash1, PrevIndex: uint16(i)}) + inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: uint16(i)}) + inputs = append(inputs, Input{PrevHash: hash3, PrevIndex: uint16(i)}) + } + for i := 15; i < 20; i++ { + inputs = append(inputs, Input{PrevHash: hash3, PrevIndex: uint16(i)}) + } + for i := 10; i < 15; i++ { + inputs = append(inputs, Input{PrevHash: hash2, PrevIndex: uint16(i)}) + inputs = append(inputs, Input{PrevHash: hash3, PrevIndex: uint16(i)}) + } + seen := make(map[uint16]bool) + res := GroupInputsByPrevHash(inputs) + require.Equal(t, 3, len(res)) + assert.Equal(t, hash2, res[0][0].PrevHash) + assert.Equal(t, 15, len(res[0])) + for i := range res[0] { + assert.Equal(t, res[0][i].PrevHash, res[0][0].PrevHash) + assert.Equal(t, false, seen[res[0][i].PrevIndex]) + seen[res[0][i].PrevIndex] = true + } + seen = make(map[uint16]bool) + assert.Equal(t, hash1, res[1][0].PrevHash) + assert.Equal(t, 10, len(res[1])) + for i := range res[1] { + assert.Equal(t, res[1][i].PrevHash, res[1][0].PrevHash) + assert.Equal(t, false, seen[res[1][i].PrevIndex]) + seen[res[1][i].PrevIndex] = true + } + seen = make(map[uint16]bool) + assert.Equal(t, hash3, res[2][0].PrevHash) + assert.Equal(t, 20, len(res[2])) + for i := range res[2] { + assert.Equal(t, res[2][i].PrevHash, res[2][0].PrevHash) + assert.Equal(t, false, seen[res[2][i].PrevIndex]) + seen[res[2][i].PrevIndex] = true + } +} diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index 46c0bee77..cdabfa3c0 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -187,16 +187,6 @@ func (t *Transaction) createHash() error { return nil } -// GroupInputsByPrevHash groups all TX inputs by their previous hash. -func (t *Transaction) GroupInputsByPrevHash() map[util.Uint256][]*Input { - m := make(map[util.Uint256][]*Input) - for i := range t.Inputs { - hash := t.Inputs[i].PrevHash - m[hash] = append(m[hash], &t.Inputs[i]) - } - return m -} - // GroupOutputByAssetID groups all TX outputs by their assetID. func (t Transaction) GroupOutputByAssetID() map[util.Uint256][]*Output { m := make(map[util.Uint256][]*Output)