mpt: split HashNode in two types
First type is non-empty HashNode, and the second one is an Empty node.
This commit is contained in:
parent
b9927c39ee
commit
7f038bd465
5 changed files with 37 additions and 20 deletions
|
@ -78,7 +78,17 @@ func (b *BaseNode) invalidateCache() {
|
|||
|
||||
// encodeNodeWithType encodes node together with it's type.
|
||||
func encodeNodeWithType(n Node, w *io.BinWriter) {
|
||||
w.WriteB(byte(n.Type()))
|
||||
switch t := n.Type(); t {
|
||||
case HashT:
|
||||
hn := n.(*HashNode)
|
||||
if !hn.hashValid {
|
||||
w.WriteB(byte(EmptyT))
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
w.WriteB(byte(t))
|
||||
}
|
||||
n.EncodeBinary(w)
|
||||
}
|
||||
|
||||
|
@ -94,9 +104,19 @@ func DecodeNodeWithType(r *io.BinReader) Node {
|
|||
case ExtensionT:
|
||||
n = new(ExtensionNode)
|
||||
case HashT:
|
||||
n = new(HashNode)
|
||||
n = &HashNode{
|
||||
BaseNode: BaseNode{
|
||||
hashValid: true,
|
||||
},
|
||||
}
|
||||
case LeafT:
|
||||
n = new(LeafNode)
|
||||
case EmptyT:
|
||||
n = &HashNode{
|
||||
BaseNode: BaseNode{
|
||||
hashValid: false,
|
||||
},
|
||||
}
|
||||
default:
|
||||
r.Err = fmt.Errorf("invalid node type: %x", typ)
|
||||
return nil
|
||||
|
|
|
@ -2,7 +2,6 @@ package mpt
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
|
@ -46,25 +45,17 @@ func (h *HashNode) Bytes() []byte {
|
|||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (h *HashNode) DecodeBinary(r *io.BinReader) {
|
||||
sz := r.ReadVarUint()
|
||||
switch sz {
|
||||
case 0:
|
||||
h.hashValid = false
|
||||
case util.Uint256Size:
|
||||
h.hashValid = true
|
||||
r.ReadBytes(h.hash[:])
|
||||
default:
|
||||
r.Err = fmt.Errorf("invalid hash node size: %d", sz)
|
||||
if h.hashValid {
|
||||
h.hash.DecodeBinary(r)
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (h HashNode) EncodeBinary(w *io.BinWriter) {
|
||||
if !h.hashValid {
|
||||
w.WriteVarUint(0)
|
||||
return
|
||||
}
|
||||
w.WriteVarBytes(h.hash[:])
|
||||
w.WriteBytes(h.hash[:])
|
||||
}
|
||||
|
||||
// EncodeBinaryAsChild implements BaseNode interface.
|
||||
|
|
|
@ -18,6 +18,7 @@ const (
|
|||
ExtensionT NodeType = 0x01
|
||||
LeafT NodeType = 0x02
|
||||
HashT NodeType = 0x03
|
||||
EmptyT NodeType = 0x04
|
||||
)
|
||||
|
||||
// NodeObject represents Node together with it's type.
|
||||
|
|
|
@ -16,6 +16,9 @@ func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) {
|
|||
t.Run("IO", func(t *testing.T) {
|
||||
bs, err := testserdes.EncodeBinary(expected)
|
||||
require.NoError(t, err)
|
||||
if hn, ok := actual.(*HashNode); ok {
|
||||
hn.hashValid = true // this field is set during NodeObject decoding
|
||||
}
|
||||
err = testserdes.DecodeBinary(bs, actual)
|
||||
if !ok {
|
||||
require.Error(t, err)
|
||||
|
@ -80,8 +83,8 @@ func TestNode_Serializable(t *testing.T) {
|
|||
})
|
||||
t.Run("InvalidSize", func(t *testing.T) {
|
||||
buf := io.NewBufBinWriter()
|
||||
buf.BinWriter.WriteVarBytes(make([]byte, 13))
|
||||
require.Error(t, testserdes.DecodeBinary(buf.Bytes(), new(HashNode)))
|
||||
buf.BinWriter.WriteBytes(make([]byte, 13))
|
||||
require.Error(t, testserdes.DecodeBinary(buf.Bytes(), &HashNode{BaseNode: BaseNode{hashValid: true}}))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -151,6 +154,6 @@ func TestRootHash(t *testing.T) {
|
|||
b.Children[9] = l2
|
||||
|
||||
r1 := NewExtensionNode([]byte{0x0A, 0x0C, 0x00, 0x01}, v1)
|
||||
require.Equal(t, "a6d1385fa2e089fd9ca79e58bee47cb4c9c949140a382580138840113412931d", r1.Hash().StringLE())
|
||||
require.Equal(t, "62d14dc02b9f905ca6ec73fb499b1eef835e482d936744e3b6298cf9ad26ba03", r.Hash().StringLE())
|
||||
require.Equal(t, "cedd9897dd1559fbd5dfe5cfb223464da6de438271028afb8d647e950cbd18e0", r1.Hash().StringLE())
|
||||
require.Equal(t, "1037e779c8a0313bd0d99c4151fa70a277c43c53a549b6444079f2e67e8ffb7b", r.Hash().StringLE())
|
||||
}
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
package result
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/internal/testserdes"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -19,7 +22,6 @@ func testProofWithKey() *ProofWithKey {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func TestGetProof_MarshalJSON(t *testing.T) {
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
p := testProofWithKey()
|
||||
|
@ -40,7 +42,7 @@ func TestGetProof_MarshalJSON(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
*/
|
||||
|
||||
func TestProofWithKey_EncodeString(t *testing.T) {
|
||||
expected := testProofWithKey()
|
||||
var actual ProofWithKey
|
||||
|
|
Loading…
Reference in a new issue