mpt: split HashNode in two types

First type is non-empty HashNode, and the second one is an Empty node.
This commit is contained in:
Anna Shaleva 2021-03-30 20:35:41 +03:00
parent b9927c39ee
commit 7f038bd465
5 changed files with 37 additions and 20 deletions
pkg

View file

@ -78,7 +78,17 @@ func (b *BaseNode) invalidateCache() {
// encodeNodeWithType encodes node together with it's type. // encodeNodeWithType encodes node together with it's type.
func encodeNodeWithType(n Node, w *io.BinWriter) { func encodeNodeWithType(n Node, w *io.BinWriter) {
w.WriteB(byte(n.Type())) switch t := n.Type(); t {
case HashT:
hn := n.(*HashNode)
if !hn.hashValid {
w.WriteB(byte(EmptyT))
break
}
fallthrough
default:
w.WriteB(byte(t))
}
n.EncodeBinary(w) n.EncodeBinary(w)
} }
@ -94,9 +104,19 @@ func DecodeNodeWithType(r *io.BinReader) Node {
case ExtensionT: case ExtensionT:
n = new(ExtensionNode) n = new(ExtensionNode)
case HashT: case HashT:
n = new(HashNode) n = &HashNode{
BaseNode: BaseNode{
hashValid: true,
},
}
case LeafT: case LeafT:
n = new(LeafNode) n = new(LeafNode)
case EmptyT:
n = &HashNode{
BaseNode: BaseNode{
hashValid: false,
},
}
default: default:
r.Err = fmt.Errorf("invalid node type: %x", typ) r.Err = fmt.Errorf("invalid node type: %x", typ)
return nil return nil

View file

@ -2,7 +2,6 @@ package mpt
import ( import (
"errors" "errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
@ -46,25 +45,17 @@ func (h *HashNode) Bytes() []byte {
// DecodeBinary implements io.Serializable. // DecodeBinary implements io.Serializable.
func (h *HashNode) DecodeBinary(r *io.BinReader) { func (h *HashNode) DecodeBinary(r *io.BinReader) {
sz := r.ReadVarUint() if h.hashValid {
switch sz { h.hash.DecodeBinary(r)
case 0:
h.hashValid = false
case util.Uint256Size:
h.hashValid = true
r.ReadBytes(h.hash[:])
default:
r.Err = fmt.Errorf("invalid hash node size: %d", sz)
} }
} }
// EncodeBinary implements io.Serializable. // EncodeBinary implements io.Serializable.
func (h HashNode) EncodeBinary(w *io.BinWriter) { func (h HashNode) EncodeBinary(w *io.BinWriter) {
if !h.hashValid { if !h.hashValid {
w.WriteVarUint(0)
return return
} }
w.WriteVarBytes(h.hash[:]) w.WriteBytes(h.hash[:])
} }
// EncodeBinaryAsChild implements BaseNode interface. // EncodeBinaryAsChild implements BaseNode interface.

View file

@ -18,6 +18,7 @@ const (
ExtensionT NodeType = 0x01 ExtensionT NodeType = 0x01
LeafT NodeType = 0x02 LeafT NodeType = 0x02
HashT NodeType = 0x03 HashT NodeType = 0x03
EmptyT NodeType = 0x04
) )
// NodeObject represents Node together with it's type. // NodeObject represents Node together with it's type.

View file

@ -16,6 +16,9 @@ func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) {
t.Run("IO", func(t *testing.T) { t.Run("IO", func(t *testing.T) {
bs, err := testserdes.EncodeBinary(expected) bs, err := testserdes.EncodeBinary(expected)
require.NoError(t, err) require.NoError(t, err)
if hn, ok := actual.(*HashNode); ok {
hn.hashValid = true // this field is set during NodeObject decoding
}
err = testserdes.DecodeBinary(bs, actual) err = testserdes.DecodeBinary(bs, actual)
if !ok { if !ok {
require.Error(t, err) require.Error(t, err)
@ -80,8 +83,8 @@ func TestNode_Serializable(t *testing.T) {
}) })
t.Run("InvalidSize", func(t *testing.T) { t.Run("InvalidSize", func(t *testing.T) {
buf := io.NewBufBinWriter() buf := io.NewBufBinWriter()
buf.BinWriter.WriteVarBytes(make([]byte, 13)) buf.BinWriter.WriteBytes(make([]byte, 13))
require.Error(t, testserdes.DecodeBinary(buf.Bytes(), new(HashNode))) require.Error(t, testserdes.DecodeBinary(buf.Bytes(), &HashNode{BaseNode: BaseNode{hashValid: true}}))
}) })
}) })
@ -151,6 +154,6 @@ func TestRootHash(t *testing.T) {
b.Children[9] = l2 b.Children[9] = l2
r1 := NewExtensionNode([]byte{0x0A, 0x0C, 0x00, 0x01}, v1) r1 := NewExtensionNode([]byte{0x0A, 0x0C, 0x00, 0x01}, v1)
require.Equal(t, "a6d1385fa2e089fd9ca79e58bee47cb4c9c949140a382580138840113412931d", r1.Hash().StringLE()) require.Equal(t, "cedd9897dd1559fbd5dfe5cfb223464da6de438271028afb8d647e950cbd18e0", r1.Hash().StringLE())
require.Equal(t, "62d14dc02b9f905ca6ec73fb499b1eef835e482d936744e3b6298cf9ad26ba03", r.Hash().StringLE()) require.Equal(t, "1037e779c8a0313bd0d99c4151fa70a277c43c53a549b6444079f2e67e8ffb7b", r.Hash().StringLE())
} }

View file

@ -1,10 +1,13 @@
package result package result
import ( import (
"encoding/json"
"testing" "testing"
"github.com/nspcc-dev/neo-go/internal/random" "github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/internal/testserdes" "github.com/nspcc-dev/neo-go/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -19,7 +22,6 @@ func testProofWithKey() *ProofWithKey {
} }
} }
/*
func TestGetProof_MarshalJSON(t *testing.T) { func TestGetProof_MarshalJSON(t *testing.T) {
t.Run("Good", func(t *testing.T) { t.Run("Good", func(t *testing.T) {
p := testProofWithKey() p := testProofWithKey()
@ -40,7 +42,7 @@ func TestGetProof_MarshalJSON(t *testing.T) {
} }
}) })
} }
*/
func TestProofWithKey_EncodeString(t *testing.T) { func TestProofWithKey_EncodeString(t *testing.T) {
expected := testProofWithKey() expected := testProofWithKey()
var actual ProofWithKey var actual ProofWithKey