mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2025-01-05 19:35:48 +00:00
commit
fbadb317f5
62 changed files with 3804 additions and 347 deletions
|
@ -8,6 +8,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
@ -33,35 +34,7 @@ func toNeoStorageKey(key []byte) []byte {
|
|||
if len(key) < util.Uint160Size {
|
||||
panic("invalid key in storage")
|
||||
}
|
||||
|
||||
var nkey []byte
|
||||
for i := util.Uint160Size - 1; i >= 0; i-- {
|
||||
nkey = append(nkey, key[i])
|
||||
}
|
||||
|
||||
key = key[util.Uint160Size:]
|
||||
|
||||
index := 0
|
||||
remain := len(key)
|
||||
for remain >= 16 {
|
||||
nkey = append(nkey, key[index:index+16]...)
|
||||
nkey = append(nkey, 0)
|
||||
index += 16
|
||||
remain -= 16
|
||||
}
|
||||
|
||||
if remain > 0 {
|
||||
nkey = append(nkey, key[index:]...)
|
||||
}
|
||||
|
||||
padding := 16 - remain
|
||||
for i := 0; i < padding; i++ {
|
||||
nkey = append(nkey, 0)
|
||||
}
|
||||
|
||||
nkey = append(nkey, byte(padding))
|
||||
|
||||
return nkey
|
||||
return mpt.ToNeoStorageKey(key)
|
||||
}
|
||||
|
||||
// batchToMap converts batch to a map so that JSON is compatible
|
||||
|
|
|
@ -2,6 +2,8 @@ ProtocolConfiguration:
|
|||
Magic: 1953787457
|
||||
AddressVersion: 23
|
||||
SecondsPerBlock: 15
|
||||
EnableStateRoot: true
|
||||
StateRootEnableIndex: 4380100
|
||||
LowPriorityThreshold: 0.000
|
||||
MemPoolSize: 50000
|
||||
StandbyValidators:
|
||||
|
|
|
@ -2,6 +2,7 @@ ProtocolConfiguration:
|
|||
Magic: 56753
|
||||
AddressVersion: 23
|
||||
SecondsPerBlock: 15
|
||||
EnableStateRoot: true
|
||||
LowPriorityThreshold: 0.000
|
||||
MemPoolSize: 50000
|
||||
StandbyValidators:
|
||||
|
|
3
go.mod
3
go.mod
|
@ -3,12 +3,13 @@ module github.com/nspcc-dev/neo-go
|
|||
require (
|
||||
github.com/Workiva/go-datastructures v1.0.50
|
||||
github.com/alicebob/miniredis v2.5.0+incompatible
|
||||
github.com/btcsuite/btcd v0.20.1-beta
|
||||
github.com/dgraph-io/badger/v2 v2.0.3
|
||||
github.com/go-redis/redis v6.10.2+incompatible
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/mr-tron/base58 v1.1.2
|
||||
github.com/nspcc-dev/dbft v0.0.0-20200531081613-7a39e7b757ac
|
||||
github.com/nspcc-dev/dbft v0.0.0-20200623100921-5a182c20965e
|
||||
github.com/nspcc-dev/rfc6979 v0.2.0
|
||||
github.com/pkg/errors v0.8.1
|
||||
github.com/prometheus/client_golang v1.2.1
|
||||
|
|
32
go.sum
32
go.sum
|
@ -13,6 +13,8 @@ github.com/abiosoft/ishell v2.0.0+incompatible h1:zpwIuEHc37EzrsIYah3cpevrIc8Oma
|
|||
github.com/abiosoft/ishell v2.0.0+incompatible/go.mod h1:HQR9AqF2R3P4XXpMpI0NAzgHf/aS6+zVXRj14cVk9qg=
|
||||
github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db h1:CjPUSXOiYptLbTdr1RceuZgSFDQ7U15ITERUGrUORx8=
|
||||
github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db/go.mod h1:rB3B4rKii8V21ydCbIzH5hZiCQE7f5E9SzUb/ZZx530=
|
||||
github.com/aead/siphash v1.0.1 h1:FwHfE/T45KPKYuuSAKyyvE+oPWcaQ+CUmFW0bPlM+kg=
|
||||
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
|
||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM=
|
||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
|
@ -28,6 +30,22 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
|
|||
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/btcsuite/btcd v0.20.1-beta h1:Ik4hyJqN8Jfyv3S4AGBOmyouMsYE3EdYODkMbQjwPGw=
|
||||
github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ=
|
||||
github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f h1:bAs4lUbRJpnnkd9VhRV3jjAVU7DJVjMaK+IsvSeZvFo=
|
||||
github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA=
|
||||
github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d h1:yJzD/yFppdVCf6ApMkVy8cUxV0XrxdP9rVf6D87/Mng=
|
||||
github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg=
|
||||
github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd h1:R/opQEbFEy9JGkIguV40SvRY1uliPX8ifOvi6ICsFCw=
|
||||
github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd/go.mod h1:HHNXQzUsZCxOoE+CPiyCTO6x34Zs86zZUiwtpXoGdtg=
|
||||
github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd h1:qdGvebPBDuYDPGi1WCPjy1tGyMpmDK8IEapSsszn7HE=
|
||||
github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd/go.mod h1:F+uVaaLLH7j4eDXPRvw78tMflu7Ie2bzYOH4Y8rRKBY=
|
||||
github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723 h1:ZA/jbKoGcVAnER6pCHPEkGdZOV7U1oLUedErBHCUMs0=
|
||||
github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723/go.mod h1:8woku9dyThutzjeg+3xrA5iCpBRH8XEEg3lh6TiUghc=
|
||||
github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792 h1:R8vQdOQdZ9Y3SkEwmHoWBmX1DNXhXZqlTpq6s4tyJGc=
|
||||
github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792/go.mod h1:ghJtEyQwv5/p4Mg4C0fgbePVuGr935/5ddU9Z3TmDRY=
|
||||
github.com/btcsuite/winsvc v1.0.0 h1:J9B4L7e3oqhXOcm+2IuNApwzQec85lE+QaikUcCs+dk=
|
||||
github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46fmI40EZs=
|
||||
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
|
||||
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
|
||||
github.com/cespare/xxhash/v2 v2.1.0 h1:yTUvW7Vhb89inJ+8irsUqiWjh8iT6sQPZiQzI6ReGkA=
|
||||
|
@ -41,6 +59,7 @@ github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc
|
|||
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
|
||||
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
|
||||
github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
|
@ -91,9 +110,15 @@ github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T
|
|||
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||
github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89 h1:12K8AlpT0/6QUXSfV0yi4Q0jkbq8NDtIKFtF61AoqV0=
|
||||
github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/jrick/logrotate v1.0.0 h1:lQ1bL/n9mBNeIXoTUoYRlK4dHuNJVofX9oWqBtPnSzI=
|
||||
github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ=
|
||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||
github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23 h1:FOOIBWrEkLgmlgGfMuZT83xIwfPDxEI2OHu6xUmJMFE=
|
||||
github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||
|
@ -131,8 +156,8 @@ github.com/nspcc-dev/dbft v0.0.0-20200117124306-478e5cfbf03a h1:ajvxgEe9qY4vvoSm
|
|||
github.com/nspcc-dev/dbft v0.0.0-20200117124306-478e5cfbf03a/go.mod h1:/YFK+XOxxg0Bfm6P92lY5eDSLYfp06XOdL8KAVgXjVk=
|
||||
github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1 h1:yEx9WznS+rjE0jl0dLujCxuZSIb+UTjF+005TJu/nNI=
|
||||
github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1/go.mod h1:O0qtn62prQSqizzoagHmuuKoz8QMkU3SzBoKdEvm3aQ=
|
||||
github.com/nspcc-dev/dbft v0.0.0-20200531081613-7a39e7b757ac h1:cXPgsp4avJ7cR1nPRdpFRHmWoMSRZ41FSvlNjpsyTiA=
|
||||
github.com/nspcc-dev/dbft v0.0.0-20200531081613-7a39e7b757ac/go.mod h1:1FYQXSbb6/9HQIkoF8XO7W/S8N7AZRkBsgwbcXRvk0E=
|
||||
github.com/nspcc-dev/dbft v0.0.0-20200623100921-5a182c20965e h1:QOT9slflIkEKb5wY0ZUC0dCmCgoqGlhOAh9+xWMIxfg=
|
||||
github.com/nspcc-dev/dbft v0.0.0-20200623100921-5a182c20965e/go.mod h1:1FYQXSbb6/9HQIkoF8XO7W/S8N7AZRkBsgwbcXRvk0E=
|
||||
github.com/nspcc-dev/neo-go v0.73.1-pre.0.20200303142215-f5a1b928ce09/go.mod h1:pPYwPZ2ks+uMnlRLUyXOpLieaDQSEaf4NM3zHVbRjmg=
|
||||
github.com/nspcc-dev/neofs-crypto v0.2.0 h1:ftN+59WqxSWz/RCgXYOfhmltOOqU+udsNQSvN6wkFck=
|
||||
github.com/nspcc-dev/neofs-crypto v0.2.0/go.mod h1:F/96fUzPM3wR+UGsPi3faVNmFlA9KAEAUQR7dMxZmNA=
|
||||
|
@ -144,10 +169,12 @@ github.com/nspcc-dev/rfc6979 v0.2.0 h1:3e1WNxrN60/6N0DW7+UYisLeZJyfqZTNOjeV/toYv
|
|||
github.com/nspcc-dev/rfc6979 v0.2.0/go.mod h1:exhIh1PdpDC5vQmyEsGvc4YDM/lyQp/452QxGq/UEso=
|
||||
github.com/onsi/ginkgo v1.6.0 h1:Ix8l273rp3QzYgXSR+c8d1fTG7UPgYkOSELPhiY/YGw=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.10.3 h1:OoxbjfXVZyod1fmWYhI7SEyaD8B00ynP3T+D5GiyHOY=
|
||||
github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/gomega v1.4.2 h1:3mYCb7aPxS/RU7TI1y4rkEn1oKmPRjNJLNEXgw7MH2I=
|
||||
github.com/onsi/gomega v1.4.2/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/onsi/gomega v1.7.1 h1:K0jcRCwNQM3vFGh1ppMtDh/+7ApJrjldlX8fA0jDTLQ=
|
||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||
|
@ -212,6 +239,7 @@ go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI=
|
|||
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
|
||||
go.uber.org/zap v1.10.0 h1:ORx85nbTijNz8ljznvCMR1ZBIPKFn3jQrag10X2AsuM=
|
||||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
|
||||
|
|
|
@ -11,6 +11,10 @@ var syscalls = map[string]map[string]string{
|
|||
"GetUsage": "Neo.Attribute.GetUsage",
|
||||
"GetData": "Neo.Attribute.GetData",
|
||||
},
|
||||
"crypto": {
|
||||
"Secp256k1Recover": "Neo.Cryptography.Secp256k1Recover",
|
||||
"Secp256r1Recover": "Neo.Cryptography.Secp256r1Recover",
|
||||
},
|
||||
"enumerator": {
|
||||
"Concat": "Neo.Enumerator.Concat",
|
||||
"Create": "Neo.Enumerator.Create",
|
||||
|
|
|
@ -20,6 +20,8 @@ const (
|
|||
type (
|
||||
ProtocolConfiguration struct {
|
||||
AddressVersion byte `yaml:"AddressVersion"`
|
||||
// EnableStateRoot specifies if exchange of state roots should be enabled.
|
||||
EnableStateRoot bool `yaml:"EnableStateRoot"`
|
||||
// FeePerExtraByte sets the expected per-byte fee for
|
||||
// transactions exceeding the MaxFreeTransactionSize.
|
||||
FeePerExtraByte float64 `yaml:"FeePerExtraByte"`
|
||||
|
@ -34,11 +36,13 @@ type (
|
|||
MaxFreeTransactionsPerBlock int `yaml:"MaxFreeTransactionsPerBlock"`
|
||||
MemPoolSize int `yaml:"MemPoolSize"`
|
||||
// SaveStorageBatch enables storage batch saving before every persist.
|
||||
SaveStorageBatch bool `yaml:"SaveStorageBatch"`
|
||||
SecondsPerBlock int `yaml:"SecondsPerBlock"`
|
||||
SeedList []string `yaml:"SeedList"`
|
||||
StandbyValidators []string `yaml:"StandbyValidators"`
|
||||
SystemFee SystemFee `yaml:"SystemFee"`
|
||||
SaveStorageBatch bool `yaml:"SaveStorageBatch"`
|
||||
SecondsPerBlock int `yaml:"SecondsPerBlock"`
|
||||
SeedList []string `yaml:"SeedList"`
|
||||
StandbyValidators []string `yaml:"StandbyValidators"`
|
||||
// StateRootEnableIndex specifies starting height for state root calculations and exchange.
|
||||
StateRootEnableIndex uint32 `yaml:"StateRootEnableIndex"`
|
||||
SystemFee SystemFee `yaml:"SystemFee"`
|
||||
// Whether to verify received blocks.
|
||||
VerifyBlocks bool `yaml:"VerifyBlocks"`
|
||||
// Whether to verify transactions in received blocks.
|
||||
|
|
|
@ -8,6 +8,9 @@ import (
|
|||
// commit represents dBFT Commit message.
|
||||
type commit struct {
|
||||
signature [signatureSize]byte
|
||||
stateSig [signatureSize]byte
|
||||
|
||||
stateRootEnabled bool
|
||||
}
|
||||
|
||||
// signatureSize is an rfc6989 signature size in bytes
|
||||
|
@ -19,11 +22,17 @@ var _ payload.Commit = (*commit)(nil)
|
|||
// EncodeBinary implements io.Serializable interface.
|
||||
func (c *commit) EncodeBinary(w *io.BinWriter) {
|
||||
w.WriteBytes(c.signature[:])
|
||||
if c.stateRootEnabled {
|
||||
w.WriteBytes(c.stateSig[:])
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable interface.
|
||||
func (c *commit) DecodeBinary(r *io.BinReader) {
|
||||
r.ReadBytes(c.signature[:])
|
||||
if c.stateRootEnabled {
|
||||
r.ReadBytes(c.stateSig[:])
|
||||
}
|
||||
}
|
||||
|
||||
// Signature implements payload.Commit interface.
|
||||
|
|
|
@ -2,6 +2,7 @@ package consensus
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"time"
|
||||
|
@ -13,7 +14,9 @@ import (
|
|||
"github.com/nspcc-dev/dbft/payload"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||
coreb "github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/cache"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
|
@ -49,9 +52,9 @@ type service struct {
|
|||
|
||||
log *zap.Logger
|
||||
// cache is a fifo cache which stores recent payloads.
|
||||
cache *relayCache
|
||||
cache *cache.HashCache
|
||||
// txx is a fifo cache which stores miner transactions.
|
||||
txx *relayCache
|
||||
txx *cache.HashCache
|
||||
dbft *dbft.DBFT
|
||||
// messages and transactions are channels needed to process
|
||||
// everything in single thread.
|
||||
|
@ -70,7 +73,7 @@ type Config struct {
|
|||
Logger *zap.Logger
|
||||
// Broadcast is a callback which is called to notify server
|
||||
// about new consensus payload to sent.
|
||||
Broadcast func(p *Payload)
|
||||
Broadcast func(cache.Hashable)
|
||||
// Chain is a core.Blockchainer instance.
|
||||
Chain core.Blockchainer
|
||||
// RequestTx is a callback to which will be called
|
||||
|
@ -96,8 +99,8 @@ func NewService(cfg Config) (Service, error) {
|
|||
Config: cfg,
|
||||
|
||||
log: cfg.Logger,
|
||||
cache: newFIFOCache(cacheMaxCapacity),
|
||||
txx: newFIFOCache(cacheMaxCapacity),
|
||||
cache: cache.NewFIFOCache(cacheMaxCapacity),
|
||||
txx: cache.NewFIFOCache(cacheMaxCapacity),
|
||||
messages: make(chan Payload, 100),
|
||||
|
||||
transactions: make(chan *transaction.Transaction, 100),
|
||||
|
@ -120,7 +123,6 @@ func NewService(cfg Config) (Service, error) {
|
|||
dbft.WithLogger(srv.log),
|
||||
dbft.WithSecondsPerBlock(cfg.TimePerBlock),
|
||||
dbft.WithGetKeyPair(srv.getKeyPair),
|
||||
dbft.WithTxPerBlock(10000),
|
||||
dbft.WithRequestTx(cfg.RequestTx),
|
||||
dbft.WithGetTx(srv.getTx),
|
||||
dbft.WithGetVerified(srv.getVerifiedTx),
|
||||
|
@ -135,13 +137,14 @@ func NewService(cfg Config) (Service, error) {
|
|||
dbft.WithGetValidators(srv.getValidators),
|
||||
dbft.WithGetConsensusAddress(srv.getConsensusAddress),
|
||||
|
||||
dbft.WithNewConsensusPayload(func() payload.ConsensusPayload { p := new(Payload); p.message = &message{}; return p }),
|
||||
dbft.WithNewPrepareRequest(func() payload.PrepareRequest { return new(prepareRequest) }),
|
||||
dbft.WithNewConsensusPayload(srv.newPayload),
|
||||
dbft.WithNewPrepareRequest(srv.newPrepareRequest),
|
||||
dbft.WithNewPrepareResponse(func() payload.PrepareResponse { return new(prepareResponse) }),
|
||||
dbft.WithNewChangeView(func() payload.ChangeView { return new(changeView) }),
|
||||
dbft.WithNewCommit(func() payload.Commit { return new(commit) }),
|
||||
dbft.WithNewCommit(srv.newCommit),
|
||||
dbft.WithNewRecoveryRequest(func() payload.RecoveryRequest { return new(recoveryRequest) }),
|
||||
dbft.WithNewRecoveryMessage(func() payload.RecoveryMessage { return new(recoveryMessage) }),
|
||||
dbft.WithVerifyPrepareRequest(srv.verifyRequest),
|
||||
)
|
||||
|
||||
if srv.dbft == nil {
|
||||
|
@ -210,6 +213,53 @@ func (s *service) eventLoop() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *service) newPayload() payload.ConsensusPayload {
|
||||
return &Payload{
|
||||
message: &message{
|
||||
stateRootEnabled: s.stateRootEnabled(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// stateRootEnabled checks if state root feature is enabled on current height.
|
||||
// It should be called only from dbft callbacks and is not protected by any mutex.
|
||||
func (s *service) stateRootEnabled() bool {
|
||||
return s.Chain.GetConfig().EnableStateRoot
|
||||
}
|
||||
|
||||
func (s *service) newPrepareRequest() payload.PrepareRequest {
|
||||
if !s.stateRootEnabled() {
|
||||
return new(prepareRequest)
|
||||
}
|
||||
sr, err := s.Chain.GetStateRoot(s.Chain.BlockHeight())
|
||||
if err == nil {
|
||||
return &prepareRequest{
|
||||
stateRootEnabled: true,
|
||||
proposalStateRoot: sr.MPTRootBase,
|
||||
}
|
||||
}
|
||||
return &prepareRequest{stateRootEnabled: true}
|
||||
}
|
||||
|
||||
func (s *service) newCommit() payload.Commit {
|
||||
if !s.stateRootEnabled() {
|
||||
return new(commit)
|
||||
}
|
||||
c := &commit{stateRootEnabled: true}
|
||||
for _, p := range s.dbft.Context.PreparationPayloads {
|
||||
if p != nil && p.ViewNumber() == s.dbft.ViewNumber && p.Type() == payload.PrepareRequestType {
|
||||
pr := p.GetPrepareRequest().(*prepareRequest)
|
||||
data := pr.proposalStateRoot.GetSignedPart()
|
||||
sign, err := s.dbft.Priv.Sign(data)
|
||||
if err == nil {
|
||||
copy(c.stateSig[:], sign)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (s *service) validatePayload(p *Payload) bool {
|
||||
validators := s.getValidators()
|
||||
if int(p.validatorIndex) >= len(validators) {
|
||||
|
@ -262,8 +312,8 @@ func (s *service) OnPayload(cp *Payload) {
|
|||
|
||||
// decode payload data into message
|
||||
if cp.message == nil {
|
||||
if err := cp.decodeData(); err != nil {
|
||||
log.Debug("can't decode payload data")
|
||||
if err := cp.decodeData(s.stateRootEnabled()); err != nil {
|
||||
log.Debug("can't decode payload data", zap.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -340,6 +390,21 @@ func (s *service) verifyBlock(b block.Block) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (s *service) verifyRequest(p payload.ConsensusPayload) error {
|
||||
if !s.stateRootEnabled() {
|
||||
return nil
|
||||
}
|
||||
r, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't get local state root: %v", err)
|
||||
}
|
||||
rb := &p.GetPrepareRequest().(*prepareRequest).proposalStateRoot
|
||||
if !r.Equals(rb) {
|
||||
return errors.New("state root mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) processBlock(b block.Block) {
|
||||
bb := &b.(*neoBlock).Block
|
||||
bb.Script = *(s.getBlockWitness(bb))
|
||||
|
@ -351,16 +416,36 @@ func (s *service) processBlock(b block.Block) {
|
|||
s.log.Warn("error on add block", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
var rb *state.MPTRootBase
|
||||
for _, p := range s.dbft.PreparationPayloads {
|
||||
if p != nil && p.Type() == payload.PrepareRequestType {
|
||||
rb = &p.GetPrepareRequest().(*prepareRequest).proposalStateRoot
|
||||
}
|
||||
}
|
||||
w := s.getWitness(func(p payload.Commit) []byte { return p.(*commit).stateSig[:] })
|
||||
r := &state.MPTRoot{
|
||||
MPTRootBase: *rb,
|
||||
Witness: w,
|
||||
}
|
||||
if err := s.Chain.AddStateRoot(r); err != nil {
|
||||
s.log.Warn("errors while adding state root", zap.Error(err))
|
||||
}
|
||||
s.Broadcast(r)
|
||||
}
|
||||
|
||||
func (s *service) getBlockWitness(b *coreb.Block) *transaction.Witness {
|
||||
func (s *service) getBlockWitness(_ *coreb.Block) *transaction.Witness {
|
||||
return s.getWitness(func(p payload.Commit) []byte { return p.Signature() })
|
||||
}
|
||||
|
||||
func (s *service) getWitness(f func(p payload.Commit) []byte) *transaction.Witness {
|
||||
dctx := s.dbft.Context
|
||||
pubs := convertKeys(dctx.Validators)
|
||||
sigs := make(map[*keys.PublicKey][]byte)
|
||||
|
||||
for i := range pubs {
|
||||
if p := dctx.CommitPayloads[i]; p != nil && p.ViewNumber() == dctx.ViewNumber {
|
||||
sigs[pubs[i]] = p.GetCommit().Signature()
|
||||
sigs[pubs[i]] = f(p.GetCommit())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -397,7 +482,7 @@ func (s *service) getBlock(h util.Uint256) block.Block {
|
|||
return &neoBlock{Block: *b}
|
||||
}
|
||||
|
||||
func (s *service) getVerifiedTx(count int) []block.Transaction {
|
||||
func (s *service) getVerifiedTx() []block.Transaction {
|
||||
pool := s.Config.Chain.GetMemPool()
|
||||
|
||||
var txx []mempool.TxWithFee
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/nspcc-dev/dbft/payload"
|
||||
"github.com/nspcc-dev/neo-go/pkg/config"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/cache"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||
|
@ -25,7 +26,7 @@ func TestNewService(t *testing.T) {
|
|||
require.NoError(t, srv.Chain.PoolTx(tx))
|
||||
|
||||
var txx []block.Transaction
|
||||
require.NotPanics(t, func() { txx = srv.getVerifiedTx(1) })
|
||||
require.NotPanics(t, func() { txx = srv.getVerifiedTx() })
|
||||
require.Len(t, txx, 2)
|
||||
require.Equal(t, tx, txx[1])
|
||||
srv.Chain.Close()
|
||||
|
@ -58,7 +59,7 @@ func TestService_GetVerified(t *testing.T) {
|
|||
srv.dbft.ViewNumber = 1
|
||||
|
||||
t.Run("new transactions will be proposed in case of failure", func(t *testing.T) {
|
||||
txx := srv.getVerifiedTx(10)
|
||||
txx := srv.getVerifiedTx()
|
||||
require.Equal(t, 2, len(txx), "there is only 1 tx in mempool")
|
||||
require.Equal(t, txs[3], txx[1])
|
||||
})
|
||||
|
@ -68,7 +69,7 @@ func TestService_GetVerified(t *testing.T) {
|
|||
require.NoError(t, srv.Chain.PoolTx(tx))
|
||||
}
|
||||
|
||||
txx := srv.getVerifiedTx(10)
|
||||
txx := srv.getVerifiedTx()
|
||||
require.Contains(t, txx, txs[0])
|
||||
require.Contains(t, txx, txs[1])
|
||||
require.NotContains(t, txx, txs[2])
|
||||
|
@ -182,7 +183,7 @@ func shouldNotReceive(t *testing.T, ch chan Payload) {
|
|||
func newTestService(t *testing.T) *service {
|
||||
srv, err := NewService(Config{
|
||||
Logger: zaptest.NewLogger(t),
|
||||
Broadcast: func(*Payload) {},
|
||||
Broadcast: func(cache.Hashable) {},
|
||||
Chain: newTestChain(t),
|
||||
RequestTx: func(...util.Uint256) {},
|
||||
Wallet: &wallet.Config{
|
||||
|
|
|
@ -22,6 +22,8 @@ type (
|
|||
Type messageType
|
||||
ViewNumber byte
|
||||
|
||||
stateRootEnabled bool
|
||||
|
||||
payload io.Serializable
|
||||
}
|
||||
|
||||
|
@ -283,15 +285,21 @@ func (m *message) DecodeBinary(r *io.BinReader) {
|
|||
cv.newViewNumber = m.ViewNumber + 1
|
||||
m.payload = cv
|
||||
case prepareRequestType:
|
||||
m.payload = new(prepareRequest)
|
||||
m.payload = &prepareRequest{
|
||||
stateRootEnabled: m.stateRootEnabled,
|
||||
}
|
||||
case prepareResponseType:
|
||||
m.payload = new(prepareResponse)
|
||||
case commitType:
|
||||
m.payload = new(commit)
|
||||
m.payload = &commit{
|
||||
stateRootEnabled: m.stateRootEnabled,
|
||||
}
|
||||
case recoveryRequestType:
|
||||
m.payload = new(recoveryRequest)
|
||||
case recoveryMessageType:
|
||||
m.payload = new(recoveryMessage)
|
||||
m.payload = &recoveryMessage{
|
||||
stateRootEnabled: m.stateRootEnabled,
|
||||
}
|
||||
default:
|
||||
r.Err = errors.Errorf("invalid type: 0x%02x", byte(m.Type))
|
||||
return
|
||||
|
@ -319,9 +327,9 @@ func (t messageType) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
// decode data of payload into it's message
|
||||
func (p *Payload) decodeData() error {
|
||||
m := new(message)
|
||||
// decodeData decodes data of payload into it's message.
|
||||
func (p *Payload) decodeData(stateRootEnabled bool) error {
|
||||
m := &message{stateRootEnabled: stateRootEnabled}
|
||||
br := io.NewBinReaderFromBuf(p.data)
|
||||
m.DecodeBinary(br)
|
||||
if br.Err != nil {
|
||||
|
|
|
@ -94,13 +94,13 @@ func TestConsensusPayload_Serializable(t *testing.T) {
|
|||
// message is nil after decoding as we didn't yet call decodeData
|
||||
require.Nil(t, actual.message)
|
||||
// message should now be decoded from actual.data byte array
|
||||
assert.NoError(t, actual.decodeData())
|
||||
assert.NoError(t, actual.decodeData(false))
|
||||
require.Equal(t, p, actual)
|
||||
|
||||
data = p.MarshalUnsigned()
|
||||
pu := new(Payload)
|
||||
require.NoError(t, pu.UnmarshalUnsigned(data))
|
||||
assert.NoError(t, pu.decodeData())
|
||||
assert.NoError(t, pu.decodeData(false))
|
||||
|
||||
p.Witness = transaction.Witness{}
|
||||
require.Equal(t, p, pu)
|
||||
|
@ -144,14 +144,14 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) {
|
|||
p := new(Payload)
|
||||
require.NoError(t, testserdes.DecodeBinary(buf, p))
|
||||
// decode `data` into `message`
|
||||
assert.NoError(t, p.decodeData())
|
||||
assert.NoError(t, p.decodeData(false))
|
||||
require.Equal(t, expected, p)
|
||||
|
||||
// invalid type
|
||||
buf[typeIndex] = 0xFF
|
||||
actual := new(Payload)
|
||||
require.NoError(t, testserdes.DecodeBinary(buf, actual))
|
||||
require.Error(t, actual.decodeData())
|
||||
require.Error(t, actual.decodeData(false))
|
||||
|
||||
// invalid format
|
||||
buf[delimeterIndex] = 0
|
||||
|
@ -165,9 +165,16 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) {
|
|||
require.Error(t, testserdes.DecodeBinary(buf, new(Payload)))
|
||||
}
|
||||
|
||||
func testEncodeDecode(srEnabled bool, mt messageType, actual io.Serializable) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
expected := randomMessage(t, mt, srEnabled)
|
||||
testserdes.EncodeDecodeBinary(t, expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommit_Serializable(t *testing.T) {
|
||||
c := randomMessage(t, commitType)
|
||||
testserdes.EncodeDecodeBinary(t, c, new(commit))
|
||||
t.Run("WithStateRoot", testEncodeDecode(true, commitType, &commit{stateRootEnabled: true}))
|
||||
t.Run("NoStateRoot", testEncodeDecode(false, commitType, &commit{stateRootEnabled: false}))
|
||||
}
|
||||
|
||||
func TestPrepareResponse_Serializable(t *testing.T) {
|
||||
|
@ -176,8 +183,8 @@ func TestPrepareResponse_Serializable(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPrepareRequest_Serializable(t *testing.T) {
|
||||
req := randomMessage(t, prepareRequestType)
|
||||
testserdes.EncodeDecodeBinary(t, req, new(prepareRequest))
|
||||
t.Run("WithStateRoot", testEncodeDecode(true, prepareRequestType, &prepareRequest{stateRootEnabled: true}))
|
||||
t.Run("NoStateRoot", testEncodeDecode(false, prepareRequestType, &prepareRequest{stateRootEnabled: false}))
|
||||
}
|
||||
|
||||
func TestRecoveryRequest_Serializable(t *testing.T) {
|
||||
|
@ -186,8 +193,8 @@ func TestRecoveryRequest_Serializable(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRecoveryMessage_Serializable(t *testing.T) {
|
||||
msg := randomMessage(t, recoveryMessageType)
|
||||
testserdes.EncodeDecodeBinary(t, msg, new(recoveryMessage))
|
||||
t.Run("WithStateRoot", testEncodeDecode(true, recoveryMessageType, &recoveryMessage{stateRootEnabled: true}))
|
||||
t.Run("NoStateRoot", testEncodeDecode(false, recoveryMessageType, &recoveryMessage{stateRootEnabled: false}))
|
||||
}
|
||||
|
||||
func randomPayload(t *testing.T, mt messageType) *Payload {
|
||||
|
@ -215,31 +222,35 @@ func randomPayload(t *testing.T, mt messageType) *Payload {
|
|||
return p
|
||||
}
|
||||
|
||||
func randomMessage(t *testing.T, mt messageType) io.Serializable {
|
||||
func randomMessage(t *testing.T, mt messageType, srEnabled ...bool) io.Serializable {
|
||||
switch mt {
|
||||
case changeViewType:
|
||||
return &changeView{
|
||||
timestamp: rand.Uint32(),
|
||||
}
|
||||
case prepareRequestType:
|
||||
return randomPrepareRequest(t)
|
||||
return randomPrepareRequest(t, srEnabled...)
|
||||
case prepareResponseType:
|
||||
return &prepareResponse{preparationHash: random.Uint256()}
|
||||
case commitType:
|
||||
var c commit
|
||||
random.Fill(c.signature[:])
|
||||
if len(srEnabled) > 0 && srEnabled[0] {
|
||||
c.stateRootEnabled = true
|
||||
random.Fill(c.stateSig[:])
|
||||
}
|
||||
return &c
|
||||
case recoveryRequestType:
|
||||
return &recoveryRequest{timestamp: rand.Uint32()}
|
||||
case recoveryMessageType:
|
||||
return randomRecoveryMessage(t)
|
||||
return randomRecoveryMessage(t, srEnabled...)
|
||||
default:
|
||||
require.Fail(t, "invalid type")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func randomPrepareRequest(t *testing.T) *prepareRequest {
|
||||
func randomPrepareRequest(t *testing.T, srEnabled ...bool) *prepareRequest {
|
||||
const txCount = 3
|
||||
|
||||
req := &prepareRequest{
|
||||
|
@ -255,15 +266,22 @@ func randomPrepareRequest(t *testing.T) *prepareRequest {
|
|||
}
|
||||
req.nextConsensus = random.Uint160()
|
||||
|
||||
if len(srEnabled) > 0 && srEnabled[0] {
|
||||
req.stateRootEnabled = true
|
||||
req.proposalStateRoot.Index = rand.Uint32()
|
||||
req.proposalStateRoot.PrevHash = random.Uint256()
|
||||
req.proposalStateRoot.Root = random.Uint256()
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func randomRecoveryMessage(t *testing.T) *recoveryMessage {
|
||||
result := randomMessage(t, prepareRequestType)
|
||||
func randomRecoveryMessage(t *testing.T, srEnabled ...bool) *recoveryMessage {
|
||||
result := randomMessage(t, prepareRequestType, srEnabled...)
|
||||
require.IsType(t, (*prepareRequest)(nil), result)
|
||||
prepReq := result.(*prepareRequest)
|
||||
|
||||
return &recoveryMessage{
|
||||
rec := &recoveryMessage{
|
||||
preparationPayloads: []*preparationCompact{
|
||||
{
|
||||
ValidatorIndex: 1,
|
||||
|
@ -297,6 +315,15 @@ func randomRecoveryMessage(t *testing.T) *recoveryMessage {
|
|||
payload: prepReq,
|
||||
},
|
||||
}
|
||||
if len(srEnabled) > 0 && srEnabled[0] {
|
||||
rec.stateRootEnabled = true
|
||||
rec.prepareRequest.stateRootEnabled = true
|
||||
for _, c := range rec.commitPayloads {
|
||||
c.stateRootEnabled = true
|
||||
random.Fill(c.StateSignature[:])
|
||||
}
|
||||
}
|
||||
return rec
|
||||
}
|
||||
|
||||
func TestPayload_Sign(t *testing.T) {
|
||||
|
|
|
@ -2,6 +2,7 @@ package consensus
|
|||
|
||||
import (
|
||||
"github.com/nspcc-dev/dbft/payload"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
|
@ -14,6 +15,9 @@ type prepareRequest struct {
|
|||
transactionHashes []util.Uint256
|
||||
minerTx transaction.Transaction
|
||||
nextConsensus util.Uint160
|
||||
proposalStateRoot state.MPTRootBase
|
||||
|
||||
stateRootEnabled bool
|
||||
}
|
||||
|
||||
var _ payload.PrepareRequest = (*prepareRequest)(nil)
|
||||
|
@ -25,6 +29,9 @@ func (p *prepareRequest) EncodeBinary(w *io.BinWriter) {
|
|||
w.WriteBytes(p.nextConsensus[:])
|
||||
w.WriteArray(p.transactionHashes)
|
||||
p.minerTx.EncodeBinary(w)
|
||||
if p.stateRootEnabled {
|
||||
p.proposalStateRoot.EncodeBinary(w)
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable interface.
|
||||
|
@ -34,6 +41,9 @@ func (p *prepareRequest) DecodeBinary(r *io.BinReader) {
|
|||
r.ReadBytes(p.nextConsensus[:])
|
||||
r.ReadArray(&p.transactionHashes)
|
||||
p.minerTx.DecodeBinary(r)
|
||||
if p.stateRootEnabled {
|
||||
p.proposalStateRoot.DecodeBinary(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Timestamp implements payload.PrepareRequest interface.
|
||||
|
|
|
@ -3,6 +3,7 @@ package consensus
|
|||
import (
|
||||
"github.com/nspcc-dev/dbft/crypto"
|
||||
"github.com/nspcc-dev/dbft/payload"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/pkg/errors"
|
||||
|
@ -16,6 +17,8 @@ type (
|
|||
commitPayloads []*commitCompact
|
||||
changeViewPayloads []*changeViewCompact
|
||||
prepareRequest *message
|
||||
|
||||
stateRootEnabled bool
|
||||
}
|
||||
|
||||
changeViewCompact struct {
|
||||
|
@ -29,7 +32,10 @@ type (
|
|||
ViewNumber byte
|
||||
ValidatorIndex uint16
|
||||
Signature [signatureSize]byte
|
||||
StateSignature [signatureSize]byte
|
||||
InvocationScript []byte
|
||||
|
||||
stateRootEnabled bool
|
||||
}
|
||||
|
||||
preparationCompact struct {
|
||||
|
@ -46,7 +52,7 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) {
|
|||
|
||||
var hasReq = r.ReadBool()
|
||||
if hasReq {
|
||||
m.prepareRequest = new(message)
|
||||
m.prepareRequest = &message{stateRootEnabled: m.stateRootEnabled}
|
||||
m.prepareRequest.DecodeBinary(r)
|
||||
if r.Err == nil && m.prepareRequest.Type != prepareRequestType {
|
||||
r.Err = errors.New("recovery message PrepareRequest has wrong type")
|
||||
|
@ -67,7 +73,16 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) {
|
|||
}
|
||||
|
||||
r.ReadArray(&m.preparationPayloads)
|
||||
r.ReadArray(&m.commitPayloads)
|
||||
lu := r.ReadVarUint()
|
||||
if lu > state.MaxValidatorsVoted {
|
||||
r.Err = errors.New("too many preparation payloads")
|
||||
return
|
||||
}
|
||||
m.commitPayloads = make([]*commitCompact, lu)
|
||||
for i := uint64(0); i < lu; i++ {
|
||||
m.commitPayloads[i] = &commitCompact{stateRootEnabled: m.stateRootEnabled}
|
||||
m.commitPayloads[i].DecodeBinary(r)
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable interface.
|
||||
|
@ -96,7 +111,7 @@ func (p *changeViewCompact) DecodeBinary(r *io.BinReader) {
|
|||
p.ValidatorIndex = r.ReadU16LE()
|
||||
p.OriginalViewNumber = r.ReadB()
|
||||
p.Timestamp = r.ReadU32LE()
|
||||
p.InvocationScript = r.ReadVarBytes()
|
||||
p.InvocationScript = r.ReadVarBytes(1024)
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable interface.
|
||||
|
@ -112,7 +127,10 @@ func (p *commitCompact) DecodeBinary(r *io.BinReader) {
|
|||
p.ViewNumber = r.ReadB()
|
||||
p.ValidatorIndex = r.ReadU16LE()
|
||||
r.ReadBytes(p.Signature[:])
|
||||
p.InvocationScript = r.ReadVarBytes()
|
||||
if p.stateRootEnabled {
|
||||
r.ReadBytes(p.StateSignature[:])
|
||||
}
|
||||
p.InvocationScript = r.ReadVarBytes(1024)
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable interface.
|
||||
|
@ -120,13 +138,16 @@ func (p *commitCompact) EncodeBinary(w *io.BinWriter) {
|
|||
w.WriteB(p.ViewNumber)
|
||||
w.WriteU16LE(p.ValidatorIndex)
|
||||
w.WriteBytes(p.Signature[:])
|
||||
if p.stateRootEnabled {
|
||||
w.WriteBytes(p.StateSignature[:])
|
||||
}
|
||||
w.WriteVarBytes(p.InvocationScript)
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable interface.
|
||||
func (p *preparationCompact) DecodeBinary(r *io.BinReader) {
|
||||
p.ValidatorIndex = r.ReadU16LE()
|
||||
p.InvocationScript = r.ReadVarBytes()
|
||||
p.InvocationScript = r.ReadVarBytes(1024)
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable interface.
|
||||
|
@ -143,6 +164,8 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) {
|
|||
Type: prepareRequestType,
|
||||
ViewNumber: p.ViewNumber(),
|
||||
payload: p.GetPrepareRequest().(*prepareRequest),
|
||||
|
||||
stateRootEnabled: m.stateRootEnabled,
|
||||
}
|
||||
h := p.Hash()
|
||||
m.preparationHash = &h
|
||||
|
@ -169,9 +192,11 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) {
|
|||
})
|
||||
case payload.CommitType:
|
||||
m.commitPayloads = append(m.commitPayloads, &commitCompact{
|
||||
stateRootEnabled: m.stateRootEnabled,
|
||||
ValidatorIndex: p.ValidatorIndex(),
|
||||
ViewNumber: p.ViewNumber(),
|
||||
Signature: p.GetCommit().(*commit).signature,
|
||||
StateSignature: p.GetCommit().(*commit).stateSig,
|
||||
InvocationScript: p.(*Payload).Witness.InvocationScript,
|
||||
})
|
||||
}
|
||||
|
@ -234,6 +259,7 @@ func (m *recoveryMessage) GetChangeViews(p payload.ConsensusPayload, validators
|
|||
newViewNumber: cv.OriginalViewNumber + 1,
|
||||
timestamp: cv.Timestamp,
|
||||
})
|
||||
c.message.ViewNumber = cv.OriginalViewNumber
|
||||
c.SetValidatorIndex(cv.ValidatorIndex)
|
||||
c.Witness.InvocationScript = cv.InvocationScript
|
||||
c.Witness.VerificationScript = getVerificationScript(cv.ValidatorIndex, validators)
|
||||
|
@ -249,7 +275,12 @@ func (m *recoveryMessage) GetCommits(p payload.ConsensusPayload, validators []cr
|
|||
ps := make([]payload.ConsensusPayload, len(m.commitPayloads))
|
||||
|
||||
for i, c := range m.commitPayloads {
|
||||
cc := fromPayload(commitType, p.(*Payload), &commit{signature: c.Signature})
|
||||
cc := fromPayload(commitType, p.(*Payload), &commit{
|
||||
signature: c.Signature,
|
||||
stateSig: c.StateSignature,
|
||||
|
||||
stateRootEnabled: m.stateRootEnabled,
|
||||
})
|
||||
cc.SetValidatorIndex(c.ValidatorIndex)
|
||||
cc.Witness.InvocationScript = c.InvocationScript
|
||||
cc.Witness.VerificationScript = getVerificationScript(c.ValidatorIndex, validators)
|
||||
|
@ -289,6 +320,8 @@ func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload {
|
|||
Type: t,
|
||||
ViewNumber: recovery.message.ViewNumber,
|
||||
payload: p,
|
||||
|
||||
stateRootEnabled: recovery.stateRootEnabled,
|
||||
},
|
||||
version: recovery.Version(),
|
||||
prevHash: recovery.PrevHash(),
|
||||
|
|
|
@ -13,9 +13,11 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/dao"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
|
@ -229,6 +231,11 @@ func (bc *Blockchain) init() error {
|
|||
}
|
||||
bc.blockHeight = bHeight
|
||||
bc.persistedHeight = bHeight
|
||||
if bc.config.EnableStateRoot {
|
||||
if err = bc.dao.InitMPT(bHeight); err != nil {
|
||||
return errors.Wrapf(err, "can't init MPT at height %d", bHeight)
|
||||
}
|
||||
}
|
||||
|
||||
hashes, err := bc.dao.GetHeaderHashes()
|
||||
if err != nil {
|
||||
|
@ -551,6 +558,23 @@ func (bc *Blockchain) getSystemFeeAmount(h util.Uint256) uint32 {
|
|||
return sf
|
||||
}
|
||||
|
||||
// GetStateProof returns proof of having key in the MPT with the specified root.
|
||||
func (bc *Blockchain) GetStateProof(root util.Uint256, key []byte) ([][]byte, error) {
|
||||
if !bc.config.EnableStateRoot {
|
||||
return nil, errors.New("state root feature is not enabled")
|
||||
}
|
||||
tr := mpt.NewTrie(mpt.NewHashNode(root), storage.NewMemCachedStore(bc.dao.Store))
|
||||
return tr.GetProof(key)
|
||||
}
|
||||
|
||||
// GetStateRoot returns state root for a given height.
|
||||
func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) {
|
||||
if !bc.config.EnableStateRoot {
|
||||
return nil, errors.New("state root feature is not enabled")
|
||||
}
|
||||
return bc.dao.GetStateRoot(height)
|
||||
}
|
||||
|
||||
// TODO: storeBlock needs some more love, its implemented as in the original
|
||||
// project. This for the sake of development speed and understanding of what
|
||||
// is happening here, quite allot as you can see :). If things are wired together
|
||||
|
@ -819,6 +843,28 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
|
|||
}
|
||||
}
|
||||
|
||||
if bc.config.EnableStateRoot {
|
||||
root := bc.dao.MPT.StateRoot()
|
||||
var prevHash util.Uint256
|
||||
if block.Index > 0 {
|
||||
prev, err := bc.dao.GetStateRoot(block.Index - 1)
|
||||
if err != nil {
|
||||
return errors.WithMessagef(err, "can't get previous state root")
|
||||
}
|
||||
prevHash = hash.DoubleSha256(prev.GetSignedPart())
|
||||
}
|
||||
err := bc.AddStateRoot(&state.MPTRoot{
|
||||
MPTRootBase: state.MPTRootBase{
|
||||
Index: block.Index,
|
||||
PrevHash: prevHash,
|
||||
Root: root,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if bc.config.SaveStorageBatch {
|
||||
bc.lastBatch = cache.DAO.GetBatch()
|
||||
}
|
||||
|
@ -829,6 +875,15 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
|
|||
bc.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
if bc.config.EnableStateRoot {
|
||||
bc.dao.MPT.Flush()
|
||||
// Every persist cycle we also compact our in-memory MPT.
|
||||
persistedHeight := atomic.LoadUint32(&bc.persistedHeight)
|
||||
if persistedHeight == block.Index-1 {
|
||||
// 10 is good and roughly estimated to fit remaining trie into 1M of memory.
|
||||
bc.dao.MPT.Collapse(10)
|
||||
}
|
||||
}
|
||||
bc.topBlock.Store(block)
|
||||
atomic.StoreUint32(&bc.blockHeight, block.Index)
|
||||
bc.memPool.RemoveStale(bc.isTxStillRelevant)
|
||||
|
@ -1492,12 +1547,13 @@ func (bc *Blockchain) ApplyPolicyToTxSet(txes []mempool.TxWithFee) []mempool.TxW
|
|||
txes = txes[:bc.config.MaxTransactionsPerBlock]
|
||||
}
|
||||
maxFree := bc.config.MaxFreeTransactionsPerBlock
|
||||
if maxFree != 0 {
|
||||
lowStart := sort.Search(len(txes), func(i int) bool {
|
||||
return bc.IsLowPriority(txes[i].Fee)
|
||||
if maxFree != 0 && len(txes) > maxFree {
|
||||
// Transactions are sorted by fee, so we just find the first free one.
|
||||
freeStart := sort.Search(len(txes), func(i int) bool {
|
||||
return txes[i].Fee == 0
|
||||
})
|
||||
if lowStart+maxFree < len(txes) {
|
||||
txes = txes[:lowStart+maxFree]
|
||||
if freeStart+maxFree < len(txes) {
|
||||
txes = txes[:freeStart+maxFree]
|
||||
}
|
||||
}
|
||||
return txes
|
||||
|
@ -1732,6 +1788,90 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction) bool {
|
|||
|
||||
}
|
||||
|
||||
// StateHeight returns height of the verified state root.
|
||||
func (bc *Blockchain) StateHeight() uint32 {
|
||||
h, _ := bc.dao.GetCurrentStateRootHeight()
|
||||
return h
|
||||
}
|
||||
|
||||
// AddStateRoot add new (possibly unverified) state root to the blockchain.
|
||||
func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error {
|
||||
if !bc.config.EnableStateRoot {
|
||||
bc.log.Warn("state root is being added but not enabled in config")
|
||||
return nil
|
||||
}
|
||||
our, err := bc.GetStateRoot(r.Index)
|
||||
if err == nil {
|
||||
if our.Flag == state.Verified {
|
||||
return bc.updateStateHeight(r.Index)
|
||||
} else if r.Witness == nil && our.Witness != nil {
|
||||
r.Witness = our.Witness
|
||||
}
|
||||
}
|
||||
if err := bc.verifyStateRoot(r); err != nil {
|
||||
return errors.WithMessage(err, "invalid state root")
|
||||
}
|
||||
if r.Index > bc.BlockHeight() { // just put it into the store for future checks
|
||||
return bc.dao.PutStateRoot(&state.MPTRootState{
|
||||
MPTRoot: *r,
|
||||
Flag: state.Unverified,
|
||||
})
|
||||
}
|
||||
|
||||
flag := state.Unverified
|
||||
if r.Witness != nil {
|
||||
if err := bc.verifyStateRootWitness(r); err != nil {
|
||||
return errors.WithMessage(err, "can't verify signature")
|
||||
}
|
||||
flag = state.Verified
|
||||
}
|
||||
err = bc.dao.PutStateRoot(&state.MPTRootState{
|
||||
MPTRoot: *r,
|
||||
Flag: flag,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bc.updateStateHeight(r.Index)
|
||||
}
|
||||
|
||||
func (bc *Blockchain) updateStateHeight(newHeight uint32) error {
|
||||
h, err := bc.dao.GetCurrentStateRootHeight()
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "can't get current state root height")
|
||||
} else if newHeight == h+1 {
|
||||
updateStateHeightMetric(newHeight)
|
||||
return bc.dao.PutCurrentStateRootHeight(h + 1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyStateRoot checks if state root is valid.
|
||||
func (bc *Blockchain) verifyStateRoot(r *state.MPTRoot) error {
|
||||
if r.Index == 0 {
|
||||
return nil
|
||||
}
|
||||
prev, err := bc.GetStateRoot(r.Index - 1)
|
||||
if err != nil {
|
||||
return errors.New("can't get previous state root")
|
||||
} else if !r.PrevHash.Equals(hash.DoubleSha256(prev.GetSignedPart())) {
|
||||
return errors.New("previous hash mismatch")
|
||||
} else if prev.Version != r.Version {
|
||||
return errors.New("version mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyStateRootWitness verifies that state root signature is correct.
|
||||
func (bc *Blockchain) verifyStateRootWitness(r *state.MPTRoot) error {
|
||||
b, err := bc.GetBlock(bc.GetHeaderHash(int(r.Index)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interopCtx := bc.newInteropContext(trigger.Verification, bc.dao, nil, nil)
|
||||
return bc.verifyHashAgainstScript(b.NextConsensus, r.Witness, hash.Sha256(r.GetSignedPart()), interopCtx, true)
|
||||
}
|
||||
|
||||
// VerifyTx verifies whether a transaction is bonafide or not. Block parameter
|
||||
// is used for easy interop access and can be omitted for transactions that are
|
||||
// not yet added into any block.
|
||||
|
|
|
@ -18,6 +18,7 @@ type Blockchainer interface {
|
|||
GetConfig() config.ProtocolConfiguration
|
||||
AddHeaders(...*block.Header) error
|
||||
AddBlock(*block.Block) error
|
||||
AddStateRoot(r *state.MPTRoot) error
|
||||
BlockHeight() uint32
|
||||
CalculateClaimable(value util.Fixed8, startHeight, endHeight uint32) (util.Fixed8, util.Fixed8, error)
|
||||
Close()
|
||||
|
@ -38,6 +39,8 @@ type Blockchainer interface {
|
|||
GetNEP5Balances(util.Uint160) *state.NEP5Balances
|
||||
GetValidators(txes ...*transaction.Transaction) ([]*keys.PublicKey, error)
|
||||
GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error)
|
||||
GetStateProof(root util.Uint256, key []byte) ([][]byte, error)
|
||||
GetStateRoot(height uint32) (*state.MPTRootState, error)
|
||||
GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem
|
||||
GetStorageItems(hash util.Uint160) (map[string]*state.StorageItem, error)
|
||||
GetTestVM() *vm.VM
|
||||
|
@ -46,6 +49,7 @@ type Blockchainer interface {
|
|||
References(t *transaction.Transaction) ([]transaction.InOut, error)
|
||||
mempool.Feer // fee interface
|
||||
PoolTx(*transaction.Transaction) error
|
||||
StateHeight() uint32
|
||||
SubscribeForBlocks(ch chan<- *block.Block)
|
||||
SubscribeForExecutions(ch chan<- *state.AppExecResult)
|
||||
SubscribeForNotifications(ch chan<- *state.NotificationEvent)
|
||||
|
|
27
pkg/consensus/cache.go → pkg/core/cache/cache.go
vendored
27
pkg/consensus/cache.go → pkg/core/cache/cache.go
vendored
|
@ -1,4 +1,4 @@
|
|||
package consensus
|
||||
package cache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
|
@ -7,9 +7,9 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// relayCache is a payload cache which is used to store
|
||||
// HashCache is a payload cache which is used to store
|
||||
// last consensus payloads.
|
||||
type relayCache struct {
|
||||
type HashCache struct {
|
||||
*sync.RWMutex
|
||||
|
||||
maxCap int
|
||||
|
@ -17,13 +17,14 @@ type relayCache struct {
|
|||
queue *list.List
|
||||
}
|
||||
|
||||
// hashable is a type of items which can be stored in the relayCache.
|
||||
type hashable interface {
|
||||
// Hashable is a type of items which can be stored in the HashCache.
|
||||
type Hashable interface {
|
||||
Hash() util.Uint256
|
||||
}
|
||||
|
||||
func newFIFOCache(capacity int) *relayCache {
|
||||
return &relayCache{
|
||||
// NewFIFOCache returns new FIFO cache with the specified capacity.
|
||||
func NewFIFOCache(capacity int) *HashCache {
|
||||
return &HashCache{
|
||||
RWMutex: new(sync.RWMutex),
|
||||
|
||||
maxCap: capacity,
|
||||
|
@ -33,7 +34,7 @@ func newFIFOCache(capacity int) *relayCache {
|
|||
}
|
||||
|
||||
// Add adds payload into a cache if it doesn't already exist.
|
||||
func (c *relayCache) Add(p hashable) {
|
||||
func (c *HashCache) Add(p Hashable) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
|
@ -45,7 +46,7 @@ func (c *relayCache) Add(p hashable) {
|
|||
if c.queue.Len() >= c.maxCap {
|
||||
first := c.queue.Front()
|
||||
c.queue.Remove(first)
|
||||
delete(c.elems, first.Value.(hashable).Hash())
|
||||
delete(c.elems, first.Value.(Hashable).Hash())
|
||||
}
|
||||
|
||||
e := c.queue.PushBack(p)
|
||||
|
@ -53,7 +54,7 @@ func (c *relayCache) Add(p hashable) {
|
|||
}
|
||||
|
||||
// Has checks if an item is already in cache.
|
||||
func (c *relayCache) Has(h util.Uint256) bool {
|
||||
func (c *HashCache) Has(h util.Uint256) bool {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
|
@ -61,13 +62,13 @@ func (c *relayCache) Has(h util.Uint256) bool {
|
|||
}
|
||||
|
||||
// Get returns payload with the specified hash from cache.
|
||||
func (c *relayCache) Get(h util.Uint256) hashable {
|
||||
func (c *HashCache) Get(h util.Uint256) Hashable {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
e, ok := c.elems[h]
|
||||
if !ok {
|
||||
return hashable(nil)
|
||||
return Hashable(nil)
|
||||
}
|
||||
return e.Value.(hashable)
|
||||
return e.Value.(Hashable)
|
||||
}
|
|
@ -1,17 +1,19 @@
|
|||
package consensus
|
||||
package cache
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/dbft/payload"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRelayCache_Add(t *testing.T) {
|
||||
const capacity = 3
|
||||
payloads := getDifferentPayloads(t, capacity+1)
|
||||
c := newFIFOCache(capacity)
|
||||
payloads := getDifferentItems(t, capacity+1)
|
||||
c := NewFIFOCache(capacity)
|
||||
require.Equal(t, 0, c.queue.Len())
|
||||
require.Equal(t, 0, len(c.elems))
|
||||
|
||||
|
@ -46,19 +48,15 @@ func TestRelayCache_Add(t *testing.T) {
|
|||
require.Equal(t, nil, c.Get(payloads[1].Hash()))
|
||||
}
|
||||
|
||||
func getDifferentPayloads(t *testing.T, n int) (payloads []Payload) {
|
||||
payloads = make([]Payload, n)
|
||||
for i := range payloads {
|
||||
var sign [signatureSize]byte
|
||||
random.Fill(sign[:])
|
||||
type testHashable []byte
|
||||
|
||||
payloads[i].message = &message{}
|
||||
payloads[i].SetValidatorIndex(uint16(i))
|
||||
payloads[i].SetType(payload.MessageType(commitType))
|
||||
payloads[i].payload = &commit{
|
||||
signature: sign,
|
||||
}
|
||||
// Hash implements Hashable.
|
||||
func (h testHashable) Hash() util.Uint256 { return hash.Sha256(h) }
|
||||
|
||||
func getDifferentItems(t *testing.T, n int) []testHashable {
|
||||
items := make([]testHashable, n)
|
||||
for i := range items {
|
||||
items[i] = random.Bytes(rand.Int() % 10)
|
||||
}
|
||||
|
||||
return
|
||||
return items
|
||||
}
|
|
@ -245,6 +245,7 @@ func (cd *Cached) FlushStorage() error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
ti.State |= flushedState
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -275,7 +276,7 @@ func (cd *Cached) getStorageItemNoCache(scripthash util.Uint160, key []byte) *st
|
|||
func (cd *Cached) getStorageItemInt(scripthash util.Uint160, key []byte, putToCache bool) *state.StorageItem {
|
||||
ti := cd.storage.getItem(scripthash, key)
|
||||
if ti != nil {
|
||||
if ti.State == delOp {
|
||||
if ti.State&delOp != 0 {
|
||||
return nil
|
||||
}
|
||||
return copyItem(&ti.StorageItem)
|
||||
|
@ -303,8 +304,10 @@ func (cd *Cached) PutStorageItem(scripthash util.Uint160, key []byte, si *state.
|
|||
item := copyItem(si)
|
||||
ti := cd.storage.getItem(scripthash, key)
|
||||
if ti != nil {
|
||||
if ti.State == delOp || ti.State == getOp {
|
||||
if ti.State&(delOp|getOp) != 0 {
|
||||
ti.State = putOp
|
||||
} else {
|
||||
ti.State = addOp
|
||||
}
|
||||
ti.StorageItem = *item
|
||||
return nil
|
||||
|
@ -357,7 +360,7 @@ func (cd *Cached) GetStorageItemsIterator(hash util.Uint160, prefix []byte) (Sto
|
|||
for ; keyIndex < len(cd.storage.keys[hash]); keyIndex++ {
|
||||
k := cd.storage.keys[hash][keyIndex]
|
||||
v := cache[k]
|
||||
if v.State != delOp && bytes.HasPrefix([]byte(k), prefix) {
|
||||
if v.State&delOp == 0 && bytes.HasPrefix([]byte(k), prefix) {
|
||||
val := make([]byte, len(v.StorageItem.Value))
|
||||
copy(val, v.StorageItem.Value)
|
||||
return []byte(k), val, nil
|
||||
|
@ -404,7 +407,7 @@ func (cd *Cached) GetStorageItems(hash util.Uint160, prefix []byte) ([]StorageIt
|
|||
|
||||
for _, k := range cd.storage.keys[hash] {
|
||||
v := cache[k]
|
||||
if v.State != delOp {
|
||||
if v.State&delOp == 0 {
|
||||
val := make([]byte, len(v.StorageItem.Value))
|
||||
copy(val, v.StorageItem.Value)
|
||||
result = append(result, StorageItemWithKey{
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"sort"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
|
@ -31,9 +32,12 @@ type DAO interface {
|
|||
GetContractState(hash util.Uint160) (*state.Contract, error)
|
||||
GetCurrentBlockHeight() (uint32, error)
|
||||
GetCurrentHeaderHeight() (i uint32, h util.Uint256, err error)
|
||||
GetCurrentStateRootHeight() (uint32, error)
|
||||
GetHeaderHashes() ([]util.Uint256, error)
|
||||
GetNEP5Balances(acc util.Uint160) (*state.NEP5Balances, error)
|
||||
GetNEP5TransferLog(acc util.Uint160, index uint32) (*state.NEP5TransferLog, error)
|
||||
GetStateRoot(height uint32) (*state.MPTRootState, error)
|
||||
PutStateRoot(root *state.MPTRootState) error
|
||||
GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem
|
||||
GetStorageItems(hash util.Uint160, prefix []byte) ([]StorageItemWithKey, error)
|
||||
GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error)
|
||||
|
@ -70,12 +74,14 @@ type DAO interface {
|
|||
|
||||
// Simple is memCached wrapper around DB, simple DAO implementation.
|
||||
type Simple struct {
|
||||
MPT *mpt.Trie
|
||||
Store *storage.MemCachedStore
|
||||
}
|
||||
|
||||
// NewSimple creates new simple dao using provided backend store.
|
||||
func NewSimple(backend storage.Store) *Simple {
|
||||
return &Simple{Store: storage.NewMemCachedStore(backend)}
|
||||
st := storage.NewMemCachedStore(backend)
|
||||
return &Simple{Store: st, MPT: mpt.NewTrie(nil, st)}
|
||||
}
|
||||
|
||||
// GetBatch returns currently accumulated DB changeset.
|
||||
|
@ -86,7 +92,9 @@ func (dao *Simple) GetBatch() *storage.MemBatch {
|
|||
// GetWrapped returns new DAO instance with another layer of wrapped
|
||||
// MemCachedStore around the current DAO Store.
|
||||
func (dao *Simple) GetWrapped() DAO {
|
||||
return NewSimple(dao.Store)
|
||||
d := NewSimple(dao.Store)
|
||||
d.MPT = dao.MPT
|
||||
return d
|
||||
}
|
||||
|
||||
// GetAndDecode performs get operation and decoding with serializable structures.
|
||||
|
@ -406,6 +414,63 @@ func (dao *Simple) PutAppExecResult(aer *state.AppExecResult) error {
|
|||
|
||||
// -- start storage item.
|
||||
|
||||
func makeStateRootKey(height uint32) []byte {
|
||||
key := make([]byte, 5)
|
||||
key[0] = byte(storage.DataMPT)
|
||||
binary.LittleEndian.PutUint32(key[1:], height)
|
||||
return key
|
||||
}
|
||||
|
||||
// InitMPT initializes MPT at the given height.
|
||||
func (dao *Simple) InitMPT(height uint32) error {
|
||||
if height == 0 {
|
||||
dao.MPT = mpt.NewTrie(nil, dao.Store)
|
||||
return nil
|
||||
}
|
||||
r, err := dao.GetStateRoot(height)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dao.MPT = mpt.NewTrie(mpt.NewHashNode(r.Root), dao.Store)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentStateRootHeight returns current state root height.
|
||||
func (dao *Simple) GetCurrentStateRootHeight() (uint32, error) {
|
||||
key := []byte{byte(storage.DataMPT)}
|
||||
val, err := dao.Store.Get(key)
|
||||
if err != nil {
|
||||
if err == storage.ErrKeyNotFound {
|
||||
err = nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return binary.LittleEndian.Uint32(val), nil
|
||||
}
|
||||
|
||||
// PutCurrentStateRootHeight updates current state root height.
|
||||
func (dao *Simple) PutCurrentStateRootHeight(height uint32) error {
|
||||
key := []byte{byte(storage.DataMPT)}
|
||||
val := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(val, height)
|
||||
return dao.Store.Put(key, val)
|
||||
}
|
||||
|
||||
// GetStateRoot returns state root of a given height.
|
||||
func (dao *Simple) GetStateRoot(height uint32) (*state.MPTRootState, error) {
|
||||
r := new(state.MPTRootState)
|
||||
err := dao.GetAndDecode(r, makeStateRootKey(height))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// PutStateRoot puts state root of a given height into the store.
|
||||
func (dao *Simple) PutStateRoot(r *state.MPTRootState) error {
|
||||
return dao.Put(r, makeStateRootKey(r.Index))
|
||||
}
|
||||
|
||||
// GetStorageItem returns StorageItem if it exists in the given store.
|
||||
func (dao *Simple) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem {
|
||||
b, err := dao.Store.Get(makeStorageItemKey(scripthash, key))
|
||||
|
@ -426,13 +491,24 @@ func (dao *Simple) GetStorageItem(scripthash util.Uint160, key []byte) *state.St
|
|||
// PutStorageItem puts given StorageItem for given script with given
|
||||
// key into the given store.
|
||||
func (dao *Simple) PutStorageItem(scripthash util.Uint160, key []byte, si *state.StorageItem) error {
|
||||
return dao.Put(si, makeStorageItemKey(scripthash, key))
|
||||
stKey := makeStorageItemKey(scripthash, key)
|
||||
k := mpt.ToNeoStorageKey(stKey[1:]) // strip STStorage prefix
|
||||
v := mpt.ToNeoStorageValue(si)
|
||||
if err := dao.MPT.Put(k, v); err != nil && err != mpt.ErrNotFound {
|
||||
return err
|
||||
}
|
||||
return dao.Store.Put(stKey, v[1:])
|
||||
}
|
||||
|
||||
// DeleteStorageItem drops storage item for the given script with the
|
||||
// given key from the store.
|
||||
func (dao *Simple) DeleteStorageItem(scripthash util.Uint160, key []byte) error {
|
||||
return dao.Store.Delete(makeStorageItemKey(scripthash, key))
|
||||
stKey := makeStorageItemKey(scripthash, key)
|
||||
k := mpt.ToNeoStorageKey(stKey[1:]) // strip STStorage prefix
|
||||
if err := dao.MPT.Delete(k); err != nil && err != mpt.ErrNotFound {
|
||||
return err
|
||||
}
|
||||
return dao.Store.Delete(stKey)
|
||||
}
|
||||
|
||||
// StorageItemWithKey is a Key-Value pair together with possible const modifier.
|
||||
|
|
|
@ -24,6 +24,7 @@ const (
|
|||
delOp
|
||||
addOp
|
||||
putOp
|
||||
flushedState
|
||||
)
|
||||
|
||||
func newItemCache() *itemCache {
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
|
@ -600,6 +603,33 @@ func (ic *interopContext) contractMigrate(v *vm.VM) error {
|
|||
return ic.contractDestroy(v)
|
||||
}
|
||||
|
||||
// secp256k1Recover recovers speck256k1 public key.
|
||||
func (ic *interopContext) secp256k1Recover(v *vm.VM) error {
|
||||
return ic.eccRecover(btcec.S256(), v)
|
||||
}
|
||||
|
||||
// secp256r1Recover recovers speck256r1 public key.
|
||||
func (ic *interopContext) secp256r1Recover(v *vm.VM) error {
|
||||
return ic.eccRecover(elliptic.P256(), v)
|
||||
}
|
||||
|
||||
// eccRecover recovers public key using ECCurve set
|
||||
func (ic *interopContext) eccRecover(curve elliptic.Curve, v *vm.VM) error {
|
||||
rBytes := v.Estack().Pop().Bytes()
|
||||
sBytes := v.Estack().Pop().Bytes()
|
||||
r := new(big.Int).SetBytes(rBytes)
|
||||
s := new(big.Int).SetBytes(sBytes)
|
||||
isEven := v.Estack().Pop().Bool()
|
||||
messageHash := v.Estack().Pop().Bytes()
|
||||
pKey, err := keys.KeyRecover(curve, r, s, messageHash, isEven)
|
||||
if err != nil {
|
||||
v.Estack().PushVal([]byte{})
|
||||
return nil
|
||||
}
|
||||
v.Estack().PushVal(pKey.UncompressedBytes()[1:])
|
||||
return nil
|
||||
}
|
||||
|
||||
// assetCreate creates an asset.
|
||||
func (ic *interopContext) assetCreate(v *vm.VM) error {
|
||||
if ic.trigger != trigger.Application {
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/dao"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
|
@ -457,8 +460,81 @@ func TestAssetGetPrecision(t *testing.T) {
|
|||
require.Equal(t, big.NewInt(int64(assetState.Precision)), precision)
|
||||
}
|
||||
|
||||
func TestSecp256k1Recover(t *testing.T) {
|
||||
v, context, chain := createVM(t)
|
||||
defer chain.Close()
|
||||
|
||||
privateKey, err := btcec.NewPrivateKey(btcec.S256())
|
||||
require.NoError(t, err)
|
||||
message := []byte("The quick brown fox jumps over the lazy dog")
|
||||
signature, err := privateKey.Sign(message)
|
||||
require.NoError(t, err)
|
||||
require.True(t, signature.Verify(message, privateKey.PubKey()))
|
||||
pubKey := keys.PublicKey{
|
||||
X: privateKey.PubKey().X,
|
||||
Y: privateKey.PubKey().Y,
|
||||
}
|
||||
expected := pubKey.UncompressedBytes()[1:]
|
||||
|
||||
// We don't know which of two recovered keys suites, so let's try both.
|
||||
putOnStackGetResult := func(isEven bool) []byte {
|
||||
v.Estack().PushVal(message)
|
||||
v.Estack().PushVal(isEven)
|
||||
v.Estack().PushVal(signature.S.Bytes())
|
||||
v.Estack().PushVal(signature.R.Bytes())
|
||||
err = context.secp256k1Recover(v)
|
||||
require.NoError(t, err)
|
||||
return v.Estack().Pop().Value().([]byte)
|
||||
}
|
||||
|
||||
// First one:
|
||||
actualFalse := putOnStackGetResult(false)
|
||||
// Second one:
|
||||
actualTrue := putOnStackGetResult(true)
|
||||
|
||||
require.True(t, bytes.Compare(expected, actualFalse) != bytes.Compare(expected, actualTrue))
|
||||
}
|
||||
|
||||
func TestSecp256r1Recover(t *testing.T) {
|
||||
v, context, chain := createVM(t)
|
||||
defer chain.Close()
|
||||
|
||||
privateKey, err := keys.NewPrivateKey()
|
||||
require.NoError(t, err)
|
||||
message := []byte("The quick brown fox jumps over the lazy dog")
|
||||
messageHash := hash.Sha256(message).BytesBE()
|
||||
signature := privateKey.Sign(message)
|
||||
require.True(t, privateKey.PublicKey().Verify(signature, messageHash))
|
||||
expected := privateKey.PublicKey().UncompressedBytes()[1:]
|
||||
|
||||
// We don't know which of two recovered keys suites, so let's try both.
|
||||
putOnStackGetResult := func(isEven bool) []byte {
|
||||
v.Estack().PushVal(messageHash)
|
||||
v.Estack().PushVal(isEven)
|
||||
v.Estack().PushVal(signature[32:64])
|
||||
v.Estack().PushVal(signature[0:32])
|
||||
err = context.secp256r1Recover(v)
|
||||
require.NoError(t, err)
|
||||
return v.Estack().Pop().Value().([]byte)
|
||||
}
|
||||
|
||||
// First one:
|
||||
actualFalse := putOnStackGetResult(false)
|
||||
// Second one:
|
||||
actualTrue := putOnStackGetResult(true)
|
||||
|
||||
require.True(t, bytes.Compare(expected, actualFalse) != bytes.Compare(expected, actualTrue))
|
||||
}
|
||||
|
||||
// Helper functions to create VM, InteropContext, TX, Account, Contract, Asset.
|
||||
|
||||
func createVM(t *testing.T) (*vm.VM, *interopContext, *Blockchain) {
|
||||
v := vm.New()
|
||||
chain := newTestChain(t)
|
||||
context := chain.newInteropContext(trigger.Application, dao.NewSimple(storage.NewMemoryStore()), nil, nil)
|
||||
return v, context, chain
|
||||
}
|
||||
|
||||
func createVMAndPushBlock(t *testing.T) (*vm.VM, *block.Block, *interopContext, *Blockchain) {
|
||||
v := vm.New()
|
||||
block := newDumbBlock()
|
||||
|
|
|
@ -52,6 +52,9 @@ func (ic *interopContext) SpawnVM() *vm.VM {
|
|||
})
|
||||
vm.RegisterInteropGetter(ic.getSystemInterop)
|
||||
vm.RegisterInteropGetter(ic.getNeoInterop)
|
||||
if ic.bc != nil && ic.bc.GetConfig().EnableStateRoot {
|
||||
vm.RegisterInteropGetter(ic.getNeoxInterop)
|
||||
}
|
||||
return vm
|
||||
}
|
||||
|
||||
|
@ -77,6 +80,12 @@ func (ic *interopContext) getNeoInterop(id uint32) *vm.InteropFuncPrice {
|
|||
return ic.getInteropFromSlice(id, neoInterops)
|
||||
}
|
||||
|
||||
// getNeoxInterop returns matching interop function from the NeoX extension
|
||||
// for a given id in the current context.
|
||||
func (ic *interopContext) getNeoxInterop(id uint32) *vm.InteropFuncPrice {
|
||||
return ic.getInteropFromSlice(id, neoxInterops)
|
||||
}
|
||||
|
||||
// getInteropFromSlice returns matching interop function from the given slice of
|
||||
// interop functions in the current context.
|
||||
func (ic *interopContext) getInteropFromSlice(id uint32, slice []interopedFunction) *vm.InteropFuncPrice {
|
||||
|
@ -276,6 +285,11 @@ var neoInterops = []interopedFunction{
|
|||
{Name: "AntShares.Transaction.GetType", Func: (*interopContext).txGetType, Price: 1},
|
||||
}
|
||||
|
||||
var neoxInterops = []interopedFunction{
|
||||
{Name: "Neo.Cryptography.Secp256k1Recover", Func: (*interopContext).secp256k1Recover, Price: 100},
|
||||
{Name: "Neo.Cryptography.Secp256r1Recover", Func: (*interopContext).secp256r1Recover, Price: 100},
|
||||
}
|
||||
|
||||
// initIDinInteropsSlice initializes IDs from names in one given
|
||||
// interopedFunction slice and then sorts it.
|
||||
func initIDinInteropsSlice(iops []interopedFunction) {
|
||||
|
@ -291,4 +305,5 @@ func initIDinInteropsSlice(iops []interopedFunction) {
|
|||
func init() {
|
||||
initIDinInteropsSlice(systemInterops)
|
||||
initIDinInteropsSlice(neoInterops)
|
||||
initIDinInteropsSlice(neoxInterops)
|
||||
}
|
||||
|
|
84
pkg/core/mpt/base.go
Normal file
84
pkg/core/mpt/base.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// BaseNode implements basic things every node needs like caching hash and
|
||||
// serialized representation. It's a basic node building block intended to be
|
||||
// included into all node types.
|
||||
type BaseNode struct {
|
||||
hash util.Uint256
|
||||
bytes []byte
|
||||
hashValid bool
|
||||
bytesValid bool
|
||||
|
||||
isFlushed bool
|
||||
}
|
||||
|
||||
// BaseNodeIface abstracts away basic Node functions.
|
||||
type BaseNodeIface interface {
|
||||
Hash() util.Uint256
|
||||
Type() NodeType
|
||||
Bytes() []byte
|
||||
IsFlushed() bool
|
||||
SetFlushed()
|
||||
}
|
||||
|
||||
// getHash returns a hash of this BaseNode.
|
||||
func (b *BaseNode) getHash(n Node) util.Uint256 {
|
||||
if !b.hashValid {
|
||||
b.updateHash(n)
|
||||
}
|
||||
return b.hash
|
||||
}
|
||||
|
||||
// getBytes returns a slice of bytes representing this node.
|
||||
func (b *BaseNode) getBytes(n Node) []byte {
|
||||
if !b.bytesValid {
|
||||
b.updateBytes(n)
|
||||
}
|
||||
return b.bytes
|
||||
}
|
||||
|
||||
// updateHash updates hash field for this BaseNode.
|
||||
func (b *BaseNode) updateHash(n Node) {
|
||||
if n.Type() == HashT {
|
||||
panic("can't update hash for hash node")
|
||||
}
|
||||
b.hash = hash.DoubleSha256(b.getBytes(n))
|
||||
b.hashValid = true
|
||||
}
|
||||
|
||||
// updateCache updates hash and bytes fields for this BaseNode.
|
||||
func (b *BaseNode) updateBytes(n Node) {
|
||||
buf := io.NewBufBinWriter()
|
||||
encodeNodeWithType(n, buf.BinWriter)
|
||||
b.bytes = buf.Bytes()
|
||||
b.bytesValid = true
|
||||
}
|
||||
|
||||
// invalidateCache sets all cache fields to invalid state.
|
||||
func (b *BaseNode) invalidateCache() {
|
||||
b.bytesValid = false
|
||||
b.hashValid = false
|
||||
b.isFlushed = false
|
||||
}
|
||||
|
||||
// IsFlushed checks for node flush status.
|
||||
func (b *BaseNode) IsFlushed() bool {
|
||||
return b.isFlushed
|
||||
}
|
||||
|
||||
// SetFlushed sets 'flushed' flag to true for this node.
|
||||
func (b *BaseNode) SetFlushed() {
|
||||
b.isFlushed = true
|
||||
}
|
||||
|
||||
// encodeNodeWithType encodes node together with it's type.
|
||||
func encodeNodeWithType(n Node, w *io.BinWriter) {
|
||||
w.WriteB(byte(n.Type()))
|
||||
n.EncodeBinary(w)
|
||||
}
|
91
pkg/core/mpt/branch.go
Normal file
91
pkg/core/mpt/branch.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// childrenCount represents a number of children of a branch node.
|
||||
childrenCount = 17
|
||||
// lastChild is the index of the last child.
|
||||
lastChild = childrenCount - 1
|
||||
)
|
||||
|
||||
// BranchNode represents MPT's branch node.
|
||||
type BranchNode struct {
|
||||
BaseNode
|
||||
Children [childrenCount]Node
|
||||
}
|
||||
|
||||
var _ Node = (*BranchNode)(nil)
|
||||
|
||||
// NewBranchNode returns new branch node.
|
||||
func NewBranchNode() *BranchNode {
|
||||
b := new(BranchNode)
|
||||
for i := 0; i < childrenCount; i++ {
|
||||
b.Children[i] = new(HashNode)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Type implements Node interface.
|
||||
func (b *BranchNode) Type() NodeType { return BranchT }
|
||||
|
||||
// Hash implements BaseNode interface.
|
||||
func (b *BranchNode) Hash() util.Uint256 {
|
||||
return b.getHash(b)
|
||||
}
|
||||
|
||||
// Bytes implements BaseNode interface.
|
||||
func (b *BranchNode) Bytes() []byte {
|
||||
return b.getBytes(b)
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (b *BranchNode) EncodeBinary(w *io.BinWriter) {
|
||||
for i := 0; i < childrenCount; i++ {
|
||||
if hn, ok := b.Children[i].(*HashNode); ok {
|
||||
hn.EncodeBinary(w)
|
||||
continue
|
||||
}
|
||||
n := NewHashNode(b.Children[i].Hash())
|
||||
n.EncodeBinary(w)
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (b *BranchNode) DecodeBinary(r *io.BinReader) {
|
||||
for i := 0; i < childrenCount; i++ {
|
||||
b.Children[i] = new(HashNode)
|
||||
b.Children[i].DecodeBinary(r)
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (b *BranchNode) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(b.Children)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (b *BranchNode) UnmarshalJSON(data []byte) error {
|
||||
var obj NodeObject
|
||||
if err := obj.UnmarshalJSON(data); err != nil {
|
||||
return err
|
||||
} else if u, ok := obj.Node.(*BranchNode); ok {
|
||||
*b = *u
|
||||
return nil
|
||||
}
|
||||
return errors.New("expected branch node")
|
||||
}
|
||||
|
||||
// splitPath splits path for a branch node.
|
||||
func splitPath(path []byte) (byte, []byte) {
|
||||
if len(path) != 0 {
|
||||
return path[0], path[1:]
|
||||
}
|
||||
return lastChild, path
|
||||
}
|
45
pkg/core/mpt/doc.go
Normal file
45
pkg/core/mpt/doc.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
/*
|
||||
Package mpt implements MPT (Merkle-Patricia Tree).
|
||||
|
||||
MPT stores key-value pairs and is a trie over 16-symbol alphabet. https://en.wikipedia.org/wiki/Trie
|
||||
Trie is a tree where values are stored in leafs and keys are paths from root to the leaf node.
|
||||
MPT consists of 4 type of nodes:
|
||||
- Leaf node contains only value.
|
||||
- Extension node contains both key and value.
|
||||
- Branch node contains 2 or more children.
|
||||
- Hash node is a compressed node and contains only actual node's hash.
|
||||
The actual node must be retrieved from storage or over the network.
|
||||
|
||||
As an example here is a trie containing 3 pairs:
|
||||
- 0x1201 -> val1
|
||||
- 0x1203 -> val2
|
||||
- 0x1224 -> val3
|
||||
- 0x12 -> val4
|
||||
|
||||
ExtensionNode(0x0102), Next
|
||||
_______________________|
|
||||
|
|
||||
BranchNode [0, 1, 2, ...], Last -> Leaf(val4)
|
||||
| |
|
||||
| ExtensionNode [0x04], Next -> Leaf(val3)
|
||||
|
|
||||
BranchNode [0, 1, 2, 3, ...], Last -> HashNode(nil)
|
||||
| |
|
||||
| Leaf(val2)
|
||||
|
|
||||
Leaf(val1)
|
||||
|
||||
There are 3 invariants that this implementation has:
|
||||
- Branch node cannot have <= 1 children
|
||||
- Extension node cannot have zero-length key
|
||||
- Extension node cannot have another Extension node in it's next field
|
||||
|
||||
Thank to these restrictions, there is a single root hash for every set of key-value pairs
|
||||
irregardless of the order they were added/removed with.
|
||||
The actual trie structure can vary because of node -> HashNode compressing.
|
||||
|
||||
There is also one optimization which cost us almost nothing in terms of complexity but is very beneficial:
|
||||
When we perform get/put/delete on a speficic path, every Hash node which was retreived from storage is
|
||||
replaced by its uncompressed form, so that subsequent hits of this not don't use storage.
|
||||
*/
|
||||
package mpt
|
87
pkg/core/mpt/extension.go
Normal file
87
pkg/core/mpt/extension.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// MaxKeyLength is the max length of the extension node key.
|
||||
const MaxKeyLength = 1125
|
||||
|
||||
// ExtensionNode represents MPT's extension node.
|
||||
type ExtensionNode struct {
|
||||
BaseNode
|
||||
key []byte
|
||||
next Node
|
||||
}
|
||||
|
||||
var _ Node = (*ExtensionNode)(nil)
|
||||
|
||||
// NewExtensionNode returns hash node with the specified key and next node.
|
||||
// Note: because it is a part of Trie, key must be mangled, i.e. must contain only bytes with high half = 0.
|
||||
func NewExtensionNode(key []byte, next Node) *ExtensionNode {
|
||||
return &ExtensionNode{
|
||||
key: key,
|
||||
next: next,
|
||||
}
|
||||
}
|
||||
|
||||
// Type implements Node interface.
|
||||
func (e ExtensionNode) Type() NodeType { return ExtensionT }
|
||||
|
||||
// Hash implements BaseNode interface.
|
||||
func (e *ExtensionNode) Hash() util.Uint256 {
|
||||
return e.getHash(e)
|
||||
}
|
||||
|
||||
// Bytes implements BaseNode interface.
|
||||
func (e *ExtensionNode) Bytes() []byte {
|
||||
return e.getBytes(e)
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (e *ExtensionNode) DecodeBinary(r *io.BinReader) {
|
||||
sz := r.ReadVarUint()
|
||||
if sz > MaxKeyLength {
|
||||
r.Err = fmt.Errorf("extension node key is too big: %d", sz)
|
||||
return
|
||||
}
|
||||
e.key = make([]byte, sz)
|
||||
r.ReadBytes(e.key)
|
||||
e.next = new(HashNode)
|
||||
e.next.DecodeBinary(r)
|
||||
e.invalidateCache()
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (e ExtensionNode) EncodeBinary(w *io.BinWriter) {
|
||||
w.WriteVarBytes(e.key)
|
||||
n := NewHashNode(e.next.Hash())
|
||||
n.EncodeBinary(w)
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (e *ExtensionNode) MarshalJSON() ([]byte, error) {
|
||||
m := map[string]interface{}{
|
||||
"key": hex.EncodeToString(e.key),
|
||||
"next": e.next,
|
||||
}
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (e *ExtensionNode) UnmarshalJSON(data []byte) error {
|
||||
var obj NodeObject
|
||||
if err := obj.UnmarshalJSON(data); err != nil {
|
||||
return err
|
||||
} else if u, ok := obj.Node.(*ExtensionNode); ok {
|
||||
*e = *u
|
||||
return nil
|
||||
}
|
||||
return errors.New("expected extension node")
|
||||
}
|
88
pkg/core/mpt/hash.go
Normal file
88
pkg/core/mpt/hash.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// HashNode represents MPT's hash node.
|
||||
type HashNode struct {
|
||||
BaseNode
|
||||
}
|
||||
|
||||
var _ Node = (*HashNode)(nil)
|
||||
|
||||
// NewHashNode returns hash node with the specified hash.
|
||||
func NewHashNode(h util.Uint256) *HashNode {
|
||||
return &HashNode{
|
||||
BaseNode: BaseNode{
|
||||
hash: h,
|
||||
hashValid: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Type implements Node interface.
|
||||
func (h *HashNode) Type() NodeType { return HashT }
|
||||
|
||||
// Hash implements Node interface.
|
||||
func (h *HashNode) Hash() util.Uint256 {
|
||||
if !h.hashValid {
|
||||
panic("can't get hash of an empty HashNode")
|
||||
}
|
||||
return h.hash
|
||||
}
|
||||
|
||||
// IsEmpty returns true iff h is an empty node i.e. contains no hash.
|
||||
func (h *HashNode) IsEmpty() bool { return !h.hashValid }
|
||||
|
||||
// Bytes returns serialized HashNode.
|
||||
func (h *HashNode) Bytes() []byte {
|
||||
return h.getBytes(h)
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (h *HashNode) DecodeBinary(r *io.BinReader) {
|
||||
sz := r.ReadVarUint()
|
||||
switch sz {
|
||||
case 0:
|
||||
h.hashValid = false
|
||||
case util.Uint256Size:
|
||||
h.hashValid = true
|
||||
r.ReadBytes(h.hash[:])
|
||||
default:
|
||||
r.Err = fmt.Errorf("invalid hash node size: %d", sz)
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (h HashNode) EncodeBinary(w *io.BinWriter) {
|
||||
if !h.hashValid {
|
||||
w.WriteVarUint(0)
|
||||
return
|
||||
}
|
||||
w.WriteVarBytes(h.hash[:])
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (h *HashNode) MarshalJSON() ([]byte, error) {
|
||||
if !h.hashValid {
|
||||
return []byte(`{}`), nil
|
||||
}
|
||||
return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (h *HashNode) UnmarshalJSON(data []byte) error {
|
||||
var obj NodeObject
|
||||
if err := obj.UnmarshalJSON(data); err != nil {
|
||||
return err
|
||||
} else if u, ok := obj.Node.(*HashNode); ok {
|
||||
*h = *u
|
||||
return nil
|
||||
}
|
||||
return errors.New("expected hash node")
|
||||
}
|
86
pkg/core/mpt/helpers.go
Normal file
86
pkg/core/mpt/helpers.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// lcp returns longest common prefix of a and b.
|
||||
// Note: it does no allocations.
|
||||
func lcp(a, b []byte) []byte {
|
||||
if len(a) < len(b) {
|
||||
return lcp(b, a)
|
||||
}
|
||||
|
||||
var i int
|
||||
for i = 0; i < len(b); i++ {
|
||||
if a[i] != b[i] {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return a[:i]
|
||||
}
|
||||
|
||||
// copySlice is a helper for copying slice if needed.
|
||||
func copySlice(a []byte) []byte {
|
||||
b := make([]byte, len(a))
|
||||
copy(b, a)
|
||||
return b
|
||||
}
|
||||
|
||||
// toNibbles mangles path by splitting every byte into 2 containing low- and high- 4-byte part.
|
||||
func toNibbles(path []byte) []byte {
|
||||
result := make([]byte, len(path)*2)
|
||||
for i := range path {
|
||||
result[i*2] = path[i] >> 4
|
||||
result[i*2+1] = path[i] & 0x0F
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToNeoStorageKey converts storage key to C# neo node's format.
|
||||
// Key is expected to be at least 20 bytes in length.
|
||||
// our format: script hash in BE + key
|
||||
// neo format: script hash in LE + key with 0 between every 16 bytes, padded to len 16.
|
||||
func ToNeoStorageKey(key []byte) []byte {
|
||||
const groupSize = 16
|
||||
|
||||
var nkey []byte
|
||||
for i := util.Uint160Size - 1; i >= 0; i-- {
|
||||
nkey = append(nkey, key[i])
|
||||
}
|
||||
|
||||
key = key[util.Uint160Size:]
|
||||
|
||||
index := 0
|
||||
remain := len(key)
|
||||
for remain >= groupSize {
|
||||
nkey = append(nkey, key[index:index+groupSize]...)
|
||||
nkey = append(nkey, 0)
|
||||
index += groupSize
|
||||
remain -= groupSize
|
||||
}
|
||||
|
||||
if remain > 0 {
|
||||
nkey = append(nkey, key[index:]...)
|
||||
}
|
||||
|
||||
padding := groupSize - remain
|
||||
for i := 0; i < padding; i++ {
|
||||
nkey = append(nkey, 0)
|
||||
}
|
||||
return append(nkey, byte(padding))
|
||||
}
|
||||
|
||||
// ToNeoStorageValue serializes si to a C# neo node's format.
|
||||
// It has additional version (0x00) byte at the beginning.
|
||||
func ToNeoStorageValue(si *state.StorageItem) []byte {
|
||||
const version = 0
|
||||
|
||||
buf := io.NewBufBinWriter()
|
||||
buf.BinWriter.WriteB(version)
|
||||
si.EncodeBinary(buf.BinWriter)
|
||||
return buf.Bytes()
|
||||
}
|
30
pkg/core/mpt/helpers_test.go
Normal file
30
pkg/core/mpt/helpers_test.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToNeoStorageKey(t *testing.T) {
|
||||
testCases := []struct{ key, res string }{
|
||||
{
|
||||
"0102030405060708091011121314151617181920",
|
||||
"20191817161514131211100908070605040302010000000000000000000000000000000010",
|
||||
},
|
||||
{
|
||||
"010203040506070809101112131415161718192021222324",
|
||||
"2019181716151413121110090807060504030201212223240000000000000000000000000c",
|
||||
},
|
||||
{
|
||||
"0102030405060708091011121314151617181920212223242526272829303132333435363738",
|
||||
"20191817161514131211100908070605040302012122232425262728293031323334353600373800000000000000000000000000000e",
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
key, _ := hex.DecodeString(tc.key)
|
||||
res, _ := hex.DecodeString(tc.res)
|
||||
require.Equal(t, res, ToNeoStorageKey(key))
|
||||
}
|
||||
}
|
73
pkg/core/mpt/leaf.go
Normal file
73
pkg/core/mpt/leaf.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// MaxValueLength is a max length of a leaf node value.
|
||||
const MaxValueLength = 1024 * 1024
|
||||
|
||||
// LeafNode represents MPT's leaf node.
|
||||
type LeafNode struct {
|
||||
BaseNode
|
||||
value []byte
|
||||
}
|
||||
|
||||
var _ Node = (*LeafNode)(nil)
|
||||
|
||||
// NewLeafNode returns hash node with the specified value.
|
||||
func NewLeafNode(value []byte) *LeafNode {
|
||||
return &LeafNode{value: value}
|
||||
}
|
||||
|
||||
// Type implements Node interface.
|
||||
func (n LeafNode) Type() NodeType { return LeafT }
|
||||
|
||||
// Hash implements BaseNode interface.
|
||||
func (n *LeafNode) Hash() util.Uint256 {
|
||||
return n.getHash(n)
|
||||
}
|
||||
|
||||
// Bytes implements BaseNode interface.
|
||||
func (n *LeafNode) Bytes() []byte {
|
||||
return n.getBytes(n)
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (n *LeafNode) DecodeBinary(r *io.BinReader) {
|
||||
sz := r.ReadVarUint()
|
||||
if sz > MaxValueLength {
|
||||
r.Err = fmt.Errorf("leaf node value is too big: %d", sz)
|
||||
return
|
||||
}
|
||||
n.value = make([]byte, sz)
|
||||
r.ReadBytes(n.value)
|
||||
n.invalidateCache()
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (n LeafNode) EncodeBinary(w *io.BinWriter) {
|
||||
w.WriteVarBytes(n.value)
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (n *LeafNode) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`{"value":"` + hex.EncodeToString(n.value) + `"}`), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (n *LeafNode) UnmarshalJSON(data []byte) error {
|
||||
var obj NodeObject
|
||||
if err := obj.UnmarshalJSON(data); err != nil {
|
||||
return err
|
||||
} else if u, ok := obj.Node.(*LeafNode); ok {
|
||||
*n = *u
|
||||
return nil
|
||||
}
|
||||
return errors.New("expected leaf node")
|
||||
}
|
134
pkg/core/mpt/node.go
Normal file
134
pkg/core/mpt/node.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// NodeType represents node type..
|
||||
type NodeType byte
|
||||
|
||||
// Node types definitions.
|
||||
const (
|
||||
BranchT NodeType = 0x00
|
||||
ExtensionT NodeType = 0x01
|
||||
HashT NodeType = 0x02
|
||||
LeafT NodeType = 0x03
|
||||
)
|
||||
|
||||
// NodeObject represents Node together with it's type.
|
||||
// It is used for serialization/deserialization where type info
|
||||
// is also expected.
|
||||
type NodeObject struct {
|
||||
Node
|
||||
}
|
||||
|
||||
// Node represents common interface of all MPT nodes.
|
||||
type Node interface {
|
||||
io.Serializable
|
||||
json.Marshaler
|
||||
json.Unmarshaler
|
||||
BaseNodeIface
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (n NodeObject) EncodeBinary(w *io.BinWriter) {
|
||||
encodeNodeWithType(n.Node, w)
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (n *NodeObject) DecodeBinary(r *io.BinReader) {
|
||||
typ := NodeType(r.ReadB())
|
||||
switch typ {
|
||||
case BranchT:
|
||||
n.Node = new(BranchNode)
|
||||
case ExtensionT:
|
||||
n.Node = new(ExtensionNode)
|
||||
case HashT:
|
||||
n.Node = new(HashNode)
|
||||
case LeafT:
|
||||
n.Node = new(LeafNode)
|
||||
default:
|
||||
r.Err = fmt.Errorf("invalid node type: %x", typ)
|
||||
return
|
||||
}
|
||||
n.Node.DecodeBinary(r)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (n *NodeObject) UnmarshalJSON(data []byte) error {
|
||||
var m map[string]json.RawMessage
|
||||
err := json.Unmarshal(data, &m)
|
||||
if err != nil { // it can be a branch node
|
||||
var nodes []NodeObject
|
||||
if err := json.Unmarshal(data, &nodes); err != nil {
|
||||
return err
|
||||
} else if len(nodes) != childrenCount {
|
||||
return errors.New("invalid length of branch node")
|
||||
}
|
||||
|
||||
b := NewBranchNode()
|
||||
for i := range b.Children {
|
||||
b.Children[i] = nodes[i].Node
|
||||
}
|
||||
n.Node = b
|
||||
return nil
|
||||
}
|
||||
|
||||
switch len(m) {
|
||||
case 0:
|
||||
n.Node = new(HashNode)
|
||||
case 1:
|
||||
if v, ok := m["hash"]; ok {
|
||||
var h util.Uint256
|
||||
if err := json.Unmarshal(v, &h); err != nil {
|
||||
return err
|
||||
}
|
||||
n.Node = NewHashNode(h)
|
||||
} else if v, ok = m["value"]; ok {
|
||||
b, err := unmarshalHex(v)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(b) > MaxValueLength {
|
||||
return errors.New("leaf value is too big")
|
||||
}
|
||||
n.Node = NewLeafNode(b)
|
||||
} else {
|
||||
return errors.New("invalid field")
|
||||
}
|
||||
case 2:
|
||||
keyRaw, ok1 := m["key"]
|
||||
nextRaw, ok2 := m["next"]
|
||||
if !ok1 || !ok2 {
|
||||
return errors.New("invalid field")
|
||||
}
|
||||
key, err := unmarshalHex(keyRaw)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(key) > MaxKeyLength {
|
||||
return errors.New("extension key is too big")
|
||||
}
|
||||
|
||||
var next NodeObject
|
||||
if err := json.Unmarshal(nextRaw, &next); err != nil {
|
||||
return err
|
||||
}
|
||||
n.Node = NewExtensionNode(key, next.Node)
|
||||
default:
|
||||
return errors.New("0, 1 or 2 fields expected")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalHex(data json.RawMessage) ([]byte, error) {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hex.DecodeString(s)
|
||||
}
|
156
pkg/core/mpt/node_test.go
Normal file
156
pkg/core/mpt/node_test.go
Normal file
|
@ -0,0 +1,156 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
t.Run("IO", func(t *testing.T) {
|
||||
bs, err := testserdes.EncodeBinary(expected)
|
||||
require.NoError(t, err)
|
||||
err = testserdes.DecodeBinary(bs, actual)
|
||||
if !ok {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected.Type(), actual.Type())
|
||||
require.Equal(t, expected.Hash(), actual.Hash())
|
||||
})
|
||||
t.Run("JSON", func(t *testing.T) {
|
||||
bs, err := json.Marshal(expected)
|
||||
require.NoError(t, err)
|
||||
err = json.Unmarshal(bs, actual)
|
||||
if !ok {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected.Type(), actual.Type())
|
||||
require.Equal(t, expected.Hash(), actual.Hash())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNode_Serializable(t *testing.T) {
|
||||
t.Run("Leaf", func(t *testing.T) {
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
l := NewLeafNode(random.Bytes(123))
|
||||
t.Run("Raw", getTestFuncEncode(true, l, new(LeafNode)))
|
||||
t.Run("WithType", getTestFuncEncode(true, &NodeObject{l}, new(NodeObject)))
|
||||
})
|
||||
t.Run("BigValue", getTestFuncEncode(false,
|
||||
NewLeafNode(random.Bytes(MaxValueLength+1)), new(LeafNode)))
|
||||
})
|
||||
|
||||
t.Run("Extension", func(t *testing.T) {
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
e := NewExtensionNode(random.Bytes(42), NewLeafNode(random.Bytes(10)))
|
||||
t.Run("Raw", getTestFuncEncode(true, e, new(ExtensionNode)))
|
||||
t.Run("WithType", getTestFuncEncode(true, &NodeObject{e}, new(NodeObject)))
|
||||
})
|
||||
t.Run("BigKey", getTestFuncEncode(false,
|
||||
NewExtensionNode(random.Bytes(MaxKeyLength+1), NewLeafNode(random.Bytes(10))), new(ExtensionNode)))
|
||||
})
|
||||
|
||||
t.Run("Branch", func(t *testing.T) {
|
||||
b := NewBranchNode()
|
||||
b.Children[0] = NewLeafNode(random.Bytes(10))
|
||||
b.Children[lastChild] = NewHashNode(random.Uint256())
|
||||
t.Run("Raw", getTestFuncEncode(true, b, new(BranchNode)))
|
||||
t.Run("WithType", getTestFuncEncode(true, &NodeObject{b}, new(NodeObject)))
|
||||
})
|
||||
|
||||
t.Run("Hash", func(t *testing.T) {
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
h := NewHashNode(random.Uint256())
|
||||
t.Run("Raw", getTestFuncEncode(true, h, new(HashNode)))
|
||||
t.Run("WithType", getTestFuncEncode(true, &NodeObject{h}, new(NodeObject)))
|
||||
})
|
||||
t.Run("Empty", func(t *testing.T) { // compare nodes, not hashes
|
||||
testserdes.EncodeDecodeBinary(t, new(HashNode), new(HashNode))
|
||||
})
|
||||
t.Run("InvalidSize", func(t *testing.T) {
|
||||
buf := io.NewBufBinWriter()
|
||||
buf.BinWriter.WriteVarBytes(make([]byte, 13))
|
||||
require.Error(t, testserdes.DecodeBinary(buf.Bytes(), new(HashNode)))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Invalid", func(t *testing.T) {
|
||||
require.Error(t, testserdes.DecodeBinary([]byte{0xFF}, new(NodeObject)))
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198
|
||||
func TestJSONSharp(t *testing.T) {
|
||||
tr := NewTrie(nil, newTestStore())
|
||||
require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11}))
|
||||
require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22}))
|
||||
require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac}))
|
||||
require.NoError(t, tr.Delete([]byte{0xac, 0x11}))
|
||||
require.NoError(t, tr.Delete([]byte{0xac, 0x22}))
|
||||
|
||||
js, err := tr.root.MarshalJSON()
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{"key":"0a0c", "next":{"value":"ac"}}`, string(js))
|
||||
}
|
||||
|
||||
func TestInvalidJSON(t *testing.T) {
|
||||
t.Run("InvalidChildrenCount", func(t *testing.T) {
|
||||
var cs [childrenCount + 1]Node
|
||||
for i := range cs {
|
||||
cs[i] = new(HashNode)
|
||||
}
|
||||
data, err := json.Marshal(cs)
|
||||
require.NoError(t, err)
|
||||
|
||||
var n NodeObject
|
||||
require.Error(t, json.Unmarshal(data, &n))
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
data []byte
|
||||
}{
|
||||
{"WrongFieldCount", []byte(`{"key":"0102", "next": {}, "field": {}}`)},
|
||||
{"InvalidField1", []byte(`{"next":{}}`)},
|
||||
{"InvalidField2", []byte(`{"key":"0102", "hash":{}}`)},
|
||||
{"InvalidKey", []byte(`{"key":"xy", "next":{}}`)},
|
||||
{"InvalidNext", []byte(`{"key":"01", "next":[]}`)},
|
||||
{"InvalidHash", []byte(`{"hash":"01"}`)},
|
||||
{"InvalidValue", []byte(`{"value":1}`)},
|
||||
{"InvalidBranch", []byte(`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]`)},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
var n NodeObject
|
||||
assert.Errorf(t, json.Unmarshal(tc.data, &n), "no error in "+tc.name)
|
||||
}
|
||||
}
|
||||
|
||||
// C# interoperability test
|
||||
// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135
|
||||
func TestRootHash(t *testing.T) {
|
||||
b := NewBranchNode()
|
||||
r := NewExtensionNode([]byte{0x0A, 0x0C}, b)
|
||||
|
||||
v1 := NewLeafNode([]byte{0xAB, 0xCD})
|
||||
l1 := NewExtensionNode([]byte{0x01}, v1)
|
||||
b.Children[0] = l1
|
||||
|
||||
v2 := NewLeafNode([]byte{0x22, 0x22})
|
||||
l2 := NewExtensionNode([]byte{0x09}, v2)
|
||||
b.Children[9] = l2
|
||||
|
||||
r1 := NewExtensionNode([]byte{0x0A, 0x0C, 0x00, 0x01}, v1)
|
||||
require.Equal(t, "dea3ab46e9461e885ed7091c1e533e0a8030b248d39cbc638962394eaca0fbb3", r1.Hash().StringLE())
|
||||
require.Equal(t, "93e8e1ffe2f83dd92fca67330e273bcc811bf64b8f8d9d1b25d5e7366b47d60d", r.Hash().StringLE())
|
||||
}
|
74
pkg/core/mpt/proof.go
Normal file
74
pkg/core/mpt/proof.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// GetProof returns a proof that key belongs to t.
|
||||
// Proof consist of serialized nodes occuring on path from the root to the leaf of key.
|
||||
func (t *Trie) GetProof(key []byte) ([][]byte, error) {
|
||||
var proof [][]byte
|
||||
path := toNibbles(key)
|
||||
r, err := t.getProof(t.root, path, &proof)
|
||||
if err != nil {
|
||||
return proof, err
|
||||
}
|
||||
t.root = r
|
||||
return proof, nil
|
||||
}
|
||||
|
||||
func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) {
|
||||
switch n := curr.(type) {
|
||||
case *LeafNode:
|
||||
if len(path) == 0 {
|
||||
*proofs = append(*proofs, copySlice(n.Bytes()))
|
||||
return n, nil
|
||||
}
|
||||
case *BranchNode:
|
||||
*proofs = append(*proofs, copySlice(n.Bytes()))
|
||||
i, path := splitPath(path)
|
||||
r, err := t.getProof(n.Children[i], path, proofs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n.Children[i] = r
|
||||
return n, nil
|
||||
case *ExtensionNode:
|
||||
if bytes.HasPrefix(path, n.key) {
|
||||
*proofs = append(*proofs, copySlice(n.Bytes()))
|
||||
r, err := t.getProof(n.next, path[len(n.key):], proofs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n.next = r
|
||||
return n, nil
|
||||
}
|
||||
case *HashNode:
|
||||
if !n.IsEmpty() {
|
||||
r, err := t.getFromStore(n.Hash())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.getProof(r, path, proofs)
|
||||
}
|
||||
}
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
// VerifyProof verifies that path indeed belongs to a MPT with the specified root hash.
|
||||
// It also returns value for the key.
|
||||
func VerifyProof(rh util.Uint256, key []byte, proofs [][]byte) ([]byte, bool) {
|
||||
path := toNibbles(key)
|
||||
tr := NewTrie(NewHashNode(rh), storage.NewMemCachedStore(storage.NewMemoryStore()))
|
||||
for i := range proofs {
|
||||
h := hash.DoubleSha256(proofs[i])
|
||||
// no errors in Put to memory store
|
||||
_ = tr.Store.Put(makeStorageKey(h[:]), proofs[i])
|
||||
}
|
||||
_, bs, err := tr.getWithPath(tr.root, path)
|
||||
return bs, err == nil
|
||||
}
|
73
pkg/core/mpt/proof_test.go
Normal file
73
pkg/core/mpt/proof_test.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newProofTrie(t *testing.T) *Trie {
|
||||
l := NewLeafNode([]byte("somevalue"))
|
||||
e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
|
||||
l2 := NewLeafNode([]byte("invalid"))
|
||||
e2 := NewExtensionNode([]byte{0x05}, NewHashNode(l2.Hash()))
|
||||
b := NewBranchNode()
|
||||
b.Children[4] = NewHashNode(e.Hash())
|
||||
b.Children[5] = e2
|
||||
|
||||
tr := NewTrie(b, newTestStore())
|
||||
require.NoError(t, tr.Put([]byte{0x12, 0x31}, []byte("value1")))
|
||||
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
|
||||
tr.putToStore(l)
|
||||
tr.putToStore(e)
|
||||
return tr
|
||||
}
|
||||
|
||||
func TestTrie_GetProof(t *testing.T) {
|
||||
tr := newProofTrie(t)
|
||||
|
||||
t.Run("MissingKey", func(t *testing.T) {
|
||||
_, err := tr.GetProof([]byte{0x12})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
_, err := tr.GetProof([]byte{0x12, 0x31})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("MissingHashNode", func(t *testing.T) {
|
||||
_, err := tr.GetProof([]byte{0x55})
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifyProof(t *testing.T) {
|
||||
tr := newProofTrie(t)
|
||||
|
||||
t.Run("Simple", func(t *testing.T) {
|
||||
proof, err := tr.GetProof([]byte{0x12, 0x32})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
v, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x32}, proof)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, []byte("value2"), v)
|
||||
})
|
||||
|
||||
t.Run("Bad", func(t *testing.T) {
|
||||
_, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x31}, proof)
|
||||
require.False(t, ok)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("InsideHash", func(t *testing.T) {
|
||||
key := []byte{0x45, 0x67}
|
||||
proof, err := tr.GetProof(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, ok := VerifyProof(tr.root.Hash(), key, proof)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, []byte("somevalue"), v)
|
||||
})
|
||||
}
|
390
pkg/core/mpt/trie.go
Normal file
390
pkg/core/mpt/trie.go
Normal file
|
@ -0,0 +1,390 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// Trie is an MPT trie storing all key-value pairs.
|
||||
type Trie struct {
|
||||
Store *storage.MemCachedStore
|
||||
|
||||
root Node
|
||||
}
|
||||
|
||||
// ErrNotFound is returned when requested trie item is missing.
|
||||
var ErrNotFound = errors.New("item not found")
|
||||
|
||||
// NewTrie returns new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors
|
||||
// so that all storage errors are processed during `store.Persist()` at the caller.
|
||||
// This also has the benefit, that every `Put` can be considered an atomic operation.
|
||||
func NewTrie(root Node, store *storage.MemCachedStore) *Trie {
|
||||
if root == nil {
|
||||
root = new(HashNode)
|
||||
}
|
||||
|
||||
return &Trie{
|
||||
Store: store,
|
||||
root: root,
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns value for the provided key in t.
|
||||
func (t *Trie) Get(key []byte) ([]byte, error) {
|
||||
path := toNibbles(key)
|
||||
r, bs, err := t.getWithPath(t.root, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.root = r
|
||||
return bs, nil
|
||||
}
|
||||
|
||||
// getWithPath returns value the provided path in a subtrie rooting in curr.
|
||||
// It also returns a current node with all hash nodes along the path
|
||||
// replaced to their "unhashed" counterparts.
|
||||
func (t *Trie) getWithPath(curr Node, path []byte) (Node, []byte, error) {
|
||||
switch n := curr.(type) {
|
||||
case *LeafNode:
|
||||
if len(path) == 0 {
|
||||
return curr, copySlice(n.value), nil
|
||||
}
|
||||
case *BranchNode:
|
||||
i, path := splitPath(path)
|
||||
r, bs, err := t.getWithPath(n.Children[i], path)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
n.Children[i] = r
|
||||
return n, bs, nil
|
||||
case *HashNode:
|
||||
if !n.IsEmpty() {
|
||||
if r, err := t.getFromStore(n.hash); err == nil {
|
||||
return t.getWithPath(r, path)
|
||||
}
|
||||
}
|
||||
case *ExtensionNode:
|
||||
if bytes.HasPrefix(path, n.key) {
|
||||
r, bs, err := t.getWithPath(n.next, path[len(n.key):])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
n.next = r
|
||||
return curr, bs, err
|
||||
}
|
||||
default:
|
||||
panic("invalid MPT node type")
|
||||
}
|
||||
return curr, nil, ErrNotFound
|
||||
}
|
||||
|
||||
// Put puts key-value pair in t.
|
||||
func (t *Trie) Put(key, value []byte) error {
|
||||
if len(key) > MaxKeyLength {
|
||||
return errors.New("key is too big")
|
||||
} else if len(value) > MaxValueLength {
|
||||
return errors.New("value is too big")
|
||||
}
|
||||
if len(value) == 0 {
|
||||
return t.Delete(key)
|
||||
}
|
||||
path := toNibbles(key)
|
||||
n := NewLeafNode(value)
|
||||
r, err := t.putIntoNode(t.root, path, n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.root = r
|
||||
return nil
|
||||
}
|
||||
|
||||
// putIntoLeaf puts val to trie if current node is a Leaf.
|
||||
// It returns Node if curr needs to be replaced and error if any.
|
||||
func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) {
|
||||
v := val.(*LeafNode)
|
||||
if len(path) == 0 {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
b := NewBranchNode()
|
||||
b.Children[path[0]] = newSubTrie(path[1:], v)
|
||||
b.Children[lastChild] = curr
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// putIntoBranch puts val to trie if current node is a Branch.
|
||||
// It returns Node if curr needs to be replaced and error if any.
|
||||
func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) {
|
||||
i, path := splitPath(path)
|
||||
r, err := t.putIntoNode(curr.Children[i], path, val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
curr.Children[i] = r
|
||||
curr.invalidateCache()
|
||||
return curr, nil
|
||||
}
|
||||
|
||||
// putIntoExtension puts val to trie if current node is an Extension.
|
||||
// It returns Node if curr needs to be replaced and error if any.
|
||||
func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) {
|
||||
if bytes.HasPrefix(path, curr.key) {
|
||||
r, err := t.putIntoNode(curr.next, path[len(curr.key):], val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
curr.next = r
|
||||
curr.invalidateCache()
|
||||
return curr, nil
|
||||
}
|
||||
|
||||
pref := lcp(curr.key, path)
|
||||
lp := len(pref)
|
||||
keyTail := curr.key[lp:]
|
||||
pathTail := path[lp:]
|
||||
|
||||
s1 := newSubTrie(keyTail[1:], curr.next)
|
||||
b := NewBranchNode()
|
||||
b.Children[keyTail[0]] = s1
|
||||
|
||||
i, pathTail := splitPath(pathTail)
|
||||
s2 := newSubTrie(pathTail, val)
|
||||
b.Children[i] = s2
|
||||
|
||||
if lp > 0 {
|
||||
return NewExtensionNode(copySlice(pref), b), nil
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// putIntoHash puts val to trie if current node is a HashNode.
|
||||
// It returns Node if curr needs to be replaced and error if any.
|
||||
func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) {
|
||||
if curr.IsEmpty() {
|
||||
return newSubTrie(path, val), nil
|
||||
}
|
||||
|
||||
result, err := t.getFromStore(curr.hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.putIntoNode(result, path, val)
|
||||
}
|
||||
|
||||
// newSubTrie create new trie containing node at provided path.
|
||||
func newSubTrie(path []byte, val Node) Node {
|
||||
if len(path) == 0 {
|
||||
return val
|
||||
}
|
||||
return NewExtensionNode(path, val)
|
||||
}
|
||||
|
||||
func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) {
|
||||
switch n := curr.(type) {
|
||||
case *LeafNode:
|
||||
return t.putIntoLeaf(n, path, val)
|
||||
case *BranchNode:
|
||||
return t.putIntoBranch(n, path, val)
|
||||
case *ExtensionNode:
|
||||
return t.putIntoExtension(n, path, val)
|
||||
case *HashNode:
|
||||
return t.putIntoHash(n, path, val)
|
||||
default:
|
||||
panic("invalid MPT node type")
|
||||
}
|
||||
}
|
||||
|
||||
// Delete removes key from trie.
|
||||
// It returns no error on missing key.
|
||||
func (t *Trie) Delete(key []byte) error {
|
||||
path := toNibbles(key)
|
||||
r, err := t.deleteFromNode(t.root, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.root = r
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) {
|
||||
i, path := splitPath(path)
|
||||
r, err := t.deleteFromNode(b.Children[i], path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.Children[i] = r
|
||||
b.invalidateCache()
|
||||
var count, index int
|
||||
for i := range b.Children {
|
||||
h, ok := b.Children[i].(*HashNode)
|
||||
if !ok || !h.IsEmpty() {
|
||||
index = i
|
||||
count++
|
||||
}
|
||||
}
|
||||
// count is >= 1 because branch node had at least 2 children before deletion.
|
||||
if count > 1 {
|
||||
return b, nil
|
||||
}
|
||||
c := b.Children[index]
|
||||
if index == lastChild {
|
||||
return c, nil
|
||||
}
|
||||
if h, ok := c.(*HashNode); ok {
|
||||
c, err = t.getFromStore(h.Hash())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if e, ok := c.(*ExtensionNode); ok {
|
||||
e.key = append([]byte{byte(index)}, e.key...)
|
||||
e.invalidateCache()
|
||||
return e, nil
|
||||
}
|
||||
|
||||
return NewExtensionNode([]byte{byte(index)}, c), nil
|
||||
}
|
||||
|
||||
func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) {
|
||||
if !bytes.HasPrefix(path, n.key) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
r, err := t.deleteFromNode(n.next, path[len(n.key):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch nxt := r.(type) {
|
||||
case *ExtensionNode:
|
||||
n.key = append(n.key, nxt.key...)
|
||||
n.next = nxt.next
|
||||
case *HashNode:
|
||||
if nxt.IsEmpty() {
|
||||
return nxt, nil
|
||||
}
|
||||
default:
|
||||
n.next = r
|
||||
}
|
||||
n.invalidateCache()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) {
|
||||
switch n := curr.(type) {
|
||||
case *LeafNode:
|
||||
if len(path) == 0 {
|
||||
return new(HashNode), nil
|
||||
}
|
||||
return nil, ErrNotFound
|
||||
case *BranchNode:
|
||||
return t.deleteFromBranch(n, path)
|
||||
case *ExtensionNode:
|
||||
return t.deleteFromExtension(n, path)
|
||||
case *HashNode:
|
||||
if n.IsEmpty() {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
newNode, err := t.getFromStore(n.Hash())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.deleteFromNode(newNode, path)
|
||||
default:
|
||||
panic("invalid MPT node type")
|
||||
}
|
||||
}
|
||||
|
||||
// StateRoot returns root hash of t.
|
||||
func (t *Trie) StateRoot() util.Uint256 {
|
||||
if hn, ok := t.root.(*HashNode); ok && hn.IsEmpty() {
|
||||
return util.Uint256{}
|
||||
}
|
||||
return t.root.Hash()
|
||||
}
|
||||
|
||||
func makeStorageKey(mptKey []byte) []byte {
|
||||
return append([]byte{byte(storage.DataMPT)}, mptKey...)
|
||||
}
|
||||
|
||||
// Flush puts every node in the trie except Hash ones to the storage.
|
||||
// Because we care only about block-level changes, there is no need to put every
|
||||
// new node to storage. Normally, flush should be called with every StateRoot persist, i.e.
|
||||
// after every block.
|
||||
func (t *Trie) Flush() {
|
||||
t.flush(t.root)
|
||||
}
|
||||
|
||||
func (t *Trie) flush(node Node) {
|
||||
if node.IsFlushed() {
|
||||
return
|
||||
}
|
||||
switch n := node.(type) {
|
||||
case *BranchNode:
|
||||
for i := range n.Children {
|
||||
t.flush(n.Children[i])
|
||||
}
|
||||
case *ExtensionNode:
|
||||
t.flush(n.next)
|
||||
case *HashNode:
|
||||
return
|
||||
}
|
||||
t.putToStore(node)
|
||||
}
|
||||
|
||||
func (t *Trie) putToStore(n Node) {
|
||||
if n.Type() == HashT {
|
||||
panic("can't put hash node in trie")
|
||||
}
|
||||
_ = t.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes()) // put in MemCached returns no errors
|
||||
n.SetFlushed()
|
||||
}
|
||||
|
||||
func (t *Trie) getFromStore(h util.Uint256) (Node, error) {
|
||||
data, err := t.Store.Get(makeStorageKey(h.BytesBE()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var n NodeObject
|
||||
r := io.NewBinReaderFromBuf(data)
|
||||
n.DecodeBinary(r)
|
||||
if r.Err != nil {
|
||||
return nil, r.Err
|
||||
}
|
||||
return n.Node, nil
|
||||
}
|
||||
|
||||
// Collapse compresses all nodes at depth n to the hash nodes.
|
||||
// Note: this function does not perform any kind of storage flushing so
|
||||
// `Flush()` should be called explicitly before invoking function.
|
||||
func (t *Trie) Collapse(depth int) {
|
||||
if depth < 0 {
|
||||
panic("negative depth")
|
||||
}
|
||||
t.root = collapse(depth, t.root)
|
||||
}
|
||||
|
||||
func collapse(depth int, node Node) Node {
|
||||
if _, ok := node.(*HashNode); ok {
|
||||
return node
|
||||
} else if depth == 0 {
|
||||
return NewHashNode(node.Hash())
|
||||
}
|
||||
|
||||
switch n := node.(type) {
|
||||
case *BranchNode:
|
||||
for i := range n.Children {
|
||||
n.Children[i] = collapse(depth-1, n.Children[i])
|
||||
}
|
||||
case *ExtensionNode:
|
||||
n.next = collapse(depth-1, n.next)
|
||||
case *LeafNode:
|
||||
case *HashNode:
|
||||
default:
|
||||
panic("invalid MPT node type")
|
||||
}
|
||||
return node
|
||||
}
|
446
pkg/core/mpt/trie_test.go
Normal file
446
pkg/core/mpt/trie_test.go
Normal file
|
@ -0,0 +1,446 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestStore() *storage.MemCachedStore {
|
||||
return storage.NewMemCachedStore(storage.NewMemoryStore())
|
||||
}
|
||||
|
||||
func newTestTrie(t *testing.T) *Trie {
|
||||
b := NewBranchNode()
|
||||
|
||||
l1 := NewLeafNode([]byte{0xAB, 0xCD})
|
||||
b.Children[0] = NewExtensionNode([]byte{0x01}, l1)
|
||||
|
||||
l2 := NewLeafNode([]byte{0x22, 0x22})
|
||||
b.Children[9] = NewExtensionNode([]byte{0x09}, l2)
|
||||
|
||||
v := NewLeafNode([]byte("hello"))
|
||||
h := NewHashNode(v.Hash())
|
||||
b.Children[10] = NewExtensionNode([]byte{0x0e}, h)
|
||||
|
||||
e := NewExtensionNode(toNibbles([]byte{0xAC}), b)
|
||||
tr := NewTrie(e, newTestStore())
|
||||
|
||||
tr.putToStore(e)
|
||||
tr.putToStore(b)
|
||||
tr.putToStore(l1)
|
||||
tr.putToStore(l2)
|
||||
tr.putToStore(v)
|
||||
tr.putToStore(b.Children[0])
|
||||
tr.putToStore(b.Children[9])
|
||||
tr.putToStore(b.Children[10])
|
||||
|
||||
return tr
|
||||
}
|
||||
|
||||
func TestTrie_PutIntoBranchNode(t *testing.T) {
|
||||
b := NewBranchNode()
|
||||
l := NewLeafNode([]byte{0x8})
|
||||
b.Children[0x7] = NewHashNode(l.Hash())
|
||||
b.Children[0x8] = NewHashNode(random.Uint256())
|
||||
tr := NewTrie(b, newTestStore())
|
||||
|
||||
// next
|
||||
require.NoError(t, tr.Put([]byte{}, []byte{0x12, 0x34}))
|
||||
tr.testHas(t, []byte{}, []byte{0x12, 0x34})
|
||||
|
||||
// empty hash node child
|
||||
require.NoError(t, tr.Put([]byte{0x66}, []byte{0x56}))
|
||||
tr.testHas(t, []byte{0x66}, []byte{0x56})
|
||||
require.True(t, isValid(tr.root))
|
||||
|
||||
// missing hash
|
||||
require.Error(t, tr.Put([]byte{0x70}, []byte{0x42}))
|
||||
require.True(t, isValid(tr.root))
|
||||
|
||||
// hash is in store
|
||||
tr.putToStore(l)
|
||||
require.NoError(t, tr.Put([]byte{0x70}, []byte{0x42}))
|
||||
require.True(t, isValid(tr.root))
|
||||
}
|
||||
|
||||
func TestTrie_PutIntoExtensionNode(t *testing.T) {
|
||||
l := NewLeafNode([]byte{0x11})
|
||||
key := []byte{0x12}
|
||||
e := NewExtensionNode(toNibbles(key), NewHashNode(l.Hash()))
|
||||
tr := NewTrie(e, newTestStore())
|
||||
|
||||
// missing hash
|
||||
require.Error(t, tr.Put(key, []byte{0x42}))
|
||||
|
||||
tr.putToStore(l)
|
||||
require.NoError(t, tr.Put(key, []byte{0x42}))
|
||||
tr.testHas(t, key, []byte{0x42})
|
||||
require.True(t, isValid(tr.root))
|
||||
}
|
||||
|
||||
func TestTrie_PutIntoHashNode(t *testing.T) {
|
||||
b := NewBranchNode()
|
||||
l := NewLeafNode(random.Bytes(5))
|
||||
e := NewExtensionNode([]byte{0x02}, l)
|
||||
b.Children[1] = NewHashNode(e.Hash())
|
||||
b.Children[9] = NewHashNode(random.Uint256())
|
||||
tr := NewTrie(b, newTestStore())
|
||||
|
||||
tr.putToStore(e)
|
||||
|
||||
t.Run("MissingLeafHash", func(t *testing.T) {
|
||||
_, err := tr.Get([]byte{0x12})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
tr.putToStore(l)
|
||||
|
||||
val := random.Bytes(3)
|
||||
require.NoError(t, tr.Put([]byte{0x12, 0x34}, val))
|
||||
tr.testHas(t, []byte{0x12, 0x34}, val)
|
||||
tr.testHas(t, []byte{0x12}, l.value)
|
||||
require.True(t, isValid(tr.root))
|
||||
}
|
||||
|
||||
func TestTrie_Put(t *testing.T) {
|
||||
trExp := newTestTrie(t)
|
||||
|
||||
trAct := NewTrie(nil, newTestStore())
|
||||
require.NoError(t, trAct.Put([]byte{0xAC, 0x01}, []byte{0xAB, 0xCD}))
|
||||
require.NoError(t, trAct.Put([]byte{0xAC, 0x99}, []byte{0x22, 0x22}))
|
||||
require.NoError(t, trAct.Put([]byte{0xAC, 0xAE}, []byte("hello")))
|
||||
|
||||
// Note: the exact tries differ because of ("acae":"hello") node is stored as Hash node in test trie.
|
||||
require.Equal(t, trExp.root.Hash(), trAct.root.Hash())
|
||||
require.True(t, isValid(trAct.root))
|
||||
}
|
||||
|
||||
func TestTrie_PutInvalid(t *testing.T) {
|
||||
tr := NewTrie(nil, newTestStore())
|
||||
key, value := []byte("key"), []byte("value")
|
||||
|
||||
// big key
|
||||
require.Error(t, tr.Put(make([]byte, MaxKeyLength+1), value))
|
||||
|
||||
// big value
|
||||
require.Error(t, tr.Put(key, make([]byte, MaxValueLength+1)))
|
||||
|
||||
// this is ok though
|
||||
require.NoError(t, tr.Put(key, value))
|
||||
tr.testHas(t, key, value)
|
||||
}
|
||||
|
||||
func TestTrie_BigPut(t *testing.T) {
|
||||
tr := NewTrie(nil, newTestStore())
|
||||
items := []struct{ k, v string }{
|
||||
{"item with long key", "value1"},
|
||||
{"item with matching prefix", "value2"},
|
||||
{"another prefix", "value3"},
|
||||
{"another prefix 2", "value4"},
|
||||
{"another ", "value5"},
|
||||
}
|
||||
|
||||
for i := range items {
|
||||
require.NoError(t, tr.Put([]byte(items[i].k), []byte(items[i].v)))
|
||||
}
|
||||
|
||||
for i := range items {
|
||||
tr.testHas(t, []byte(items[i].k), []byte(items[i].v))
|
||||
}
|
||||
|
||||
t.Run("Rewrite", func(t *testing.T) {
|
||||
k, v := []byte(items[0].k), []byte{0x01, 0x23}
|
||||
require.NoError(t, tr.Put(k, v))
|
||||
tr.testHas(t, k, v)
|
||||
})
|
||||
|
||||
t.Run("Remove", func(t *testing.T) {
|
||||
k := []byte(items[1].k)
|
||||
require.NoError(t, tr.Put(k, []byte{}))
|
||||
tr.testHas(t, k, nil)
|
||||
})
|
||||
}
|
||||
|
||||
func (tr *Trie) testHas(t *testing.T, key, value []byte) {
|
||||
v, err := tr.Get(key)
|
||||
if value == nil {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, value, v)
|
||||
}
|
||||
|
||||
// isValid checks for 3 invariants:
|
||||
// - BranchNode contains > 1 children
|
||||
// - ExtensionNode do not contain another extension node
|
||||
// - ExtensionNode do not have nil key
|
||||
// It is used only during testing to catch possible bugs.
|
||||
func isValid(curr Node) bool {
|
||||
switch n := curr.(type) {
|
||||
case *BranchNode:
|
||||
var count int
|
||||
for i := range n.Children {
|
||||
if !isValid(n.Children[i]) {
|
||||
return false
|
||||
}
|
||||
hn, ok := n.Children[i].(*HashNode)
|
||||
if !ok || !hn.IsEmpty() {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count > 1
|
||||
case *ExtensionNode:
|
||||
_, ok := n.next.(*ExtensionNode)
|
||||
return len(n.key) != 0 && !ok
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrie_Get(t *testing.T) {
|
||||
t.Run("HashNode", func(t *testing.T) {
|
||||
tr := newTestTrie(t)
|
||||
tr.testHas(t, []byte{0xAC, 0xAE}, []byte("hello"))
|
||||
})
|
||||
t.Run("UnfoldRoot", func(t *testing.T) {
|
||||
tr := newTestTrie(t)
|
||||
single := NewTrie(NewHashNode(tr.root.Hash()), tr.Store)
|
||||
single.testHas(t, []byte{0xAC}, nil)
|
||||
single.testHas(t, []byte{0xAC, 0x01}, []byte{0xAB, 0xCD})
|
||||
single.testHas(t, []byte{0xAC, 0x99}, []byte{0x22, 0x22})
|
||||
single.testHas(t, []byte{0xAC, 0xAE}, []byte("hello"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestTrie_Flush(t *testing.T) {
|
||||
pairs := map[string][]byte{
|
||||
"": []byte("value0"),
|
||||
"key1": []byte("value1"),
|
||||
"key2": []byte("value2"),
|
||||
}
|
||||
|
||||
tr := NewTrie(nil, newTestStore())
|
||||
for k, v := range pairs {
|
||||
require.NoError(t, tr.Put([]byte(k), v))
|
||||
}
|
||||
|
||||
tr.Flush()
|
||||
tr = NewTrie(NewHashNode(tr.StateRoot()), tr.Store)
|
||||
for k, v := range pairs {
|
||||
actual, err := tr.Get([]byte(k))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, v, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrie_Delete(t *testing.T) {
|
||||
t.Run("Hash", func(t *testing.T) {
|
||||
t.Run("FromStore", func(t *testing.T) {
|
||||
l := NewLeafNode([]byte{0x12})
|
||||
tr := NewTrie(NewHashNode(l.Hash()), newTestStore())
|
||||
t.Run("NotInStore", func(t *testing.T) {
|
||||
require.Error(t, tr.Delete([]byte{}))
|
||||
})
|
||||
|
||||
tr.putToStore(l)
|
||||
tr.testHas(t, []byte{}, []byte{0x12})
|
||||
require.NoError(t, tr.Delete([]byte{}))
|
||||
tr.testHas(t, []byte{}, nil)
|
||||
})
|
||||
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
tr := NewTrie(nil, newTestStore())
|
||||
require.Error(t, tr.Delete([]byte{}))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Leaf", func(t *testing.T) {
|
||||
l := NewLeafNode([]byte{0x12, 0x34})
|
||||
tr := NewTrie(l, newTestStore())
|
||||
t.Run("NonExistentKey", func(t *testing.T) {
|
||||
require.Error(t, tr.Delete([]byte{0x12}))
|
||||
tr.testHas(t, []byte{}, []byte{0x12, 0x34})
|
||||
})
|
||||
require.NoError(t, tr.Delete([]byte{}))
|
||||
tr.testHas(t, []byte{}, nil)
|
||||
})
|
||||
|
||||
t.Run("Extension", func(t *testing.T) {
|
||||
t.Run("SingleKey", func(t *testing.T) {
|
||||
l := NewLeafNode([]byte{0x12, 0x34})
|
||||
e := NewExtensionNode([]byte{0x0A, 0x0B}, l)
|
||||
tr := NewTrie(e, newTestStore())
|
||||
|
||||
t.Run("NonExistentKey", func(t *testing.T) {
|
||||
require.Error(t, tr.Delete([]byte{}))
|
||||
tr.testHas(t, []byte{0xAB}, []byte{0x12, 0x34})
|
||||
})
|
||||
|
||||
require.NoError(t, tr.Delete([]byte{0xAB}))
|
||||
require.True(t, tr.root.(*HashNode).IsEmpty())
|
||||
})
|
||||
|
||||
t.Run("MultipleKeys", func(t *testing.T) {
|
||||
b := NewBranchNode()
|
||||
b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x12, 0x34}))
|
||||
b.Children[6] = NewExtensionNode([]byte{0x07}, NewLeafNode([]byte{0x56, 0x78}))
|
||||
e := NewExtensionNode([]byte{0x01, 0x02}, b)
|
||||
tr := NewTrie(e, newTestStore())
|
||||
|
||||
h := e.Hash()
|
||||
require.NoError(t, tr.Delete([]byte{0x12, 0x01}))
|
||||
tr.testHas(t, []byte{0x12, 0x01}, nil)
|
||||
tr.testHas(t, []byte{0x12, 0x67}, []byte{0x56, 0x78})
|
||||
|
||||
require.NotEqual(t, h, tr.root.Hash())
|
||||
require.Equal(t, toNibbles([]byte{0x12, 0x67}), e.key)
|
||||
require.IsType(t, (*LeafNode)(nil), e.next)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Branch", func(t *testing.T) {
|
||||
t.Run("3 Children", func(t *testing.T) {
|
||||
b := NewBranchNode()
|
||||
b.Children[lastChild] = NewLeafNode([]byte{0x12})
|
||||
b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x34}))
|
||||
b.Children[1] = NewExtensionNode([]byte{0x06}, NewLeafNode([]byte{0x56}))
|
||||
tr := NewTrie(b, newTestStore())
|
||||
require.NoError(t, tr.Delete([]byte{0x16}))
|
||||
tr.testHas(t, []byte{}, []byte{0x12})
|
||||
tr.testHas(t, []byte{0x01}, []byte{0x34})
|
||||
tr.testHas(t, []byte{0x16}, nil)
|
||||
})
|
||||
t.Run("2 Children", func(t *testing.T) {
|
||||
newt := func(t *testing.T) *Trie {
|
||||
b := NewBranchNode()
|
||||
b.Children[lastChild] = NewLeafNode([]byte{0x12})
|
||||
l := NewLeafNode([]byte{0x34})
|
||||
e := NewExtensionNode([]byte{0x06}, l)
|
||||
b.Children[5] = NewHashNode(e.Hash())
|
||||
tr := NewTrie(b, newTestStore())
|
||||
tr.putToStore(l)
|
||||
tr.putToStore(e)
|
||||
return tr
|
||||
}
|
||||
|
||||
t.Run("DeleteLast", func(t *testing.T) {
|
||||
t.Run("MergeExtension", func(t *testing.T) {
|
||||
tr := newt(t)
|
||||
require.NoError(t, tr.Delete([]byte{}))
|
||||
tr.testHas(t, []byte{}, nil)
|
||||
tr.testHas(t, []byte{0x56}, []byte{0x34})
|
||||
require.IsType(t, (*ExtensionNode)(nil), tr.root)
|
||||
})
|
||||
|
||||
t.Run("LeaveLeaf", func(t *testing.T) {
|
||||
c := NewBranchNode()
|
||||
c.Children[5] = NewLeafNode([]byte{0x05})
|
||||
c.Children[6] = NewLeafNode([]byte{0x06})
|
||||
|
||||
b := NewBranchNode()
|
||||
b.Children[lastChild] = NewLeafNode([]byte{0x12})
|
||||
b.Children[5] = c
|
||||
tr := NewTrie(b, newTestStore())
|
||||
|
||||
require.NoError(t, tr.Delete([]byte{}))
|
||||
tr.testHas(t, []byte{}, nil)
|
||||
tr.testHas(t, []byte{0x55}, []byte{0x05})
|
||||
tr.testHas(t, []byte{0x56}, []byte{0x06})
|
||||
require.IsType(t, (*ExtensionNode)(nil), tr.root)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("DeleteMiddle", func(t *testing.T) {
|
||||
tr := newt(t)
|
||||
require.NoError(t, tr.Delete([]byte{0x56}))
|
||||
tr.testHas(t, []byte{}, []byte{0x12})
|
||||
tr.testHas(t, []byte{0x56}, nil)
|
||||
require.IsType(t, (*LeafNode)(nil), tr.root)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestTrie_PanicInvalidRoot(t *testing.T) {
|
||||
tr := &Trie{Store: newTestStore()}
|
||||
require.Panics(t, func() { _ = tr.Put([]byte{1}, []byte{2}) })
|
||||
require.Panics(t, func() { _, _ = tr.Get([]byte{1}) })
|
||||
require.Panics(t, func() { _ = tr.Delete([]byte{1}) })
|
||||
}
|
||||
|
||||
func TestTrie_Collapse(t *testing.T) {
|
||||
t.Run("PanicNegative", func(t *testing.T) {
|
||||
tr := newTestTrie(t)
|
||||
require.Panics(t, func() { tr.Collapse(-1) })
|
||||
})
|
||||
t.Run("Depth=0", func(t *testing.T) {
|
||||
tr := newTestTrie(t)
|
||||
h := tr.root.Hash()
|
||||
|
||||
_, ok := tr.root.(*HashNode)
|
||||
require.False(t, ok)
|
||||
|
||||
tr.Collapse(0)
|
||||
_, ok = tr.root.(*HashNode)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, h, tr.root.Hash())
|
||||
})
|
||||
t.Run("Branch,Depth=1", func(t *testing.T) {
|
||||
b := NewBranchNode()
|
||||
e := NewExtensionNode([]byte{0x01}, NewLeafNode([]byte("value1")))
|
||||
he := e.Hash()
|
||||
b.Children[0] = e
|
||||
hb := b.Hash()
|
||||
|
||||
tr := NewTrie(b, newTestStore())
|
||||
tr.Collapse(1)
|
||||
|
||||
newb, ok := tr.root.(*BranchNode)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, hb, newb.Hash())
|
||||
require.IsType(t, (*HashNode)(nil), b.Children[0])
|
||||
require.Equal(t, he, b.Children[0].Hash())
|
||||
})
|
||||
t.Run("Extension,Depth=1", func(t *testing.T) {
|
||||
l := NewLeafNode([]byte("value"))
|
||||
hl := l.Hash()
|
||||
e := NewExtensionNode([]byte{0x01}, l)
|
||||
h := e.Hash()
|
||||
tr := NewTrie(e, newTestStore())
|
||||
tr.Collapse(1)
|
||||
|
||||
newe, ok := tr.root.(*ExtensionNode)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, h, newe.Hash())
|
||||
require.IsType(t, (*HashNode)(nil), newe.next)
|
||||
require.Equal(t, hl, newe.next.Hash())
|
||||
})
|
||||
t.Run("Leaf", func(t *testing.T) {
|
||||
l := NewLeafNode([]byte("value"))
|
||||
tr := NewTrie(l, newTestStore())
|
||||
tr.Collapse(10)
|
||||
require.Equal(t, NewLeafNode([]byte("value")), tr.root)
|
||||
})
|
||||
t.Run("Hash", func(t *testing.T) {
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
tr := NewTrie(new(HashNode), newTestStore())
|
||||
require.NotPanics(t, func() { tr.Collapse(1) })
|
||||
hn, ok := tr.root.(*HashNode)
|
||||
require.True(t, ok)
|
||||
require.True(t, hn.IsEmpty())
|
||||
})
|
||||
|
||||
h := random.Uint256()
|
||||
hn := NewHashNode(h)
|
||||
tr := NewTrie(hn, newTestStore())
|
||||
tr.Collapse(10)
|
||||
|
||||
newRoot, ok := tr.root.(*HashNode)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, NewHashNode(h), newRoot)
|
||||
})
|
||||
}
|
|
@ -30,6 +30,14 @@ var (
|
|||
Namespace: "neogo",
|
||||
},
|
||||
)
|
||||
//stateHeight prometheus metric.
|
||||
stateHeight = prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Help: "Current verified state height",
|
||||
Name: "current_state_height",
|
||||
Namespace: "neogo",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -51,3 +59,7 @@ func updateHeaderHeightMetric(hHeight int) {
|
|||
func updateBlockHeightMetric(bHeight uint32) {
|
||||
blockHeight.Set(float64(bHeight))
|
||||
}
|
||||
|
||||
func updateStateHeightMetric(sHeight uint32) {
|
||||
stateHeight.Set(float64(sHeight))
|
||||
}
|
||||
|
|
146
pkg/core/state/mpt_root.go
Normal file
146
pkg/core/state/mpt_root.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// MPTRootBase represents storage state root.
|
||||
type MPTRootBase struct {
|
||||
Version byte `json:"version"`
|
||||
Index uint32 `json:"index"`
|
||||
PrevHash util.Uint256 `json:"prehash"`
|
||||
Root util.Uint256 `json:"stateroot"`
|
||||
}
|
||||
|
||||
// MPTRoot represents storage state root together with sign info.
|
||||
type MPTRoot struct {
|
||||
MPTRootBase
|
||||
Witness *transaction.Witness `json:"witness,omitempty"`
|
||||
}
|
||||
|
||||
// MPTRootStateFlag represents verification state of the state root.
|
||||
type MPTRootStateFlag byte
|
||||
|
||||
// Possible verification states of MPTRoot.
|
||||
const (
|
||||
Unverified MPTRootStateFlag = 0x00
|
||||
Verified MPTRootStateFlag = 0x01
|
||||
Invalid MPTRootStateFlag = 0x03
|
||||
)
|
||||
|
||||
// MPTRootState represents state root together with its verification state.
|
||||
type MPTRootState struct {
|
||||
MPTRoot `json:"stateroot"`
|
||||
Flag MPTRootStateFlag `json:"flag"`
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (s *MPTRootState) EncodeBinary(w *io.BinWriter) {
|
||||
w.WriteB(byte(s.Flag))
|
||||
s.MPTRoot.EncodeBinary(w)
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (s *MPTRootState) DecodeBinary(r *io.BinReader) {
|
||||
s.Flag = MPTRootStateFlag(r.ReadB())
|
||||
s.MPTRoot.DecodeBinary(r)
|
||||
}
|
||||
|
||||
// GetSignedPart returns part of MPTRootBase which needs to be signed.
|
||||
func (s *MPTRootBase) GetSignedPart() []byte {
|
||||
buf := io.NewBufBinWriter()
|
||||
s.EncodeBinary(buf.BinWriter)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// Equals checks if s == other.
|
||||
func (s *MPTRootBase) Equals(other *MPTRootBase) bool {
|
||||
return s.Version == other.Version && s.Index == other.Index &&
|
||||
s.PrevHash.Equals(other.PrevHash) && s.Root.Equals(other.Root)
|
||||
}
|
||||
|
||||
// Hash returns hash of s.
|
||||
func (s *MPTRootBase) Hash() util.Uint256 {
|
||||
return hash.DoubleSha256(s.GetSignedPart())
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (s *MPTRootBase) DecodeBinary(r *io.BinReader) {
|
||||
s.Version = r.ReadB()
|
||||
s.Index = r.ReadU32LE()
|
||||
s.PrevHash.DecodeBinary(r)
|
||||
s.Root.DecodeBinary(r)
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (s *MPTRootBase) EncodeBinary(w *io.BinWriter) {
|
||||
w.WriteB(s.Version)
|
||||
w.WriteU32LE(s.Index)
|
||||
s.PrevHash.EncodeBinary(w)
|
||||
s.Root.EncodeBinary(w)
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (s *MPTRoot) DecodeBinary(r *io.BinReader) {
|
||||
s.MPTRootBase.DecodeBinary(r)
|
||||
|
||||
var ws []transaction.Witness
|
||||
r.ReadArray(&ws, 1)
|
||||
if len(ws) == 1 {
|
||||
s.Witness = &ws[0]
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (s *MPTRoot) EncodeBinary(w *io.BinWriter) {
|
||||
s.MPTRootBase.EncodeBinary(w)
|
||||
if s.Witness == nil {
|
||||
w.WriteVarUint(0)
|
||||
} else {
|
||||
w.WriteArray([]*transaction.Witness{s.Witness})
|
||||
}
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (f MPTRootStateFlag) String() string {
|
||||
switch f {
|
||||
case Unverified:
|
||||
return "Unverified"
|
||||
case Verified:
|
||||
return "Verified"
|
||||
case Invalid:
|
||||
return "Invalid"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (f MPTRootStateFlag) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + f.String() + `"`), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (f *MPTRootStateFlag) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
switch s {
|
||||
case "Unverified":
|
||||
*f = Unverified
|
||||
case "Verified":
|
||||
*f = Verified
|
||||
case "Invalid":
|
||||
*f = Invalid
|
||||
default:
|
||||
return errors.New("unknown flag")
|
||||
}
|
||||
return nil
|
||||
}
|
100
pkg/core/state/mpt_root_test.go
Normal file
100
pkg/core/state/mpt_root_test.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testStateRoot() *MPTRoot {
|
||||
return &MPTRoot{
|
||||
MPTRootBase: MPTRootBase{
|
||||
Version: byte(rand.Uint32()),
|
||||
Index: rand.Uint32(),
|
||||
PrevHash: random.Uint256(),
|
||||
Root: random.Uint256(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateRoot_Serializable(t *testing.T) {
|
||||
r := testStateRoot()
|
||||
testserdes.EncodeDecodeBinary(t, r, new(MPTRoot))
|
||||
|
||||
t.Run("WithWitness", func(t *testing.T) {
|
||||
r.Witness = &transaction.Witness{
|
||||
InvocationScript: random.Bytes(10),
|
||||
VerificationScript: random.Bytes(11),
|
||||
}
|
||||
testserdes.EncodeDecodeBinary(t, r, new(MPTRoot))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStateRootEquals(t *testing.T) {
|
||||
r1 := testStateRoot()
|
||||
r2 := *r1
|
||||
require.True(t, r1.Equals(&r2.MPTRootBase))
|
||||
|
||||
r2.MPTRootBase.Index++
|
||||
require.False(t, r1.Equals(&r2.MPTRootBase))
|
||||
}
|
||||
|
||||
func TestMPTRootState_Serializable(t *testing.T) {
|
||||
rs := &MPTRootState{
|
||||
MPTRoot: *testStateRoot(),
|
||||
Flag: 0x04,
|
||||
}
|
||||
rs.MPTRoot.Witness = &transaction.Witness{
|
||||
InvocationScript: random.Bytes(10),
|
||||
VerificationScript: random.Bytes(11),
|
||||
}
|
||||
testserdes.EncodeDecodeBinary(t, rs, new(MPTRootState))
|
||||
}
|
||||
|
||||
func TestMPTRootStateUnverifiedByDefault(t *testing.T) {
|
||||
var r MPTRootState
|
||||
require.Equal(t, Unverified, r.Flag)
|
||||
}
|
||||
|
||||
func TestMPTRoot_MarshalJSON(t *testing.T) {
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
r := testStateRoot()
|
||||
rs := &MPTRootState{
|
||||
MPTRoot: *r,
|
||||
Flag: Verified,
|
||||
}
|
||||
testserdes.MarshalUnmarshalJSON(t, rs, new(MPTRootState))
|
||||
})
|
||||
|
||||
t.Run("Compatibility", func(t *testing.T) {
|
||||
js := []byte(`{
|
||||
"flag": "Unverified",
|
||||
"stateroot": {
|
||||
"version": 1,
|
||||
"index": 3000000,
|
||||
"prehash": "0x4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90",
|
||||
"stateroot": "0xb2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3"
|
||||
}}`)
|
||||
|
||||
rs := new(MPTRootState)
|
||||
require.NoError(t, json.Unmarshal(js, &rs))
|
||||
|
||||
require.EqualValues(t, 1, rs.Version)
|
||||
require.EqualValues(t, 3000000, rs.Index)
|
||||
require.Nil(t, rs.Witness)
|
||||
|
||||
u, err := util.Uint256DecodeStringLE("4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, u, rs.PrevHash)
|
||||
|
||||
u, err = util.Uint256DecodeStringLE("b2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, u, rs.Root)
|
||||
})
|
||||
}
|
|
@ -9,6 +9,7 @@ import (
|
|||
const (
|
||||
DataBlock KeyPrefix = 0x01
|
||||
DataTransaction KeyPrefix = 0x02
|
||||
DataMPT KeyPrefix = 0x03
|
||||
STAccount KeyPrefix = 0x40
|
||||
STCoin KeyPrefix = 0x44
|
||||
STSpentCoin KeyPrefix = 0x45
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package keys
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/x509"
|
||||
|
@ -10,6 +9,7 @@ import (
|
|||
"fmt"
|
||||
"math/big"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
|
@ -18,6 +18,9 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// coordLen is the number of bytes in serialized X or Y coordinate.
|
||||
const coordLen = 32
|
||||
|
||||
// PublicKeys is a list of public keys.
|
||||
type PublicKeys []*PublicKey
|
||||
|
||||
|
@ -94,23 +97,49 @@ func NewPublicKeyFromString(s string) (*PublicKey, error) {
|
|||
return pubKey, nil
|
||||
}
|
||||
|
||||
// Bytes returns the byte array representation of the public key.
|
||||
func (p *PublicKey) Bytes() []byte {
|
||||
// getBytes serializes X and Y using compressed or uncompressed format.
|
||||
func (p *PublicKey) getBytes(compressed bool) []byte {
|
||||
if p.IsInfinity() {
|
||||
return []byte{0x00}
|
||||
}
|
||||
|
||||
var (
|
||||
x = p.X.Bytes()
|
||||
paddedX = append(bytes.Repeat([]byte{0x00}, 32-len(x)), x...)
|
||||
prefix = byte(0x03)
|
||||
)
|
||||
|
||||
if p.Y.Bit(0) == 0 {
|
||||
prefix = byte(0x02)
|
||||
var resLen = 1 + coordLen
|
||||
if !compressed {
|
||||
resLen += coordLen
|
||||
}
|
||||
var res = make([]byte, resLen)
|
||||
var prefix byte
|
||||
|
||||
return append([]byte{prefix}, paddedX...)
|
||||
xBytes := p.X.Bytes()
|
||||
copy(res[1+coordLen-len(xBytes):], xBytes)
|
||||
if compressed {
|
||||
if p.Y.Bit(0) == 0 {
|
||||
prefix = 0x02
|
||||
} else {
|
||||
prefix = 0x03
|
||||
}
|
||||
} else {
|
||||
prefix = 0x04
|
||||
yBytes := p.Y.Bytes()
|
||||
copy(res[1+coordLen+coordLen-len(yBytes):], yBytes)
|
||||
|
||||
}
|
||||
res[0] = prefix
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// Bytes returns byte array representation of the public key in compressed
|
||||
// form (33 bytes with 0x02 or 0x03 prefix, except infinity which is always 0).
|
||||
func (p *PublicKey) Bytes() []byte {
|
||||
return p.getBytes(true)
|
||||
}
|
||||
|
||||
// UncompressedBytes returns byte array representation of the public key in
|
||||
// uncompressed form (65 bytes with 0x04 prefix, except infinity which is
|
||||
// always 0).
|
||||
func (p *PublicKey) UncompressedBytes() []byte {
|
||||
return p.getBytes(false)
|
||||
}
|
||||
|
||||
// NewPublicKeyFromASN1 returns a NEO PublicKey from the ASN.1 serialized key.
|
||||
|
@ -134,15 +163,25 @@ func NewPublicKeyFromASN1(data []byte) (*PublicKey, error) {
|
|||
}
|
||||
|
||||
// decodeCompressedY performs decompression of Y coordinate for given X and Y's least significant bit.
|
||||
func decodeCompressedY(x *big.Int, ylsb uint) (*big.Int, error) {
|
||||
c := elliptic.P256()
|
||||
cp := c.Params()
|
||||
three := big.NewInt(3)
|
||||
/* y**2 = x**3 + a*x + b % p */
|
||||
xCubed := new(big.Int).Exp(x, three, cp.P)
|
||||
threeX := new(big.Int).Mul(x, three)
|
||||
threeX.Mod(threeX, cp.P)
|
||||
ySquared := new(big.Int).Sub(xCubed, threeX)
|
||||
// We use here a short-form Weierstrass curve (https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html)
|
||||
// y² = x³ + ax + b. Two types of elliptic curves are supported:
|
||||
// 1. Secp256k1 (Koblitz curve): y² = x³ + b,
|
||||
// 2. Secp256r1 (Random curve): y² = x³ - 3x + b.
|
||||
// To decode compressed curve point we perform the following operation: y = sqrt(x³ + ax + b mod p)
|
||||
// where `p` denotes the order of the underlying curve field
|
||||
func decodeCompressedY(x *big.Int, ylsb uint, curve elliptic.Curve) (*big.Int, error) {
|
||||
var a *big.Int
|
||||
switch curve.(type) {
|
||||
case *btcec.KoblitzCurve:
|
||||
a = big.NewInt(0)
|
||||
default:
|
||||
a = big.NewInt(3)
|
||||
}
|
||||
cp := curve.Params()
|
||||
xCubed := new(big.Int).Exp(x, big.NewInt(3), cp.P)
|
||||
aX := new(big.Int).Mul(x, a)
|
||||
aX.Mod(aX, cp.P)
|
||||
ySquared := new(big.Int).Sub(xCubed, aX)
|
||||
ySquared.Add(ySquared, cp.B)
|
||||
ySquared.Mod(ySquared, cp.P)
|
||||
y := new(big.Int).ModSqrt(ySquared, cp.P)
|
||||
|
@ -196,7 +235,7 @@ func (p *PublicKey) DecodeBinary(r *io.BinReader) {
|
|||
}
|
||||
x = new(big.Int).SetBytes(xbytes)
|
||||
ylsb := uint(prefix & 0x1)
|
||||
y, err = decodeCompressedY(x, ylsb)
|
||||
y, err = decodeCompressedY(x, ylsb, p256)
|
||||
if err != nil {
|
||||
r.Err = err
|
||||
return
|
||||
|
@ -306,3 +345,84 @@ func (p *PublicKey) UnmarshalJSON(data []byte) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// KeyRecover recovers public key from the given signature (r, s) on the given message hash using given elliptic curve.
|
||||
// Algorithm source: SEC 1 Ver 2.0, section 4.1.6, pages 47-48 (https://www.secg.org/sec1-v2.pdf).
|
||||
// Flag isEven denotes Y's least significant bit in decompression algorithm.
|
||||
func KeyRecover(curve elliptic.Curve, r, s *big.Int, messageHash []byte, isEven bool) (PublicKey, error) {
|
||||
var (
|
||||
res PublicKey
|
||||
err error
|
||||
)
|
||||
if r.Cmp(big.NewInt(1)) == -1 || s.Cmp(big.NewInt(1)) == -1 {
|
||||
return res, errors.New("invalid signature")
|
||||
}
|
||||
params := curve.Params()
|
||||
// Calculate h = (Q + 1 + 2 * Sqrt(Q)) / N
|
||||
// num := new(big.Int).Add(new(big.Int).Add(params.P, big.NewInt(1)), new(big.Int).Mul(big.NewInt(2), new(big.Int).Sqrt(params.P)))
|
||||
// h := new(big.Int).Div(num, params.N)
|
||||
// We are skipping this step for secp256k1 and secp256r1 because we know cofactor of these curves (h=1)
|
||||
// (see section 2.4 of http://www.secg.org/sec2-v2.pdf)
|
||||
h := 1
|
||||
for i := 0; i <= h; i++ {
|
||||
// Step 1.1: x = (n * i) + r
|
||||
Rx := new(big.Int).Mul(params.N, big.NewInt(int64(i)))
|
||||
Rx.Add(Rx, r)
|
||||
if Rx.Cmp(params.P) == 1 {
|
||||
break
|
||||
}
|
||||
|
||||
// Steps 1.2 and 1.3: get point R (Ry)
|
||||
var R *big.Int
|
||||
if isEven {
|
||||
R, err = decodeCompressedY(Rx, 0, curve)
|
||||
} else {
|
||||
R, err = decodeCompressedY(Rx, 1, curve)
|
||||
}
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Step 1.4: check n*R is point at infinity
|
||||
nRx, nR := curve.ScalarMult(Rx, R, params.N.Bytes())
|
||||
if nRx.Sign() != 0 || nR.Sign() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Step 1.5: compute e
|
||||
e := hashToInt(messageHash, curve)
|
||||
|
||||
// Step 1.6: Q = r^-1 (sR-eG)
|
||||
invr := new(big.Int).ModInverse(r, params.N)
|
||||
// First term.
|
||||
invrS := new(big.Int).Mul(invr, s)
|
||||
invrS.Mod(invrS, params.N)
|
||||
sRx, sR := curve.ScalarMult(Rx, R, invrS.Bytes())
|
||||
// Second term.
|
||||
e.Neg(e)
|
||||
e.Mod(e, params.N)
|
||||
e.Mul(e, invr)
|
||||
e.Mod(e, params.N)
|
||||
minuseGx, minuseGy := curve.ScalarBaseMult(e.Bytes())
|
||||
Qx, Qy := curve.Add(sRx, sR, minuseGx, minuseGy)
|
||||
res.X = Qx
|
||||
res.Y = Qy
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// copied from crypto/ecdsa
|
||||
func hashToInt(hash []byte, c elliptic.Curve) *big.Int {
|
||||
orderBits := c.Params().N.BitLen()
|
||||
orderBytes := (orderBits + 7) / 8
|
||||
if len(hash) > orderBytes {
|
||||
hash = hash[:orderBytes]
|
||||
}
|
||||
|
||||
ret := new(big.Int).SetBytes(hash)
|
||||
excess := len(hash)*8 - orderBits
|
||||
if excess > 0 {
|
||||
ret.Rsh(ret, uint(excess))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
package keys
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -85,10 +89,14 @@ func TestPubkeyToAddress(t *testing.T) {
|
|||
|
||||
func TestDecodeBytes(t *testing.T) {
|
||||
pubKey := getPubKey(t)
|
||||
decodedPubKey := &PublicKey{}
|
||||
err := decodedPubKey.DecodeBytes(pubKey.Bytes())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pubKey, decodedPubKey)
|
||||
var testBytesFunction = func(t *testing.T, bytesFunction func() []byte) {
|
||||
decodedPubKey := &PublicKey{}
|
||||
err := decodedPubKey.DecodeBytes(bytesFunction())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pubKey, decodedPubKey)
|
||||
}
|
||||
t.Run("compressed", func(t *testing.T) { testBytesFunction(t, pubKey.Bytes) })
|
||||
t.Run("uncompressed", func(t *testing.T) { testBytesFunction(t, pubKey.UncompressedBytes) })
|
||||
}
|
||||
|
||||
func TestDecodeBytesBadInfinity(t *testing.T) {
|
||||
|
@ -179,3 +187,91 @@ func TestUnmarshallJSONBadFormat(t *testing.T) {
|
|||
err := json.Unmarshal([]byte(str), actual)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRecoverSecp256r1(t *testing.T) {
|
||||
privateKey, err := NewPrivateKey()
|
||||
require.NoError(t, err)
|
||||
message := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100}
|
||||
messageHash := hash.Sha256(message).BytesBE()
|
||||
signature := privateKey.Sign(message)
|
||||
r := new(big.Int).SetBytes(signature[0:32])
|
||||
s := new(big.Int).SetBytes(signature[32:64])
|
||||
require.True(t, privateKey.PublicKey().Verify(signature, messageHash))
|
||||
// To test this properly, we should provide correct isEven flag. This flag denotes which one of
|
||||
// the two recovered R points in decodeCompressedY method should be chosen. Let's suppose that we
|
||||
// don't know which of them suites, so to test KeyRecover we should check both and only
|
||||
// one of them gives us the correct public key.
|
||||
recoveredKeyFalse, err := KeyRecover(elliptic.P256(), r, s, messageHash, false)
|
||||
require.NoError(t, err)
|
||||
recoveredKeyTrue, err := KeyRecover(elliptic.P256(), r, s, messageHash, true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, privateKey.PublicKey().Equal(&recoveredKeyFalse) != privateKey.PublicKey().Equal(&recoveredKeyTrue))
|
||||
}
|
||||
|
||||
func TestRecoverSecp256r1Static(t *testing.T) {
|
||||
// These data were taken from the reference KeyRecoverTest: https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_ECDsa.cs#L22
|
||||
// To update this test, run the reference KeyRecover(ECCurve.Secp256r1) testcase and fetch the following data from it:
|
||||
// privateKey -> b
|
||||
// message -> messageHash
|
||||
// signatures[0] -> r
|
||||
// signatures[1] -> s
|
||||
// v -> isEven
|
||||
// Note, that C# BigInteger has different byte order from that used in Go.
|
||||
b := []byte{123, 245, 126, 56, 3, 123, 197, 199, 26, 31, 212, 186, 120, 195, 168, 153, 57, 108, 234, 49, 107, 203, 44, 207, 185, 212, 187, 129, 74, 43, 225, 69}
|
||||
privateKey, err := NewPrivateKeyFromBytes(b)
|
||||
require.NoError(t, err)
|
||||
messageHash := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100}
|
||||
r := new(big.Int).SetBytes([]byte{1, 85, 226, 63, 133, 113, 217, 188, 249, 22, 213, 203, 225, 199, 32, 131, 118, 23, 28, 101, 139, 211, 13, 111, 242, 158, 193, 227, 196, 106, 3, 4})
|
||||
s := new(big.Int).SetBytes([]byte{65, 174, 206, 164, 81, 34, 76, 104, 5, 49, 51, 20, 221, 183, 157, 199, 199, 47, 78, 137, 172, 99, 212, 110, 129, 72, 236, 59, 250, 81, 200, 13})
|
||||
// Just ensure it's a valid signature.
|
||||
require.True(t, privateKey.PublicKey().Verify(append(r.Bytes(), s.Bytes()...), messageHash))
|
||||
recoveredKey, err := KeyRecover(elliptic.P256(), r, s, messageHash, false)
|
||||
require.NoError(t, err)
|
||||
require.True(t, privateKey.PublicKey().Equal(&recoveredKey))
|
||||
}
|
||||
|
||||
func TestRecoverSecp256k1(t *testing.T) {
|
||||
privateKey, err := btcec.NewPrivateKey(btcec.S256())
|
||||
message := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100}
|
||||
signature, err := privateKey.Sign(message)
|
||||
require.NoError(t, err)
|
||||
require.True(t, signature.Verify(message, privateKey.PubKey()))
|
||||
// To test this properly, we should provide correct isEven flag. This flag denotes which one of
|
||||
// the two recovered R points in decodeCompressedY method should be chosen. Let's suppose that we
|
||||
// don't know which of them suites, so to test KeyRecover we should check both and only
|
||||
// one of them gives us the correct public key.
|
||||
recoveredKeyFalse, err := KeyRecover(btcec.S256(), signature.R, signature.S, message, false)
|
||||
require.NoError(t, err)
|
||||
recoveredKeyTrue, err := KeyRecover(btcec.S256(), signature.R, signature.S, message, true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, (privateKey.PubKey().X.Cmp(recoveredKeyFalse.X) == 0 &&
|
||||
privateKey.PubKey().Y.Cmp(recoveredKeyFalse.Y) == 0) !=
|
||||
(privateKey.PubKey().X.Cmp(recoveredKeyTrue.X) == 0 &&
|
||||
privateKey.PubKey().Y.Cmp(recoveredKeyTrue.Y) == 0))
|
||||
}
|
||||
|
||||
func TestRecoverSecp256k1Static(t *testing.T) {
|
||||
// These data were taken from the reference testcase: https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_ECDsa.cs#L22
|
||||
// To update this test, run the reference KeyRecover(ECCurve.Secp256k1) testcase and fetch the following data from it:
|
||||
// privateKey -> b
|
||||
// message -> messageHash
|
||||
// signatures[0] -> r
|
||||
// signatures[1] -> s
|
||||
// v -> isEven
|
||||
// Note, that C# BigInteger has different byte order from that used in Go.
|
||||
b := []byte{156, 3, 247, 58, 246, 250, 236, 27, 118, 60, 180, 177, 18, 92, 204, 206, 144, 245, 148, 141, 86, 212, 151, 181, 15, 113, 172, 180, 177, 228, 100, 32}
|
||||
_, publicKey := btcec.PrivKeyFromBytes(btcec.S256(), b)
|
||||
messageHash := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100}
|
||||
r := new(big.Int).SetBytes([]byte{88, 169, 242, 111, 210, 184, 180, 46, 67, 108, 176, 77, 57, 250, 58, 36, 110, 81, 225, 65, 90, 47, 215, 91, 27, 227, 57, 6, 9, 228, 100, 50})
|
||||
s := new(big.Int).SetBytes([]byte{86, 150, 81, 190, 17, 181, 212, 241, 184, 36, 136, 116, 232, 207, 46, 45, 149, 167, 15, 98, 113, 137, 66, 98, 214, 165, 38, 232, 98, 96, 79, 197})
|
||||
signature := btcec.Signature{
|
||||
R: r,
|
||||
S: s,
|
||||
}
|
||||
// Just ensure it's a valid signature.
|
||||
require.True(t, signature.Verify(messageHash, publicKey))
|
||||
recoveredKey, err := KeyRecover(btcec.S256(), r, s, messageHash, false)
|
||||
require.NoError(t, err)
|
||||
require.True(t, new(big.Int).SetBytes([]byte{112, 186, 29, 131, 169, 21, 212, 95, 81, 172, 201, 145, 168, 108, 129, 90, 6, 111, 80, 39, 136, 157, 15, 181, 98, 108, 133, 108, 144, 80, 23, 225}).Cmp(recoveredKey.X) == 0)
|
||||
require.True(t, new(big.Int).SetBytes([]byte{187, 102, 202, 42, 152, 133, 222, 55, 137, 228, 154, 80, 182, 35, 133, 14, 55, 165, 36, 64, 178, 55, 13, 112, 224, 143, 66, 143, 208, 18, 2, 211}).Cmp(recoveredKey.Y) == 0)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
Package crypto provides an interface to VM cryptographic instructions.
|
||||
Package crypto provides an interface to VM cryptographic instructions and syscalls.
|
||||
*/
|
||||
package crypto
|
||||
|
||||
|
@ -30,3 +30,23 @@ func Hash256(b []byte) []byte {
|
|||
func VerifySignature(msg []byte, sig []byte, pub []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Secp256k1Recover recovers public key from the given signature (r, s) on the
|
||||
// given message hash using Secp256k1 elliptic curve. Flag isEven denotes Y's
|
||||
// least significant bit in decompression algorithm. The return value is byte
|
||||
// array representation of the public key which is either empty (if it's not
|
||||
// possible to recover key) or contains 32 bytes in BE for X point (in case of
|
||||
// success). This function uses Neo.Cryptography.Secp256k1Recover syscall.
|
||||
func Secp256k1Recover(r []byte, s []byte, messageHash []byte, isEven bool) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Secp256r1Recover recovers public key from the given signature (r, s) on the
|
||||
// given message hash using Secp256r1 elliptic curve. Flag isEven denotes Y's
|
||||
// least significant bit in decompression algorithm. The return value is byte
|
||||
// array representation of the public key which is either empty (if it's not
|
||||
// possible to recover key) or contains 32 bytes in BE for X point (in case of
|
||||
// success). This function uses Neo.Cryptography.Secp256r1Recover syscall.
|
||||
func Secp256r1Recover(r []byte, s []byte, messageHash []byte, isEven bool) []byte {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -8,9 +8,9 @@ import (
|
|||
"reflect"
|
||||
)
|
||||
|
||||
// maxArraySize is a maximums size of an array which can be decoded.
|
||||
// MaxArraySize is the maximum size of an array which can be decoded.
|
||||
// It is taken from https://github.com/neo-project/neo/blob/master/neo/IO/Helper.cs#L130
|
||||
const maxArraySize = 0x1000000
|
||||
const MaxArraySize = 0x1000000
|
||||
|
||||
// BinReader is a convenient wrapper around a io.Reader and err object.
|
||||
// Used to simplify error handling when reading into a struct with many fields.
|
||||
|
@ -110,7 +110,7 @@ func (r *BinReader) ReadArray(t interface{}, maxSize ...int) {
|
|||
elemType := sliceType.Elem()
|
||||
isPtr := elemType.Kind() == reflect.Ptr
|
||||
|
||||
ms := maxArraySize
|
||||
ms := MaxArraySize
|
||||
if len(maxSize) != 0 {
|
||||
ms = maxSize[0]
|
||||
}
|
||||
|
@ -168,8 +168,16 @@ func (r *BinReader) ReadVarUint() uint64 {
|
|||
|
||||
// ReadVarBytes reads the next set of bytes from the underlying reader.
|
||||
// ReadVarUInt() is used to determine how large that slice is
|
||||
func (r *BinReader) ReadVarBytes() []byte {
|
||||
func (r *BinReader) ReadVarBytes(maxSize ...int) []byte {
|
||||
n := r.ReadVarUint()
|
||||
ms := MaxArraySize
|
||||
if len(maxSize) != 0 {
|
||||
ms = maxSize[0]
|
||||
}
|
||||
if n > uint64(ms) {
|
||||
r.Err = fmt.Errorf("byte-slice is too big (%d)", n)
|
||||
return nil
|
||||
}
|
||||
b := make([]byte, n)
|
||||
r.ReadBytes(b)
|
||||
return b
|
||||
|
|
|
@ -143,6 +143,35 @@ func TestBufBinWriter_Len(t *testing.T) {
|
|||
require.Equal(t, 1, bw.Len())
|
||||
}
|
||||
|
||||
func TestBinReader_ReadVarBytes(t *testing.T) {
|
||||
buf := make([]byte, 11)
|
||||
for i := range buf {
|
||||
buf[i] = byte(i)
|
||||
}
|
||||
w := NewBufBinWriter()
|
||||
w.WriteVarBytes(buf)
|
||||
require.NoError(t, w.Err)
|
||||
data := w.Bytes()
|
||||
|
||||
t.Run("NoArguments", func(t *testing.T) {
|
||||
r := NewBinReaderFromBuf(data)
|
||||
actual := r.ReadVarBytes()
|
||||
require.NoError(t, r.Err)
|
||||
require.Equal(t, buf, actual)
|
||||
})
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
r := NewBinReaderFromBuf(data)
|
||||
actual := r.ReadVarBytes(11)
|
||||
require.NoError(t, r.Err)
|
||||
require.Equal(t, buf, actual)
|
||||
})
|
||||
t.Run("Bad", func(t *testing.T) {
|
||||
r := NewBinReaderFromBuf(data)
|
||||
r.ReadVarBytes(10)
|
||||
require.Error(t, r.Err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriterErrHandling(t *testing.T) {
|
||||
var badio = &badRW{}
|
||||
bw := NewBinWriterFromIO(badio)
|
||||
|
|
|
@ -59,6 +59,9 @@ func (chain *testChain) AddBlock(block *block.Block) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
func (chain *testChain) AddStateRoot(r *state.MPTRoot) error {
|
||||
panic("TODO")
|
||||
}
|
||||
func (chain *testChain) BlockHeight() uint32 {
|
||||
return atomic.LoadUint32(&chain.blockheight)
|
||||
}
|
||||
|
@ -105,6 +108,12 @@ func (chain testChain) GetEnrollments() ([]*state.Validator, error) {
|
|||
func (chain testChain) GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) {
|
||||
panic("TODO")
|
||||
}
|
||||
func (chain testChain) GetStateProof(util.Uint256, []byte) ([][]byte, error) {
|
||||
panic("TODO")
|
||||
}
|
||||
func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) {
|
||||
panic("TODO")
|
||||
}
|
||||
func (chain testChain) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem {
|
||||
panic("TODO")
|
||||
}
|
||||
|
@ -145,7 +154,9 @@ func (chain testChain) IsLowPriority(util.Fixed8) bool {
|
|||
func (chain testChain) PoolTx(*transaction.Transaction) error {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
func (chain testChain) StateHeight() uint32 {
|
||||
panic("TODO")
|
||||
}
|
||||
func (chain testChain) SubscribeForBlocks(ch chan<- *block.Block) {
|
||||
panic("TODO")
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/config"
|
||||
"github.com/nspcc-dev/neo-go/pkg/consensus"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
|
@ -59,12 +60,15 @@ const (
|
|||
CMDGetBlocks CommandType = "getblocks"
|
||||
CMDGetData CommandType = "getdata"
|
||||
CMDGetHeaders CommandType = "getheaders"
|
||||
CMDGetRoots CommandType = "getroots"
|
||||
CMDHeaders CommandType = "headers"
|
||||
CMDInv CommandType = "inv"
|
||||
CMDMempool CommandType = "mempool"
|
||||
CMDMerkleBlock CommandType = "merkleblock"
|
||||
CMDPing CommandType = "ping"
|
||||
CMDPong CommandType = "pong"
|
||||
CMDRoots CommandType = "roots"
|
||||
CMDStateRoot CommandType = "stateroot"
|
||||
CMDTX CommandType = "tx"
|
||||
CMDUnknown CommandType = "unknown"
|
||||
CMDVerack CommandType = "verack"
|
||||
|
@ -124,6 +128,8 @@ func (m *Message) CommandType() CommandType {
|
|||
return CMDGetData
|
||||
case "getheaders":
|
||||
return CMDGetHeaders
|
||||
case "getroots":
|
||||
return CMDGetRoots
|
||||
case "headers":
|
||||
return CMDHeaders
|
||||
case "inv":
|
||||
|
@ -136,6 +142,10 @@ func (m *Message) CommandType() CommandType {
|
|||
return CMDPing
|
||||
case "pong":
|
||||
return CMDPong
|
||||
case "roots":
|
||||
return CMDRoots
|
||||
case "stateroot":
|
||||
return CMDStateRoot
|
||||
case "tx":
|
||||
return CMDTX
|
||||
case "verack":
|
||||
|
@ -191,6 +201,8 @@ func (m *Message) decodePayload(br *io.BinReader) error {
|
|||
fallthrough
|
||||
case CMDGetHeaders:
|
||||
p = &payload.GetBlocks{}
|
||||
case CMDGetRoots:
|
||||
p = &payload.GetStateRoots{}
|
||||
case CMDHeaders:
|
||||
p = &payload.Headers{}
|
||||
case CMDTX:
|
||||
|
@ -199,6 +211,10 @@ func (m *Message) decodePayload(br *io.BinReader) error {
|
|||
p = &payload.MerkleBlock{}
|
||||
case CMDPing, CMDPong:
|
||||
p = &payload.Ping{}
|
||||
case CMDRoots:
|
||||
p = &payload.StateRoots{}
|
||||
case CMDStateRoot:
|
||||
p = &state.MPTRoot{}
|
||||
default:
|
||||
return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command))
|
||||
}
|
||||
|
|
|
@ -18,6 +18,8 @@ func (i InventoryType) String() string {
|
|||
return "TX"
|
||||
case 0x02:
|
||||
return "block"
|
||||
case StateRootType:
|
||||
return "stateroot"
|
||||
case 0xe0:
|
||||
return "consensus"
|
||||
default:
|
||||
|
@ -27,13 +29,14 @@ func (i InventoryType) String() string {
|
|||
|
||||
// Valid returns true if the inventory (type) is known.
|
||||
func (i InventoryType) Valid() bool {
|
||||
return i == BlockType || i == TXType || i == ConsensusType
|
||||
return i == BlockType || i == TXType || i == ConsensusType || i == StateRootType
|
||||
}
|
||||
|
||||
// List of valid InventoryTypes.
|
||||
const (
|
||||
TXType InventoryType = 0x01 // 1
|
||||
BlockType InventoryType = 0x02 // 2
|
||||
StateRootType InventoryType = 0x03 // 3
|
||||
ConsensusType InventoryType = 0xe0 // 224
|
||||
)
|
||||
|
||||
|
|
43
pkg/network/payload/state_root.go
Normal file
43
pkg/network/payload/state_root.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
)
|
||||
|
||||
// MaxStateRootsAllowed is a maxumum amount of state roots
|
||||
// which can be sent in a single payload.
|
||||
const MaxStateRootsAllowed = 2000
|
||||
|
||||
// StateRoots contains multiple StateRoots.
|
||||
type StateRoots struct {
|
||||
Roots []state.MPTRoot
|
||||
}
|
||||
|
||||
// GetStateRoots represents request for state roots.
|
||||
type GetStateRoots struct {
|
||||
Start uint32
|
||||
Count uint32
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (s *StateRoots) EncodeBinary(w *io.BinWriter) {
|
||||
w.WriteArray(s.Roots)
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (s *StateRoots) DecodeBinary(r *io.BinReader) {
|
||||
r.ReadArray(&s.Roots, MaxStateRootsAllowed)
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (g *GetStateRoots) DecodeBinary(r *io.BinReader) {
|
||||
g.Start = r.ReadU32LE()
|
||||
g.Count = r.ReadU32LE()
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (g *GetStateRoots) EncodeBinary(w *io.BinWriter) {
|
||||
w.WriteU32LE(g.Start)
|
||||
w.WriteU32LE(g.Count)
|
||||
}
|
51
pkg/network/payload/state_root_test.go
Normal file
51
pkg/network/payload/state_root_test.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
|
||||
)
|
||||
|
||||
func TestStateRoots_Serializable(t *testing.T) {
|
||||
expected := &StateRoots{
|
||||
Roots: []state.MPTRoot{
|
||||
{
|
||||
MPTRootBase: state.MPTRootBase{
|
||||
Index: rand.Uint32(),
|
||||
PrevHash: random.Uint256(),
|
||||
Root: random.Uint256(),
|
||||
},
|
||||
Witness: &transaction.Witness{
|
||||
InvocationScript: random.Bytes(10),
|
||||
VerificationScript: random.Bytes(11),
|
||||
},
|
||||
},
|
||||
{
|
||||
MPTRootBase: state.MPTRootBase{
|
||||
Index: rand.Uint32(),
|
||||
PrevHash: random.Uint256(),
|
||||
Root: random.Uint256(),
|
||||
},
|
||||
Witness: &transaction.Witness{
|
||||
InvocationScript: random.Bytes(10),
|
||||
VerificationScript: random.Bytes(11),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testserdes.EncodeDecodeBinary(t, expected, new(StateRoots))
|
||||
}
|
||||
|
||||
func TestGetStateRoots_Serializable(t *testing.T) {
|
||||
expected := &GetStateRoots{
|
||||
Start: rand.Uint32(),
|
||||
Count: rand.Uint32(),
|
||||
}
|
||||
|
||||
testserdes.EncodeDecodeBinary(t, expected, new(GetStateRoots))
|
||||
}
|
|
@ -13,6 +13,8 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/consensus"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/cache"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
|
@ -28,6 +30,7 @@ const (
|
|||
maxBlockBatch = 200
|
||||
maxAddrsToSend = 200
|
||||
minPoolCount = 30
|
||||
stateRootCacheSize = 100
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -66,6 +69,7 @@ type (
|
|||
|
||||
transactions chan *transaction.Transaction
|
||||
|
||||
stateCache cache.HashCache
|
||||
consensusStarted *atomic.Bool
|
||||
|
||||
log *zap.Logger
|
||||
|
@ -98,6 +102,7 @@ func NewServer(config ServerConfig, chain core.Blockchainer, log *zap.Logger) (*
|
|||
unregister: make(chan peerDrop),
|
||||
peers: make(map[Peer]bool),
|
||||
consensusStarted: atomic.NewBool(false),
|
||||
stateCache: *cache.NewFIFOCache(stateRootCacheSize),
|
||||
log: log,
|
||||
transactions: make(chan *transaction.Transaction, 64),
|
||||
}
|
||||
|
@ -469,6 +474,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
|
|||
cp := s.consensus.GetPayload(h)
|
||||
return cp != nil
|
||||
},
|
||||
payload.StateRootType: s.stateCache.Has,
|
||||
}
|
||||
if exists := typExists[inv.Type]; exists != nil {
|
||||
for _, hash := range inv.Hashes {
|
||||
|
@ -507,6 +513,11 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
|
|||
if err == nil {
|
||||
msg = s.MkMsg(CMDBlock, b)
|
||||
}
|
||||
case payload.StateRootType:
|
||||
r := s.stateCache.Get(hash)
|
||||
if r != nil {
|
||||
msg = s.MkMsg(CMDStateRoot, r.(*state.MPTRoot))
|
||||
}
|
||||
case payload.ConsensusType:
|
||||
if cp := s.consensus.GetPayload(hash); cp != nil {
|
||||
msg = s.MkMsg(CMDConsensus, cp)
|
||||
|
@ -589,6 +600,87 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error {
|
|||
return p.EnqueueP2PMessage(msg)
|
||||
}
|
||||
|
||||
// handleGetRootsCmd processees `getroots` request.
|
||||
func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error {
|
||||
cfg := s.chain.GetConfig()
|
||||
if !cfg.EnableStateRoot || gr.Start < cfg.StateRootEnableIndex {
|
||||
return nil
|
||||
}
|
||||
count := gr.Count
|
||||
if count > payload.MaxStateRootsAllowed {
|
||||
count = payload.MaxStateRootsAllowed
|
||||
}
|
||||
var rs payload.StateRoots
|
||||
for height := gr.Start; height < gr.Start+gr.Count; height++ {
|
||||
r, err := s.chain.GetStateRoot(height)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if r.Flag == state.Verified {
|
||||
rs.Roots = append(rs.Roots, r.MPTRoot)
|
||||
}
|
||||
}
|
||||
msg := s.MkMsg(CMDRoots, &rs)
|
||||
return p.EnqueueP2PMessage(msg)
|
||||
}
|
||||
|
||||
// handleStateRootsCmd processees `roots` request.
|
||||
func (s *Server) handleRootsCmd(p Peer, rs *payload.StateRoots) error {
|
||||
if !s.chain.GetConfig().EnableStateRoot {
|
||||
return nil
|
||||
}
|
||||
h := s.chain.StateHeight()
|
||||
if h < s.chain.GetConfig().StateRootEnableIndex {
|
||||
h = s.chain.GetConfig().StateRootEnableIndex
|
||||
}
|
||||
for i := range rs.Roots {
|
||||
if rs.Roots[i].Index <= h {
|
||||
continue
|
||||
}
|
||||
_ = s.chain.AddStateRoot(&rs.Roots[i])
|
||||
}
|
||||
// request more state roots from peer if needed
|
||||
return s.requestStateRoot(p)
|
||||
}
|
||||
|
||||
// requestStateRoot sends `getroots` message to get verified state roots.
|
||||
func (s *Server) requestStateRoot(p Peer) error {
|
||||
stateHeight := s.chain.StateHeight()
|
||||
hdrHeight := s.chain.BlockHeight()
|
||||
enableIndex := s.chain.GetConfig().StateRootEnableIndex
|
||||
if hdrHeight < enableIndex {
|
||||
return nil
|
||||
}
|
||||
if stateHeight < enableIndex {
|
||||
stateHeight = enableIndex - 1
|
||||
}
|
||||
count := uint32(payload.MaxStateRootsAllowed)
|
||||
if diff := hdrHeight - stateHeight; diff < count {
|
||||
count = diff
|
||||
}
|
||||
if count == 0 {
|
||||
return nil
|
||||
}
|
||||
gr := &payload.GetStateRoots{
|
||||
Start: stateHeight + 1,
|
||||
Count: count,
|
||||
}
|
||||
return p.EnqueueP2PMessage(s.MkMsg(CMDGetRoots, gr))
|
||||
}
|
||||
|
||||
// handleStateRootCmd processees `stateroot` request.
|
||||
func (s *Server) handleStateRootCmd(r *state.MPTRoot) error {
|
||||
if !s.chain.GetConfig().EnableStateRoot {
|
||||
return nil
|
||||
}
|
||||
// we ignore error, because there is nothing wrong if we already have this state root
|
||||
err := s.chain.AddStateRoot(r)
|
||||
if err == nil && !s.stateCache.Has(r.Hash()) {
|
||||
s.stateCache.Add(r)
|
||||
s.broadcastMessage(s.MkMsg(CMDStateRoot, r))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleConsensusCmd processes received consensus payload.
|
||||
// It never returns an error.
|
||||
func (s *Server) handleConsensusCmd(cp *consensus.Payload) error {
|
||||
|
@ -697,6 +789,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
|||
case CMDGetHeaders:
|
||||
gh := msg.Payload.(*payload.GetBlocks)
|
||||
return s.handleGetHeadersCmd(peer, gh)
|
||||
case CMDGetRoots:
|
||||
gr := msg.Payload.(*payload.GetStateRoots)
|
||||
return s.handleGetRootsCmd(peer, gr)
|
||||
case CMDHeaders:
|
||||
headers := msg.Payload.(*payload.Headers)
|
||||
go s.handleHeadersCmd(peer, headers)
|
||||
|
@ -718,6 +813,12 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
|||
case CMDPong:
|
||||
pong := msg.Payload.(*payload.Ping)
|
||||
return s.handlePong(peer, pong)
|
||||
case CMDRoots:
|
||||
rs := msg.Payload.(*payload.StateRoots)
|
||||
return s.handleRootsCmd(peer, rs)
|
||||
case CMDStateRoot:
|
||||
r := msg.Payload.(*state.MPTRoot)
|
||||
return s.handleStateRootCmd(r)
|
||||
case CMDVersion, CMDVerack:
|
||||
return fmt.Errorf("received '%s' after the handshake", msg.CommandType())
|
||||
}
|
||||
|
@ -741,11 +842,20 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleNewPayload(p *consensus.Payload) {
|
||||
msg := s.MkMsg(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()}))
|
||||
// It's high priority because it directly affects consensus process,
|
||||
// even though it's just an inv.
|
||||
s.broadcastHPMessage(msg)
|
||||
func (s *Server) handleNewPayload(item cache.Hashable) {
|
||||
switch p := item.(type) {
|
||||
case *consensus.Payload:
|
||||
msg := s.MkMsg(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()}))
|
||||
// It's high priority because it directly affects consensus process,
|
||||
// even though it's just an inv.
|
||||
s.broadcastHPMessage(msg)
|
||||
case *state.MPTRoot:
|
||||
s.stateCache.Add(p)
|
||||
msg := s.MkMsg(CMDStateRoot, p)
|
||||
s.broadcastMessage(msg)
|
||||
default:
|
||||
s.log.Warn("unknown item type", zap.String("type", fmt.Sprintf("%T", p)))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) requestTx(hashes ...util.Uint256) {
|
||||
|
|
|
@ -251,6 +251,9 @@ func (p *TCPPeer) StartProtocol() {
|
|||
if p.LastBlockIndex() > p.server.chain.BlockHeight() {
|
||||
err = p.server.requestBlocks(p)
|
||||
}
|
||||
if err == nil && p.server.chain.GetConfig().EnableStateRoot {
|
||||
err = p.server.requestStateRoot(p)
|
||||
}
|
||||
if err == nil {
|
||||
timer.Reset(p.server.ProtoTickInterval)
|
||||
}
|
||||
|
|
|
@ -181,8 +181,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
},
|
||||
func(t *testing.T, p *request.Params) {
|
||||
param, ok := p.Value(1)
|
||||
require.Equal(t, true, ok)
|
||||
param := p.Value(1)
|
||||
require.NotNil(t, param)
|
||||
require.Equal(t, request.TxFilterT, param.Type)
|
||||
filt, ok := param.Value.(request.TxFilter)
|
||||
require.Equal(t, true, ok)
|
||||
|
@ -196,8 +196,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
},
|
||||
func(t *testing.T, p *request.Params) {
|
||||
param, ok := p.Value(1)
|
||||
require.Equal(t, true, ok)
|
||||
param := p.Value(1)
|
||||
require.NotNil(t, param)
|
||||
require.Equal(t, request.NotificationFilterT, param.Type)
|
||||
filt, ok := param.Value.(request.NotificationFilter)
|
||||
require.Equal(t, true, ok)
|
||||
|
@ -211,8 +211,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
},
|
||||
func(t *testing.T, p *request.Params) {
|
||||
param, ok := p.Value(1)
|
||||
require.Equal(t, true, ok)
|
||||
param := p.Value(1)
|
||||
require.NotNil(t, param)
|
||||
require.Equal(t, request.ExecutionFilterT, param.Type)
|
||||
filt, ok := param.Value.(request.ExecutionFilter)
|
||||
require.Equal(t, true, ok)
|
||||
|
|
|
@ -61,12 +61,17 @@ const (
|
|||
ExecutionFilterT
|
||||
)
|
||||
|
||||
var errMissingParameter = errors.New("parameter is missing")
|
||||
|
||||
func (p Param) String() string {
|
||||
return fmt.Sprintf("%v", p.Value)
|
||||
}
|
||||
|
||||
// GetString returns string value of the parameter.
|
||||
func (p Param) GetString() (string, error) {
|
||||
func (p *Param) GetString() (string, error) {
|
||||
if p == nil {
|
||||
return "", errMissingParameter
|
||||
}
|
||||
str, ok := p.Value.(string)
|
||||
if !ok {
|
||||
return "", errors.New("not a string")
|
||||
|
@ -74,8 +79,26 @@ func (p Param) GetString() (string, error) {
|
|||
return str, nil
|
||||
}
|
||||
|
||||
// GetBoolean returns boolean value of the parameter.
|
||||
func (p *Param) GetBoolean() bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
switch p.Type {
|
||||
case NumberT:
|
||||
return p.Value != 0
|
||||
case StringT:
|
||||
return p.Value != ""
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// GetInt returns int value of te parameter.
|
||||
func (p Param) GetInt() (int, error) {
|
||||
func (p *Param) GetInt() (int, error) {
|
||||
if p == nil {
|
||||
return 0, errMissingParameter
|
||||
}
|
||||
i, ok := p.Value.(int)
|
||||
if ok {
|
||||
return i, nil
|
||||
|
@ -86,7 +109,10 @@ func (p Param) GetInt() (int, error) {
|
|||
}
|
||||
|
||||
// GetArray returns a slice of Params stored in the parameter.
|
||||
func (p Param) GetArray() ([]Param, error) {
|
||||
func (p *Param) GetArray() ([]Param, error) {
|
||||
if p == nil {
|
||||
return nil, errMissingParameter
|
||||
}
|
||||
a, ok := p.Value.([]Param)
|
||||
if !ok {
|
||||
return nil, errors.New("not an array")
|
||||
|
@ -95,7 +121,7 @@ func (p Param) GetArray() ([]Param, error) {
|
|||
}
|
||||
|
||||
// GetUint256 returns Uint256 value of the parameter.
|
||||
func (p Param) GetUint256() (util.Uint256, error) {
|
||||
func (p *Param) GetUint256() (util.Uint256, error) {
|
||||
s, err := p.GetString()
|
||||
if err != nil {
|
||||
return util.Uint256{}, err
|
||||
|
@ -105,7 +131,7 @@ func (p Param) GetUint256() (util.Uint256, error) {
|
|||
}
|
||||
|
||||
// GetUint160FromHex returns Uint160 value of the parameter encoded in hex.
|
||||
func (p Param) GetUint160FromHex() (util.Uint160, error) {
|
||||
func (p *Param) GetUint160FromHex() (util.Uint160, error) {
|
||||
s, err := p.GetString()
|
||||
if err != nil {
|
||||
return util.Uint160{}, err
|
||||
|
@ -119,7 +145,7 @@ func (p Param) GetUint160FromHex() (util.Uint160, error) {
|
|||
|
||||
// GetUint160FromAddress returns Uint160 value of the parameter that was
|
||||
// supplied as an address.
|
||||
func (p Param) GetUint160FromAddress() (util.Uint160, error) {
|
||||
func (p *Param) GetUint160FromAddress() (util.Uint160, error) {
|
||||
s, err := p.GetString()
|
||||
if err != nil {
|
||||
return util.Uint160{}, err
|
||||
|
@ -129,7 +155,10 @@ func (p Param) GetUint160FromAddress() (util.Uint160, error) {
|
|||
}
|
||||
|
||||
// GetFuncParam returns current parameter as a function call parameter.
|
||||
func (p Param) GetFuncParam() (FuncParam, error) {
|
||||
func (p *Param) GetFuncParam() (FuncParam, error) {
|
||||
if p == nil {
|
||||
return FuncParam{}, errMissingParameter
|
||||
}
|
||||
fp, ok := p.Value.(FuncParam)
|
||||
if !ok {
|
||||
return FuncParam{}, errors.New("not a function parameter")
|
||||
|
@ -139,7 +168,7 @@ func (p Param) GetFuncParam() (FuncParam, error) {
|
|||
|
||||
// GetBytesHex returns []byte value of the parameter if
|
||||
// it is a hex-encoded string.
|
||||
func (p Param) GetBytesHex() ([]byte, error) {
|
||||
func (p *Param) GetBytesHex() ([]byte, error) {
|
||||
s, err := p.GetString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -163,6 +192,11 @@ func (p *Param) UnmarshalJSON(data []byte) error {
|
|||
{ArrayT, &[]Param{}},
|
||||
}
|
||||
|
||||
if bytes.Equal(data, []byte("null")) {
|
||||
p.Type = defaultT
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, cur := range attempts {
|
||||
r := bytes.NewReader(data)
|
||||
jd := json.NewDecoder(r)
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
)
|
||||
|
||||
func TestParam_UnmarshalJSON(t *testing.T) {
|
||||
msg := `["str1", 123, ["str2", 3], [{"type": "String", "value": "jajaja"}],
|
||||
msg := `["str1", 123, null, ["str2", 3], [{"type": "String", "value": "jajaja"}],
|
||||
{"type": "MinerTransaction"},
|
||||
{"contract": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"},
|
||||
{"state": "HALT"}]`
|
||||
|
@ -29,6 +29,9 @@ func TestParam_UnmarshalJSON(t *testing.T) {
|
|||
Type: NumberT,
|
||||
Value: 123,
|
||||
},
|
||||
{
|
||||
Type: defaultT,
|
||||
},
|
||||
{
|
||||
Type: ArrayT,
|
||||
Value: []Param{
|
||||
|
|
|
@ -7,20 +7,19 @@ type (
|
|||
|
||||
// Value returns the param struct for the given
|
||||
// index if it exists.
|
||||
func (p Params) Value(index int) (*Param, bool) {
|
||||
func (p Params) Value(index int) *Param {
|
||||
if len(p) > index {
|
||||
return &p[index], true
|
||||
return &p[index]
|
||||
}
|
||||
|
||||
return nil, false
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValueWithType returns the param struct at the given index if it
|
||||
// exists and matches the given type.
|
||||
func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) {
|
||||
if val, ok := p.Value(index); ok && val.Type == valType {
|
||||
return val, true
|
||||
func (p Params) ValueWithType(index int, valType paramType) *Param {
|
||||
if val := p.Value(index); val != nil && val.Type == valType {
|
||||
return val
|
||||
}
|
||||
|
||||
return nil, false
|
||||
return nil
|
||||
}
|
||||
|
|
122
pkg/rpc/response/result/mpt.go
Normal file
122
pkg/rpc/response/result/mpt.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package result
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
)
|
||||
|
||||
// StateHeight is a result of getstateheight RPC.
|
||||
type StateHeight struct {
|
||||
BlockHeight uint32 `json:"blockHeight"`
|
||||
StateHeight uint32 `json:"stateHeight"`
|
||||
}
|
||||
|
||||
// ProofWithKey represens key-proof pair.
|
||||
type ProofWithKey struct {
|
||||
Key []byte
|
||||
Proof [][]byte
|
||||
}
|
||||
|
||||
// GetProof is a result of getproof RPC.
|
||||
type GetProof struct {
|
||||
Result ProofWithKey `json:"proof"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// VerifyProof is a result of verifyproof RPC.
|
||||
// nil Value is considered invalid.
|
||||
type VerifyProof struct {
|
||||
Value []byte
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (p *ProofWithKey) MarshalJSON() ([]byte, error) {
|
||||
w := io.NewBufBinWriter()
|
||||
p.EncodeBinary(w.BinWriter)
|
||||
if w.Err != nil {
|
||||
return nil, w.Err
|
||||
}
|
||||
return []byte(`"` + hex.EncodeToString(w.Bytes()) + `"`), nil
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (p *ProofWithKey) EncodeBinary(w *io.BinWriter) {
|
||||
w.WriteVarBytes(p.Key)
|
||||
w.WriteVarUint(uint64(len(p.Proof)))
|
||||
for i := range p.Proof {
|
||||
w.WriteVarBytes(p.Proof[i])
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (p *ProofWithKey) DecodeBinary(r *io.BinReader) {
|
||||
p.Key = r.ReadVarBytes()
|
||||
sz := r.ReadVarUint()
|
||||
for i := uint64(0); i < sz; i++ {
|
||||
p.Proof = append(p.Proof, r.ReadVarBytes())
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (p *ProofWithKey) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
return p.FromString(s)
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (p *ProofWithKey) String() string {
|
||||
w := io.NewBufBinWriter()
|
||||
p.EncodeBinary(w.BinWriter)
|
||||
return hex.EncodeToString(w.Bytes())
|
||||
}
|
||||
|
||||
// FromString decodes p from hex-encoded string.
|
||||
func (p *ProofWithKey) FromString(s string) error {
|
||||
rawProof, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r := io.NewBinReaderFromBuf(rawProof)
|
||||
p.DecodeBinary(r)
|
||||
return r.Err
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (p *VerifyProof) MarshalJSON() ([]byte, error) {
|
||||
if p.Value == nil {
|
||||
return []byte(`"invalid"`), nil
|
||||
}
|
||||
return []byte(`{"value":"` + hex.EncodeToString(p.Value) + `"}`), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (p *VerifyProof) UnmarshalJSON(data []byte) error {
|
||||
if bytes.Equal(data, []byte(`"invalid"`)) {
|
||||
p.Value = nil
|
||||
return nil
|
||||
}
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(m) != 1 {
|
||||
return errors.New("must have single key")
|
||||
}
|
||||
v, ok := m["value"]
|
||||
if !ok {
|
||||
return errors.New("invalid json")
|
||||
}
|
||||
b, err := hex.DecodeString(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Value = b
|
||||
return nil
|
||||
}
|
68
pkg/rpc/response/result/mpt_test.go
Normal file
68
pkg/rpc/response/result/mpt_test.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package result
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testProofWithKey() *ProofWithKey {
|
||||
return &ProofWithKey{
|
||||
Key: random.Bytes(10),
|
||||
Proof: [][]byte{
|
||||
random.Bytes(12),
|
||||
random.Bytes(0),
|
||||
random.Bytes(34),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProof_MarshalJSON(t *testing.T) {
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
p := &GetProof{
|
||||
Result: *testProofWithKey(),
|
||||
Success: true,
|
||||
}
|
||||
testserdes.MarshalUnmarshalJSON(t, p, new(GetProof))
|
||||
})
|
||||
t.Run("Compatibility", func(t *testing.T) {
|
||||
js := []byte(`{
|
||||
"proof" : "25ddeb9aa1bfc353c9c54e21dffb470f65d9c22a0662616c616e63654f70000000000000000708fd12020020666eaa8a6e75d43a97d76e72b605c7e05189f0c57ec19d84acdb75810f18239d202c83028ce3d7abcf4e4f95d05fbfdfa5e18bde3a8fbb65a57559d6b5ea09425c2090c40d440744a848e3b407a00e4efb692a957245a1efc9cb8496cb05fd328ee620dd2652bf25dfc3ad5fee7b200ccf3e3ae50772ff8ed58907e4dab8e7d4b2489720d8a5d5ed75b5b0f256d0a2cf5c220b4ddae2a228ef0fc0212b689f3811dfa94620342cc0d73fabd2440ed2cc735a9608391a510e1981b321a9f4258682706adc9620ced036e52f39387b9c58ade7bf8c3ca8959b64d8031d36d9b1c62f3f1c51c7cb2031072c7c801b5c1614dae441383a65344acd238f13db28ff0a39c0626e597f002062552d64c616d8b2a6a93d22936055110c0065728aa2b4fbf4d76b108390b474203322d3c93c741674a307cf6455e77c02ceeda307d4ec23fd809a2a420b4243f82052ab92a9cedc6716ad4c66a8a3e423b195b05bdebde456f992bff48f2561e99720e6379995e7053823b8ba8fb8af9623cf48e89f60c989598445df5e711db42a6f20192894ed637e86561ff6a4b8dea4539dee8bddb2fb20bf4ae3499852985c88b120e0005edd09f2335aa6b59ff4723e1262b2192adaa5e3e56f79e662f07041f04c2033577f3e2c5bb0e58746980a07cdfad2f872e2b9a10bcc27b7c678c85576df8420f0f04180d15b6eaa0c43e62380084c75ad773d790700a7120c6c4da1fc51693000fd720100209648e8f10a5ff4c209009b9a09697babbe1b2150d0948c1970a560282a1bfa4720988af8f34859dd8309bffea0b1dff9c8cef0b9b0d6a1852d40786627729ae7be00206ebf4f1b7861bca041cbb8feca75158511ca43a1810d17e1e3017468e8cef0de20cac93064090a7da09f8202c17d1e6cbb9a16eb43afcb032e80719cbf05b3446d2019b76a10b91fb99ec08814e8108e5490b879fb09a190cb2c129dfd98335bd5de000020b1da1198bacacf2adc0d863929d77c285ce3a26e736203d0c0a69a1312255fb2207ee8aa092f49348bd89f9c4bf004b0bee2241a2d0acfe7b3ce08e414b04a5717205b0dda71eac8a4e4cdc6a7b939748c0a78abb54f2547a780e6df67b25530330f000020fc358fb9d1e0d36461e015ac8e35f97072a9f9e750a3c25722a2b1a858fcb82d203c52c9fac6d4694b351390158334a9166bc3478ceb9bea2b0b244915f918239e20d526344a24ff19ee6a9f5c5beb833f4eb6d51191590350e26fa50b138493473f005200000000000000000000002077c404fec0a4265568951dbd096572787d109fab105213f4f292a5f53ce72fca00000020b8d1c7a386eaba83ce83ee0700d4ca9b86e75d147d670ea05123e438231d895000004801250b090a0a010b0f0c0305030c090c05040e02010d0f0f0b0407000f06050d090c02020a0006202af2097cf9d3f42e49f6b3c3dd254e7cbdab3485b029721cbbbf1ad0455a810852000000000000002055170506f4b18bc573a909b51cb21bdd5d303ec511f6cdfb1c6a1ab8d8a1dad020ee774c1b9fe1d8ea8d05823837d959da48af74f384d52f06c42c9d146c5258e300000000000000000072000000204457a6fe530ee953ad1f9caf63daf7f86719c9986df2d0b6917021eb379800f00020406bfc79da4ba6f37452a679d13cca252585d34f7e94a480b047bad9427f233e00000000201ce15a2373d28e0dc5f2000cf308f155d06f72070a29e5af1528c8f05f29d248000000000000004301200601060c0601060e06030605040f0700000000000000000000000000000000072091b83866bbd7450115b462e8d48601af3c3e9a35e7018d2b98a23e107c15c200090307000410a328e800",
|
||||
"success" : true
|
||||
}`)
|
||||
|
||||
var p GetProof
|
||||
require.NoError(t, json.Unmarshal(js, &p))
|
||||
require.Equal(t, 8, len(p.Result.Proof))
|
||||
for i := range p.Result.Proof { // smoke test that every chunk is correctly encoded node
|
||||
r := io.NewBinReaderFromBuf(p.Result.Proof[i])
|
||||
var n mpt.NodeObject
|
||||
n.DecodeBinary(r)
|
||||
require.NoError(t, r.Err)
|
||||
require.NotNil(t, n.Node)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestProofWithKey_EncodeString(t *testing.T) {
|
||||
expected := testProofWithKey()
|
||||
var actual ProofWithKey
|
||||
require.NoError(t, actual.FromString(expected.String()))
|
||||
require.Equal(t, expected, &actual)
|
||||
}
|
||||
|
||||
func TestVerifyProof_MarshalJSON(t *testing.T) {
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
vp := &VerifyProof{random.Bytes(100)}
|
||||
testserdes.MarshalUnmarshalJSON(t, vp, new(VerifyProof))
|
||||
})
|
||||
t.Run("NoValue", func(t *testing.T) {
|
||||
vp := new(VerifyProof)
|
||||
testserdes.MarshalUnmarshalJSON(t, vp, &VerifyProof{[]byte{1, 2, 3}})
|
||||
})
|
||||
}
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/gorilla/websocket"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
|
@ -95,6 +96,9 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon
|
|||
"getpeers": (*Server).getPeers,
|
||||
"getrawmempool": (*Server).getRawMempool,
|
||||
"getrawtransaction": (*Server).getrawtransaction,
|
||||
"getproof": (*Server).getProof,
|
||||
"getstateheight": (*Server).getStateHeight,
|
||||
"getstateroot": (*Server).getStateRoot,
|
||||
"getstorage": (*Server).getStorage,
|
||||
"gettransactionheight": (*Server).getTransactionHeight,
|
||||
"gettxout": (*Server).getTxOut,
|
||||
|
@ -108,6 +112,7 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon
|
|||
"sendrawtransaction": (*Server).sendrawtransaction,
|
||||
"submitblock": (*Server).submitBlock,
|
||||
"validateaddress": (*Server).validateAddress,
|
||||
"verifyproof": (*Server).verifyProof,
|
||||
}
|
||||
|
||||
var rpcWsHandlers = map[string]func(*Server, request.Params, *subscriber) (interface{}, *response.Error){
|
||||
|
@ -389,8 +394,8 @@ func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Er
|
|||
func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) {
|
||||
var hash util.Uint256
|
||||
|
||||
param, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
param := reqParams.Value(0)
|
||||
if param == nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
|
@ -416,7 +421,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro
|
|||
return nil, response.NewInternalServerError(fmt.Sprintf("Problem locating block with hash: %s", hash), err)
|
||||
}
|
||||
|
||||
if len(reqParams) == 2 && reqParams[1].Value == 1 {
|
||||
if reqParams.Value(1).GetBoolean() {
|
||||
return result.NewBlock(block, s.chain), nil
|
||||
}
|
||||
writer := io.NewBufBinWriter()
|
||||
|
@ -425,8 +430,8 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro
|
|||
}
|
||||
|
||||
func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.ValueWithType(0, request.NumberT)
|
||||
if !ok {
|
||||
param := reqParams.ValueWithType(0, request.NumberT)
|
||||
if param == nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
num, err := s.blockHeightFromParam(param)
|
||||
|
@ -463,20 +468,15 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error)
|
|||
}
|
||||
|
||||
func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
param := reqParams.Value(0)
|
||||
if param == nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
return validateAddress(param.Value), nil
|
||||
}
|
||||
|
||||
func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
paramAssetID, err := param.GetUint256()
|
||||
paramAssetID, err := reqParams.ValueWithType(0, request.StringT).GetUint256()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -490,12 +490,7 @@ func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response
|
|||
|
||||
// getApplicationLog returns the contract log based on the specified txid.
|
||||
func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
txHash, err := param.GetUint256()
|
||||
txHash, err := reqParams.Value(0).GetUint256()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -522,10 +517,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *resp
|
|||
}
|
||||
|
||||
func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
p := ps.ValueWithType(0, request.StringT)
|
||||
u, err := p.GetUint160FromAddress()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -574,11 +566,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error)
|
|||
}
|
||||
|
||||
func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
u, err := p.GetUint160FromHex()
|
||||
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -607,11 +595,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Erro
|
|||
}
|
||||
|
||||
func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
u, err := p.GetUint160FromAddress()
|
||||
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -708,25 +692,96 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6
|
|||
return d, nil
|
||||
}
|
||||
|
||||
func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) {
|
||||
param, ok := ps.Value(0)
|
||||
if !ok {
|
||||
func (s *Server) getProof(ps request.Params) (interface{}, *response.Error) {
|
||||
root, err := ps.Value(0).GetUint256()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
sc, err := ps.Value(1).GetUint160FromHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
sc = sc.Reverse()
|
||||
key, err := ps.Value(2).GetBytesHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
skey := mpt.ToNeoStorageKey(append(sc.BytesBE(), key...))
|
||||
proof, err := s.chain.GetStateProof(root, skey)
|
||||
return &result.GetProof{
|
||||
Result: result.ProofWithKey{
|
||||
Key: skey,
|
||||
Proof: proof,
|
||||
},
|
||||
Success: err == nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
scriptHash, err := param.GetUint160FromHex()
|
||||
func (s *Server) verifyProof(ps request.Params) (interface{}, *response.Error) {
|
||||
root, err := ps.Value(0).GetUint256()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
proofStr, err := ps.Value(1).GetString()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
var p result.ProofWithKey
|
||||
if err := p.FromString(proofStr); err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
vp := new(result.VerifyProof)
|
||||
val, ok := mpt.VerifyProof(root, p.Key, p.Proof)
|
||||
if ok {
|
||||
var si state.StorageItem
|
||||
r := io.NewBinReaderFromBuf(val[1:])
|
||||
si.DecodeBinary(r)
|
||||
if r.Err != nil {
|
||||
return nil, response.NewInternalServerError("invalid item in trie", r.Err)
|
||||
}
|
||||
vp.Value = si.Value
|
||||
}
|
||||
return vp, nil
|
||||
}
|
||||
|
||||
func (s *Server) getStateHeight(_ request.Params) (interface{}, *response.Error) {
|
||||
return &result.StateHeight{
|
||||
BlockHeight: s.chain.BlockHeight(),
|
||||
StateHeight: s.chain.StateHeight(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) getStateRoot(ps request.Params) (interface{}, *response.Error) {
|
||||
p := ps.Value(0)
|
||||
if p == nil {
|
||||
return nil, response.NewRPCError("Invalid parameter.", "", nil)
|
||||
}
|
||||
var rt *state.MPTRootState
|
||||
var h util.Uint256
|
||||
height, err := p.GetInt()
|
||||
if err == nil {
|
||||
rt, err = s.chain.GetStateRoot(uint32(height))
|
||||
} else if h, err = p.GetUint256(); err == nil {
|
||||
hdr, err := s.chain.GetHeader(h)
|
||||
if err == nil {
|
||||
rt, err = s.chain.GetStateRoot(hdr.Index)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, response.NewRPCError("Unknown state root.", "", err)
|
||||
}
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) {
|
||||
scriptHash, err := ps.Value(0).GetUint160FromHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
scriptHash = scriptHash.Reverse()
|
||||
|
||||
param, ok = ps.Value(1)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
key, err := param.GetBytesHex()
|
||||
key, err := ps.Value(1).GetBytesHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -743,30 +798,17 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
|
|||
var resultsErr *response.Error
|
||||
var results interface{}
|
||||
|
||||
if param0, ok := reqParams.Value(0); !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
} else if txHash, err := param0.GetUint256(); err != nil {
|
||||
if txHash, err := reqParams.Value(0).GetUint256(); err != nil {
|
||||
resultsErr = response.ErrInvalidParams
|
||||
} else if tx, height, err := s.chain.GetTransaction(txHash); err != nil {
|
||||
err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash)
|
||||
return nil, response.NewRPCError("Unknown transaction", err.Error(), err)
|
||||
} else if len(reqParams) >= 2 {
|
||||
} else if reqParams.Value(1).GetBoolean() {
|
||||
_header := s.chain.GetHeaderHash(int(height))
|
||||
header, err := s.chain.GetHeader(_header)
|
||||
if err != nil {
|
||||
resultsErr = response.NewInvalidParamsError(err.Error(), err)
|
||||
}
|
||||
|
||||
param1, _ := reqParams.Value(1)
|
||||
switch v := param1.Value.(type) {
|
||||
|
||||
case int, float64, bool, string:
|
||||
if v == 0 || v == "0" || v == 0.0 || v == false || v == "false" {
|
||||
results = hex.EncodeToString(tx.Bytes())
|
||||
} else {
|
||||
results = result.NewTransactionOutputRaw(tx, header, s.chain)
|
||||
}
|
||||
default:
|
||||
} else {
|
||||
results = result.NewTransactionOutputRaw(tx, header, s.chain)
|
||||
}
|
||||
} else {
|
||||
|
@ -777,12 +819,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
|
|||
}
|
||||
|
||||
func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
h, err := p.GetUint256()
|
||||
h, err := ps.Value(0).GetUint256()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -796,22 +833,12 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response
|
|||
}
|
||||
|
||||
func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
h, err := p.GetUint256()
|
||||
h, err := ps.Value(0).GetUint256()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
p, ok = ps.ValueWithType(1, request.NumberT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
num, err := p.GetInt()
|
||||
num, err := ps.ValueWithType(1, request.NumberT).GetInt()
|
||||
if err != nil || num < 0 {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -833,18 +860,15 @@ func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) {
|
|||
func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) {
|
||||
var results interface{}
|
||||
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
} else if scriptHash, err := param.GetUint160FromHex(); err != nil {
|
||||
scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
cs := s.chain.GetContractState(scriptHash)
|
||||
if cs != nil {
|
||||
results = result.NewContractState(cs)
|
||||
} else {
|
||||
cs := s.chain.GetContractState(scriptHash)
|
||||
if cs != nil {
|
||||
results = result.NewContractState(cs)
|
||||
} else {
|
||||
return nil, response.NewRPCError("Unknown contract", "", nil)
|
||||
}
|
||||
return nil, response.NewRPCError("Unknown contract", "", nil)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
@ -862,33 +886,31 @@ func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (in
|
|||
var resultsErr *response.Error
|
||||
var results interface{}
|
||||
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
} else if scriptHash, err := param.GetUint160FromAddress(); err != nil {
|
||||
param := reqParams.ValueWithType(0, request.StringT)
|
||||
scriptHash, err := param.GetUint160FromAddress()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
as := s.chain.GetAccountState(scriptHash)
|
||||
if as == nil {
|
||||
as = state.NewAccount(scriptHash)
|
||||
}
|
||||
if unspents {
|
||||
str, err := param.GetString()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
results = result.NewUnspents(as, s.chain, str)
|
||||
} else {
|
||||
as := s.chain.GetAccountState(scriptHash)
|
||||
if as == nil {
|
||||
as = state.NewAccount(scriptHash)
|
||||
}
|
||||
if unspents {
|
||||
str, err := param.GetString()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
results = result.NewUnspents(as, s.chain, str)
|
||||
} else {
|
||||
results = result.NewAccountState(as)
|
||||
}
|
||||
results = result.NewAccountState(as)
|
||||
}
|
||||
return results, resultsErr
|
||||
}
|
||||
|
||||
// getBlockSysFee returns the system fees of the block, based on the specified index.
|
||||
func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.ValueWithType(0, request.NumberT)
|
||||
if !ok {
|
||||
param := reqParams.ValueWithType(0, request.NumberT)
|
||||
if param == nil {
|
||||
return 0, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
|
@ -913,26 +935,12 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *respons
|
|||
|
||||
// getBlockHeader returns the corresponding block header information according to the specified script hash.
|
||||
func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) {
|
||||
var verbose bool
|
||||
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
hash, err := param.GetUint256()
|
||||
hash, err := reqParams.ValueWithType(0, request.StringT).GetUint256()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
param, ok = reqParams.ValueWithType(1, request.NumberT)
|
||||
if ok {
|
||||
v, err := param.GetInt()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
verbose = v != 0
|
||||
}
|
||||
|
||||
verbose := reqParams.Value(1).GetBoolean()
|
||||
h, err := s.chain.GetHeader(hash)
|
||||
if err != nil {
|
||||
return nil, response.NewRPCError("unknown block", "", nil)
|
||||
|
@ -952,11 +960,7 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *respons
|
|||
|
||||
// getUnclaimed returns unclaimed GAS amount of the specified address.
|
||||
func (s *Server) getUnclaimed(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
u, err := p.GetUint160FromAddress()
|
||||
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -997,19 +1001,11 @@ func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error)
|
|||
|
||||
// invoke implements the `invoke` RPC call.
|
||||
func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) {
|
||||
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
scriptHash, err := scriptHashHex.GetUint160FromHex()
|
||||
scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
sliceP, ok := reqParams.ValueWithType(1, request.ArrayT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
slice, err := sliceP.GetArray()
|
||||
slice, err := reqParams.ValueWithType(1, request.ArrayT).GetArray()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -1022,11 +1018,7 @@ func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error)
|
|||
|
||||
// invokescript implements the `invokescript` RPC call.
|
||||
func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) {
|
||||
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
scriptHash, err := scriptHashHex.GetUint160FromHex()
|
||||
scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -1069,11 +1061,7 @@ func (s *Server) runScriptInVM(script []byte) *result.Invoke {
|
|||
|
||||
// submitBlock broadcasts a raw block over the NEO network.
|
||||
func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
blockBytes, err := param.GetBytesHex()
|
||||
blockBytes, err := reqParams.ValueWithType(0, request.StringT).GetBytesHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -1134,11 +1122,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res
|
|||
|
||||
// subscribe handles subscription requests from websocket clients.
|
||||
func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
|
||||
p, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
streamName, err := p.GetString()
|
||||
streamName, err := reqParams.Value(0).GetString()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -1148,8 +1132,7 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface
|
|||
}
|
||||
// Optional filter.
|
||||
var filter interface{}
|
||||
p, ok = reqParams.Value(1)
|
||||
if ok {
|
||||
if p := reqParams.Value(1); p != nil {
|
||||
// It doesn't accept filters.
|
||||
if event == response.BlockEventID {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -1224,11 +1207,7 @@ func (s *Server) subscribeToChannel(event response.EventID) {
|
|||
|
||||
// unsubscribe handles unsubscription requests from websocket clients.
|
||||
func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
|
||||
p, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
id, err := p.GetInt()
|
||||
id, err := reqParams.Value(0).GetInt()
|
||||
if err != nil || id < 0 {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
|
|
@ -16,6 +16,8 @@ import (
|
|||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
|
||||
|
@ -213,6 +215,54 @@ var rpcTestCases = map[string][]rpcTestCase{
|
|||
},
|
||||
},
|
||||
},
|
||||
"getproof": {
|
||||
{
|
||||
name: "no params",
|
||||
params: `[]`,
|
||||
fail: true,
|
||||
},
|
||||
{
|
||||
name: "invalid root",
|
||||
params: `["0xabcdef"]`,
|
||||
fail: true,
|
||||
},
|
||||
{
|
||||
name: "invalid contract",
|
||||
params: `["0000000000000000000000000000000000000000000000000000000000000000", "0xabcdef"]`,
|
||||
fail: true,
|
||||
},
|
||||
{
|
||||
name: "invalid key",
|
||||
params: `["0000000000000000000000000000000000000000000000000000000000000000", "` + testContractHash + `", "notahex"]`,
|
||||
fail: true,
|
||||
},
|
||||
},
|
||||
"getstateheight": {
|
||||
{
|
||||
name: "positive",
|
||||
params: `[]`,
|
||||
result: func(_ *executor) interface{} { return new(result.StateHeight) },
|
||||
check: func(t *testing.T, e *executor, res interface{}) {
|
||||
sh, ok := res.(*result.StateHeight)
|
||||
require.True(t, ok)
|
||||
|
||||
require.Equal(t, e.chain.BlockHeight(), sh.BlockHeight)
|
||||
require.Equal(t, e.chain.StateHeight(), sh.StateHeight)
|
||||
},
|
||||
},
|
||||
},
|
||||
"getstateroot": {
|
||||
{
|
||||
name: "no params",
|
||||
params: `[]`,
|
||||
fail: true,
|
||||
},
|
||||
{
|
||||
name: "invalid hash",
|
||||
params: `["0x1234567890"]`,
|
||||
fail: true,
|
||||
},
|
||||
},
|
||||
"getstorage": {
|
||||
{
|
||||
name: "positive",
|
||||
|
@ -928,6 +978,52 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) []
|
|||
})
|
||||
}
|
||||
|
||||
t.Run("getproof", func(t *testing.T) {
|
||||
r, err := chain.GetStateRoot(205)
|
||||
require.NoError(t, err)
|
||||
|
||||
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getproof", "params": ["%s", "%s", "%x"]}`,
|
||||
r.Root.StringLE(), testContractHash, []byte("testkey"))
|
||||
fmt.Println(rpc)
|
||||
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||
fmt.Println(string(body))
|
||||
rawRes := checkErrGetResult(t, body, false)
|
||||
res := new(result.GetProof)
|
||||
require.NoError(t, json.Unmarshal(rawRes, res))
|
||||
require.True(t, res.Success)
|
||||
h, _ := hex.DecodeString(testContractHash)
|
||||
skey := append(h, []byte("testkey")...)
|
||||
require.Equal(t, mpt.ToNeoStorageKey(skey), res.Result.Key)
|
||||
require.True(t, len(res.Result.Proof) > 0)
|
||||
|
||||
rpc = fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "verifyproof", "params": ["%s", "%s"]}`,
|
||||
r.Root.StringLE(), res.Result.String())
|
||||
body = doRPCCall(rpc, httpSrv.URL, t)
|
||||
rawRes = checkErrGetResult(t, body, false)
|
||||
vp := new(result.VerifyProof)
|
||||
require.NoError(t, json.Unmarshal(rawRes, vp))
|
||||
require.Equal(t, []byte("testvalue"), vp.Value)
|
||||
})
|
||||
|
||||
t.Run("getstateroot", func(t *testing.T) {
|
||||
testRoot := func(t *testing.T, p string) {
|
||||
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getstateroot", "params": [%s]}`, p)
|
||||
fmt.Println(rpc)
|
||||
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||
rawRes := checkErrGetResult(t, body, false)
|
||||
|
||||
res := new(state.MPTRootState)
|
||||
require.NoError(t, json.Unmarshal(rawRes, res))
|
||||
require.NotEqual(t, util.Uint256{}, res.Root) // be sure this test uses valid height
|
||||
|
||||
expected, err := e.chain.GetStateRoot(205)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, res)
|
||||
}
|
||||
t.Run("ByHeight", func(t *testing.T) { testRoot(t, strconv.FormatInt(205, 10)) })
|
||||
t.Run("ByHash", func(t *testing.T) { testRoot(t, `"`+chain.GetHeaderHash(205).StringLE()+`"`) })
|
||||
})
|
||||
|
||||
t.Run("getrawtransaction", func(t *testing.T) {
|
||||
block, _ := chain.GetBlock(chain.GetHeaderHash(0))
|
||||
TXHash := block.Transactions[1].Hash()
|
||||
|
|
Loading…
Reference in a new issue