core: mandate passing from as a subprefix for (*Trie).Find

However, we need to distinguish empty subprefix and nil subprefix (no
start specified) to match the C# behaviour.
This commit is contained in:
Anna Shaleva 2021-10-13 11:38:53 +03:00
parent 8e7c76827b
commit 892eadf86d
4 changed files with 37 additions and 19 deletions

View file

@ -394,13 +394,16 @@ func TestCompatibility_Find(t *testing.T) {
})
t.Run("from is not in tree", func(t *testing.T) {
t.Run("matching", func(t *testing.T) {
check(t, []byte("aa30"), 1)
check(t, []byte("30"), 1)
})
t.Run("non-matching", func(t *testing.T) {
check(t, []byte("aa60"), 0)
check(t, []byte("60"), 0)
})
})
t.Run("from is in tree", func(t *testing.T) {
check(t, []byte("aa10"), 1) // without `from` key
check(t, []byte("10"), 1) // without `from` key
})
t.Run("from matching start", func(t *testing.T) {
check(t, []byte{}, 2) // without `from` key
})
}

View file

@ -526,22 +526,19 @@ func collapse(depth int, node Node) Node {
}
// Find returns list of storage key-value pairs whose key is prefixed by the specified
// prefix starting from the specified path (not including the item at the specified path
// if so). The `max` number of elements is returned at max.
// prefix starting from the specified `prefix`+`from` path (not including the item at
// the specified `prefix`+`from` path if so). The `max` number of elements is returned at max.
func (t *Trie) Find(prefix, from []byte, max int) ([]storage.KeyValue, error) {
if len(prefix) > MaxKeyLength {
return nil, errors.New("invalid prefix length")
}
if len(from) > MaxKeyLength {
if len(from) > MaxKeyLength-len(prefix) {
return nil, errors.New("invalid from length")
}
prefixP := toNibbles(prefix)
fromP := []byte{}
if len(from) > 0 {
if !bytes.HasPrefix(from, prefix) {
return nil, errors.New("`from` argument doesn't match specified prefix")
}
fromP = toNibbles(from)[len(prefixP):]
fromP = toNibbles(from)
}
_, start, path, err := t.getWithPath(t.root, prefixP, false)
if err != nil {
@ -572,10 +569,9 @@ func (t *Trie) Find(prefix, from []byte, max int) ([]storage.KeyValue, error) {
b := NewBillet(t.root.Hash(), false, t.Store)
process := func(pathToNode []byte, node Node, _ []byte) bool {
if leaf, ok := node.(*LeafNode); ok {
key := append(prefix, pathToNode...)
if !bytes.Equal(key, from) { // (*Billet).traverse includes `from` path into result if so. Need to filter out manually.
if from == nil || !bytes.Equal(pathToNode, from) { // (*Billet).traverse includes `from` path into result if so. Need to filter out manually.
res = append(res, storage.KeyValue{
Key: key,
Key: append(slice.Copy(prefix), pathToNode...),
Value: slice.Copy(leaf.value),
})
count++

View file

@ -1095,6 +1095,15 @@ func (s *Server) findStates(ps request.Params) (interface{}, *response.Error) {
if err != nil {
return nil, response.WrapErrorWithData(response.ErrInvalidParams, errors.New("invalid key"))
}
if len(key) > 0 {
if !bytes.HasPrefix(key, prefix) {
return nil, response.WrapErrorWithData(response.ErrInvalidParams, errors.New("key doesn't match prefix"))
}
key = key[len(prefix):]
} else {
// empty ("") key shouldn't exclude item matching prefix from the result
key = nil
}
}
if len(ps) > 4 {
count, err = ps.Value(4).GetInt()
@ -1110,11 +1119,7 @@ func (s *Server) findStates(ps request.Params) (interface{}, *response.Error) {
return nil, respErr
}
pKey := makeStorageKey(cs.ID, prefix)
var sKey []byte
if len(key) > 0 {
sKey = makeStorageKey(cs.ID, key)
}
kvs, err := s.chain.GetStateModule().FindStates(root, pKey, sKey, count+1) // +1 to define result truncation
kvs, err := s.chain.GetStateModule().FindStates(root, pKey, key, count+1) // +1 to define result truncation
if err != nil {
return nil, response.NewInternalServerError("failed to find historical items", err)
}

View file

@ -1515,6 +1515,20 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) []
Truncated: false,
})
})
t.Run("good: empty prefix, no limit", func(t *testing.T) {
// empty prefix should be considered as no prefix specified.
root, err := e.chain.GetStateModule().GetStateRoot(16)
require.NoError(t, err)
params := fmt.Sprintf(`"%s", "%s", "%s", ""`, root.Root.StringLE(), testContractHash, base64.StdEncoding.EncodeToString([]byte("aa")))
testFindStates(t, params, root.Root, result.FindStates{
Results: []result.KeyValue{
{Key: []byte("aa10"), Value: []byte("v2")},
{Key: []byte("aa50"), Value: []byte("v3")},
{Key: []byte("aa"), Value: []byte("v1")},
},
Truncated: false,
})
})
t.Run("good: with prefix, no limit", func(t *testing.T) {
// pairs for this test where put to the contract storage at block #16
root, err := e.chain.GetStateModule().GetStateRoot(16)
@ -1527,7 +1541,7 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) []
Truncated: false,
})
})
t.Run("good: no prefix, with limit", func(t *testing.T) {
t.Run("good: empty prefix, with limit", func(t *testing.T) {
for limit := 2; limit < 5; limit++ {
// pairs for this test where put to the contract storage at block #16
root, err := e.chain.GetStateModule().GetStateRoot(16)