diff --git a/cli/server/dump.go b/cli/server/dump.go index ead0b3b33..51ab2399d 100644 --- a/cli/server/dump.go +++ b/cli/server/dump.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/util" ) @@ -33,35 +34,7 @@ func toNeoStorageKey(key []byte) []byte { if len(key) < util.Uint160Size { panic("invalid key in storage") } - - var nkey []byte - for i := util.Uint160Size - 1; i >= 0; i-- { - nkey = append(nkey, key[i]) - } - - key = key[util.Uint160Size:] - - index := 0 - remain := len(key) - for remain >= 16 { - nkey = append(nkey, key[index:index+16]...) - nkey = append(nkey, 0) - index += 16 - remain -= 16 - } - - if remain > 0 { - nkey = append(nkey, key[index:]...) - } - - padding := 16 - remain - for i := 0; i < padding; i++ { - nkey = append(nkey, 0) - } - - nkey = append(nkey, byte(padding)) - - return nkey + return mpt.ToNeoStorageKey(key) } // batchToMap converts batch to a map so that JSON is compatible diff --git a/config/protocol.testnet.yml b/config/protocol.testnet.yml index 711485d26..8ccb17005 100644 --- a/config/protocol.testnet.yml +++ b/config/protocol.testnet.yml @@ -2,6 +2,8 @@ ProtocolConfiguration: Magic: 1953787457 AddressVersion: 23 SecondsPerBlock: 15 + EnableStateRoot: true + StateRootEnableIndex: 4380100 LowPriorityThreshold: 0.000 MemPoolSize: 50000 StandbyValidators: diff --git a/config/protocol.unit_testnet.yml b/config/protocol.unit_testnet.yml index c21e1c3f0..b9ca4a761 100644 --- a/config/protocol.unit_testnet.yml +++ b/config/protocol.unit_testnet.yml @@ -2,6 +2,7 @@ ProtocolConfiguration: Magic: 56753 AddressVersion: 23 SecondsPerBlock: 15 + EnableStateRoot: true LowPriorityThreshold: 0.000 MemPoolSize: 50000 StandbyValidators: diff --git a/go.mod b/go.mod index 68b323af0..696d28e6a 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,13 @@ module github.com/nspcc-dev/neo-go require ( github.com/Workiva/go-datastructures v1.0.50 github.com/alicebob/miniredis v2.5.0+incompatible + github.com/btcsuite/btcd v0.20.1-beta github.com/dgraph-io/badger/v2 v2.0.3 github.com/go-redis/redis v6.10.2+incompatible github.com/go-yaml/yaml v2.1.0+incompatible github.com/gorilla/websocket v1.4.2 github.com/mr-tron/base58 v1.1.2 - github.com/nspcc-dev/dbft v0.0.0-20200531081613-7a39e7b757ac + github.com/nspcc-dev/dbft v0.0.0-20200623100921-5a182c20965e github.com/nspcc-dev/rfc6979 v0.2.0 github.com/pkg/errors v0.8.1 github.com/prometheus/client_golang v1.2.1 diff --git a/go.sum b/go.sum index f739edfcc..3d3345df9 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/abiosoft/ishell v2.0.0+incompatible h1:zpwIuEHc37EzrsIYah3cpevrIc8Oma github.com/abiosoft/ishell v2.0.0+incompatible/go.mod h1:HQR9AqF2R3P4XXpMpI0NAzgHf/aS6+zVXRj14cVk9qg= github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db h1:CjPUSXOiYptLbTdr1RceuZgSFDQ7U15ITERUGrUORx8= github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db/go.mod h1:rB3B4rKii8V21ydCbIzH5hZiCQE7f5E9SzUb/ZZx530= +github.com/aead/siphash v1.0.1 h1:FwHfE/T45KPKYuuSAKyyvE+oPWcaQ+CUmFW0bPlM+kg= +github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod 
h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -28,6 +30,22 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/btcsuite/btcd v0.20.1-beta h1:Ik4hyJqN8Jfyv3S4AGBOmyouMsYE3EdYODkMbQjwPGw= +github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= +github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f h1:bAs4lUbRJpnnkd9VhRV3jjAVU7DJVjMaK+IsvSeZvFo= +github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA= +github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d h1:yJzD/yFppdVCf6ApMkVy8cUxV0XrxdP9rVf6D87/Mng= +github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= +github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd h1:R/opQEbFEy9JGkIguV40SvRY1uliPX8ifOvi6ICsFCw= +github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd/go.mod h1:HHNXQzUsZCxOoE+CPiyCTO6x34Zs86zZUiwtpXoGdtg= +github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd h1:qdGvebPBDuYDPGi1WCPjy1tGyMpmDK8IEapSsszn7HE= +github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd/go.mod h1:F+uVaaLLH7j4eDXPRvw78tMflu7Ie2bzYOH4Y8rRKBY= +github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723 h1:ZA/jbKoGcVAnER6pCHPEkGdZOV7U1oLUedErBHCUMs0= +github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723/go.mod h1:8woku9dyThutzjeg+3xrA5iCpBRH8XEEg3lh6TiUghc= +github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792 h1:R8vQdOQdZ9Y3SkEwmHoWBmX1DNXhXZqlTpq6s4tyJGc= +github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792/go.mod h1:ghJtEyQwv5/p4Mg4C0fgbePVuGr935/5ddU9Z3TmDRY= +github.com/btcsuite/winsvc v1.0.0 h1:J9B4L7e3oqhXOcm+2IuNApwzQec85lE+QaikUcCs+dk= +github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46fmI40EZs= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.0 h1:yTUvW7Vhb89inJ+8irsUqiWjh8iT6sQPZiQzI6ReGkA= @@ -41,6 +59,7 @@ github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= +github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -91,9 +110,15 @@ github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/jessevdk/go-flags 
v0.0.0-20141203071132-1679536dcc89 h1:12K8AlpT0/6QUXSfV0yi4Q0jkbq8NDtIKFtF61AoqV0= +github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jrick/logrotate v1.0.0 h1:lQ1bL/n9mBNeIXoTUoYRlK4dHuNJVofX9oWqBtPnSzI= +github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23 h1:FOOIBWrEkLgmlgGfMuZT83xIwfPDxEI2OHu6xUmJMFE= +github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -131,8 +156,8 @@ github.com/nspcc-dev/dbft v0.0.0-20200117124306-478e5cfbf03a h1:ajvxgEe9qY4vvoSm github.com/nspcc-dev/dbft v0.0.0-20200117124306-478e5cfbf03a/go.mod h1:/YFK+XOxxg0Bfm6P92lY5eDSLYfp06XOdL8KAVgXjVk= github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1 h1:yEx9WznS+rjE0jl0dLujCxuZSIb+UTjF+005TJu/nNI= github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1/go.mod h1:O0qtn62prQSqizzoagHmuuKoz8QMkU3SzBoKdEvm3aQ= -github.com/nspcc-dev/dbft v0.0.0-20200531081613-7a39e7b757ac h1:cXPgsp4avJ7cR1nPRdpFRHmWoMSRZ41FSvlNjpsyTiA= -github.com/nspcc-dev/dbft v0.0.0-20200531081613-7a39e7b757ac/go.mod h1:1FYQXSbb6/9HQIkoF8XO7W/S8N7AZRkBsgwbcXRvk0E= +github.com/nspcc-dev/dbft v0.0.0-20200623100921-5a182c20965e h1:QOT9slflIkEKb5wY0ZUC0dCmCgoqGlhOAh9+xWMIxfg= +github.com/nspcc-dev/dbft v0.0.0-20200623100921-5a182c20965e/go.mod h1:1FYQXSbb6/9HQIkoF8XO7W/S8N7AZRkBsgwbcXRvk0E= github.com/nspcc-dev/neo-go v0.73.1-pre.0.20200303142215-f5a1b928ce09/go.mod h1:pPYwPZ2ks+uMnlRLUyXOpLieaDQSEaf4NM3zHVbRjmg= github.com/nspcc-dev/neofs-crypto v0.2.0 h1:ftN+59WqxSWz/RCgXYOfhmltOOqU+udsNQSvN6wkFck= github.com/nspcc-dev/neofs-crypto v0.2.0/go.mod h1:F/96fUzPM3wR+UGsPi3faVNmFlA9KAEAUQR7dMxZmNA= @@ -144,10 +169,12 @@ github.com/nspcc-dev/rfc6979 v0.2.0 h1:3e1WNxrN60/6N0DW7+UYisLeZJyfqZTNOjeV/toYv github.com/nspcc-dev/rfc6979 v0.2.0/go.mod h1:exhIh1PdpDC5vQmyEsGvc4YDM/lyQp/452QxGq/UEso= github.com/onsi/ginkgo v1.6.0 h1:Ix8l273rp3QzYgXSR+c8d1fTG7UPgYkOSELPhiY/YGw= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.10.3 h1:OoxbjfXVZyod1fmWYhI7SEyaD8B00ynP3T+D5GiyHOY= github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.2 h1:3mYCb7aPxS/RU7TI1y4rkEn1oKmPRjNJLNEXgw7MH2I= github.com/onsi/gomega v1.4.2/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1 h1:K0jcRCwNQM3vFGh1ppMtDh/+7ApJrjldlX8fA0jDTLQ= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= @@ -212,6 +239,7 @@ 
go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/zap v1.10.0 h1:ORx85nbTijNz8ljznvCMR1ZBIPKFn3jQrag10X2AsuM= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= diff --git a/pkg/compiler/syscall.go b/pkg/compiler/syscall.go index 086112040..e8c372020 100644 --- a/pkg/compiler/syscall.go +++ b/pkg/compiler/syscall.go @@ -11,6 +11,10 @@ var syscalls = map[string]map[string]string{ "GetUsage": "Neo.Attribute.GetUsage", "GetData": "Neo.Attribute.GetData", }, + "crypto": { + "Secp256k1Recover": "Neo.Cryptography.Secp256k1Recover", + "Secp256r1Recover": "Neo.Cryptography.Secp256r1Recover", + }, "enumerator": { "Concat": "Neo.Enumerator.Concat", "Create": "Neo.Enumerator.Create", diff --git a/pkg/config/protocol_config.go b/pkg/config/protocol_config.go index c84de79a2..dff592038 100644 --- a/pkg/config/protocol_config.go +++ b/pkg/config/protocol_config.go @@ -20,6 +20,8 @@ const ( type ( ProtocolConfiguration struct { AddressVersion byte `yaml:"AddressVersion"` + // EnableStateRoot specifies if exchange of state roots should be enabled. + EnableStateRoot bool `yaml:"EnableStateRoot"` // FeePerExtraByte sets the expected per-byte fee for // transactions exceeding the MaxFreeTransactionSize. FeePerExtraByte float64 `yaml:"FeePerExtraByte"` @@ -34,11 +36,13 @@ type ( MaxFreeTransactionsPerBlock int `yaml:"MaxFreeTransactionsPerBlock"` MemPoolSize int `yaml:"MemPoolSize"` // SaveStorageBatch enables storage batch saving before every persist. - SaveStorageBatch bool `yaml:"SaveStorageBatch"` - SecondsPerBlock int `yaml:"SecondsPerBlock"` - SeedList []string `yaml:"SeedList"` - StandbyValidators []string `yaml:"StandbyValidators"` - SystemFee SystemFee `yaml:"SystemFee"` + SaveStorageBatch bool `yaml:"SaveStorageBatch"` + SecondsPerBlock int `yaml:"SecondsPerBlock"` + SeedList []string `yaml:"SeedList"` + StandbyValidators []string `yaml:"StandbyValidators"` + // StateRootEnableIndex specifies starting height for state root calculations and exchange. + StateRootEnableIndex uint32 `yaml:"StateRootEnableIndex"` + SystemFee SystemFee `yaml:"SystemFee"` // Whether to verify received blocks. VerifyBlocks bool `yaml:"VerifyBlocks"` // Whether to verify transactions in received blocks. diff --git a/pkg/consensus/commit.go b/pkg/consensus/commit.go index 492a1a156..a000b7abf 100644 --- a/pkg/consensus/commit.go +++ b/pkg/consensus/commit.go @@ -8,6 +8,9 @@ import ( // commit represents dBFT Commit message. type commit struct { signature [signatureSize]byte + stateSig [signatureSize]byte + + stateRootEnabled bool } // signatureSize is an rfc6989 signature size in bytes @@ -19,11 +22,17 @@ var _ payload.Commit = (*commit)(nil) // EncodeBinary implements io.Serializable interface. func (c *commit) EncodeBinary(w *io.BinWriter) { w.WriteBytes(c.signature[:]) + if c.stateRootEnabled { + w.WriteBytes(c.stateSig[:]) + } } // DecodeBinary implements io.Serializable interface. 
func (c *commit) DecodeBinary(r *io.BinReader) { r.ReadBytes(c.signature[:]) + if c.stateRootEnabled { + r.ReadBytes(c.stateSig[:]) + } } // Signature implements payload.Commit interface. diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index fbb84a3e1..fa135d362 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -2,6 +2,7 @@ package consensus import ( "errors" + "fmt" "math/rand" "sort" "time" @@ -13,7 +14,9 @@ import ( "github.com/nspcc-dev/dbft/payload" "github.com/nspcc-dev/neo-go/pkg/core" coreb "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/cache" "github.com/nspcc-dev/neo-go/pkg/core/mempool" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/smartcontract" @@ -49,9 +52,9 @@ type service struct { log *zap.Logger // cache is a fifo cache which stores recent payloads. - cache *relayCache + cache *cache.HashCache // txx is a fifo cache which stores miner transactions. - txx *relayCache + txx *cache.HashCache dbft *dbft.DBFT // messages and transactions are channels needed to process // everything in single thread. @@ -70,7 +73,7 @@ type Config struct { Logger *zap.Logger // Broadcast is a callback which is called to notify server // about new consensus payload to sent. - Broadcast func(p *Payload) + Broadcast func(cache.Hashable) // Chain is a core.Blockchainer instance. Chain core.Blockchainer // RequestTx is a callback to which will be called @@ -96,8 +99,8 @@ func NewService(cfg Config) (Service, error) { Config: cfg, log: cfg.Logger, - cache: newFIFOCache(cacheMaxCapacity), - txx: newFIFOCache(cacheMaxCapacity), + cache: cache.NewFIFOCache(cacheMaxCapacity), + txx: cache.NewFIFOCache(cacheMaxCapacity), messages: make(chan Payload, 100), transactions: make(chan *transaction.Transaction, 100), @@ -120,7 +123,6 @@ func NewService(cfg Config) (Service, error) { dbft.WithLogger(srv.log), dbft.WithSecondsPerBlock(cfg.TimePerBlock), dbft.WithGetKeyPair(srv.getKeyPair), - dbft.WithTxPerBlock(10000), dbft.WithRequestTx(cfg.RequestTx), dbft.WithGetTx(srv.getTx), dbft.WithGetVerified(srv.getVerifiedTx), @@ -135,13 +137,14 @@ func NewService(cfg Config) (Service, error) { dbft.WithGetValidators(srv.getValidators), dbft.WithGetConsensusAddress(srv.getConsensusAddress), - dbft.WithNewConsensusPayload(func() payload.ConsensusPayload { p := new(Payload); p.message = &message{}; return p }), - dbft.WithNewPrepareRequest(func() payload.PrepareRequest { return new(prepareRequest) }), + dbft.WithNewConsensusPayload(srv.newPayload), + dbft.WithNewPrepareRequest(srv.newPrepareRequest), dbft.WithNewPrepareResponse(func() payload.PrepareResponse { return new(prepareResponse) }), dbft.WithNewChangeView(func() payload.ChangeView { return new(changeView) }), - dbft.WithNewCommit(func() payload.Commit { return new(commit) }), + dbft.WithNewCommit(srv.newCommit), dbft.WithNewRecoveryRequest(func() payload.RecoveryRequest { return new(recoveryRequest) }), dbft.WithNewRecoveryMessage(func() payload.RecoveryMessage { return new(recoveryMessage) }), + dbft.WithVerifyPrepareRequest(srv.verifyRequest), ) if srv.dbft == nil { @@ -210,6 +213,53 @@ func (s *service) eventLoop() { } } +func (s *service) newPayload() payload.ConsensusPayload { + return &Payload{ + message: &message{ + stateRootEnabled: s.stateRootEnabled(), + }, + } +} + +// stateRootEnabled checks if state root feature is enabled on current height. 
+// It should be called only from dbft callbacks and is not protected by any mutex. +func (s *service) stateRootEnabled() bool { + return s.Chain.GetConfig().EnableStateRoot +} + +func (s *service) newPrepareRequest() payload.PrepareRequest { + if !s.stateRootEnabled() { + return new(prepareRequest) + } + sr, err := s.Chain.GetStateRoot(s.Chain.BlockHeight()) + if err == nil { + return &prepareRequest{ + stateRootEnabled: true, + proposalStateRoot: sr.MPTRootBase, + } + } + return &prepareRequest{stateRootEnabled: true} +} + +func (s *service) newCommit() payload.Commit { + if !s.stateRootEnabled() { + return new(commit) + } + c := &commit{stateRootEnabled: true} + for _, p := range s.dbft.Context.PreparationPayloads { + if p != nil && p.ViewNumber() == s.dbft.ViewNumber && p.Type() == payload.PrepareRequestType { + pr := p.GetPrepareRequest().(*prepareRequest) + data := pr.proposalStateRoot.GetSignedPart() + sign, err := s.dbft.Priv.Sign(data) + if err == nil { + copy(c.stateSig[:], sign) + } + break + } + } + return c +} + func (s *service) validatePayload(p *Payload) bool { validators := s.getValidators() if int(p.validatorIndex) >= len(validators) { @@ -262,8 +312,8 @@ func (s *service) OnPayload(cp *Payload) { // decode payload data into message if cp.message == nil { - if err := cp.decodeData(); err != nil { - log.Debug("can't decode payload data") + if err := cp.decodeData(s.stateRootEnabled()); err != nil { + log.Debug("can't decode payload data", zap.Error(err)) return } } @@ -340,6 +390,21 @@ func (s *service) verifyBlock(b block.Block) bool { return true } +func (s *service) verifyRequest(p payload.ConsensusPayload) error { + if !s.stateRootEnabled() { + return nil + } + r, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1) + if err != nil { + return fmt.Errorf("can't get local state root: %v", err) + } + rb := &p.GetPrepareRequest().(*prepareRequest).proposalStateRoot + if !r.Equals(rb) { + return errors.New("state root mismatch") + } + return nil +} + func (s *service) processBlock(b block.Block) { bb := &b.(*neoBlock).Block bb.Script = *(s.getBlockWitness(bb)) @@ -351,16 +416,36 @@ func (s *service) processBlock(b block.Block) { s.log.Warn("error on add block", zap.Error(err)) } } + + var rb *state.MPTRootBase + for _, p := range s.dbft.PreparationPayloads { + if p != nil && p.Type() == payload.PrepareRequestType { + rb = &p.GetPrepareRequest().(*prepareRequest).proposalStateRoot + } + } + w := s.getWitness(func(p payload.Commit) []byte { return p.(*commit).stateSig[:] }) + r := &state.MPTRoot{ + MPTRootBase: *rb, + Witness: w, + } + if err := s.Chain.AddStateRoot(r); err != nil { + s.log.Warn("errors while adding state root", zap.Error(err)) + } + s.Broadcast(r) } -func (s *service) getBlockWitness(b *coreb.Block) *transaction.Witness { +func (s *service) getBlockWitness(_ *coreb.Block) *transaction.Witness { + return s.getWitness(func(p payload.Commit) []byte { return p.Signature() }) +} + +func (s *service) getWitness(f func(p payload.Commit) []byte) *transaction.Witness { dctx := s.dbft.Context pubs := convertKeys(dctx.Validators) sigs := make(map[*keys.PublicKey][]byte) for i := range pubs { if p := dctx.CommitPayloads[i]; p != nil && p.ViewNumber() == dctx.ViewNumber { - sigs[pubs[i]] = p.GetCommit().Signature() + sigs[pubs[i]] = f(p.GetCommit()) } } @@ -397,7 +482,7 @@ func (s *service) getBlock(h util.Uint256) block.Block { return &neoBlock{Block: *b} } -func (s *service) getVerifiedTx(count int) []block.Transaction { +func (s *service) getVerifiedTx() 
[]block.Transaction { pool := s.Config.Chain.GetMemPool() var txx []mempool.TxWithFee diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index 285971622..e15c6ccf7 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -7,6 +7,7 @@ import ( "github.com/nspcc-dev/dbft/payload" "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/core" + "github.com/nspcc-dev/neo-go/pkg/core/cache" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" @@ -25,7 +26,7 @@ func TestNewService(t *testing.T) { require.NoError(t, srv.Chain.PoolTx(tx)) var txx []block.Transaction - require.NotPanics(t, func() { txx = srv.getVerifiedTx(1) }) + require.NotPanics(t, func() { txx = srv.getVerifiedTx() }) require.Len(t, txx, 2) require.Equal(t, tx, txx[1]) srv.Chain.Close() @@ -58,7 +59,7 @@ func TestService_GetVerified(t *testing.T) { srv.dbft.ViewNumber = 1 t.Run("new transactions will be proposed in case of failure", func(t *testing.T) { - txx := srv.getVerifiedTx(10) + txx := srv.getVerifiedTx() require.Equal(t, 2, len(txx), "there is only 1 tx in mempool") require.Equal(t, txs[3], txx[1]) }) @@ -68,7 +69,7 @@ func TestService_GetVerified(t *testing.T) { require.NoError(t, srv.Chain.PoolTx(tx)) } - txx := srv.getVerifiedTx(10) + txx := srv.getVerifiedTx() require.Contains(t, txx, txs[0]) require.Contains(t, txx, txs[1]) require.NotContains(t, txx, txs[2]) @@ -182,7 +183,7 @@ func shouldNotReceive(t *testing.T, ch chan Payload) { func newTestService(t *testing.T) *service { srv, err := NewService(Config{ Logger: zaptest.NewLogger(t), - Broadcast: func(*Payload) {}, + Broadcast: func(cache.Hashable) {}, Chain: newTestChain(t), RequestTx: func(...util.Uint256) {}, Wallet: &wallet.Config{ diff --git a/pkg/consensus/payload.go b/pkg/consensus/payload.go index 925125fc2..c74ac9515 100644 --- a/pkg/consensus/payload.go +++ b/pkg/consensus/payload.go @@ -22,6 +22,8 @@ type ( Type messageType ViewNumber byte + stateRootEnabled bool + payload io.Serializable } @@ -283,15 +285,21 @@ func (m *message) DecodeBinary(r *io.BinReader) { cv.newViewNumber = m.ViewNumber + 1 m.payload = cv case prepareRequestType: - m.payload = new(prepareRequest) + m.payload = &prepareRequest{ + stateRootEnabled: m.stateRootEnabled, + } case prepareResponseType: m.payload = new(prepareResponse) case commitType: - m.payload = new(commit) + m.payload = &commit{ + stateRootEnabled: m.stateRootEnabled, + } case recoveryRequestType: m.payload = new(recoveryRequest) case recoveryMessageType: - m.payload = new(recoveryMessage) + m.payload = &recoveryMessage{ + stateRootEnabled: m.stateRootEnabled, + } default: r.Err = errors.Errorf("invalid type: 0x%02x", byte(m.Type)) return @@ -319,9 +327,9 @@ func (t messageType) String() string { } } -// decode data of payload into it's message -func (p *Payload) decodeData() error { - m := new(message) +// decodeData decodes data of payload into it's message. 
+func (p *Payload) decodeData(stateRootEnabled bool) error { + m := &message{stateRootEnabled: stateRootEnabled} br := io.NewBinReaderFromBuf(p.data) m.DecodeBinary(br) if br.Err != nil { diff --git a/pkg/consensus/payload_test.go b/pkg/consensus/payload_test.go index c07ff651a..f060ede5d 100644 --- a/pkg/consensus/payload_test.go +++ b/pkg/consensus/payload_test.go @@ -94,13 +94,13 @@ func TestConsensusPayload_Serializable(t *testing.T) { // message is nil after decoding as we didn't yet call decodeData require.Nil(t, actual.message) // message should now be decoded from actual.data byte array - assert.NoError(t, actual.decodeData()) + assert.NoError(t, actual.decodeData(false)) require.Equal(t, p, actual) data = p.MarshalUnsigned() pu := new(Payload) require.NoError(t, pu.UnmarshalUnsigned(data)) - assert.NoError(t, pu.decodeData()) + assert.NoError(t, pu.decodeData(false)) p.Witness = transaction.Witness{} require.Equal(t, p, pu) @@ -144,14 +144,14 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) { p := new(Payload) require.NoError(t, testserdes.DecodeBinary(buf, p)) // decode `data` into `message` - assert.NoError(t, p.decodeData()) + assert.NoError(t, p.decodeData(false)) require.Equal(t, expected, p) // invalid type buf[typeIndex] = 0xFF actual := new(Payload) require.NoError(t, testserdes.DecodeBinary(buf, actual)) - require.Error(t, actual.decodeData()) + require.Error(t, actual.decodeData(false)) // invalid format buf[delimeterIndex] = 0 @@ -165,9 +165,16 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) { require.Error(t, testserdes.DecodeBinary(buf, new(Payload))) } +func testEncodeDecode(srEnabled bool, mt messageType, actual io.Serializable) func(t *testing.T) { + return func(t *testing.T) { + expected := randomMessage(t, mt, srEnabled) + testserdes.EncodeDecodeBinary(t, expected, actual) + } +} + func TestCommit_Serializable(t *testing.T) { - c := randomMessage(t, commitType) - testserdes.EncodeDecodeBinary(t, c, new(commit)) + t.Run("WithStateRoot", testEncodeDecode(true, commitType, &commit{stateRootEnabled: true})) + t.Run("NoStateRoot", testEncodeDecode(false, commitType, &commit{stateRootEnabled: false})) } func TestPrepareResponse_Serializable(t *testing.T) { @@ -176,8 +183,8 @@ func TestPrepareResponse_Serializable(t *testing.T) { } func TestPrepareRequest_Serializable(t *testing.T) { - req := randomMessage(t, prepareRequestType) - testserdes.EncodeDecodeBinary(t, req, new(prepareRequest)) + t.Run("WithStateRoot", testEncodeDecode(true, prepareRequestType, &prepareRequest{stateRootEnabled: true})) + t.Run("NoStateRoot", testEncodeDecode(false, prepareRequestType, &prepareRequest{stateRootEnabled: false})) } func TestRecoveryRequest_Serializable(t *testing.T) { @@ -186,8 +193,8 @@ func TestRecoveryRequest_Serializable(t *testing.T) { } func TestRecoveryMessage_Serializable(t *testing.T) { - msg := randomMessage(t, recoveryMessageType) - testserdes.EncodeDecodeBinary(t, msg, new(recoveryMessage)) + t.Run("WithStateRoot", testEncodeDecode(true, recoveryMessageType, &recoveryMessage{stateRootEnabled: true})) + t.Run("NoStateRoot", testEncodeDecode(false, recoveryMessageType, &recoveryMessage{stateRootEnabled: false})) } func randomPayload(t *testing.T, mt messageType) *Payload { @@ -215,31 +222,35 @@ func randomPayload(t *testing.T, mt messageType) *Payload { return p } -func randomMessage(t *testing.T, mt messageType) io.Serializable { +func randomMessage(t *testing.T, mt messageType, srEnabled ...bool) io.Serializable { switch mt { case 
changeViewType: return &changeView{ timestamp: rand.Uint32(), } case prepareRequestType: - return randomPrepareRequest(t) + return randomPrepareRequest(t, srEnabled...) case prepareResponseType: return &prepareResponse{preparationHash: random.Uint256()} case commitType: var c commit random.Fill(c.signature[:]) + if len(srEnabled) > 0 && srEnabled[0] { + c.stateRootEnabled = true + random.Fill(c.stateSig[:]) + } return &c case recoveryRequestType: return &recoveryRequest{timestamp: rand.Uint32()} case recoveryMessageType: - return randomRecoveryMessage(t) + return randomRecoveryMessage(t, srEnabled...) default: require.Fail(t, "invalid type") return nil } } -func randomPrepareRequest(t *testing.T) *prepareRequest { +func randomPrepareRequest(t *testing.T, srEnabled ...bool) *prepareRequest { const txCount = 3 req := &prepareRequest{ @@ -255,15 +266,22 @@ func randomPrepareRequest(t *testing.T) *prepareRequest { } req.nextConsensus = random.Uint160() + if len(srEnabled) > 0 && srEnabled[0] { + req.stateRootEnabled = true + req.proposalStateRoot.Index = rand.Uint32() + req.proposalStateRoot.PrevHash = random.Uint256() + req.proposalStateRoot.Root = random.Uint256() + } + return req } -func randomRecoveryMessage(t *testing.T) *recoveryMessage { - result := randomMessage(t, prepareRequestType) +func randomRecoveryMessage(t *testing.T, srEnabled ...bool) *recoveryMessage { + result := randomMessage(t, prepareRequestType, srEnabled...) require.IsType(t, (*prepareRequest)(nil), result) prepReq := result.(*prepareRequest) - return &recoveryMessage{ + rec := &recoveryMessage{ preparationPayloads: []*preparationCompact{ { ValidatorIndex: 1, @@ -297,6 +315,15 @@ func randomRecoveryMessage(t *testing.T) *recoveryMessage { payload: prepReq, }, } + if len(srEnabled) > 0 && srEnabled[0] { + rec.stateRootEnabled = true + rec.prepareRequest.stateRootEnabled = true + for _, c := range rec.commitPayloads { + c.stateRootEnabled = true + random.Fill(c.StateSignature[:]) + } + } + return rec } func TestPayload_Sign(t *testing.T) { diff --git a/pkg/consensus/prepare_request.go b/pkg/consensus/prepare_request.go index f40b74ab0..8a28f3a5c 100644 --- a/pkg/consensus/prepare_request.go +++ b/pkg/consensus/prepare_request.go @@ -2,6 +2,7 @@ package consensus import ( "github.com/nspcc-dev/dbft/payload" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -14,6 +15,9 @@ type prepareRequest struct { transactionHashes []util.Uint256 minerTx transaction.Transaction nextConsensus util.Uint160 + proposalStateRoot state.MPTRootBase + + stateRootEnabled bool } var _ payload.PrepareRequest = (*prepareRequest)(nil) @@ -25,6 +29,9 @@ func (p *prepareRequest) EncodeBinary(w *io.BinWriter) { w.WriteBytes(p.nextConsensus[:]) w.WriteArray(p.transactionHashes) p.minerTx.EncodeBinary(w) + if p.stateRootEnabled { + p.proposalStateRoot.EncodeBinary(w) + } } // DecodeBinary implements io.Serializable interface. @@ -34,6 +41,9 @@ func (p *prepareRequest) DecodeBinary(r *io.BinReader) { r.ReadBytes(p.nextConsensus[:]) r.ReadArray(&p.transactionHashes) p.minerTx.DecodeBinary(r) + if p.stateRootEnabled { + p.proposalStateRoot.DecodeBinary(r) + } } // Timestamp implements payload.PrepareRequest interface. 
diff --git a/pkg/consensus/recovery_message.go b/pkg/consensus/recovery_message.go index 030db04ab..17c7601f8 100644 --- a/pkg/consensus/recovery_message.go +++ b/pkg/consensus/recovery_message.go @@ -3,6 +3,7 @@ package consensus import ( "github.com/nspcc-dev/dbft/crypto" "github.com/nspcc-dev/dbft/payload" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/pkg/errors" @@ -16,6 +17,8 @@ type ( commitPayloads []*commitCompact changeViewPayloads []*changeViewCompact prepareRequest *message + + stateRootEnabled bool } changeViewCompact struct { @@ -29,7 +32,10 @@ type ( ViewNumber byte ValidatorIndex uint16 Signature [signatureSize]byte + StateSignature [signatureSize]byte InvocationScript []byte + + stateRootEnabled bool } preparationCompact struct { @@ -46,7 +52,7 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) { var hasReq = r.ReadBool() if hasReq { - m.prepareRequest = new(message) + m.prepareRequest = &message{stateRootEnabled: m.stateRootEnabled} m.prepareRequest.DecodeBinary(r) if r.Err == nil && m.prepareRequest.Type != prepareRequestType { r.Err = errors.New("recovery message PrepareRequest has wrong type") @@ -67,7 +73,16 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) { } r.ReadArray(&m.preparationPayloads) - r.ReadArray(&m.commitPayloads) + lu := r.ReadVarUint() + if lu > state.MaxValidatorsVoted { + r.Err = errors.New("too many preparation payloads") + return + } + m.commitPayloads = make([]*commitCompact, lu) + for i := uint64(0); i < lu; i++ { + m.commitPayloads[i] = &commitCompact{stateRootEnabled: m.stateRootEnabled} + m.commitPayloads[i].DecodeBinary(r) + } } // EncodeBinary implements io.Serializable interface. @@ -96,7 +111,7 @@ func (p *changeViewCompact) DecodeBinary(r *io.BinReader) { p.ValidatorIndex = r.ReadU16LE() p.OriginalViewNumber = r.ReadB() p.Timestamp = r.ReadU32LE() - p.InvocationScript = r.ReadVarBytes() + p.InvocationScript = r.ReadVarBytes(1024) } // EncodeBinary implements io.Serializable interface. @@ -112,7 +127,10 @@ func (p *commitCompact) DecodeBinary(r *io.BinReader) { p.ViewNumber = r.ReadB() p.ValidatorIndex = r.ReadU16LE() r.ReadBytes(p.Signature[:]) - p.InvocationScript = r.ReadVarBytes() + if p.stateRootEnabled { + r.ReadBytes(p.StateSignature[:]) + } + p.InvocationScript = r.ReadVarBytes(1024) } // EncodeBinary implements io.Serializable interface. @@ -120,13 +138,16 @@ func (p *commitCompact) EncodeBinary(w *io.BinWriter) { w.WriteB(p.ViewNumber) w.WriteU16LE(p.ValidatorIndex) w.WriteBytes(p.Signature[:]) + if p.stateRootEnabled { + w.WriteBytes(p.StateSignature[:]) + } w.WriteVarBytes(p.InvocationScript) } // DecodeBinary implements io.Serializable interface. func (p *preparationCompact) DecodeBinary(r *io.BinReader) { p.ValidatorIndex = r.ReadU16LE() - p.InvocationScript = r.ReadVarBytes() + p.InvocationScript = r.ReadVarBytes(1024) } // EncodeBinary implements io.Serializable interface. 
@@ -143,6 +164,8 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) { Type: prepareRequestType, ViewNumber: p.ViewNumber(), payload: p.GetPrepareRequest().(*prepareRequest), + + stateRootEnabled: m.stateRootEnabled, } h := p.Hash() m.preparationHash = &h @@ -169,9 +192,11 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) { }) case payload.CommitType: m.commitPayloads = append(m.commitPayloads, &commitCompact{ + stateRootEnabled: m.stateRootEnabled, ValidatorIndex: p.ValidatorIndex(), ViewNumber: p.ViewNumber(), Signature: p.GetCommit().(*commit).signature, + StateSignature: p.GetCommit().(*commit).stateSig, InvocationScript: p.(*Payload).Witness.InvocationScript, }) } @@ -234,6 +259,7 @@ func (m *recoveryMessage) GetChangeViews(p payload.ConsensusPayload, validators newViewNumber: cv.OriginalViewNumber + 1, timestamp: cv.Timestamp, }) + c.message.ViewNumber = cv.OriginalViewNumber c.SetValidatorIndex(cv.ValidatorIndex) c.Witness.InvocationScript = cv.InvocationScript c.Witness.VerificationScript = getVerificationScript(cv.ValidatorIndex, validators) @@ -249,7 +275,12 @@ func (m *recoveryMessage) GetCommits(p payload.ConsensusPayload, validators []cr ps := make([]payload.ConsensusPayload, len(m.commitPayloads)) for i, c := range m.commitPayloads { - cc := fromPayload(commitType, p.(*Payload), &commit{signature: c.Signature}) + cc := fromPayload(commitType, p.(*Payload), &commit{ + signature: c.Signature, + stateSig: c.StateSignature, + + stateRootEnabled: m.stateRootEnabled, + }) cc.SetValidatorIndex(c.ValidatorIndex) cc.Witness.InvocationScript = c.InvocationScript cc.Witness.VerificationScript = getVerificationScript(c.ValidatorIndex, validators) @@ -289,6 +320,8 @@ func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload { Type: t, ViewNumber: recovery.message.ViewNumber, payload: p, + + stateRootEnabled: recovery.stateRootEnabled, }, version: recovery.Version(), prevHash: recovery.PrevHash(), diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index aa452bf86..fc005077b 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -13,9 +13,11 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/mempool" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract" @@ -229,6 +231,11 @@ func (bc *Blockchain) init() error { } bc.blockHeight = bHeight bc.persistedHeight = bHeight + if bc.config.EnableStateRoot { + if err = bc.dao.InitMPT(bHeight); err != nil { + return errors.Wrapf(err, "can't init MPT at height %d", bHeight) + } + } hashes, err := bc.dao.GetHeaderHashes() if err != nil { @@ -551,6 +558,23 @@ func (bc *Blockchain) getSystemFeeAmount(h util.Uint256) uint32 { return sf } +// GetStateProof returns proof of having key in the MPT with the specified root. +func (bc *Blockchain) GetStateProof(root util.Uint256, key []byte) ([][]byte, error) { + if !bc.config.EnableStateRoot { + return nil, errors.New("state root feature is not enabled") + } + tr := mpt.NewTrie(mpt.NewHashNode(root), storage.NewMemCachedStore(bc.dao.Store)) + return tr.GetProof(key) +} + +// GetStateRoot returns state root for a given height. 
+func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + if !bc.config.EnableStateRoot { + return nil, errors.New("state root feature is not enabled") + } + return bc.dao.GetStateRoot(height) +} + // TODO: storeBlock needs some more love, its implemented as in the original // project. This for the sake of development speed and understanding of what // is happening here, quite allot as you can see :). If things are wired together @@ -819,6 +843,28 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } } + if bc.config.EnableStateRoot { + root := bc.dao.MPT.StateRoot() + var prevHash util.Uint256 + if block.Index > 0 { + prev, err := bc.dao.GetStateRoot(block.Index - 1) + if err != nil { + return errors.WithMessagef(err, "can't get previous state root") + } + prevHash = hash.DoubleSha256(prev.GetSignedPart()) + } + err := bc.AddStateRoot(&state.MPTRoot{ + MPTRootBase: state.MPTRootBase{ + Index: block.Index, + PrevHash: prevHash, + Root: root, + }, + }) + if err != nil { + return err + } + } + if bc.config.SaveStorageBatch { bc.lastBatch = cache.DAO.GetBatch() } @@ -829,6 +875,15 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { bc.lock.Unlock() return err } + if bc.config.EnableStateRoot { + bc.dao.MPT.Flush() + // Every persist cycle we also compact our in-memory MPT. + persistedHeight := atomic.LoadUint32(&bc.persistedHeight) + if persistedHeight == block.Index-1 { + // 10 is good and roughly estimated to fit remaining trie into 1M of memory. + bc.dao.MPT.Collapse(10) + } + } bc.topBlock.Store(block) atomic.StoreUint32(&bc.blockHeight, block.Index) bc.memPool.RemoveStale(bc.isTxStillRelevant) @@ -1492,12 +1547,13 @@ func (bc *Blockchain) ApplyPolicyToTxSet(txes []mempool.TxWithFee) []mempool.TxW txes = txes[:bc.config.MaxTransactionsPerBlock] } maxFree := bc.config.MaxFreeTransactionsPerBlock - if maxFree != 0 { - lowStart := sort.Search(len(txes), func(i int) bool { - return bc.IsLowPriority(txes[i].Fee) + if maxFree != 0 && len(txes) > maxFree { + // Transactions are sorted by fee, so we just find the first free one. + freeStart := sort.Search(len(txes), func(i int) bool { + return txes[i].Fee == 0 }) - if lowStart+maxFree < len(txes) { - txes = txes[:lowStart+maxFree] + if freeStart+maxFree < len(txes) { + txes = txes[:freeStart+maxFree] } } return txes @@ -1732,6 +1788,90 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction) bool { } +// StateHeight returns height of the verified state root. +func (bc *Blockchain) StateHeight() uint32 { + h, _ := bc.dao.GetCurrentStateRootHeight() + return h +} + +// AddStateRoot add new (possibly unverified) state root to the blockchain. 
+func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error { + if !bc.config.EnableStateRoot { + bc.log.Warn("state root is being added but not enabled in config") + return nil + } + our, err := bc.GetStateRoot(r.Index) + if err == nil { + if our.Flag == state.Verified { + return bc.updateStateHeight(r.Index) + } else if r.Witness == nil && our.Witness != nil { + r.Witness = our.Witness + } + } + if err := bc.verifyStateRoot(r); err != nil { + return errors.WithMessage(err, "invalid state root") + } + if r.Index > bc.BlockHeight() { // just put it into the store for future checks + return bc.dao.PutStateRoot(&state.MPTRootState{ + MPTRoot: *r, + Flag: state.Unverified, + }) + } + + flag := state.Unverified + if r.Witness != nil { + if err := bc.verifyStateRootWitness(r); err != nil { + return errors.WithMessage(err, "can't verify signature") + } + flag = state.Verified + } + err = bc.dao.PutStateRoot(&state.MPTRootState{ + MPTRoot: *r, + Flag: flag, + }) + if err != nil { + return err + } + return bc.updateStateHeight(r.Index) +} + +func (bc *Blockchain) updateStateHeight(newHeight uint32) error { + h, err := bc.dao.GetCurrentStateRootHeight() + if err != nil { + return errors.WithMessage(err, "can't get current state root height") + } else if newHeight == h+1 { + updateStateHeightMetric(newHeight) + return bc.dao.PutCurrentStateRootHeight(h + 1) + } + return nil +} + +// verifyStateRoot checks if state root is valid. +func (bc *Blockchain) verifyStateRoot(r *state.MPTRoot) error { + if r.Index == 0 { + return nil + } + prev, err := bc.GetStateRoot(r.Index - 1) + if err != nil { + return errors.New("can't get previous state root") + } else if !r.PrevHash.Equals(hash.DoubleSha256(prev.GetSignedPart())) { + return errors.New("previous hash mismatch") + } else if prev.Version != r.Version { + return errors.New("version mismatch") + } + return nil +} + +// verifyStateRootWitness verifies that state root signature is correct. +func (bc *Blockchain) verifyStateRootWitness(r *state.MPTRoot) error { + b, err := bc.GetBlock(bc.GetHeaderHash(int(r.Index))) + if err != nil { + return err + } + interopCtx := bc.newInteropContext(trigger.Verification, bc.dao, nil, nil) + return bc.verifyHashAgainstScript(b.NextConsensus, r.Witness, hash.Sha256(r.GetSignedPart()), interopCtx, true) +} + // VerifyTx verifies whether a transaction is bonafide or not. Block parameter // is used for easy interop access and can be omitted for transactions that are // not yet added into any block. 
diff --git a/pkg/core/blockchainer.go b/pkg/core/blockchainer.go index d3e0309de..d32a3eb5e 100644 --- a/pkg/core/blockchainer.go +++ b/pkg/core/blockchainer.go @@ -18,6 +18,7 @@ type Blockchainer interface { GetConfig() config.ProtocolConfiguration AddHeaders(...*block.Header) error AddBlock(*block.Block) error + AddStateRoot(r *state.MPTRoot) error BlockHeight() uint32 CalculateClaimable(value util.Fixed8, startHeight, endHeight uint32) (util.Fixed8, util.Fixed8, error) Close() @@ -38,6 +39,8 @@ type Blockchainer interface { GetNEP5Balances(util.Uint160) *state.NEP5Balances GetValidators(txes ...*transaction.Transaction) ([]*keys.PublicKey, error) GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) + GetStateProof(root util.Uint256, key []byte) ([][]byte, error) + GetStateRoot(height uint32) (*state.MPTRootState, error) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem GetStorageItems(hash util.Uint160) (map[string]*state.StorageItem, error) GetTestVM() *vm.VM @@ -46,6 +49,7 @@ type Blockchainer interface { References(t *transaction.Transaction) ([]transaction.InOut, error) mempool.Feer // fee interface PoolTx(*transaction.Transaction) error + StateHeight() uint32 SubscribeForBlocks(ch chan<- *block.Block) SubscribeForExecutions(ch chan<- *state.AppExecResult) SubscribeForNotifications(ch chan<- *state.NotificationEvent) diff --git a/pkg/consensus/cache.go b/pkg/core/cache/cache.go similarity index 60% rename from pkg/consensus/cache.go rename to pkg/core/cache/cache.go index 4a6853803..962b779ed 100644 --- a/pkg/consensus/cache.go +++ b/pkg/core/cache/cache.go @@ -1,4 +1,4 @@ -package consensus +package cache import ( "container/list" @@ -7,9 +7,9 @@ import ( "github.com/nspcc-dev/neo-go/pkg/util" ) -// relayCache is a payload cache which is used to store +// HashCache is a payload cache which is used to store // last consensus payloads. -type relayCache struct { +type HashCache struct { *sync.RWMutex maxCap int @@ -17,13 +17,14 @@ type relayCache struct { queue *list.List } -// hashable is a type of items which can be stored in the relayCache. -type hashable interface { +// Hashable is a type of items which can be stored in the HashCache. +type Hashable interface { Hash() util.Uint256 } -func newFIFOCache(capacity int) *relayCache { - return &relayCache{ +// NewFIFOCache returns new FIFO cache with the specified capacity. +func NewFIFOCache(capacity int) *HashCache { + return &HashCache{ RWMutex: new(sync.RWMutex), maxCap: capacity, @@ -33,7 +34,7 @@ func newFIFOCache(capacity int) *relayCache { } // Add adds payload into a cache if it doesn't already exist. -func (c *relayCache) Add(p hashable) { +func (c *HashCache) Add(p Hashable) { c.Lock() defer c.Unlock() @@ -45,7 +46,7 @@ func (c *relayCache) Add(p hashable) { if c.queue.Len() >= c.maxCap { first := c.queue.Front() c.queue.Remove(first) - delete(c.elems, first.Value.(hashable).Hash()) + delete(c.elems, first.Value.(Hashable).Hash()) } e := c.queue.PushBack(p) @@ -53,7 +54,7 @@ func (c *relayCache) Add(p hashable) { } // Has checks if an item is already in cache. -func (c *relayCache) Has(h util.Uint256) bool { +func (c *HashCache) Has(h util.Uint256) bool { c.RLock() defer c.RUnlock() @@ -61,13 +62,13 @@ func (c *relayCache) Has(h util.Uint256) bool { } // Get returns payload with the specified hash from cache. 
-func (c *relayCache) Get(h util.Uint256) hashable { +func (c *HashCache) Get(h util.Uint256) Hashable { c.RLock() defer c.RUnlock() e, ok := c.elems[h] if !ok { - return hashable(nil) + return Hashable(nil) } - return e.Value.(hashable) + return e.Value.(Hashable) } diff --git a/pkg/consensus/cache_test.go b/pkg/core/cache/cache_test.go similarity index 68% rename from pkg/consensus/cache_test.go rename to pkg/core/cache/cache_test.go index cd4ebe5a3..e8288e2d7 100644 --- a/pkg/consensus/cache_test.go +++ b/pkg/core/cache/cache_test.go @@ -1,17 +1,19 @@ -package consensus +package cache import ( + "math/rand" "testing" - "github.com/nspcc-dev/dbft/payload" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" ) func TestRelayCache_Add(t *testing.T) { const capacity = 3 - payloads := getDifferentPayloads(t, capacity+1) - c := newFIFOCache(capacity) + payloads := getDifferentItems(t, capacity+1) + c := NewFIFOCache(capacity) require.Equal(t, 0, c.queue.Len()) require.Equal(t, 0, len(c.elems)) @@ -46,19 +48,15 @@ func TestRelayCache_Add(t *testing.T) { require.Equal(t, nil, c.Get(payloads[1].Hash())) } -func getDifferentPayloads(t *testing.T, n int) (payloads []Payload) { - payloads = make([]Payload, n) - for i := range payloads { - var sign [signatureSize]byte - random.Fill(sign[:]) +type testHashable []byte - payloads[i].message = &message{} - payloads[i].SetValidatorIndex(uint16(i)) - payloads[i].SetType(payload.MessageType(commitType)) - payloads[i].payload = &commit{ - signature: sign, - } +// Hash implements Hashable. +func (h testHashable) Hash() util.Uint256 { return hash.Sha256(h) } + +func getDifferentItems(t *testing.T, n int) []testHashable { + items := make([]testHashable, n) + for i := range items { + items[i] = random.Bytes(rand.Int() % 10) } - - return + return items } diff --git a/pkg/core/dao/cacheddao.go b/pkg/core/dao/cacheddao.go index 0fe8eba62..fc905f919 100644 --- a/pkg/core/dao/cacheddao.go +++ b/pkg/core/dao/cacheddao.go @@ -245,6 +245,7 @@ func (cd *Cached) FlushStorage() error { return err } } + ti.State |= flushedState } } return nil @@ -275,7 +276,7 @@ func (cd *Cached) getStorageItemNoCache(scripthash util.Uint160, key []byte) *st func (cd *Cached) getStorageItemInt(scripthash util.Uint160, key []byte, putToCache bool) *state.StorageItem { ti := cd.storage.getItem(scripthash, key) if ti != nil { - if ti.State == delOp { + if ti.State&delOp != 0 { return nil } return copyItem(&ti.StorageItem) @@ -303,8 +304,10 @@ func (cd *Cached) PutStorageItem(scripthash util.Uint160, key []byte, si *state. 
item := copyItem(si) ti := cd.storage.getItem(scripthash, key) if ti != nil { - if ti.State == delOp || ti.State == getOp { + if ti.State&(delOp|getOp) != 0 { ti.State = putOp + } else { + ti.State = addOp } ti.StorageItem = *item return nil @@ -357,7 +360,7 @@ func (cd *Cached) GetStorageItemsIterator(hash util.Uint160, prefix []byte) (Sto for ; keyIndex < len(cd.storage.keys[hash]); keyIndex++ { k := cd.storage.keys[hash][keyIndex] v := cache[k] - if v.State != delOp && bytes.HasPrefix([]byte(k), prefix) { + if v.State&delOp == 0 && bytes.HasPrefix([]byte(k), prefix) { val := make([]byte, len(v.StorageItem.Value)) copy(val, v.StorageItem.Value) return []byte(k), val, nil @@ -404,7 +407,7 @@ func (cd *Cached) GetStorageItems(hash util.Uint160, prefix []byte) ([]StorageIt for _, k := range cd.storage.keys[hash] { v := cache[k] - if v.State != delOp { + if v.State&delOp == 0 { val := make([]byte, len(v.StorageItem.Value)) copy(val, v.StorageItem.Value) result = append(result, StorageItemWithKey{ diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index 7126969ac..262584d47 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -7,6 +7,7 @@ import ( "sort" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -31,9 +32,12 @@ type DAO interface { GetContractState(hash util.Uint160) (*state.Contract, error) GetCurrentBlockHeight() (uint32, error) GetCurrentHeaderHeight() (i uint32, h util.Uint256, err error) + GetCurrentStateRootHeight() (uint32, error) GetHeaderHashes() ([]util.Uint256, error) GetNEP5Balances(acc util.Uint160) (*state.NEP5Balances, error) GetNEP5TransferLog(acc util.Uint160, index uint32) (*state.NEP5TransferLog, error) + GetStateRoot(height uint32) (*state.MPTRootState, error) + PutStateRoot(root *state.MPTRootState) error GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem GetStorageItems(hash util.Uint160, prefix []byte) ([]StorageItemWithKey, error) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) @@ -70,12 +74,14 @@ type DAO interface { // Simple is memCached wrapper around DB, simple DAO implementation. type Simple struct { + MPT *mpt.Trie Store *storage.MemCachedStore } // NewSimple creates new simple dao using provided backend store. func NewSimple(backend storage.Store) *Simple { - return &Simple{Store: storage.NewMemCachedStore(backend)} + st := storage.NewMemCachedStore(backend) + return &Simple{Store: st, MPT: mpt.NewTrie(nil, st)} } // GetBatch returns currently accumulated DB changeset. @@ -86,7 +92,9 @@ func (dao *Simple) GetBatch() *storage.MemBatch { // GetWrapped returns new DAO instance with another layer of wrapped // MemCachedStore around the current DAO Store. func (dao *Simple) GetWrapped() DAO { - return NewSimple(dao.Store) + d := NewSimple(dao.Store) + d.MPT = dao.MPT + return d } // GetAndDecode performs get operation and decoding with serializable structures. @@ -406,6 +414,63 @@ func (dao *Simple) PutAppExecResult(aer *state.AppExecResult) error { // -- start storage item. +func makeStateRootKey(height uint32) []byte { + key := make([]byte, 5) + key[0] = byte(storage.DataMPT) + binary.LittleEndian.PutUint32(key[1:], height) + return key +} + +// InitMPT initializes MPT at the given height. 
+func (dao *Simple) InitMPT(height uint32) error { + if height == 0 { + dao.MPT = mpt.NewTrie(nil, dao.Store) + return nil + } + r, err := dao.GetStateRoot(height) + if err != nil { + return err + } + dao.MPT = mpt.NewTrie(mpt.NewHashNode(r.Root), dao.Store) + return nil +} + +// GetCurrentStateRootHeight returns current state root height. +func (dao *Simple) GetCurrentStateRootHeight() (uint32, error) { + key := []byte{byte(storage.DataMPT)} + val, err := dao.Store.Get(key) + if err != nil { + if err == storage.ErrKeyNotFound { + err = nil + } + return 0, err + } + return binary.LittleEndian.Uint32(val), nil +} + +// PutCurrentStateRootHeight updates current state root height. +func (dao *Simple) PutCurrentStateRootHeight(height uint32) error { + key := []byte{byte(storage.DataMPT)} + val := make([]byte, 4) + binary.LittleEndian.PutUint32(val, height) + return dao.Store.Put(key, val) +} + +// GetStateRoot returns state root of a given height. +func (dao *Simple) GetStateRoot(height uint32) (*state.MPTRootState, error) { + r := new(state.MPTRootState) + err := dao.GetAndDecode(r, makeStateRootKey(height)) + if err != nil { + return nil, err + } + return r, nil +} + +// PutStateRoot puts state root of a given height into the store. +func (dao *Simple) PutStateRoot(r *state.MPTRootState) error { + return dao.Put(r, makeStateRootKey(r.Index)) +} + // GetStorageItem returns StorageItem if it exists in the given store. func (dao *Simple) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem { b, err := dao.Store.Get(makeStorageItemKey(scripthash, key)) @@ -426,13 +491,24 @@ func (dao *Simple) GetStorageItem(scripthash util.Uint160, key []byte) *state.St // PutStorageItem puts given StorageItem for given script with given // key into the given store. func (dao *Simple) PutStorageItem(scripthash util.Uint160, key []byte, si *state.StorageItem) error { - return dao.Put(si, makeStorageItemKey(scripthash, key)) + stKey := makeStorageItemKey(scripthash, key) + k := mpt.ToNeoStorageKey(stKey[1:]) // strip STStorage prefix + v := mpt.ToNeoStorageValue(si) + if err := dao.MPT.Put(k, v); err != nil && err != mpt.ErrNotFound { + return err + } + return dao.Store.Put(stKey, v[1:]) } // DeleteStorageItem drops storage item for the given script with the // given key from the store. func (dao *Simple) DeleteStorageItem(scripthash util.Uint160, key []byte) error { - return dao.Store.Delete(makeStorageItemKey(scripthash, key)) + stKey := makeStorageItemKey(scripthash, key) + k := mpt.ToNeoStorageKey(stKey[1:]) // strip STStorage prefix + if err := dao.MPT.Delete(k); err != nil && err != mpt.ErrNotFound { + return err + } + return dao.Store.Delete(stKey) } // StorageItemWithKey is a Key-Value pair together with possible const modifier. 
diff --git a/pkg/core/dao/storage_item.go b/pkg/core/dao/storage_item.go
index 5a961e6bc..ade05cfd6 100644
--- a/pkg/core/dao/storage_item.go
+++ b/pkg/core/dao/storage_item.go
@@ -24,6 +24,7 @@ const (
 	delOp
 	addOp
 	putOp
+	flushedState
 )
 
 func newItemCache() *itemCache {
diff --git a/pkg/core/interop_neo.go b/pkg/core/interop_neo.go
index cd3c0d8fd..418f943a6 100644
--- a/pkg/core/interop_neo.go
+++ b/pkg/core/interop_neo.go
@@ -1,10 +1,13 @@
 package core
 
 import (
+	"crypto/elliptic"
 	"errors"
 	"fmt"
 	"math"
+	"math/big"
 
+	"github.com/btcsuite/btcd/btcec"
 	"github.com/nspcc-dev/neo-go/pkg/core/state"
 	"github.com/nspcc-dev/neo-go/pkg/core/storage"
 	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
@@ -600,6 +603,33 @@ func (ic *interopContext) contractMigrate(v *vm.VM) error {
 	return ic.contractDestroy(v)
 }
 
+// secp256k1Recover recovers a secp256k1 public key.
+func (ic *interopContext) secp256k1Recover(v *vm.VM) error {
+	return ic.eccRecover(btcec.S256(), v)
+}
+
+// secp256r1Recover recovers a secp256r1 public key.
+func (ic *interopContext) secp256r1Recover(v *vm.VM) error {
+	return ic.eccRecover(elliptic.P256(), v)
+}
+
+// eccRecover recovers a public key using the given elliptic curve.
+func (ic *interopContext) eccRecover(curve elliptic.Curve, v *vm.VM) error {
+	rBytes := v.Estack().Pop().Bytes()
+	sBytes := v.Estack().Pop().Bytes()
+	r := new(big.Int).SetBytes(rBytes)
+	s := new(big.Int).SetBytes(sBytes)
+	isEven := v.Estack().Pop().Bool()
+	messageHash := v.Estack().Pop().Bytes()
+	pKey, err := keys.KeyRecover(curve, r, s, messageHash, isEven)
+	if err != nil {
+		v.Estack().PushVal([]byte{})
+		return nil
+	}
+	v.Estack().PushVal(pKey.UncompressedBytes()[1:])
+	return nil
+}
+
 // assetCreate creates an asset.
 func (ic *interopContext) assetCreate(v *vm.VM) error {
 	if ic.trigger != trigger.Application {
diff --git a/pkg/core/interop_neo_test.go b/pkg/core/interop_neo_test.go
index 135302927..d03a864f5 100644
--- a/pkg/core/interop_neo_test.go
+++ b/pkg/core/interop_neo_test.go
@@ -1,14 +1,17 @@
 package core
 
 import (
+	"bytes"
 	"math/big"
 	"testing"
 
+	"github.com/btcsuite/btcd/btcec"
 	"github.com/nspcc-dev/neo-go/pkg/core/block"
 	"github.com/nspcc-dev/neo-go/pkg/core/dao"
 	"github.com/nspcc-dev/neo-go/pkg/core/state"
 	"github.com/nspcc-dev/neo-go/pkg/core/storage"
 	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
+	"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
 	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
 	"github.com/nspcc-dev/neo-go/pkg/internal/random"
 	"github.com/nspcc-dev/neo-go/pkg/smartcontract"
@@ -457,8 +460,81 @@ func TestAssetGetPrecision(t *testing.T) {
 	require.Equal(t, big.NewInt(int64(assetState.Precision)), precision)
 }
 
+func TestSecp256k1Recover(t *testing.T) {
+	v, context, chain := createVM(t)
+	defer chain.Close()
+
+	privateKey, err := btcec.NewPrivateKey(btcec.S256())
+	require.NoError(t, err)
+	message := []byte("The quick brown fox jumps over the lazy dog")
+	signature, err := privateKey.Sign(message)
+	require.NoError(t, err)
+	require.True(t, signature.Verify(message, privateKey.PubKey()))
+	pubKey := keys.PublicKey{
+		X: privateKey.PubKey().X,
+		Y: privateKey.PubKey().Y,
+	}
+	expected := pubKey.UncompressedBytes()[1:]
+
+	// We don't know which of the two recovered keys suits, so let's try both.
+ putOnStackGetResult := func(isEven bool) []byte { + v.Estack().PushVal(message) + v.Estack().PushVal(isEven) + v.Estack().PushVal(signature.S.Bytes()) + v.Estack().PushVal(signature.R.Bytes()) + err = context.secp256k1Recover(v) + require.NoError(t, err) + return v.Estack().Pop().Value().([]byte) + } + + // First one: + actualFalse := putOnStackGetResult(false) + // Second one: + actualTrue := putOnStackGetResult(true) + + require.True(t, bytes.Compare(expected, actualFalse) != bytes.Compare(expected, actualTrue)) +} + +func TestSecp256r1Recover(t *testing.T) { + v, context, chain := createVM(t) + defer chain.Close() + + privateKey, err := keys.NewPrivateKey() + require.NoError(t, err) + message := []byte("The quick brown fox jumps over the lazy dog") + messageHash := hash.Sha256(message).BytesBE() + signature := privateKey.Sign(message) + require.True(t, privateKey.PublicKey().Verify(signature, messageHash)) + expected := privateKey.PublicKey().UncompressedBytes()[1:] + + // We don't know which of two recovered keys suites, so let's try both. + putOnStackGetResult := func(isEven bool) []byte { + v.Estack().PushVal(messageHash) + v.Estack().PushVal(isEven) + v.Estack().PushVal(signature[32:64]) + v.Estack().PushVal(signature[0:32]) + err = context.secp256r1Recover(v) + require.NoError(t, err) + return v.Estack().Pop().Value().([]byte) + } + + // First one: + actualFalse := putOnStackGetResult(false) + // Second one: + actualTrue := putOnStackGetResult(true) + + require.True(t, bytes.Compare(expected, actualFalse) != bytes.Compare(expected, actualTrue)) +} + // Helper functions to create VM, InteropContext, TX, Account, Contract, Asset. +func createVM(t *testing.T) (*vm.VM, *interopContext, *Blockchain) { + v := vm.New() + chain := newTestChain(t) + context := chain.newInteropContext(trigger.Application, dao.NewSimple(storage.NewMemoryStore()), nil, nil) + return v, context, chain +} + func createVMAndPushBlock(t *testing.T) (*vm.VM, *block.Block, *interopContext, *Blockchain) { v := vm.New() block := newDumbBlock() diff --git a/pkg/core/interops.go b/pkg/core/interops.go index 25040bce6..940d5928b 100644 --- a/pkg/core/interops.go +++ b/pkg/core/interops.go @@ -52,6 +52,9 @@ func (ic *interopContext) SpawnVM() *vm.VM { }) vm.RegisterInteropGetter(ic.getSystemInterop) vm.RegisterInteropGetter(ic.getNeoInterop) + if ic.bc != nil && ic.bc.GetConfig().EnableStateRoot { + vm.RegisterInteropGetter(ic.getNeoxInterop) + } return vm } @@ -77,6 +80,12 @@ func (ic *interopContext) getNeoInterop(id uint32) *vm.InteropFuncPrice { return ic.getInteropFromSlice(id, neoInterops) } +// getNeoxInterop returns matching interop function from the NeoX extension +// for a given id in the current context. +func (ic *interopContext) getNeoxInterop(id uint32) *vm.InteropFuncPrice { + return ic.getInteropFromSlice(id, neoxInterops) +} + // getInteropFromSlice returns matching interop function from the given slice of // interop functions in the current context. 
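Outside of the VM, the recovery semantics come down to the parity flag choosing one of the two candidate public keys that match the same (r, s) pair, which is why the tests above try both values. A hedged sketch mirroring that use of keys.KeyRecover directly, with a made-up message and simplified error handling:

package main

import (
    "fmt"

    "github.com/btcsuite/btcd/btcec"
    "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
)

func main() {
    priv, err := btcec.NewPrivateKey(btcec.S256())
    if err != nil {
        panic(err)
    }
    msg := []byte("The quick brown fox jumps over the lazy dog")
    sig, err := priv.Sign(msg)
    if err != nil {
        panic(err)
    }

    // One (r, s) pair matches two points on the curve; the boolean the interop
    // pops from the stack selects one of them by Y parity.
    for _, isEven := range []bool{false, true} {
        pub, err := keys.KeyRecover(btcec.S256(), sig.R, sig.S, msg, isEven)
        if err != nil {
            fmt.Println("recover failed:", err)
            continue
        }
        // The interop pushes the 64-byte X||Y form (0x04 prefix stripped).
        fmt.Printf("isEven=%v -> %x\n", isEven, pub.UncompressedBytes()[1:])
    }
}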
func (ic *interopContext) getInteropFromSlice(id uint32, slice []interopedFunction) *vm.InteropFuncPrice { @@ -276,6 +285,11 @@ var neoInterops = []interopedFunction{ {Name: "AntShares.Transaction.GetType", Func: (*interopContext).txGetType, Price: 1}, } +var neoxInterops = []interopedFunction{ + {Name: "Neo.Cryptography.Secp256k1Recover", Func: (*interopContext).secp256k1Recover, Price: 100}, + {Name: "Neo.Cryptography.Secp256r1Recover", Func: (*interopContext).secp256r1Recover, Price: 100}, +} + // initIDinInteropsSlice initializes IDs from names in one given // interopedFunction slice and then sorts it. func initIDinInteropsSlice(iops []interopedFunction) { @@ -291,4 +305,5 @@ func initIDinInteropsSlice(iops []interopedFunction) { func init() { initIDinInteropsSlice(systemInterops) initIDinInteropsSlice(neoInterops) + initIDinInteropsSlice(neoxInterops) } diff --git a/pkg/core/mpt/base.go b/pkg/core/mpt/base.go new file mode 100644 index 000000000..9f10cc333 --- /dev/null +++ b/pkg/core/mpt/base.go @@ -0,0 +1,84 @@ +package mpt + +import ( + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// BaseNode implements basic things every node needs like caching hash and +// serialized representation. It's a basic node building block intended to be +// included into all node types. +type BaseNode struct { + hash util.Uint256 + bytes []byte + hashValid bool + bytesValid bool + + isFlushed bool +} + +// BaseNodeIface abstracts away basic Node functions. +type BaseNodeIface interface { + Hash() util.Uint256 + Type() NodeType + Bytes() []byte + IsFlushed() bool + SetFlushed() +} + +// getHash returns a hash of this BaseNode. +func (b *BaseNode) getHash(n Node) util.Uint256 { + if !b.hashValid { + b.updateHash(n) + } + return b.hash +} + +// getBytes returns a slice of bytes representing this node. +func (b *BaseNode) getBytes(n Node) []byte { + if !b.bytesValid { + b.updateBytes(n) + } + return b.bytes +} + +// updateHash updates hash field for this BaseNode. +func (b *BaseNode) updateHash(n Node) { + if n.Type() == HashT { + panic("can't update hash for hash node") + } + b.hash = hash.DoubleSha256(b.getBytes(n)) + b.hashValid = true +} + +// updateCache updates hash and bytes fields for this BaseNode. +func (b *BaseNode) updateBytes(n Node) { + buf := io.NewBufBinWriter() + encodeNodeWithType(n, buf.BinWriter) + b.bytes = buf.Bytes() + b.bytesValid = true +} + +// invalidateCache sets all cache fields to invalid state. +func (b *BaseNode) invalidateCache() { + b.bytesValid = false + b.hashValid = false + b.isFlushed = false +} + +// IsFlushed checks for node flush status. +func (b *BaseNode) IsFlushed() bool { + return b.isFlushed +} + +// SetFlushed sets 'flushed' flag to true for this node. +func (b *BaseNode) SetFlushed() { + b.isFlushed = true +} + +// encodeNodeWithType encodes node together with it's type. +func encodeNodeWithType(n Node, w *io.BinWriter) { + w.WriteB(byte(n.Type())) + n.EncodeBinary(w) +} diff --git a/pkg/core/mpt/branch.go b/pkg/core/mpt/branch.go new file mode 100644 index 000000000..fbad5d29e --- /dev/null +++ b/pkg/core/mpt/branch.go @@ -0,0 +1,91 @@ +package mpt + +import ( + "encoding/json" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +const ( + // childrenCount represents a number of children of a branch node. + childrenCount = 17 + // lastChild is the index of the last child. 
+    lastChild = childrenCount - 1
+)
+
+// BranchNode represents MPT's branch node.
+type BranchNode struct {
+    BaseNode
+    Children [childrenCount]Node
+}
+
+var _ Node = (*BranchNode)(nil)
+
+// NewBranchNode returns a new branch node.
+func NewBranchNode() *BranchNode {
+    b := new(BranchNode)
+    for i := 0; i < childrenCount; i++ {
+        b.Children[i] = new(HashNode)
+    }
+    return b
+}
+
+// Type implements Node interface.
+func (b *BranchNode) Type() NodeType { return BranchT }
+
+// Hash implements BaseNode interface.
+func (b *BranchNode) Hash() util.Uint256 {
+    return b.getHash(b)
+}
+
+// Bytes implements BaseNode interface.
+func (b *BranchNode) Bytes() []byte {
+    return b.getBytes(b)
+}
+
+// EncodeBinary implements io.Serializable.
+func (b *BranchNode) EncodeBinary(w *io.BinWriter) {
+    for i := 0; i < childrenCount; i++ {
+        if hn, ok := b.Children[i].(*HashNode); ok {
+            hn.EncodeBinary(w)
+            continue
+        }
+        n := NewHashNode(b.Children[i].Hash())
+        n.EncodeBinary(w)
+    }
+}
+
+// DecodeBinary implements io.Serializable.
+func (b *BranchNode) DecodeBinary(r *io.BinReader) {
+    for i := 0; i < childrenCount; i++ {
+        b.Children[i] = new(HashNode)
+        b.Children[i].DecodeBinary(r)
+    }
+}
+
+// MarshalJSON implements json.Marshaler.
+func (b *BranchNode) MarshalJSON() ([]byte, error) {
+    return json.Marshal(b.Children)
+}
+
+// UnmarshalJSON implements json.Unmarshaler.
+func (b *BranchNode) UnmarshalJSON(data []byte) error {
+    var obj NodeObject
+    if err := obj.UnmarshalJSON(data); err != nil {
+        return err
+    } else if u, ok := obj.Node.(*BranchNode); ok {
+        *b = *u
+        return nil
+    }
+    return errors.New("expected branch node")
+}
+
+// splitPath splits path for a branch node.
+func splitPath(path []byte) (byte, []byte) {
+    if len(path) != 0 {
+        return path[0], path[1:]
+    }
+    return lastChild, path
+}
diff --git a/pkg/core/mpt/doc.go b/pkg/core/mpt/doc.go
new file mode 100644
index 000000000..c307665b3
--- /dev/null
+++ b/pkg/core/mpt/doc.go
@@ -0,0 +1,45 @@
+/*
+Package mpt implements MPT (Merkle-Patricia Tree).
+
+MPT stores key-value pairs and is a trie over a 16-symbol alphabet. https://en.wikipedia.org/wiki/Trie
+Trie is a tree where values are stored in leaves and keys are paths from the root to a leaf node.
+MPT consists of 4 types of nodes:
+- Leaf node contains only a value.
+- Extension node contains a key and a next node.
+- Branch node contains 2 or more children.
+- Hash node is a compressed node and contains only the actual node's hash.
+  The actual node must be retrieved from storage or over the network.
+
+As an example, here is a trie containing 4 pairs:
+- 0x1201 -> val1
+- 0x1203 -> val2
+- 0x1224 -> val3
+- 0x12 -> val4
+
+ExtensionNode(0x0102), Next
+ _______________________|
+ |
+BranchNode [0, 1, 2, ...], Last -> Leaf(val4)
+ |    |
+ |    ExtensionNode [0x04], Next -> Leaf(val3)
+ |
+ BranchNode [0, 1, 2, 3, ...], Last -> HashNode(nil)
+   |    |
+   |    Leaf(val2)
+   |
+   Leaf(val1)
+
+This implementation maintains 3 invariants:
+- Branch node cannot have <= 1 child
+- Extension node cannot have a zero-length key
+- Extension node cannot have another Extension node in its next field
+
+Thanks to these restrictions, there is a single root hash for every set of key-value pairs,
+regardless of the order in which they were added or removed.
+The actual trie structure can vary because of node -> HashNode compressing.
+
+There is also one optimization which costs us almost nothing in terms of complexity but is very beneficial:
+when we perform a get/put/delete on a specific path, every Hash node that was retrieved from storage is
+replaced by its uncompressed form, so that subsequent hits on this node don't use storage.
+*/
+package mpt
diff --git a/pkg/core/mpt/extension.go b/pkg/core/mpt/extension.go
new file mode 100644
index 000000000..8bcc11c24
--- /dev/null
+++ b/pkg/core/mpt/extension.go
@@ -0,0 +1,87 @@
+package mpt
+
+import (
+    "encoding/hex"
+    "encoding/json"
+    "errors"
+    "fmt"
+
+    "github.com/nspcc-dev/neo-go/pkg/io"
+    "github.com/nspcc-dev/neo-go/pkg/util"
+)
+
+// MaxKeyLength is the max length of the extension node key.
+const MaxKeyLength = 1125
+
+// ExtensionNode represents MPT's extension node.
+type ExtensionNode struct {
+    BaseNode
+    key  []byte
+    next Node
+}
+
+var _ Node = (*ExtensionNode)(nil)
+
+// NewExtensionNode returns an extension node with the specified key and next node.
+// Note: because it is a part of Trie, key must be mangled, i.e. must contain only bytes with high half = 0.
+func NewExtensionNode(key []byte, next Node) *ExtensionNode {
+    return &ExtensionNode{
+        key:  key,
+        next: next,
+    }
+}
+
+// Type implements Node interface.
+func (e ExtensionNode) Type() NodeType { return ExtensionT }
+
+// Hash implements BaseNode interface.
+func (e *ExtensionNode) Hash() util.Uint256 {
+    return e.getHash(e)
+}
+
+// Bytes implements BaseNode interface.
+func (e *ExtensionNode) Bytes() []byte {
+    return e.getBytes(e)
+}
+
+// DecodeBinary implements io.Serializable.
+func (e *ExtensionNode) DecodeBinary(r *io.BinReader) {
+    sz := r.ReadVarUint()
+    if sz > MaxKeyLength {
+        r.Err = fmt.Errorf("extension node key is too big: %d", sz)
+        return
+    }
+    e.key = make([]byte, sz)
+    r.ReadBytes(e.key)
+    e.next = new(HashNode)
+    e.next.DecodeBinary(r)
+    e.invalidateCache()
+}
+
+// EncodeBinary implements io.Serializable.
+func (e ExtensionNode) EncodeBinary(w *io.BinWriter) {
+    w.WriteVarBytes(e.key)
+    n := NewHashNode(e.next.Hash())
+    n.EncodeBinary(w)
+}
+
+// MarshalJSON implements json.Marshaler.
+func (e *ExtensionNode) MarshalJSON() ([]byte, error) {
+    m := map[string]interface{}{
+        "key":  hex.EncodeToString(e.key),
+        "next": e.next,
+    }
+    return json.Marshal(m)
+}
+
+// UnmarshalJSON implements json.Unmarshaler.
+func (e *ExtensionNode) UnmarshalJSON(data []byte) error {
+    var obj NodeObject
+    if err := obj.UnmarshalJSON(data); err != nil {
+        return err
+    } else if u, ok := obj.Node.(*ExtensionNode); ok {
+        *e = *u
+        return nil
+    }
+    return errors.New("expected extension node")
+}
diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go
new file mode 100644
index 000000000..42519a1ac
--- /dev/null
+++ b/pkg/core/mpt/hash.go
@@ -0,0 +1,88 @@
+package mpt
+
+import (
+    "errors"
+    "fmt"
+
+    "github.com/nspcc-dev/neo-go/pkg/io"
+    "github.com/nspcc-dev/neo-go/pkg/util"
+)
+
+// HashNode represents MPT's hash node.
+type HashNode struct {
+    BaseNode
+}
+
+var _ Node = (*HashNode)(nil)
+
+// NewHashNode returns a hash node with the specified hash.
+func NewHashNode(h util.Uint256) *HashNode {
+    return &HashNode{
+        BaseNode: BaseNode{
+            hash:      h,
+            hashValid: true,
+        },
+    }
+}
+
+// Type implements Node interface.
+func (h *HashNode) Type() NodeType { return HashT }
+
+// Hash implements Node interface.
+func (h *HashNode) Hash() util.Uint256 {
+    if !h.hashValid {
+        panic("can't get hash of an empty HashNode")
+    }
+    return h.hash
+}
+
+// IsEmpty returns true iff h is an empty node i.e.
contains no hash. +func (h *HashNode) IsEmpty() bool { return !h.hashValid } + +// Bytes returns serialized HashNode. +func (h *HashNode) Bytes() []byte { + return h.getBytes(h) +} + +// DecodeBinary implements io.Serializable. +func (h *HashNode) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + switch sz { + case 0: + h.hashValid = false + case util.Uint256Size: + h.hashValid = true + r.ReadBytes(h.hash[:]) + default: + r.Err = fmt.Errorf("invalid hash node size: %d", sz) + } +} + +// EncodeBinary implements io.Serializable. +func (h HashNode) EncodeBinary(w *io.BinWriter) { + if !h.hashValid { + w.WriteVarUint(0) + return + } + w.WriteVarBytes(h.hash[:]) +} + +// MarshalJSON implements json.Marshaler. +func (h *HashNode) MarshalJSON() ([]byte, error) { + if !h.hashValid { + return []byte(`{}`), nil + } + return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (h *HashNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*HashNode); ok { + *h = *u + return nil + } + return errors.New("expected hash node") +} diff --git a/pkg/core/mpt/helpers.go b/pkg/core/mpt/helpers.go new file mode 100644 index 000000000..4f508445d --- /dev/null +++ b/pkg/core/mpt/helpers.go @@ -0,0 +1,86 @@ +package mpt + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// lcp returns longest common prefix of a and b. +// Note: it does no allocations. +func lcp(a, b []byte) []byte { + if len(a) < len(b) { + return lcp(b, a) + } + + var i int + for i = 0; i < len(b); i++ { + if a[i] != b[i] { + break + } + } + + return a[:i] +} + +// copySlice is a helper for copying slice if needed. +func copySlice(a []byte) []byte { + b := make([]byte, len(a)) + copy(b, a) + return b +} + +// toNibbles mangles path by splitting every byte into 2 containing low- and high- 4-byte part. +func toNibbles(path []byte) []byte { + result := make([]byte, len(path)*2) + for i := range path { + result[i*2] = path[i] >> 4 + result[i*2+1] = path[i] & 0x0F + } + return result +} + +// ToNeoStorageKey converts storage key to C# neo node's format. +// Key is expected to be at least 20 bytes in length. +// our format: script hash in BE + key +// neo format: script hash in LE + key with 0 between every 16 bytes, padded to len 16. +func ToNeoStorageKey(key []byte) []byte { + const groupSize = 16 + + var nkey []byte + for i := util.Uint160Size - 1; i >= 0; i-- { + nkey = append(nkey, key[i]) + } + + key = key[util.Uint160Size:] + + index := 0 + remain := len(key) + for remain >= groupSize { + nkey = append(nkey, key[index:index+groupSize]...) + nkey = append(nkey, 0) + index += groupSize + remain -= groupSize + } + + if remain > 0 { + nkey = append(nkey, key[index:]...) + } + + padding := groupSize - remain + for i := 0; i < padding; i++ { + nkey = append(nkey, 0) + } + return append(nkey, byte(padding)) +} + +// ToNeoStorageValue serializes si to a C# neo node's format. +// It has additional version (0x00) byte at the beginning. 
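A quick worked example of the nibble mangling described above; the helper is restated here only because it is unexported in helpers.go, and the sample keys are arbitrary. The first result, 0a0c, is the same mangled key that appears in the C# compatibility tests further below.

package main

import "fmt"

// toNibbles restates the unexported helper from pkg/core/mpt/helpers.go:
// each byte of the key becomes two 4-bit symbols, which is what makes the
// trie 16-ary.
func toNibbles(path []byte) []byte {
    result := make([]byte, len(path)*2)
    for i := range path {
        result[i*2] = path[i] >> 4
        result[i*2+1] = path[i] & 0x0F
    }
    return result
}

func main() {
    fmt.Printf("%x\n", toNibbles([]byte{0xAC}))       // 0a0c
    fmt.Printf("%x\n", toNibbles([]byte{0x12, 0x31})) // 01020301
}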
+func ToNeoStorageValue(si *state.StorageItem) []byte { + const version = 0 + + buf := io.NewBufBinWriter() + buf.BinWriter.WriteB(version) + si.EncodeBinary(buf.BinWriter) + return buf.Bytes() +} diff --git a/pkg/core/mpt/helpers_test.go b/pkg/core/mpt/helpers_test.go new file mode 100644 index 000000000..a27542e8f --- /dev/null +++ b/pkg/core/mpt/helpers_test.go @@ -0,0 +1,30 @@ +package mpt + +import ( + "encoding/hex" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToNeoStorageKey(t *testing.T) { + testCases := []struct{ key, res string }{ + { + "0102030405060708091011121314151617181920", + "20191817161514131211100908070605040302010000000000000000000000000000000010", + }, + { + "010203040506070809101112131415161718192021222324", + "2019181716151413121110090807060504030201212223240000000000000000000000000c", + }, + { + "0102030405060708091011121314151617181920212223242526272829303132333435363738", + "20191817161514131211100908070605040302012122232425262728293031323334353600373800000000000000000000000000000e", + }, + } + for _, tc := range testCases { + key, _ := hex.DecodeString(tc.key) + res, _ := hex.DecodeString(tc.res) + require.Equal(t, res, ToNeoStorageKey(key)) + } +} diff --git a/pkg/core/mpt/leaf.go b/pkg/core/mpt/leaf.go new file mode 100644 index 000000000..82dd8eef6 --- /dev/null +++ b/pkg/core/mpt/leaf.go @@ -0,0 +1,73 @@ +package mpt + +import ( + "encoding/hex" + "errors" + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MaxValueLength is a max length of a leaf node value. +const MaxValueLength = 1024 * 1024 + +// LeafNode represents MPT's leaf node. +type LeafNode struct { + BaseNode + value []byte +} + +var _ Node = (*LeafNode)(nil) + +// NewLeafNode returns hash node with the specified value. +func NewLeafNode(value []byte) *LeafNode { + return &LeafNode{value: value} +} + +// Type implements Node interface. +func (n LeafNode) Type() NodeType { return LeafT } + +// Hash implements BaseNode interface. +func (n *LeafNode) Hash() util.Uint256 { + return n.getHash(n) +} + +// Bytes implements BaseNode interface. +func (n *LeafNode) Bytes() []byte { + return n.getBytes(n) +} + +// DecodeBinary implements io.Serializable. +func (n *LeafNode) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + if sz > MaxValueLength { + r.Err = fmt.Errorf("leaf node value is too big: %d", sz) + return + } + n.value = make([]byte, sz) + r.ReadBytes(n.value) + n.invalidateCache() +} + +// EncodeBinary implements io.Serializable. +func (n LeafNode) EncodeBinary(w *io.BinWriter) { + w.WriteVarBytes(n.value) +} + +// MarshalJSON implements json.Marshaler. +func (n *LeafNode) MarshalJSON() ([]byte, error) { + return []byte(`{"value":"` + hex.EncodeToString(n.value) + `"}`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (n *LeafNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*LeafNode); ok { + *n = *u + return nil + } + return errors.New("expected leaf node") +} diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go new file mode 100644 index 000000000..86e675a01 --- /dev/null +++ b/pkg/core/mpt/node.go @@ -0,0 +1,134 @@ +package mpt + +import ( + "encoding/hex" + "encoding/json" + "errors" + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// NodeType represents node type.. +type NodeType byte + +// Node types definitions. 
+const ( + BranchT NodeType = 0x00 + ExtensionT NodeType = 0x01 + HashT NodeType = 0x02 + LeafT NodeType = 0x03 +) + +// NodeObject represents Node together with it's type. +// It is used for serialization/deserialization where type info +// is also expected. +type NodeObject struct { + Node +} + +// Node represents common interface of all MPT nodes. +type Node interface { + io.Serializable + json.Marshaler + json.Unmarshaler + BaseNodeIface +} + +// EncodeBinary implements io.Serializable. +func (n NodeObject) EncodeBinary(w *io.BinWriter) { + encodeNodeWithType(n.Node, w) +} + +// DecodeBinary implements io.Serializable. +func (n *NodeObject) DecodeBinary(r *io.BinReader) { + typ := NodeType(r.ReadB()) + switch typ { + case BranchT: + n.Node = new(BranchNode) + case ExtensionT: + n.Node = new(ExtensionNode) + case HashT: + n.Node = new(HashNode) + case LeafT: + n.Node = new(LeafNode) + default: + r.Err = fmt.Errorf("invalid node type: %x", typ) + return + } + n.Node.DecodeBinary(r) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (n *NodeObject) UnmarshalJSON(data []byte) error { + var m map[string]json.RawMessage + err := json.Unmarshal(data, &m) + if err != nil { // it can be a branch node + var nodes []NodeObject + if err := json.Unmarshal(data, &nodes); err != nil { + return err + } else if len(nodes) != childrenCount { + return errors.New("invalid length of branch node") + } + + b := NewBranchNode() + for i := range b.Children { + b.Children[i] = nodes[i].Node + } + n.Node = b + return nil + } + + switch len(m) { + case 0: + n.Node = new(HashNode) + case 1: + if v, ok := m["hash"]; ok { + var h util.Uint256 + if err := json.Unmarshal(v, &h); err != nil { + return err + } + n.Node = NewHashNode(h) + } else if v, ok = m["value"]; ok { + b, err := unmarshalHex(v) + if err != nil { + return err + } else if len(b) > MaxValueLength { + return errors.New("leaf value is too big") + } + n.Node = NewLeafNode(b) + } else { + return errors.New("invalid field") + } + case 2: + keyRaw, ok1 := m["key"] + nextRaw, ok2 := m["next"] + if !ok1 || !ok2 { + return errors.New("invalid field") + } + key, err := unmarshalHex(keyRaw) + if err != nil { + return err + } else if len(key) > MaxKeyLength { + return errors.New("extension key is too big") + } + + var next NodeObject + if err := json.Unmarshal(nextRaw, &next); err != nil { + return err + } + n.Node = NewExtensionNode(key, next.Node) + default: + return errors.New("0, 1 or 2 fields expected") + } + return nil +} + +func unmarshalHex(data json.RawMessage) ([]byte, error) { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return hex.DecodeString(s) +} diff --git a/pkg/core/mpt/node_test.go b/pkg/core/mpt/node_test.go new file mode 100644 index 000000000..e3aab54d6 --- /dev/null +++ b/pkg/core/mpt/node_test.go @@ -0,0 +1,156 @@ +package mpt + +import ( + "encoding/json" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) { + return func(t *testing.T) { + t.Run("IO", func(t *testing.T) { + bs, err := testserdes.EncodeBinary(expected) + require.NoError(t, err) + err = testserdes.DecodeBinary(bs, actual) + if !ok { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, expected.Type(), actual.Type()) + 
require.Equal(t, expected.Hash(), actual.Hash()) + }) + t.Run("JSON", func(t *testing.T) { + bs, err := json.Marshal(expected) + require.NoError(t, err) + err = json.Unmarshal(bs, actual) + if !ok { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, expected.Type(), actual.Type()) + require.Equal(t, expected.Hash(), actual.Hash()) + }) + } +} + +func TestNode_Serializable(t *testing.T) { + t.Run("Leaf", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + l := NewLeafNode(random.Bytes(123)) + t.Run("Raw", getTestFuncEncode(true, l, new(LeafNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{l}, new(NodeObject))) + }) + t.Run("BigValue", getTestFuncEncode(false, + NewLeafNode(random.Bytes(MaxValueLength+1)), new(LeafNode))) + }) + + t.Run("Extension", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + e := NewExtensionNode(random.Bytes(42), NewLeafNode(random.Bytes(10))) + t.Run("Raw", getTestFuncEncode(true, e, new(ExtensionNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{e}, new(NodeObject))) + }) + t.Run("BigKey", getTestFuncEncode(false, + NewExtensionNode(random.Bytes(MaxKeyLength+1), NewLeafNode(random.Bytes(10))), new(ExtensionNode))) + }) + + t.Run("Branch", func(t *testing.T) { + b := NewBranchNode() + b.Children[0] = NewLeafNode(random.Bytes(10)) + b.Children[lastChild] = NewHashNode(random.Uint256()) + t.Run("Raw", getTestFuncEncode(true, b, new(BranchNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{b}, new(NodeObject))) + }) + + t.Run("Hash", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + h := NewHashNode(random.Uint256()) + t.Run("Raw", getTestFuncEncode(true, h, new(HashNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{h}, new(NodeObject))) + }) + t.Run("Empty", func(t *testing.T) { // compare nodes, not hashes + testserdes.EncodeDecodeBinary(t, new(HashNode), new(HashNode)) + }) + t.Run("InvalidSize", func(t *testing.T) { + buf := io.NewBufBinWriter() + buf.BinWriter.WriteVarBytes(make([]byte, 13)) + require.Error(t, testserdes.DecodeBinary(buf.Bytes(), new(HashNode))) + }) + }) + + t.Run("Invalid", func(t *testing.T) { + require.Error(t, testserdes.DecodeBinary([]byte{0xFF}, new(NodeObject))) + }) +} + +// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198 +func TestJSONSharp(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11})) + require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22})) + require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac})) + require.NoError(t, tr.Delete([]byte{0xac, 0x11})) + require.NoError(t, tr.Delete([]byte{0xac, 0x22})) + + js, err := tr.root.MarshalJSON() + require.NoError(t, err) + require.JSONEq(t, `{"key":"0a0c", "next":{"value":"ac"}}`, string(js)) +} + +func TestInvalidJSON(t *testing.T) { + t.Run("InvalidChildrenCount", func(t *testing.T) { + var cs [childrenCount + 1]Node + for i := range cs { + cs[i] = new(HashNode) + } + data, err := json.Marshal(cs) + require.NoError(t, err) + + var n NodeObject + require.Error(t, json.Unmarshal(data, &n)) + }) + + testCases := []struct { + name string + data []byte + }{ + {"WrongFieldCount", []byte(`{"key":"0102", "next": {}, "field": {}}`)}, + {"InvalidField1", []byte(`{"next":{}}`)}, + {"InvalidField2", []byte(`{"key":"0102", "hash":{}}`)}, + {"InvalidKey", []byte(`{"key":"xy", "next":{}}`)}, + {"InvalidNext", []byte(`{"key":"01", "next":[]}`)}, + {"InvalidHash", 
[]byte(`{"hash":"01"}`)}, + {"InvalidValue", []byte(`{"value":1}`)}, + {"InvalidBranch", []byte(`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]`)}, + } + for _, tc := range testCases { + var n NodeObject + assert.Errorf(t, json.Unmarshal(tc.data, &n), "no error in "+tc.name) + } +} + +// C# interoperability test +// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135 +func TestRootHash(t *testing.T) { + b := NewBranchNode() + r := NewExtensionNode([]byte{0x0A, 0x0C}, b) + + v1 := NewLeafNode([]byte{0xAB, 0xCD}) + l1 := NewExtensionNode([]byte{0x01}, v1) + b.Children[0] = l1 + + v2 := NewLeafNode([]byte{0x22, 0x22}) + l2 := NewExtensionNode([]byte{0x09}, v2) + b.Children[9] = l2 + + r1 := NewExtensionNode([]byte{0x0A, 0x0C, 0x00, 0x01}, v1) + require.Equal(t, "dea3ab46e9461e885ed7091c1e533e0a8030b248d39cbc638962394eaca0fbb3", r1.Hash().StringLE()) + require.Equal(t, "93e8e1ffe2f83dd92fca67330e273bcc811bf64b8f8d9d1b25d5e7366b47d60d", r.Hash().StringLE()) +} diff --git a/pkg/core/mpt/proof.go b/pkg/core/mpt/proof.go new file mode 100644 index 000000000..5f8fcdc84 --- /dev/null +++ b/pkg/core/mpt/proof.go @@ -0,0 +1,74 @@ +package mpt + +import ( + "bytes" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// GetProof returns a proof that key belongs to t. +// Proof consist of serialized nodes occuring on path from the root to the leaf of key. +func (t *Trie) GetProof(key []byte) ([][]byte, error) { + var proof [][]byte + path := toNibbles(key) + r, err := t.getProof(t.root, path, &proof) + if err != nil { + return proof, err + } + t.root = r + return proof, nil +} + +func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + *proofs = append(*proofs, copySlice(n.Bytes())) + return n, nil + } + case *BranchNode: + *proofs = append(*proofs, copySlice(n.Bytes())) + i, path := splitPath(path) + r, err := t.getProof(n.Children[i], path, proofs) + if err != nil { + return nil, err + } + n.Children[i] = r + return n, nil + case *ExtensionNode: + if bytes.HasPrefix(path, n.key) { + *proofs = append(*proofs, copySlice(n.Bytes())) + r, err := t.getProof(n.next, path[len(n.key):], proofs) + if err != nil { + return nil, err + } + n.next = r + return n, nil + } + case *HashNode: + if !n.IsEmpty() { + r, err := t.getFromStore(n.Hash()) + if err != nil { + return nil, err + } + return t.getProof(r, path, proofs) + } + } + return nil, ErrNotFound +} + +// VerifyProof verifies that path indeed belongs to a MPT with the specified root hash. +// It also returns value for the key. 
+func VerifyProof(rh util.Uint256, key []byte, proofs [][]byte) ([]byte, bool) { + path := toNibbles(key) + tr := NewTrie(NewHashNode(rh), storage.NewMemCachedStore(storage.NewMemoryStore())) + for i := range proofs { + h := hash.DoubleSha256(proofs[i]) + // no errors in Put to memory store + _ = tr.Store.Put(makeStorageKey(h[:]), proofs[i]) + } + _, bs, err := tr.getWithPath(tr.root, path) + return bs, err == nil +} diff --git a/pkg/core/mpt/proof_test.go b/pkg/core/mpt/proof_test.go new file mode 100644 index 000000000..17301af15 --- /dev/null +++ b/pkg/core/mpt/proof_test.go @@ -0,0 +1,73 @@ +package mpt + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func newProofTrie(t *testing.T) *Trie { + l := NewLeafNode([]byte("somevalue")) + e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l) + l2 := NewLeafNode([]byte("invalid")) + e2 := NewExtensionNode([]byte{0x05}, NewHashNode(l2.Hash())) + b := NewBranchNode() + b.Children[4] = NewHashNode(e.Hash()) + b.Children[5] = e2 + + tr := NewTrie(b, newTestStore()) + require.NoError(t, tr.Put([]byte{0x12, 0x31}, []byte("value1"))) + require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2"))) + tr.putToStore(l) + tr.putToStore(e) + return tr +} + +func TestTrie_GetProof(t *testing.T) { + tr := newProofTrie(t) + + t.Run("MissingKey", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x12}) + require.Error(t, err) + }) + + t.Run("Valid", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x12, 0x31}) + require.NoError(t, err) + }) + + t.Run("MissingHashNode", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x55}) + require.Error(t, err) + }) +} + +func TestVerifyProof(t *testing.T) { + tr := newProofTrie(t) + + t.Run("Simple", func(t *testing.T) { + proof, err := tr.GetProof([]byte{0x12, 0x32}) + require.NoError(t, err) + + t.Run("Good", func(t *testing.T) { + v, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x32}, proof) + require.True(t, ok) + require.Equal(t, []byte("value2"), v) + }) + + t.Run("Bad", func(t *testing.T) { + _, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x31}, proof) + require.False(t, ok) + }) + }) + + t.Run("InsideHash", func(t *testing.T) { + key := []byte{0x45, 0x67} + proof, err := tr.GetProof(key) + require.NoError(t, err) + + v, ok := VerifyProof(tr.root.Hash(), key, proof) + require.True(t, ok) + require.Equal(t, []byte("somevalue"), v) + }) +} diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go new file mode 100644 index 000000000..08d128d88 --- /dev/null +++ b/pkg/core/mpt/trie.go @@ -0,0 +1,390 @@ +package mpt + +import ( + "bytes" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// Trie is an MPT trie storing all key-value pairs. +type Trie struct { + Store *storage.MemCachedStore + + root Node +} + +// ErrNotFound is returned when requested trie item is missing. +var ErrNotFound = errors.New("item not found") + +// NewTrie returns new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors +// so that all storage errors are processed during `store.Persist()` at the caller. +// This also has the benefit, that every `Put` can be considered an atomic operation. +func NewTrie(root Node, store *storage.MemCachedStore) *Trie { + if root == nil { + root = new(HashNode) + } + + return &Trie{ + Store: store, + root: root, + } +} + +// Get returns value for the provided key in t. 
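A minimal end-to-end sketch of the proof API introduced here, using only the exported Trie, GetProof and VerifyProof functions; keys and values are arbitrary.

package main

import (
    "fmt"

    "github.com/nspcc-dev/neo-go/pkg/core/mpt"
    "github.com/nspcc-dev/neo-go/pkg/core/storage"
)

func main() {
    tr := mpt.NewTrie(nil, storage.NewMemCachedStore(storage.NewMemoryStore()))
    if err := tr.Put([]byte{0x12, 0x31}, []byte("value1")); err != nil {
        panic(err)
    }
    if err := tr.Put([]byte{0x12, 0x32}, []byte("value2")); err != nil {
        panic(err)
    }

    // The proof is just the serialized nodes on the path from the root to the
    // leaf; anyone holding the state root can re-check it without the full trie.
    proof, err := tr.GetProof([]byte{0x12, 0x32})
    if err != nil {
        panic(err)
    }
    value, ok := mpt.VerifyProof(tr.StateRoot(), []byte{0x12, 0x32}, proof)
    fmt.Println(ok, string(value)) // true value2
}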
+func (t *Trie) Get(key []byte) ([]byte, error) { + path := toNibbles(key) + r, bs, err := t.getWithPath(t.root, path) + if err != nil { + return nil, err + } + t.root = r + return bs, nil +} + +// getWithPath returns value the provided path in a subtrie rooting in curr. +// It also returns a current node with all hash nodes along the path +// replaced to their "unhashed" counterparts. +func (t *Trie) getWithPath(curr Node, path []byte) (Node, []byte, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + return curr, copySlice(n.value), nil + } + case *BranchNode: + i, path := splitPath(path) + r, bs, err := t.getWithPath(n.Children[i], path) + if err != nil { + return nil, nil, err + } + n.Children[i] = r + return n, bs, nil + case *HashNode: + if !n.IsEmpty() { + if r, err := t.getFromStore(n.hash); err == nil { + return t.getWithPath(r, path) + } + } + case *ExtensionNode: + if bytes.HasPrefix(path, n.key) { + r, bs, err := t.getWithPath(n.next, path[len(n.key):]) + if err != nil { + return nil, nil, err + } + n.next = r + return curr, bs, err + } + default: + panic("invalid MPT node type") + } + return curr, nil, ErrNotFound +} + +// Put puts key-value pair in t. +func (t *Trie) Put(key, value []byte) error { + if len(key) > MaxKeyLength { + return errors.New("key is too big") + } else if len(value) > MaxValueLength { + return errors.New("value is too big") + } + if len(value) == 0 { + return t.Delete(key) + } + path := toNibbles(key) + n := NewLeafNode(value) + r, err := t.putIntoNode(t.root, path, n) + if err != nil { + return err + } + t.root = r + return nil +} + +// putIntoLeaf puts val to trie if current node is a Leaf. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) { + v := val.(*LeafNode) + if len(path) == 0 { + return v, nil + } + + b := NewBranchNode() + b.Children[path[0]] = newSubTrie(path[1:], v) + b.Children[lastChild] = curr + return b, nil +} + +// putIntoBranch puts val to trie if current node is a Branch. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) { + i, path := splitPath(path) + r, err := t.putIntoNode(curr.Children[i], path, val) + if err != nil { + return nil, err + } + curr.Children[i] = r + curr.invalidateCache() + return curr, nil +} + +// putIntoExtension puts val to trie if current node is an Extension. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) { + if bytes.HasPrefix(path, curr.key) { + r, err := t.putIntoNode(curr.next, path[len(curr.key):], val) + if err != nil { + return nil, err + } + curr.next = r + curr.invalidateCache() + return curr, nil + } + + pref := lcp(curr.key, path) + lp := len(pref) + keyTail := curr.key[lp:] + pathTail := path[lp:] + + s1 := newSubTrie(keyTail[1:], curr.next) + b := NewBranchNode() + b.Children[keyTail[0]] = s1 + + i, pathTail := splitPath(pathTail) + s2 := newSubTrie(pathTail, val) + b.Children[i] = s2 + + if lp > 0 { + return NewExtensionNode(copySlice(pref), b), nil + } + return b, nil +} + +// putIntoHash puts val to trie if current node is a HashNode. +// It returns Node if curr needs to be replaced and error if any. 
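The splitting logic in putIntoExtension/putIntoBranch is what gives the single-root-hash property promised in the package documentation. A small sketch, with arbitrary keys, checking that two insertion orders converge to the same state root:

package main

import (
    "fmt"

    "github.com/nspcc-dev/neo-go/pkg/core/mpt"
    "github.com/nspcc-dev/neo-go/pkg/core/storage"
)

func newTrie() *mpt.Trie {
    return mpt.NewTrie(nil, storage.NewMemCachedStore(storage.NewMemoryStore()))
}

func main() {
    pairs := [][2][]byte{
        {[]byte{0xAC, 0x01}, []byte("one")},
        {[]byte{0xAC, 0x99}, []byte("two")},
        {[]byte{0xAC}, []byte("three")},
    }

    t1, t2 := newTrie(), newTrie()
    for _, p := range pairs { // forward order
        _ = t1.Put(p[0], p[1])
    }
    for i := len(pairs) - 1; i >= 0; i-- { // reverse order
        _ = t2.Put(pairs[i][0], pairs[i][1])
    }

    // Branch/extension splitting produces a canonical shape, so both
    // insertion orders end up with the same root hash.
    fmt.Println(t1.StateRoot().Equals(t2.StateRoot())) // true
}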
+func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) { + if curr.IsEmpty() { + return newSubTrie(path, val), nil + } + + result, err := t.getFromStore(curr.hash) + if err != nil { + return nil, err + } + return t.putIntoNode(result, path, val) +} + +// newSubTrie create new trie containing node at provided path. +func newSubTrie(path []byte, val Node) Node { + if len(path) == 0 { + return val + } + return NewExtensionNode(path, val) +} + +func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + return t.putIntoLeaf(n, path, val) + case *BranchNode: + return t.putIntoBranch(n, path, val) + case *ExtensionNode: + return t.putIntoExtension(n, path, val) + case *HashNode: + return t.putIntoHash(n, path, val) + default: + panic("invalid MPT node type") + } +} + +// Delete removes key from trie. +// It returns no error on missing key. +func (t *Trie) Delete(key []byte) error { + path := toNibbles(key) + r, err := t.deleteFromNode(t.root, path) + if err != nil { + return err + } + t.root = r + return nil +} + +func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { + i, path := splitPath(path) + r, err := t.deleteFromNode(b.Children[i], path) + if err != nil { + return nil, err + } + b.Children[i] = r + b.invalidateCache() + var count, index int + for i := range b.Children { + h, ok := b.Children[i].(*HashNode) + if !ok || !h.IsEmpty() { + index = i + count++ + } + } + // count is >= 1 because branch node had at least 2 children before deletion. + if count > 1 { + return b, nil + } + c := b.Children[index] + if index == lastChild { + return c, nil + } + if h, ok := c.(*HashNode); ok { + c, err = t.getFromStore(h.Hash()) + if err != nil { + return nil, err + } + } + if e, ok := c.(*ExtensionNode); ok { + e.key = append([]byte{byte(index)}, e.key...) + e.invalidateCache() + return e, nil + } + + return NewExtensionNode([]byte{byte(index)}, c), nil +} + +func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) { + if !bytes.HasPrefix(path, n.key) { + return nil, ErrNotFound + } + r, err := t.deleteFromNode(n.next, path[len(n.key):]) + if err != nil { + return nil, err + } + switch nxt := r.(type) { + case *ExtensionNode: + n.key = append(n.key, nxt.key...) + n.next = nxt.next + case *HashNode: + if nxt.IsEmpty() { + return nxt, nil + } + default: + n.next = r + } + n.invalidateCache() + return n, nil +} + +func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + return new(HashNode), nil + } + return nil, ErrNotFound + case *BranchNode: + return t.deleteFromBranch(n, path) + case *ExtensionNode: + return t.deleteFromExtension(n, path) + case *HashNode: + if n.IsEmpty() { + return nil, ErrNotFound + } + newNode, err := t.getFromStore(n.Hash()) + if err != nil { + return nil, err + } + return t.deleteFromNode(newNode, path) + default: + panic("invalid MPT node type") + } +} + +// StateRoot returns root hash of t. +func (t *Trie) StateRoot() util.Uint256 { + if hn, ok := t.root.(*HashNode); ok && hn.IsEmpty() { + return util.Uint256{} + } + return t.root.Hash() +} + +func makeStorageKey(mptKey []byte) []byte { + return append([]byte{byte(storage.DataMPT)}, mptKey...) +} + +// Flush puts every node in the trie except Hash ones to the storage. +// Because we care only about block-level changes, there is no need to put every +// new node to storage. 
Normally, flush should be called with every StateRoot persist, i.e. +// after every block. +func (t *Trie) Flush() { + t.flush(t.root) +} + +func (t *Trie) flush(node Node) { + if node.IsFlushed() { + return + } + switch n := node.(type) { + case *BranchNode: + for i := range n.Children { + t.flush(n.Children[i]) + } + case *ExtensionNode: + t.flush(n.next) + case *HashNode: + return + } + t.putToStore(node) +} + +func (t *Trie) putToStore(n Node) { + if n.Type() == HashT { + panic("can't put hash node in trie") + } + _ = t.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes()) // put in MemCached returns no errors + n.SetFlushed() +} + +func (t *Trie) getFromStore(h util.Uint256) (Node, error) { + data, err := t.Store.Get(makeStorageKey(h.BytesBE())) + if err != nil { + return nil, err + } + + var n NodeObject + r := io.NewBinReaderFromBuf(data) + n.DecodeBinary(r) + if r.Err != nil { + return nil, r.Err + } + return n.Node, nil +} + +// Collapse compresses all nodes at depth n to the hash nodes. +// Note: this function does not perform any kind of storage flushing so +// `Flush()` should be called explicitly before invoking function. +func (t *Trie) Collapse(depth int) { + if depth < 0 { + panic("negative depth") + } + t.root = collapse(depth, t.root) +} + +func collapse(depth int, node Node) Node { + if _, ok := node.(*HashNode); ok { + return node + } else if depth == 0 { + return NewHashNode(node.Hash()) + } + + switch n := node.(type) { + case *BranchNode: + for i := range n.Children { + n.Children[i] = collapse(depth-1, n.Children[i]) + } + case *ExtensionNode: + n.next = collapse(depth-1, n.next) + case *LeafNode: + case *HashNode: + default: + panic("invalid MPT node type") + } + return node +} diff --git a/pkg/core/mpt/trie_test.go b/pkg/core/mpt/trie_test.go new file mode 100644 index 000000000..d06e08168 --- /dev/null +++ b/pkg/core/mpt/trie_test.go @@ -0,0 +1,446 @@ +package mpt + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/stretchr/testify/require" +) + +func newTestStore() *storage.MemCachedStore { + return storage.NewMemCachedStore(storage.NewMemoryStore()) +} + +func newTestTrie(t *testing.T) *Trie { + b := NewBranchNode() + + l1 := NewLeafNode([]byte{0xAB, 0xCD}) + b.Children[0] = NewExtensionNode([]byte{0x01}, l1) + + l2 := NewLeafNode([]byte{0x22, 0x22}) + b.Children[9] = NewExtensionNode([]byte{0x09}, l2) + + v := NewLeafNode([]byte("hello")) + h := NewHashNode(v.Hash()) + b.Children[10] = NewExtensionNode([]byte{0x0e}, h) + + e := NewExtensionNode(toNibbles([]byte{0xAC}), b) + tr := NewTrie(e, newTestStore()) + + tr.putToStore(e) + tr.putToStore(b) + tr.putToStore(l1) + tr.putToStore(l2) + tr.putToStore(v) + tr.putToStore(b.Children[0]) + tr.putToStore(b.Children[9]) + tr.putToStore(b.Children[10]) + + return tr +} + +func TestTrie_PutIntoBranchNode(t *testing.T) { + b := NewBranchNode() + l := NewLeafNode([]byte{0x8}) + b.Children[0x7] = NewHashNode(l.Hash()) + b.Children[0x8] = NewHashNode(random.Uint256()) + tr := NewTrie(b, newTestStore()) + + // next + require.NoError(t, tr.Put([]byte{}, []byte{0x12, 0x34})) + tr.testHas(t, []byte{}, []byte{0x12, 0x34}) + + // empty hash node child + require.NoError(t, tr.Put([]byte{0x66}, []byte{0x56})) + tr.testHas(t, []byte{0x66}, []byte{0x56}) + require.True(t, isValid(tr.root)) + + // missing hash + require.Error(t, tr.Put([]byte{0x70}, []byte{0x42})) + require.True(t, isValid(tr.root)) + + // hash is in store + tr.putToStore(l) 
+ require.NoError(t, tr.Put([]byte{0x70}, []byte{0x42})) + require.True(t, isValid(tr.root)) +} + +func TestTrie_PutIntoExtensionNode(t *testing.T) { + l := NewLeafNode([]byte{0x11}) + key := []byte{0x12} + e := NewExtensionNode(toNibbles(key), NewHashNode(l.Hash())) + tr := NewTrie(e, newTestStore()) + + // missing hash + require.Error(t, tr.Put(key, []byte{0x42})) + + tr.putToStore(l) + require.NoError(t, tr.Put(key, []byte{0x42})) + tr.testHas(t, key, []byte{0x42}) + require.True(t, isValid(tr.root)) +} + +func TestTrie_PutIntoHashNode(t *testing.T) { + b := NewBranchNode() + l := NewLeafNode(random.Bytes(5)) + e := NewExtensionNode([]byte{0x02}, l) + b.Children[1] = NewHashNode(e.Hash()) + b.Children[9] = NewHashNode(random.Uint256()) + tr := NewTrie(b, newTestStore()) + + tr.putToStore(e) + + t.Run("MissingLeafHash", func(t *testing.T) { + _, err := tr.Get([]byte{0x12}) + require.Error(t, err) + }) + + tr.putToStore(l) + + val := random.Bytes(3) + require.NoError(t, tr.Put([]byte{0x12, 0x34}, val)) + tr.testHas(t, []byte{0x12, 0x34}, val) + tr.testHas(t, []byte{0x12}, l.value) + require.True(t, isValid(tr.root)) +} + +func TestTrie_Put(t *testing.T) { + trExp := newTestTrie(t) + + trAct := NewTrie(nil, newTestStore()) + require.NoError(t, trAct.Put([]byte{0xAC, 0x01}, []byte{0xAB, 0xCD})) + require.NoError(t, trAct.Put([]byte{0xAC, 0x99}, []byte{0x22, 0x22})) + require.NoError(t, trAct.Put([]byte{0xAC, 0xAE}, []byte("hello"))) + + // Note: the exact tries differ because of ("acae":"hello") node is stored as Hash node in test trie. + require.Equal(t, trExp.root.Hash(), trAct.root.Hash()) + require.True(t, isValid(trAct.root)) +} + +func TestTrie_PutInvalid(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + key, value := []byte("key"), []byte("value") + + // big key + require.Error(t, tr.Put(make([]byte, MaxKeyLength+1), value)) + + // big value + require.Error(t, tr.Put(key, make([]byte, MaxValueLength+1))) + + // this is ok though + require.NoError(t, tr.Put(key, value)) + tr.testHas(t, key, value) +} + +func TestTrie_BigPut(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + items := []struct{ k, v string }{ + {"item with long key", "value1"}, + {"item with matching prefix", "value2"}, + {"another prefix", "value3"}, + {"another prefix 2", "value4"}, + {"another ", "value5"}, + } + + for i := range items { + require.NoError(t, tr.Put([]byte(items[i].k), []byte(items[i].v))) + } + + for i := range items { + tr.testHas(t, []byte(items[i].k), []byte(items[i].v)) + } + + t.Run("Rewrite", func(t *testing.T) { + k, v := []byte(items[0].k), []byte{0x01, 0x23} + require.NoError(t, tr.Put(k, v)) + tr.testHas(t, k, v) + }) + + t.Run("Remove", func(t *testing.T) { + k := []byte(items[1].k) + require.NoError(t, tr.Put(k, []byte{})) + tr.testHas(t, k, nil) + }) +} + +func (tr *Trie) testHas(t *testing.T, key, value []byte) { + v, err := tr.Get(key) + if value == nil { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, value, v) +} + +// isValid checks for 3 invariants: +// - BranchNode contains > 1 children +// - ExtensionNode do not contain another extension node +// - ExtensionNode do not have nil key +// It is used only during testing to catch possible bugs. 
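The trie persists nodes only on Flush; a fresh trie constructed from just the root hash then resolves nodes from the store on demand through getFromStore. A short sketch of that cycle with arbitrary keys and values:

package main

import (
    "fmt"

    "github.com/nspcc-dev/neo-go/pkg/core/mpt"
    "github.com/nspcc-dev/neo-go/pkg/core/storage"
)

func main() {
    store := storage.NewMemCachedStore(storage.NewMemoryStore())

    tr := mpt.NewTrie(nil, store)
    _ = tr.Put([]byte("key1"), []byte("value1"))
    _ = tr.Put([]byte("key2"), []byte("value2"))

    // Nothing is written until Flush; after it every non-Hash node sits in
    // the store, keyed by its hash under the DataMPT prefix.
    tr.Flush()
    root := tr.StateRoot()

    // A trie that knows only the root hash can lazily load everything else.
    reopened := mpt.NewTrie(mpt.NewHashNode(root), store)
    v, err := reopened.Get([]byte("key2"))
    fmt.Println(string(v), err) // value2 <nil>
}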
+func isValid(curr Node) bool { + switch n := curr.(type) { + case *BranchNode: + var count int + for i := range n.Children { + if !isValid(n.Children[i]) { + return false + } + hn, ok := n.Children[i].(*HashNode) + if !ok || !hn.IsEmpty() { + count++ + } + } + return count > 1 + case *ExtensionNode: + _, ok := n.next.(*ExtensionNode) + return len(n.key) != 0 && !ok + default: + return true + } +} + +func TestTrie_Get(t *testing.T) { + t.Run("HashNode", func(t *testing.T) { + tr := newTestTrie(t) + tr.testHas(t, []byte{0xAC, 0xAE}, []byte("hello")) + }) + t.Run("UnfoldRoot", func(t *testing.T) { + tr := newTestTrie(t) + single := NewTrie(NewHashNode(tr.root.Hash()), tr.Store) + single.testHas(t, []byte{0xAC}, nil) + single.testHas(t, []byte{0xAC, 0x01}, []byte{0xAB, 0xCD}) + single.testHas(t, []byte{0xAC, 0x99}, []byte{0x22, 0x22}) + single.testHas(t, []byte{0xAC, 0xAE}, []byte("hello")) + }) +} + +func TestTrie_Flush(t *testing.T) { + pairs := map[string][]byte{ + "": []byte("value0"), + "key1": []byte("value1"), + "key2": []byte("value2"), + } + + tr := NewTrie(nil, newTestStore()) + for k, v := range pairs { + require.NoError(t, tr.Put([]byte(k), v)) + } + + tr.Flush() + tr = NewTrie(NewHashNode(tr.StateRoot()), tr.Store) + for k, v := range pairs { + actual, err := tr.Get([]byte(k)) + require.NoError(t, err) + require.Equal(t, v, actual) + } +} + +func TestTrie_Delete(t *testing.T) { + t.Run("Hash", func(t *testing.T) { + t.Run("FromStore", func(t *testing.T) { + l := NewLeafNode([]byte{0x12}) + tr := NewTrie(NewHashNode(l.Hash()), newTestStore()) + t.Run("NotInStore", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{})) + }) + + tr.putToStore(l) + tr.testHas(t, []byte{}, []byte{0x12}) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + }) + + t.Run("Empty", func(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + require.Error(t, tr.Delete([]byte{})) + }) + }) + + t.Run("Leaf", func(t *testing.T) { + l := NewLeafNode([]byte{0x12, 0x34}) + tr := NewTrie(l, newTestStore()) + t.Run("NonExistentKey", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{0x12})) + tr.testHas(t, []byte{}, []byte{0x12, 0x34}) + }) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + }) + + t.Run("Extension", func(t *testing.T) { + t.Run("SingleKey", func(t *testing.T) { + l := NewLeafNode([]byte{0x12, 0x34}) + e := NewExtensionNode([]byte{0x0A, 0x0B}, l) + tr := NewTrie(e, newTestStore()) + + t.Run("NonExistentKey", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{0xAB}, []byte{0x12, 0x34}) + }) + + require.NoError(t, tr.Delete([]byte{0xAB})) + require.True(t, tr.root.(*HashNode).IsEmpty()) + }) + + t.Run("MultipleKeys", func(t *testing.T) { + b := NewBranchNode() + b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x12, 0x34})) + b.Children[6] = NewExtensionNode([]byte{0x07}, NewLeafNode([]byte{0x56, 0x78})) + e := NewExtensionNode([]byte{0x01, 0x02}, b) + tr := NewTrie(e, newTestStore()) + + h := e.Hash() + require.NoError(t, tr.Delete([]byte{0x12, 0x01})) + tr.testHas(t, []byte{0x12, 0x01}, nil) + tr.testHas(t, []byte{0x12, 0x67}, []byte{0x56, 0x78}) + + require.NotEqual(t, h, tr.root.Hash()) + require.Equal(t, toNibbles([]byte{0x12, 0x67}), e.key) + require.IsType(t, (*LeafNode)(nil), e.next) + }) + }) + + t.Run("Branch", func(t *testing.T) { + t.Run("3 Children", func(t *testing.T) { + b := NewBranchNode() + b.Children[lastChild] = NewLeafNode([]byte{0x12}) + b.Children[0] = 
NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x34})) + b.Children[1] = NewExtensionNode([]byte{0x06}, NewLeafNode([]byte{0x56})) + tr := NewTrie(b, newTestStore()) + require.NoError(t, tr.Delete([]byte{0x16})) + tr.testHas(t, []byte{}, []byte{0x12}) + tr.testHas(t, []byte{0x01}, []byte{0x34}) + tr.testHas(t, []byte{0x16}, nil) + }) + t.Run("2 Children", func(t *testing.T) { + newt := func(t *testing.T) *Trie { + b := NewBranchNode() + b.Children[lastChild] = NewLeafNode([]byte{0x12}) + l := NewLeafNode([]byte{0x34}) + e := NewExtensionNode([]byte{0x06}, l) + b.Children[5] = NewHashNode(e.Hash()) + tr := NewTrie(b, newTestStore()) + tr.putToStore(l) + tr.putToStore(e) + return tr + } + + t.Run("DeleteLast", func(t *testing.T) { + t.Run("MergeExtension", func(t *testing.T) { + tr := newt(t) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + tr.testHas(t, []byte{0x56}, []byte{0x34}) + require.IsType(t, (*ExtensionNode)(nil), tr.root) + }) + + t.Run("LeaveLeaf", func(t *testing.T) { + c := NewBranchNode() + c.Children[5] = NewLeafNode([]byte{0x05}) + c.Children[6] = NewLeafNode([]byte{0x06}) + + b := NewBranchNode() + b.Children[lastChild] = NewLeafNode([]byte{0x12}) + b.Children[5] = c + tr := NewTrie(b, newTestStore()) + + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + tr.testHas(t, []byte{0x55}, []byte{0x05}) + tr.testHas(t, []byte{0x56}, []byte{0x06}) + require.IsType(t, (*ExtensionNode)(nil), tr.root) + }) + }) + + t.Run("DeleteMiddle", func(t *testing.T) { + tr := newt(t) + require.NoError(t, tr.Delete([]byte{0x56})) + tr.testHas(t, []byte{}, []byte{0x12}) + tr.testHas(t, []byte{0x56}, nil) + require.IsType(t, (*LeafNode)(nil), tr.root) + }) + }) + }) +} + +func TestTrie_PanicInvalidRoot(t *testing.T) { + tr := &Trie{Store: newTestStore()} + require.Panics(t, func() { _ = tr.Put([]byte{1}, []byte{2}) }) + require.Panics(t, func() { _, _ = tr.Get([]byte{1}) }) + require.Panics(t, func() { _ = tr.Delete([]byte{1}) }) +} + +func TestTrie_Collapse(t *testing.T) { + t.Run("PanicNegative", func(t *testing.T) { + tr := newTestTrie(t) + require.Panics(t, func() { tr.Collapse(-1) }) + }) + t.Run("Depth=0", func(t *testing.T) { + tr := newTestTrie(t) + h := tr.root.Hash() + + _, ok := tr.root.(*HashNode) + require.False(t, ok) + + tr.Collapse(0) + _, ok = tr.root.(*HashNode) + require.True(t, ok) + require.Equal(t, h, tr.root.Hash()) + }) + t.Run("Branch,Depth=1", func(t *testing.T) { + b := NewBranchNode() + e := NewExtensionNode([]byte{0x01}, NewLeafNode([]byte("value1"))) + he := e.Hash() + b.Children[0] = e + hb := b.Hash() + + tr := NewTrie(b, newTestStore()) + tr.Collapse(1) + + newb, ok := tr.root.(*BranchNode) + require.True(t, ok) + require.Equal(t, hb, newb.Hash()) + require.IsType(t, (*HashNode)(nil), b.Children[0]) + require.Equal(t, he, b.Children[0].Hash()) + }) + t.Run("Extension,Depth=1", func(t *testing.T) { + l := NewLeafNode([]byte("value")) + hl := l.Hash() + e := NewExtensionNode([]byte{0x01}, l) + h := e.Hash() + tr := NewTrie(e, newTestStore()) + tr.Collapse(1) + + newe, ok := tr.root.(*ExtensionNode) + require.True(t, ok) + require.Equal(t, h, newe.Hash()) + require.IsType(t, (*HashNode)(nil), newe.next) + require.Equal(t, hl, newe.next.Hash()) + }) + t.Run("Leaf", func(t *testing.T) { + l := NewLeafNode([]byte("value")) + tr := NewTrie(l, newTestStore()) + tr.Collapse(10) + require.Equal(t, NewLeafNode([]byte("value")), tr.root) + }) + t.Run("Hash", func(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + tr := 
NewTrie(new(HashNode), newTestStore()) + require.NotPanics(t, func() { tr.Collapse(1) }) + hn, ok := tr.root.(*HashNode) + require.True(t, ok) + require.True(t, hn.IsEmpty()) + }) + + h := random.Uint256() + hn := NewHashNode(h) + tr := NewTrie(hn, newTestStore()) + tr.Collapse(10) + + newRoot, ok := tr.root.(*HashNode) + require.True(t, ok) + require.Equal(t, NewHashNode(h), newRoot) + }) +} diff --git a/pkg/core/prometheus.go b/pkg/core/prometheus.go index b81fb847d..c849e3459 100644 --- a/pkg/core/prometheus.go +++ b/pkg/core/prometheus.go @@ -30,6 +30,14 @@ var ( Namespace: "neogo", }, ) + //stateHeight prometheus metric. + stateHeight = prometheus.NewGauge( + prometheus.GaugeOpts{ + Help: "Current verified state height", + Name: "current_state_height", + Namespace: "neogo", + }, + ) ) func init() { @@ -51,3 +59,7 @@ func updateHeaderHeightMetric(hHeight int) { func updateBlockHeightMetric(bHeight uint32) { blockHeight.Set(float64(bHeight)) } + +func updateStateHeightMetric(sHeight uint32) { + stateHeight.Set(float64(sHeight)) +} diff --git a/pkg/core/state/mpt_root.go b/pkg/core/state/mpt_root.go new file mode 100644 index 000000000..dea3f62fa --- /dev/null +++ b/pkg/core/state/mpt_root.go @@ -0,0 +1,146 @@ +package state + +import ( + "encoding/json" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MPTRootBase represents storage state root. +type MPTRootBase struct { + Version byte `json:"version"` + Index uint32 `json:"index"` + PrevHash util.Uint256 `json:"prehash"` + Root util.Uint256 `json:"stateroot"` +} + +// MPTRoot represents storage state root together with sign info. +type MPTRoot struct { + MPTRootBase + Witness *transaction.Witness `json:"witness,omitempty"` +} + +// MPTRootStateFlag represents verification state of the state root. +type MPTRootStateFlag byte + +// Possible verification states of MPTRoot. +const ( + Unverified MPTRootStateFlag = 0x00 + Verified MPTRootStateFlag = 0x01 + Invalid MPTRootStateFlag = 0x03 +) + +// MPTRootState represents state root together with its verification state. +type MPTRootState struct { + MPTRoot `json:"stateroot"` + Flag MPTRootStateFlag `json:"flag"` +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRootState) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(s.Flag)) + s.MPTRoot.EncodeBinary(w) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRootState) DecodeBinary(r *io.BinReader) { + s.Flag = MPTRootStateFlag(r.ReadB()) + s.MPTRoot.DecodeBinary(r) +} + +// GetSignedPart returns part of MPTRootBase which needs to be signed. +func (s *MPTRootBase) GetSignedPart() []byte { + buf := io.NewBufBinWriter() + s.EncodeBinary(buf.BinWriter) + return buf.Bytes() +} + +// Equals checks if s == other. +func (s *MPTRootBase) Equals(other *MPTRootBase) bool { + return s.Version == other.Version && s.Index == other.Index && + s.PrevHash.Equals(other.PrevHash) && s.Root.Equals(other.Root) +} + +// Hash returns hash of s. +func (s *MPTRootBase) Hash() util.Uint256 { + return hash.DoubleSha256(s.GetSignedPart()) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRootBase) DecodeBinary(r *io.BinReader) { + s.Version = r.ReadB() + s.Index = r.ReadU32LE() + s.PrevHash.DecodeBinary(r) + s.Root.DecodeBinary(r) +} + +// EncodeBinary implements io.Serializable. 
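For reference, a hedged sketch of how a state root's signed part and hash are derived from the fields above; the index and root value are arbitrary placeholders.

package main

import (
    "fmt"

    "github.com/nspcc-dev/neo-go/pkg/core/state"
    "github.com/nspcc-dev/neo-go/pkg/util"
)

func main() {
    r := state.MPTRootBase{
        Version:  0,
        Index:    3000000,        // arbitrary height
        PrevHash: util.Uint256{}, // zero for the first signed root
        Root:     util.Uint256{0x01},
    }

    // GetSignedPart is Version | Index (uint32 LE) | PrevHash | Root,
    // i.e. 1 + 4 + 32 + 32 = 69 bytes; Hash() is DoubleSha256 over it.
    fmt.Println(len(r.GetSignedPart())) // 69
    fmt.Println(r.Hash().StringLE())
}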
+func (s *MPTRootBase) EncodeBinary(w *io.BinWriter) { + w.WriteB(s.Version) + w.WriteU32LE(s.Index) + s.PrevHash.EncodeBinary(w) + s.Root.EncodeBinary(w) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRoot) DecodeBinary(r *io.BinReader) { + s.MPTRootBase.DecodeBinary(r) + + var ws []transaction.Witness + r.ReadArray(&ws, 1) + if len(ws) == 1 { + s.Witness = &ws[0] + } +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRoot) EncodeBinary(w *io.BinWriter) { + s.MPTRootBase.EncodeBinary(w) + if s.Witness == nil { + w.WriteVarUint(0) + } else { + w.WriteArray([]*transaction.Witness{s.Witness}) + } +} + +// String implements fmt.Stringer. +func (f MPTRootStateFlag) String() string { + switch f { + case Unverified: + return "Unverified" + case Verified: + return "Verified" + case Invalid: + return "Invalid" + default: + return "" + } +} + +// MarshalJSON implements json.Marshaler. +func (f MPTRootStateFlag) MarshalJSON() ([]byte, error) { + return []byte(`"` + f.String() + `"`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (f *MPTRootStateFlag) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + switch s { + case "Unverified": + *f = Unverified + case "Verified": + *f = Verified + case "Invalid": + *f = Invalid + default: + return errors.New("unknown flag") + } + return nil +} diff --git a/pkg/core/state/mpt_root_test.go b/pkg/core/state/mpt_root_test.go new file mode 100644 index 000000000..f1c0b5c61 --- /dev/null +++ b/pkg/core/state/mpt_root_test.go @@ -0,0 +1,100 @@ +package state + +import ( + "encoding/json" + "math/rand" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +func testStateRoot() *MPTRoot { + return &MPTRoot{ + MPTRootBase: MPTRootBase{ + Version: byte(rand.Uint32()), + Index: rand.Uint32(), + PrevHash: random.Uint256(), + Root: random.Uint256(), + }, + } +} + +func TestStateRoot_Serializable(t *testing.T) { + r := testStateRoot() + testserdes.EncodeDecodeBinary(t, r, new(MPTRoot)) + + t.Run("WithWitness", func(t *testing.T) { + r.Witness = &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + } + testserdes.EncodeDecodeBinary(t, r, new(MPTRoot)) + }) +} + +func TestStateRootEquals(t *testing.T) { + r1 := testStateRoot() + r2 := *r1 + require.True(t, r1.Equals(&r2.MPTRootBase)) + + r2.MPTRootBase.Index++ + require.False(t, r1.Equals(&r2.MPTRootBase)) +} + +func TestMPTRootState_Serializable(t *testing.T) { + rs := &MPTRootState{ + MPTRoot: *testStateRoot(), + Flag: 0x04, + } + rs.MPTRoot.Witness = &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + } + testserdes.EncodeDecodeBinary(t, rs, new(MPTRootState)) +} + +func TestMPTRootStateUnverifiedByDefault(t *testing.T) { + var r MPTRootState + require.Equal(t, Unverified, r.Flag) +} + +func TestMPTRoot_MarshalJSON(t *testing.T) { + t.Run("Good", func(t *testing.T) { + r := testStateRoot() + rs := &MPTRootState{ + MPTRoot: *r, + Flag: Verified, + } + testserdes.MarshalUnmarshalJSON(t, rs, new(MPTRootState)) + }) + + t.Run("Compatibility", func(t *testing.T) { + js := []byte(`{ + "flag": "Unverified", + "stateroot": { + "version": 1, + "index": 3000000, + "prehash": 
"0x4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90", + "stateroot": "0xb2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3" + }}`) + + rs := new(MPTRootState) + require.NoError(t, json.Unmarshal(js, &rs)) + + require.EqualValues(t, 1, rs.Version) + require.EqualValues(t, 3000000, rs.Index) + require.Nil(t, rs.Witness) + + u, err := util.Uint256DecodeStringLE("4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90") + require.NoError(t, err) + require.Equal(t, u, rs.PrevHash) + + u, err = util.Uint256DecodeStringLE("b2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3") + require.NoError(t, err) + require.Equal(t, u, rs.Root) + }) +} diff --git a/pkg/core/storage/store.go b/pkg/core/storage/store.go index 73b3866b1..2bb57646d 100644 --- a/pkg/core/storage/store.go +++ b/pkg/core/storage/store.go @@ -9,6 +9,7 @@ import ( const ( DataBlock KeyPrefix = 0x01 DataTransaction KeyPrefix = 0x02 + DataMPT KeyPrefix = 0x03 STAccount KeyPrefix = 0x40 STCoin KeyPrefix = 0x44 STSpentCoin KeyPrefix = 0x45 diff --git a/pkg/crypto/keys/publickey.go b/pkg/crypto/keys/publickey.go index d10d9c9f4..471d06675 100644 --- a/pkg/crypto/keys/publickey.go +++ b/pkg/crypto/keys/publickey.go @@ -1,7 +1,6 @@ package keys import ( - "bytes" "crypto/ecdsa" "crypto/elliptic" "crypto/x509" @@ -10,6 +9,7 @@ import ( "fmt" "math/big" + "github.com/btcsuite/btcd/btcec" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/io" @@ -18,6 +18,9 @@ import ( "github.com/pkg/errors" ) +// coordLen is the number of bytes in serialized X or Y coordinate. +const coordLen = 32 + // PublicKeys is a list of public keys. type PublicKeys []*PublicKey @@ -94,23 +97,49 @@ func NewPublicKeyFromString(s string) (*PublicKey, error) { return pubKey, nil } -// Bytes returns the byte array representation of the public key. -func (p *PublicKey) Bytes() []byte { +// getBytes serializes X and Y using compressed or uncompressed format. +func (p *PublicKey) getBytes(compressed bool) []byte { if p.IsInfinity() { return []byte{0x00} } - var ( - x = p.X.Bytes() - paddedX = append(bytes.Repeat([]byte{0x00}, 32-len(x)), x...) - prefix = byte(0x03) - ) - - if p.Y.Bit(0) == 0 { - prefix = byte(0x02) + var resLen = 1 + coordLen + if !compressed { + resLen += coordLen } + var res = make([]byte, resLen) + var prefix byte - return append([]byte{prefix}, paddedX...) + xBytes := p.X.Bytes() + copy(res[1+coordLen-len(xBytes):], xBytes) + if compressed { + if p.Y.Bit(0) == 0 { + prefix = 0x02 + } else { + prefix = 0x03 + } + } else { + prefix = 0x04 + yBytes := p.Y.Bytes() + copy(res[1+coordLen+coordLen-len(yBytes):], yBytes) + + } + res[0] = prefix + + return res +} + +// Bytes returns byte array representation of the public key in compressed +// form (33 bytes with 0x02 or 0x03 prefix, except infinity which is always 0). +func (p *PublicKey) Bytes() []byte { + return p.getBytes(true) +} + +// UncompressedBytes returns byte array representation of the public key in +// uncompressed form (65 bytes with 0x04 prefix, except infinity which is +// always 0). +func (p *PublicKey) UncompressedBytes() []byte { + return p.getBytes(false) } // NewPublicKeyFromASN1 returns a NEO PublicKey from the ASN.1 serialized key. @@ -134,15 +163,25 @@ func NewPublicKeyFromASN1(data []byte) (*PublicKey, error) { } // decodeCompressedY performs decompression of Y coordinate for given X and Y's least significant bit. 
-func decodeCompressedY(x *big.Int, ylsb uint) (*big.Int, error) { - c := elliptic.P256() - cp := c.Params() - three := big.NewInt(3) - /* y**2 = x**3 + a*x + b % p */ - xCubed := new(big.Int).Exp(x, three, cp.P) - threeX := new(big.Int).Mul(x, three) - threeX.Mod(threeX, cp.P) - ySquared := new(big.Int).Sub(xCubed, threeX) +// We use here a short-form Weierstrass curve (https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html) +// y² = x³ + ax + b. Two types of elliptic curves are supported: +// 1. Secp256k1 (Koblitz curve): y² = x³ + b, +// 2. Secp256r1 (Random curve): y² = x³ - 3x + b. +// To decode compressed curve point we perform the following operation: y = sqrt(x³ + ax + b mod p) +// where `p` denotes the order of the underlying curve field +func decodeCompressedY(x *big.Int, ylsb uint, curve elliptic.Curve) (*big.Int, error) { + var a *big.Int + switch curve.(type) { + case *btcec.KoblitzCurve: + a = big.NewInt(0) + default: + a = big.NewInt(3) + } + cp := curve.Params() + xCubed := new(big.Int).Exp(x, big.NewInt(3), cp.P) + aX := new(big.Int).Mul(x, a) + aX.Mod(aX, cp.P) + ySquared := new(big.Int).Sub(xCubed, aX) ySquared.Add(ySquared, cp.B) ySquared.Mod(ySquared, cp.P) y := new(big.Int).ModSqrt(ySquared, cp.P) @@ -196,7 +235,7 @@ func (p *PublicKey) DecodeBinary(r *io.BinReader) { } x = new(big.Int).SetBytes(xbytes) ylsb := uint(prefix & 0x1) - y, err = decodeCompressedY(x, ylsb) + y, err = decodeCompressedY(x, ylsb, p256) if err != nil { r.Err = err return @@ -306,3 +345,84 @@ func (p *PublicKey) UnmarshalJSON(data []byte) error { return nil } + +// KeyRecover recovers public key from the given signature (r, s) on the given message hash using given elliptic curve. +// Algorithm source: SEC 1 Ver 2.0, section 4.1.6, pages 47-48 (https://www.secg.org/sec1-v2.pdf). +// Flag isEven denotes Y's least significant bit in decompression algorithm. +func KeyRecover(curve elliptic.Curve, r, s *big.Int, messageHash []byte, isEven bool) (PublicKey, error) { + var ( + res PublicKey + err error + ) + if r.Cmp(big.NewInt(1)) == -1 || s.Cmp(big.NewInt(1)) == -1 { + return res, errors.New("invalid signature") + } + params := curve.Params() + // Calculate h = (Q + 1 + 2 * Sqrt(Q)) / N + // num := new(big.Int).Add(new(big.Int).Add(params.P, big.NewInt(1)), new(big.Int).Mul(big.NewInt(2), new(big.Int).Sqrt(params.P))) + // h := new(big.Int).Div(num, params.N) + // We are skipping this step for secp256k1 and secp256r1 because we know cofactor of these curves (h=1) + // (see section 2.4 of http://www.secg.org/sec2-v2.pdf) + h := 1 + for i := 0; i <= h; i++ { + // Step 1.1: x = (n * i) + r + Rx := new(big.Int).Mul(params.N, big.NewInt(int64(i))) + Rx.Add(Rx, r) + if Rx.Cmp(params.P) == 1 { + break + } + + // Steps 1.2 and 1.3: get point R (Ry) + var R *big.Int + if isEven { + R, err = decodeCompressedY(Rx, 0, curve) + } else { + R, err = decodeCompressedY(Rx, 1, curve) + } + if err != nil { + return res, err + } + + // Step 1.4: check n*R is point at infinity + nRx, nR := curve.ScalarMult(Rx, R, params.N.Bytes()) + if nRx.Sign() != 0 || nR.Sign() != 0 { + continue + } + + // Step 1.5: compute e + e := hashToInt(messageHash, curve) + + // Step 1.6: Q = r^-1 (sR-eG) + invr := new(big.Int).ModInverse(r, params.N) + // First term. + invrS := new(big.Int).Mul(invr, s) + invrS.Mod(invrS, params.N) + sRx, sR := curve.ScalarMult(Rx, R, invrS.Bytes()) + // Second term. 
+ e.Neg(e) + e.Mod(e, params.N) + e.Mul(e, invr) + e.Mod(e, params.N) + minuseGx, minuseGy := curve.ScalarBaseMult(e.Bytes()) + Qx, Qy := curve.Add(sRx, sR, minuseGx, minuseGy) + res.X = Qx + res.Y = Qy + } + return res, nil +} + +// copied from crypto/ecdsa +func hashToInt(hash []byte, c elliptic.Curve) *big.Int { + orderBits := c.Params().N.BitLen() + orderBytes := (orderBits + 7) / 8 + if len(hash) > orderBytes { + hash = hash[:orderBytes] + } + + ret := new(big.Int).SetBytes(hash) + excess := len(hash)*8 - orderBits + if excess > 0 { + ret.Rsh(ret, uint(excess)) + } + return ret +} diff --git a/pkg/crypto/keys/publickey_test.go b/pkg/crypto/keys/publickey_test.go index a9c265e4b..ddee5a332 100644 --- a/pkg/crypto/keys/publickey_test.go +++ b/pkg/crypto/keys/publickey_test.go @@ -1,12 +1,16 @@ package keys import ( + "crypto/elliptic" "encoding/hex" "encoding/json" + "math/big" "math/rand" "sort" "testing" + "github.com/btcsuite/btcd/btcec" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" "github.com/stretchr/testify/require" ) @@ -85,10 +89,14 @@ func TestPubkeyToAddress(t *testing.T) { func TestDecodeBytes(t *testing.T) { pubKey := getPubKey(t) - decodedPubKey := &PublicKey{} - err := decodedPubKey.DecodeBytes(pubKey.Bytes()) - require.NoError(t, err) - require.Equal(t, pubKey, decodedPubKey) + var testBytesFunction = func(t *testing.T, bytesFunction func() []byte) { + decodedPubKey := &PublicKey{} + err := decodedPubKey.DecodeBytes(bytesFunction()) + require.NoError(t, err) + require.Equal(t, pubKey, decodedPubKey) + } + t.Run("compressed", func(t *testing.T) { testBytesFunction(t, pubKey.Bytes) }) + t.Run("uncompressed", func(t *testing.T) { testBytesFunction(t, pubKey.UncompressedBytes) }) } func TestDecodeBytesBadInfinity(t *testing.T) { @@ -179,3 +187,91 @@ func TestUnmarshallJSONBadFormat(t *testing.T) { err := json.Unmarshal([]byte(str), actual) require.Error(t, err) } + +func TestRecoverSecp256r1(t *testing.T) { + privateKey, err := NewPrivateKey() + require.NoError(t, err) + message := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100} + messageHash := hash.Sha256(message).BytesBE() + signature := privateKey.Sign(message) + r := new(big.Int).SetBytes(signature[0:32]) + s := new(big.Int).SetBytes(signature[32:64]) + require.True(t, privateKey.PublicKey().Verify(signature, messageHash)) + // To test this properly, we should provide correct isEven flag. This flag denotes which one of + // the two recovered R points in decodeCompressedY method should be chosen. Let's suppose that we + // don't know which of them suites, so to test KeyRecover we should check both and only + // one of them gives us the correct public key. 
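As the comment above explains, recovery always yields two candidate points that differ only in the parity of Y, so a caller that does not know the right isEven value can try both and keep the candidate that matches a known key (or that verifies the signature). A sketch of standalone usage mirroring this test, assuming the keys and hash packages from this patch:

package main

import (
	"crypto/elliptic"
	"fmt"
	"math/big"

	"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
)

func main() {
	priv, err := keys.NewPrivateKey()
	if err != nil {
		panic(err)
	}
	msg := []byte("HelloWorld")
	digest := hash.Sha256(msg).BytesBE()
	sig := priv.Sign(msg) // 64 bytes: r || s
	r := new(big.Int).SetBytes(sig[:32])
	s := new(big.Int).SetBytes(sig[32:])
	for _, isEven := range []bool{false, true} {
		pub, err := keys.KeyRecover(elliptic.P256(), r, s, digest, isEven)
		if err != nil {
			continue
		}
		if priv.PublicKey().Equal(&pub) {
			fmt.Println("recovered the signing key with isEven =", isEven)
		}
	}
}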
+ recoveredKeyFalse, err := KeyRecover(elliptic.P256(), r, s, messageHash, false) + require.NoError(t, err) + recoveredKeyTrue, err := KeyRecover(elliptic.P256(), r, s, messageHash, true) + require.NoError(t, err) + require.True(t, privateKey.PublicKey().Equal(&recoveredKeyFalse) != privateKey.PublicKey().Equal(&recoveredKeyTrue)) +} + +func TestRecoverSecp256r1Static(t *testing.T) { + // These data were taken from the reference KeyRecoverTest: https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_ECDsa.cs#L22 + // To update this test, run the reference KeyRecover(ECCurve.Secp256r1) testcase and fetch the following data from it: + // privateKey -> b + // message -> messageHash + // signatures[0] -> r + // signatures[1] -> s + // v -> isEven + // Note, that C# BigInteger has different byte order from that used in Go. + b := []byte{123, 245, 126, 56, 3, 123, 197, 199, 26, 31, 212, 186, 120, 195, 168, 153, 57, 108, 234, 49, 107, 203, 44, 207, 185, 212, 187, 129, 74, 43, 225, 69} + privateKey, err := NewPrivateKeyFromBytes(b) + require.NoError(t, err) + messageHash := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100} + r := new(big.Int).SetBytes([]byte{1, 85, 226, 63, 133, 113, 217, 188, 249, 22, 213, 203, 225, 199, 32, 131, 118, 23, 28, 101, 139, 211, 13, 111, 242, 158, 193, 227, 196, 106, 3, 4}) + s := new(big.Int).SetBytes([]byte{65, 174, 206, 164, 81, 34, 76, 104, 5, 49, 51, 20, 221, 183, 157, 199, 199, 47, 78, 137, 172, 99, 212, 110, 129, 72, 236, 59, 250, 81, 200, 13}) + // Just ensure it's a valid signature. + require.True(t, privateKey.PublicKey().Verify(append(r.Bytes(), s.Bytes()...), messageHash)) + recoveredKey, err := KeyRecover(elliptic.P256(), r, s, messageHash, false) + require.NoError(t, err) + require.True(t, privateKey.PublicKey().Equal(&recoveredKey)) +} + +func TestRecoverSecp256k1(t *testing.T) { + privateKey, err := btcec.NewPrivateKey(btcec.S256()) + message := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100} + signature, err := privateKey.Sign(message) + require.NoError(t, err) + require.True(t, signature.Verify(message, privateKey.PubKey())) + // To test this properly, we should provide correct isEven flag. This flag denotes which one of + // the two recovered R points in decodeCompressedY method should be chosen. Let's suppose that we + // don't know which of them suites, so to test KeyRecover we should check both and only + // one of them gives us the correct public key. + recoveredKeyFalse, err := KeyRecover(btcec.S256(), signature.R, signature.S, message, false) + require.NoError(t, err) + recoveredKeyTrue, err := KeyRecover(btcec.S256(), signature.R, signature.S, message, true) + require.NoError(t, err) + require.True(t, (privateKey.PubKey().X.Cmp(recoveredKeyFalse.X) == 0 && + privateKey.PubKey().Y.Cmp(recoveredKeyFalse.Y) == 0) != + (privateKey.PubKey().X.Cmp(recoveredKeyTrue.X) == 0 && + privateKey.PubKey().Y.Cmp(recoveredKeyTrue.Y) == 0)) +} + +func TestRecoverSecp256k1Static(t *testing.T) { + // These data were taken from the reference testcase: https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_ECDsa.cs#L22 + // To update this test, run the reference KeyRecover(ECCurve.Secp256k1) testcase and fetch the following data from it: + // privateKey -> b + // message -> messageHash + // signatures[0] -> r + // signatures[1] -> s + // v -> isEven + // Note, that C# BigInteger has different byte order from that used in Go. 
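The vectors in these static tests have already been converted to Go's big-endian order; when porting fresh ones from the C# reference test, the little-endian BigInteger bytes need to be reversed (and any trailing sign byte dropped) before big.Int.SetBytes. A tiny hypothetical helper for that conversion:

package main

import (
	"fmt"
	"math/big"
)

// bigIntFromCSharpLE converts a little-endian byte array (as produced by C#'s
// BigInteger.ToByteArray for a non-negative value, sign byte stripped) into a
// Go big.Int, which expects big-endian input. Hypothetical helper.
func bigIntFromCSharpLE(le []byte) *big.Int {
	be := make([]byte, len(le))
	for i, b := range le {
		be[len(le)-1-i] = b
	}
	return new(big.Int).SetBytes(be)
}

func main() {
	fmt.Println(bigIntFromCSharpLE([]byte{0x01, 0x02})) // 0x0201 = 513
}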
+ b := []byte{156, 3, 247, 58, 246, 250, 236, 27, 118, 60, 180, 177, 18, 92, 204, 206, 144, 245, 148, 141, 86, 212, 151, 181, 15, 113, 172, 180, 177, 228, 100, 32} + _, publicKey := btcec.PrivKeyFromBytes(btcec.S256(), b) + messageHash := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100} + r := new(big.Int).SetBytes([]byte{88, 169, 242, 111, 210, 184, 180, 46, 67, 108, 176, 77, 57, 250, 58, 36, 110, 81, 225, 65, 90, 47, 215, 91, 27, 227, 57, 6, 9, 228, 100, 50}) + s := new(big.Int).SetBytes([]byte{86, 150, 81, 190, 17, 181, 212, 241, 184, 36, 136, 116, 232, 207, 46, 45, 149, 167, 15, 98, 113, 137, 66, 98, 214, 165, 38, 232, 98, 96, 79, 197}) + signature := btcec.Signature{ + R: r, + S: s, + } + // Just ensure it's a valid signature. + require.True(t, signature.Verify(messageHash, publicKey)) + recoveredKey, err := KeyRecover(btcec.S256(), r, s, messageHash, false) + require.NoError(t, err) + require.True(t, new(big.Int).SetBytes([]byte{112, 186, 29, 131, 169, 21, 212, 95, 81, 172, 201, 145, 168, 108, 129, 90, 6, 111, 80, 39, 136, 157, 15, 181, 98, 108, 133, 108, 144, 80, 23, 225}).Cmp(recoveredKey.X) == 0) + require.True(t, new(big.Int).SetBytes([]byte{187, 102, 202, 42, 152, 133, 222, 55, 137, 228, 154, 80, 182, 35, 133, 14, 55, 165, 36, 64, 178, 55, 13, 112, 224, 143, 66, 143, 208, 18, 2, 211}).Cmp(recoveredKey.Y) == 0) +} diff --git a/pkg/interop/crypto/crypto.go b/pkg/interop/crypto/crypto.go index 5dbee46ed..02fce7c5a 100644 --- a/pkg/interop/crypto/crypto.go +++ b/pkg/interop/crypto/crypto.go @@ -1,5 +1,5 @@ /* -Package crypto provides an interface to VM cryptographic instructions. +Package crypto provides an interface to VM cryptographic instructions and syscalls. */ package crypto @@ -30,3 +30,23 @@ func Hash256(b []byte) []byte { func VerifySignature(msg []byte, sig []byte, pub []byte) bool { return false } + +// Secp256k1Recover recovers public key from the given signature (r, s) on the +// given message hash using Secp256k1 elliptic curve. Flag isEven denotes Y's +// least significant bit in decompression algorithm. The return value is byte +// array representation of the public key which is either empty (if it's not +// possible to recover key) or contains 32 bytes in BE for X point (in case of +// success). This function uses Neo.Cryptography.Secp256k1Recover syscall. +func Secp256k1Recover(r []byte, s []byte, messageHash []byte, isEven bool) []byte { + return nil +} + +// Secp256r1Recover recovers public key from the given signature (r, s) on the +// given message hash using Secp256r1 elliptic curve. Flag isEven denotes Y's +// least significant bit in decompression algorithm. The return value is byte +// array representation of the public key which is either empty (if it's not +// possible to recover key) or contains 32 bytes in BE for X point (in case of +// success). This function uses Neo.Cryptography.Secp256r1Recover syscall. +func Secp256r1Recover(r []byte, s []byte, messageHash []byte, isEven bool) []byte { + return nil +} diff --git a/pkg/io/binaryReader.go b/pkg/io/binaryReader.go index fd23355a2..b8c935c80 100644 --- a/pkg/io/binaryReader.go +++ b/pkg/io/binaryReader.go @@ -8,9 +8,9 @@ import ( "reflect" ) -// maxArraySize is a maximums size of an array which can be decoded. +// MaxArraySize is the maximum size of an array which can be decoded. 
// It is taken from https://github.com/neo-project/neo/blob/master/neo/IO/Helper.cs#L130 -const maxArraySize = 0x1000000 +const MaxArraySize = 0x1000000 // BinReader is a convenient wrapper around a io.Reader and err object. // Used to simplify error handling when reading into a struct with many fields. @@ -110,7 +110,7 @@ func (r *BinReader) ReadArray(t interface{}, maxSize ...int) { elemType := sliceType.Elem() isPtr := elemType.Kind() == reflect.Ptr - ms := maxArraySize + ms := MaxArraySize if len(maxSize) != 0 { ms = maxSize[0] } @@ -168,8 +168,16 @@ func (r *BinReader) ReadVarUint() uint64 { // ReadVarBytes reads the next set of bytes from the underlying reader. // ReadVarUInt() is used to determine how large that slice is -func (r *BinReader) ReadVarBytes() []byte { +func (r *BinReader) ReadVarBytes(maxSize ...int) []byte { n := r.ReadVarUint() + ms := MaxArraySize + if len(maxSize) != 0 { + ms = maxSize[0] + } + if n > uint64(ms) { + r.Err = fmt.Errorf("byte-slice is too big (%d)", n) + return nil + } b := make([]byte, n) r.ReadBytes(b) return b diff --git a/pkg/io/binaryrw_test.go b/pkg/io/binaryrw_test.go index d5e1cf8c6..fd998d503 100644 --- a/pkg/io/binaryrw_test.go +++ b/pkg/io/binaryrw_test.go @@ -143,6 +143,35 @@ func TestBufBinWriter_Len(t *testing.T) { require.Equal(t, 1, bw.Len()) } +func TestBinReader_ReadVarBytes(t *testing.T) { + buf := make([]byte, 11) + for i := range buf { + buf[i] = byte(i) + } + w := NewBufBinWriter() + w.WriteVarBytes(buf) + require.NoError(t, w.Err) + data := w.Bytes() + + t.Run("NoArguments", func(t *testing.T) { + r := NewBinReaderFromBuf(data) + actual := r.ReadVarBytes() + require.NoError(t, r.Err) + require.Equal(t, buf, actual) + }) + t.Run("Good", func(t *testing.T) { + r := NewBinReaderFromBuf(data) + actual := r.ReadVarBytes(11) + require.NoError(t, r.Err) + require.Equal(t, buf, actual) + }) + t.Run("Bad", func(t *testing.T) { + r := NewBinReaderFromBuf(data) + r.ReadVarBytes(10) + require.Error(t, r.Err) + }) +} + func TestWriterErrHandling(t *testing.T) { var badio = &badRW{} bw := NewBinWriterFromIO(badio) diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index a719d012d..0bc0b6637 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -59,6 +59,9 @@ func (chain *testChain) AddBlock(block *block.Block) error { } return nil } +func (chain *testChain) AddStateRoot(r *state.MPTRoot) error { + panic("TODO") +} func (chain *testChain) BlockHeight() uint32 { return atomic.LoadUint32(&chain.blockheight) } @@ -105,6 +108,12 @@ func (chain testChain) GetEnrollments() ([]*state.Validator, error) { func (chain testChain) GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) { panic("TODO") } +func (chain testChain) GetStateProof(util.Uint256, []byte) ([][]byte, error) { + panic("TODO") +} +func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + panic("TODO") +} func (chain testChain) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem { panic("TODO") } @@ -145,7 +154,9 @@ func (chain testChain) IsLowPriority(util.Fixed8) bool { func (chain testChain) PoolTx(*transaction.Transaction) error { panic("TODO") } - +func (chain testChain) StateHeight() uint32 { + panic("TODO") +} func (chain testChain) SubscribeForBlocks(ch chan<- *block.Block) { panic("TODO") } diff --git a/pkg/network/message.go b/pkg/network/message.go index a8bedc96c..f17b62658 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -8,6 +8,7 @@ import ( 
"github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" @@ -59,12 +60,15 @@ const ( CMDGetBlocks CommandType = "getblocks" CMDGetData CommandType = "getdata" CMDGetHeaders CommandType = "getheaders" + CMDGetRoots CommandType = "getroots" CMDHeaders CommandType = "headers" CMDInv CommandType = "inv" CMDMempool CommandType = "mempool" CMDMerkleBlock CommandType = "merkleblock" CMDPing CommandType = "ping" CMDPong CommandType = "pong" + CMDRoots CommandType = "roots" + CMDStateRoot CommandType = "stateroot" CMDTX CommandType = "tx" CMDUnknown CommandType = "unknown" CMDVerack CommandType = "verack" @@ -124,6 +128,8 @@ func (m *Message) CommandType() CommandType { return CMDGetData case "getheaders": return CMDGetHeaders + case "getroots": + return CMDGetRoots case "headers": return CMDHeaders case "inv": @@ -136,6 +142,10 @@ func (m *Message) CommandType() CommandType { return CMDPing case "pong": return CMDPong + case "roots": + return CMDRoots + case "stateroot": + return CMDStateRoot case "tx": return CMDTX case "verack": @@ -191,6 +201,8 @@ func (m *Message) decodePayload(br *io.BinReader) error { fallthrough case CMDGetHeaders: p = &payload.GetBlocks{} + case CMDGetRoots: + p = &payload.GetStateRoots{} case CMDHeaders: p = &payload.Headers{} case CMDTX: @@ -199,6 +211,10 @@ func (m *Message) decodePayload(br *io.BinReader) error { p = &payload.MerkleBlock{} case CMDPing, CMDPong: p = &payload.Ping{} + case CMDRoots: + p = &payload.StateRoots{} + case CMDStateRoot: + p = &state.MPTRoot{} default: return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command)) } diff --git a/pkg/network/payload/inventory.go b/pkg/network/payload/inventory.go index d582e0486..fd5f9ed71 100644 --- a/pkg/network/payload/inventory.go +++ b/pkg/network/payload/inventory.go @@ -18,6 +18,8 @@ func (i InventoryType) String() string { return "TX" case 0x02: return "block" + case StateRootType: + return "stateroot" case 0xe0: return "consensus" default: @@ -27,13 +29,14 @@ func (i InventoryType) String() string { // Valid returns true if the inventory (type) is known. func (i InventoryType) Valid() bool { - return i == BlockType || i == TXType || i == ConsensusType + return i == BlockType || i == TXType || i == ConsensusType || i == StateRootType } // List of valid InventoryTypes. const ( TXType InventoryType = 0x01 // 1 BlockType InventoryType = 0x02 // 2 + StateRootType InventoryType = 0x03 // 3 ConsensusType InventoryType = 0xe0 // 224 ) diff --git a/pkg/network/payload/state_root.go b/pkg/network/payload/state_root.go new file mode 100644 index 000000000..f43584375 --- /dev/null +++ b/pkg/network/payload/state_root.go @@ -0,0 +1,43 @@ +package payload + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/io" +) + +// MaxStateRootsAllowed is a maxumum amount of state roots +// which can be sent in a single payload. +const MaxStateRootsAllowed = 2000 + +// StateRoots contains multiple StateRoots. +type StateRoots struct { + Roots []state.MPTRoot +} + +// GetStateRoots represents request for state roots. +type GetStateRoots struct { + Start uint32 + Count uint32 +} + +// EncodeBinary implements io.Serializable. 
+func (s *StateRoots) EncodeBinary(w *io.BinWriter) { + w.WriteArray(s.Roots) +} + +// DecodeBinary implements io.Serializable. +func (s *StateRoots) DecodeBinary(r *io.BinReader) { + r.ReadArray(&s.Roots, MaxStateRootsAllowed) +} + +// DecodeBinary implements io.Serializable. +func (g *GetStateRoots) DecodeBinary(r *io.BinReader) { + g.Start = r.ReadU32LE() + g.Count = r.ReadU32LE() +} + +// EncodeBinary implements io.Serializable. +func (g *GetStateRoots) EncodeBinary(w *io.BinWriter) { + w.WriteU32LE(g.Start) + w.WriteU32LE(g.Count) +} diff --git a/pkg/network/payload/state_root_test.go b/pkg/network/payload/state_root_test.go new file mode 100644 index 000000000..a3f670713 --- /dev/null +++ b/pkg/network/payload/state_root_test.go @@ -0,0 +1,51 @@ +package payload + +import ( + "math/rand" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" +) + +func TestStateRoots_Serializable(t *testing.T) { + expected := &StateRoots{ + Roots: []state.MPTRoot{ + { + MPTRootBase: state.MPTRootBase{ + Index: rand.Uint32(), + PrevHash: random.Uint256(), + Root: random.Uint256(), + }, + Witness: &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + }, + }, + { + MPTRootBase: state.MPTRootBase{ + Index: rand.Uint32(), + PrevHash: random.Uint256(), + Root: random.Uint256(), + }, + Witness: &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + }, + }, + }, + } + + testserdes.EncodeDecodeBinary(t, expected, new(StateRoots)) +} + +func TestGetStateRoots_Serializable(t *testing.T) { + expected := &GetStateRoots{ + Start: rand.Uint32(), + Count: rand.Uint32(), + } + + testserdes.EncodeDecodeBinary(t, expected, new(GetStateRoots)) +} diff --git a/pkg/network/server.go b/pkg/network/server.go index 1836cdf92..9364da476 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -13,6 +13,8 @@ import ( "github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/cache" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/util" @@ -28,6 +30,7 @@ const ( maxBlockBatch = 200 maxAddrsToSend = 200 minPoolCount = 30 + stateRootCacheSize = 100 ) var ( @@ -66,6 +69,7 @@ type ( transactions chan *transaction.Transaction + stateCache cache.HashCache consensusStarted *atomic.Bool log *zap.Logger @@ -98,6 +102,7 @@ func NewServer(config ServerConfig, chain core.Blockchainer, log *zap.Logger) (* unregister: make(chan peerDrop), peers: make(map[Peer]bool), consensusStarted: atomic.NewBool(false), + stateCache: *cache.NewFIFOCache(stateRootCacheSize), log: log, transactions: make(chan *transaction.Transaction, 64), } @@ -469,6 +474,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { cp := s.consensus.GetPayload(h) return cp != nil }, + payload.StateRootType: s.stateCache.Has, } if exists := typExists[inv.Type]; exists != nil { for _, hash := range inv.Hashes { @@ -507,6 +513,11 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { if err == nil { msg = s.MkMsg(CMDBlock, b) } + case payload.StateRootType: + r := s.stateCache.Get(hash) + if r != nil { + msg = s.MkMsg(CMDStateRoot, 
r.(*state.MPTRoot)) + } + case payload.ConsensusType: + if cp := s.consensus.GetPayload(hash); cp != nil { + msg = s.MkMsg(CMDConsensus, cp) @@ -589,6 +600,87 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error { return p.EnqueueP2PMessage(msg) } +// handleGetRootsCmd processes `getroots` request. +func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error { + cfg := s.chain.GetConfig() + if !cfg.EnableStateRoot || gr.Start < cfg.StateRootEnableIndex { + return nil + } + count := gr.Count + if count > payload.MaxStateRootsAllowed { + count = payload.MaxStateRootsAllowed + } + var rs payload.StateRoots + for height := gr.Start; height < gr.Start+count; height++ { + r, err := s.chain.GetStateRoot(height) + if err != nil { + return err + } else if r.Flag == state.Verified { + rs.Roots = append(rs.Roots, r.MPTRoot) + } + } + msg := s.MkMsg(CMDRoots, &rs) + return p.EnqueueP2PMessage(msg) +} + +// handleRootsCmd processes `roots` request. +func (s *Server) handleRootsCmd(p Peer, rs *payload.StateRoots) error { + if !s.chain.GetConfig().EnableStateRoot { + return nil + } + h := s.chain.StateHeight() + if h < s.chain.GetConfig().StateRootEnableIndex { + h = s.chain.GetConfig().StateRootEnableIndex + } + for i := range rs.Roots { + if rs.Roots[i].Index <= h { + continue + } + _ = s.chain.AddStateRoot(&rs.Roots[i]) + } + // request more state roots from peer if needed + return s.requestStateRoot(p) +} + +// requestStateRoot sends `getroots` message to get verified state roots. +func (s *Server) requestStateRoot(p Peer) error { + stateHeight := s.chain.StateHeight() + hdrHeight := s.chain.BlockHeight() + enableIndex := s.chain.GetConfig().StateRootEnableIndex + if hdrHeight < enableIndex { + return nil + } + if stateHeight < enableIndex { + stateHeight = enableIndex - 1 + } + count := uint32(payload.MaxStateRootsAllowed) + if diff := hdrHeight - stateHeight; diff < count { + count = diff + } + if count == 0 { + return nil + } + gr := &payload.GetStateRoots{ + Start: stateHeight + 1, + Count: count, + } + return p.EnqueueP2PMessage(s.MkMsg(CMDGetRoots, gr)) +} + +// handleStateRootCmd processes `stateroot` request. +func (s *Server) handleStateRootCmd(r *state.MPTRoot) error { + if !s.chain.GetConfig().EnableStateRoot { + return nil + } + // we ignore the error, because there is nothing wrong if we already have this state root + err := s.chain.AddStateRoot(r) + if err == nil && !s.stateCache.Has(r.Hash()) { + s.stateCache.Add(r) + s.broadcastMessage(s.MkMsg(CMDStateRoot, r)) + } + return nil +} + +// handleConsensusCmd processes received consensus payload. // It never returns an error.
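requestStateRoot above derives the request window from the locally verified state height and the block height, capping it at MaxStateRootsAllowed. A worked example of that arithmetic with hypothetical heights (enable index 1000000, verified state root at 1004000, blocks at 1007000):

package main

import "fmt"

func main() {
	const maxStateRootsAllowed = 2000 // mirrors payload.MaxStateRootsAllowed

	enableIndex := uint32(1000000) // hypothetical StateRootEnableIndex
	stateHeight := uint32(1004000) // last verified state root
	hdrHeight := uint32(1007000)   // current block height

	if stateHeight < enableIndex {
		stateHeight = enableIndex - 1
	}
	count := uint32(maxStateRootsAllowed)
	if diff := hdrHeight - stateHeight; diff < count {
		count = diff
	}
	// 3000 roots are missing, but the request is capped at 2000.
	fmt.Println("getroots start:", stateHeight+1, "count:", count) // 1004001 2000
}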
func (s *Server) handleConsensusCmd(cp *consensus.Payload) error { @@ -697,6 +789,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDGetHeaders: gh := msg.Payload.(*payload.GetBlocks) return s.handleGetHeadersCmd(peer, gh) + case CMDGetRoots: + gr := msg.Payload.(*payload.GetStateRoots) + return s.handleGetRootsCmd(peer, gr) case CMDHeaders: headers := msg.Payload.(*payload.Headers) go s.handleHeadersCmd(peer, headers) @@ -718,6 +813,12 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDPong: pong := msg.Payload.(*payload.Ping) return s.handlePong(peer, pong) + case CMDRoots: + rs := msg.Payload.(*payload.StateRoots) + return s.handleRootsCmd(peer, rs) + case CMDStateRoot: + r := msg.Payload.(*state.MPTRoot) + return s.handleStateRootCmd(r) case CMDVersion, CMDVerack: return fmt.Errorf("received '%s' after the handshake", msg.CommandType()) } @@ -741,11 +842,20 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { return nil } -func (s *Server) handleNewPayload(p *consensus.Payload) { - msg := s.MkMsg(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()})) - // It's high priority because it directly affects consensus process, - // even though it's just an inv. - s.broadcastHPMessage(msg) +func (s *Server) handleNewPayload(item cache.Hashable) { + switch p := item.(type) { + case *consensus.Payload: + msg := s.MkMsg(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()})) + // It's high priority because it directly affects consensus process, + // even though it's just an inv. + s.broadcastHPMessage(msg) + case *state.MPTRoot: + s.stateCache.Add(p) + msg := s.MkMsg(CMDStateRoot, p) + s.broadcastMessage(msg) + default: + s.log.Warn("unknown item type", zap.String("type", fmt.Sprintf("%T", p))) + } } func (s *Server) requestTx(hashes ...util.Uint256) { diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index db1c13bc4..5643c335d 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -251,6 +251,9 @@ func (p *TCPPeer) StartProtocol() { if p.LastBlockIndex() > p.server.chain.BlockHeight() { err = p.server.requestBlocks(p) } + if err == nil && p.server.chain.GetConfig().EnableStateRoot { + err = p.server.requestStateRoot(p) + } if err == nil { timer.Reset(p.server.ProtoTickInterval) } diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 1eebe08bd..b548e74a3 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -181,8 +181,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.TxFilterT, param.Type) filt, ok := param.Value.(request.TxFilter) require.Equal(t, true, ok) @@ -196,8 +196,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.NotificationFilterT, param.Type) filt, ok := param.Value.(request.NotificationFilter) require.Equal(t, true, ok) @@ -211,8 +211,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, 
request.ExecutionFilterT, param.Type) filt, ok := param.Value.(request.ExecutionFilter) require.Equal(t, true, ok) diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index 42159c336..205a2a93e 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -61,12 +61,17 @@ const ( ExecutionFilterT ) +var errMissingParameter = errors.New("parameter is missing") + func (p Param) String() string { return fmt.Sprintf("%v", p.Value) } // GetString returns string value of the parameter. -func (p Param) GetString() (string, error) { +func (p *Param) GetString() (string, error) { + if p == nil { + return "", errMissingParameter + } str, ok := p.Value.(string) if !ok { return "", errors.New("not a string") @@ -74,8 +79,26 @@ func (p Param) GetString() (string, error) { return str, nil } +// GetBoolean returns boolean value of the parameter. +func (p *Param) GetBoolean() bool { + if p == nil { + return false + } + switch p.Type { + case NumberT: + return p.Value != 0 + case StringT: + return p.Value != "" + default: + return true + } +} + // GetInt returns int value of te parameter. -func (p Param) GetInt() (int, error) { +func (p *Param) GetInt() (int, error) { + if p == nil { + return 0, errMissingParameter + } i, ok := p.Value.(int) if ok { return i, nil @@ -86,7 +109,10 @@ func (p Param) GetInt() (int, error) { } // GetArray returns a slice of Params stored in the parameter. -func (p Param) GetArray() ([]Param, error) { +func (p *Param) GetArray() ([]Param, error) { + if p == nil { + return nil, errMissingParameter + } a, ok := p.Value.([]Param) if !ok { return nil, errors.New("not an array") @@ -95,7 +121,7 @@ func (p Param) GetArray() ([]Param, error) { } // GetUint256 returns Uint256 value of the parameter. -func (p Param) GetUint256() (util.Uint256, error) { +func (p *Param) GetUint256() (util.Uint256, error) { s, err := p.GetString() if err != nil { return util.Uint256{}, err @@ -105,7 +131,7 @@ func (p Param) GetUint256() (util.Uint256, error) { } // GetUint160FromHex returns Uint160 value of the parameter encoded in hex. -func (p Param) GetUint160FromHex() (util.Uint160, error) { +func (p *Param) GetUint160FromHex() (util.Uint160, error) { s, err := p.GetString() if err != nil { return util.Uint160{}, err @@ -119,7 +145,7 @@ func (p Param) GetUint160FromHex() (util.Uint160, error) { // GetUint160FromAddress returns Uint160 value of the parameter that was // supplied as an address. -func (p Param) GetUint160FromAddress() (util.Uint160, error) { +func (p *Param) GetUint160FromAddress() (util.Uint160, error) { s, err := p.GetString() if err != nil { return util.Uint160{}, err @@ -129,7 +155,10 @@ func (p Param) GetUint160FromAddress() (util.Uint160, error) { } // GetFuncParam returns current parameter as a function call parameter. -func (p Param) GetFuncParam() (FuncParam, error) { +func (p *Param) GetFuncParam() (FuncParam, error) { + if p == nil { + return FuncParam{}, errMissingParameter + } fp, ok := p.Value.(FuncParam) if !ok { return FuncParam{}, errors.New("not a function parameter") @@ -139,7 +168,7 @@ func (p Param) GetFuncParam() (FuncParam, error) { // GetBytesHex returns []byte value of the parameter if // it is a hex-encoded string. 
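With the getters switched to pointer receivers, a nil *Param is a valid receiver, so handlers can chain Params.Value(i) straight into a getter and get errMissingParameter back instead of checking presence first. A short sketch of the resulting calling pattern, assuming only the request package from this patch:

package main

import (
	"fmt"

	"github.com/nspcc-dev/neo-go/pkg/rpc/request"
)

func main() {
	ps := request.Params{} // empty parameter list

	// Value(0) is nil here; GetString on a nil *Param reports the missing
	// parameter instead of panicking.
	if _, err := ps.Value(0).GetString(); err != nil {
		fmt.Println("error:", err)
	}

	// GetBoolean treats a missing parameter as false.
	fmt.Println(ps.Value(1).GetBoolean())
}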
-func (p Param) GetBytesHex() ([]byte, error) { +func (p *Param) GetBytesHex() ([]byte, error) { s, err := p.GetString() if err != nil { return nil, err @@ -163,6 +192,11 @@ func (p *Param) UnmarshalJSON(data []byte) error { {ArrayT, &[]Param{}}, } + if bytes.Equal(data, []byte("null")) { + p.Type = defaultT + return nil + } + for _, cur := range attempts { r := bytes.NewReader(data) jd := json.NewDecoder(r) diff --git a/pkg/rpc/request/param_test.go b/pkg/rpc/request/param_test.go index da04ea540..7bf2ae22d 100644 --- a/pkg/rpc/request/param_test.go +++ b/pkg/rpc/request/param_test.go @@ -14,7 +14,7 @@ import ( ) func TestParam_UnmarshalJSON(t *testing.T) { - msg := `["str1", 123, ["str2", 3], [{"type": "String", "value": "jajaja"}], + msg := `["str1", 123, null, ["str2", 3], [{"type": "String", "value": "jajaja"}], {"type": "MinerTransaction"}, {"contract": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, {"state": "HALT"}]` @@ -29,6 +29,9 @@ func TestParam_UnmarshalJSON(t *testing.T) { Type: NumberT, Value: 123, }, + { + Type: defaultT, + }, { Type: ArrayT, Value: []Param{ diff --git a/pkg/rpc/request/params.go b/pkg/rpc/request/params.go index 8b1945cb1..dd2ac35b9 100644 --- a/pkg/rpc/request/params.go +++ b/pkg/rpc/request/params.go @@ -7,20 +7,19 @@ type ( // Value returns the param struct for the given // index if it exists. -func (p Params) Value(index int) (*Param, bool) { +func (p Params) Value(index int) *Param { if len(p) > index { - return &p[index], true + return &p[index] } - return nil, false + return nil } // ValueWithType returns the param struct at the given index if it // exists and matches the given type. -func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) { - if val, ok := p.Value(index); ok && val.Type == valType { - return val, true +func (p Params) ValueWithType(index int, valType paramType) *Param { + if val := p.Value(index); val != nil && val.Type == valType { + return val } - - return nil, false + return nil } diff --git a/pkg/rpc/response/result/mpt.go b/pkg/rpc/response/result/mpt.go new file mode 100644 index 000000000..10ef7e8c3 --- /dev/null +++ b/pkg/rpc/response/result/mpt.go @@ -0,0 +1,122 @@ +package result + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/io" +) + +// StateHeight is a result of getstateheight RPC. +type StateHeight struct { + BlockHeight uint32 `json:"blockHeight"` + StateHeight uint32 `json:"stateHeight"` +} + +// ProofWithKey represens key-proof pair. +type ProofWithKey struct { + Key []byte + Proof [][]byte +} + +// GetProof is a result of getproof RPC. +type GetProof struct { + Result ProofWithKey `json:"proof"` + Success bool `json:"success"` +} + +// VerifyProof is a result of verifyproof RPC. +// nil Value is considered invalid. +type VerifyProof struct { + Value []byte +} + +// MarshalJSON implements json.Marshaler. +func (p *ProofWithKey) MarshalJSON() ([]byte, error) { + w := io.NewBufBinWriter() + p.EncodeBinary(w.BinWriter) + if w.Err != nil { + return nil, w.Err + } + return []byte(`"` + hex.EncodeToString(w.Bytes()) + `"`), nil +} + +// EncodeBinary implements io.Serializable. +func (p *ProofWithKey) EncodeBinary(w *io.BinWriter) { + w.WriteVarBytes(p.Key) + w.WriteVarUint(uint64(len(p.Proof))) + for i := range p.Proof { + w.WriteVarBytes(p.Proof[i]) + } +} + +// DecodeBinary implements io.Serializable. 
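On the wire a ProofWithKey is the var-bytes key followed by a var-uint count and that many var-bytes proof nodes; in JSON the whole structure travels as one hex string, while VerifyProof marshals either to the literal "invalid" or to {"value":"<hex>"}. A sketch of the round trip, assuming the result package from this patch:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/nspcc-dev/neo-go/pkg/rpc/response/result"
)

func main() {
	p := &result.ProofWithKey{
		Key:   []byte{0x0a, 0x0b},             // hypothetical storage key
		Proof: [][]byte{{0x01}, {0x02, 0x03}}, // hypothetical proof nodes
	}
	s := p.String() // hex of: key || count || proof nodes

	var back result.ProofWithKey
	if err := back.FromString(s); err != nil {
		panic(err)
	}
	fmt.Println(s, len(back.Proof)) // hex string and 2

	failed, _ := json.Marshal(new(result.VerifyProof))
	fmt.Println(string(failed)) // "invalid" (nil Value marks a failed proof)
}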
+func (p *ProofWithKey) DecodeBinary(r *io.BinReader) { + p.Key = r.ReadVarBytes() + sz := r.ReadVarUint() + for i := uint64(0); i < sz; i++ { + p.Proof = append(p.Proof, r.ReadVarBytes()) + } +} + +// UnmarshalJSON implements json.Unmarshaler. +func (p *ProofWithKey) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + return p.FromString(s) +} + +// String implements fmt.Stringer. +func (p *ProofWithKey) String() string { + w := io.NewBufBinWriter() + p.EncodeBinary(w.BinWriter) + return hex.EncodeToString(w.Bytes()) +} + +// FromString decodes p from hex-encoded string. +func (p *ProofWithKey) FromString(s string) error { + rawProof, err := hex.DecodeString(s) + if err != nil { + return err + } + r := io.NewBinReaderFromBuf(rawProof) + p.DecodeBinary(r) + return r.Err +} + +// MarshalJSON implements json.Marshaler. +func (p *VerifyProof) MarshalJSON() ([]byte, error) { + if p.Value == nil { + return []byte(`"invalid"`), nil + } + return []byte(`{"value":"` + hex.EncodeToString(p.Value) + `"}`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (p *VerifyProof) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, []byte(`"invalid"`)) { + p.Value = nil + return nil + } + var m map[string]string + if err := json.Unmarshal(data, &m); err != nil { + return err + } + if len(m) != 1 { + return errors.New("must have single key") + } + v, ok := m["value"] + if !ok { + return errors.New("invalid json") + } + b, err := hex.DecodeString(v) + if err != nil { + return err + } + p.Value = b + return nil +} diff --git a/pkg/rpc/response/result/mpt_test.go b/pkg/rpc/response/result/mpt_test.go new file mode 100644 index 000000000..22e0c021c --- /dev/null +++ b/pkg/rpc/response/result/mpt_test.go @@ -0,0 +1,68 @@ +package result + +import ( + "encoding/json" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/mpt" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/stretchr/testify/require" +) + +func testProofWithKey() *ProofWithKey { + return &ProofWithKey{ + Key: random.Bytes(10), + Proof: [][]byte{ + random.Bytes(12), + random.Bytes(0), + random.Bytes(34), + }, + } +} + +func TestGetProof_MarshalJSON(t *testing.T) { + t.Run("Good", func(t *testing.T) { + p := &GetProof{ + Result: *testProofWithKey(), + Success: true, + } + testserdes.MarshalUnmarshalJSON(t, p, new(GetProof)) + }) + t.Run("Compatibility", func(t *testing.T) { + js := []byte(`{ + "proof" : 
"25ddeb9aa1bfc353c9c54e21dffb470f65d9c22a0662616c616e63654f70000000000000000708fd12020020666eaa8a6e75d43a97d76e72b605c7e05189f0c57ec19d84acdb75810f18239d202c83028ce3d7abcf4e4f95d05fbfdfa5e18bde3a8fbb65a57559d6b5ea09425c2090c40d440744a848e3b407a00e4efb692a957245a1efc9cb8496cb05fd328ee620dd2652bf25dfc3ad5fee7b200ccf3e3ae50772ff8ed58907e4dab8e7d4b2489720d8a5d5ed75b5b0f256d0a2cf5c220b4ddae2a228ef0fc0212b689f3811dfa94620342cc0d73fabd2440ed2cc735a9608391a510e1981b321a9f4258682706adc9620ced036e52f39387b9c58ade7bf8c3ca8959b64d8031d36d9b1c62f3f1c51c7cb2031072c7c801b5c1614dae441383a65344acd238f13db28ff0a39c0626e597f002062552d64c616d8b2a6a93d22936055110c0065728aa2b4fbf4d76b108390b474203322d3c93c741674a307cf6455e77c02ceeda307d4ec23fd809a2a420b4243f82052ab92a9cedc6716ad4c66a8a3e423b195b05bdebde456f992bff48f2561e99720e6379995e7053823b8ba8fb8af9623cf48e89f60c989598445df5e711db42a6f20192894ed637e86561ff6a4b8dea4539dee8bddb2fb20bf4ae3499852985c88b120e0005edd09f2335aa6b59ff4723e1262b2192adaa5e3e56f79e662f07041f04c2033577f3e2c5bb0e58746980a07cdfad2f872e2b9a10bcc27b7c678c85576df8420f0f04180d15b6eaa0c43e62380084c75ad773d790700a7120c6c4da1fc51693000fd720100209648e8f10a5ff4c209009b9a09697babbe1b2150d0948c1970a560282a1bfa4720988af8f34859dd8309bffea0b1dff9c8cef0b9b0d6a1852d40786627729ae7be00206ebf4f1b7861bca041cbb8feca75158511ca43a1810d17e1e3017468e8cef0de20cac93064090a7da09f8202c17d1e6cbb9a16eb43afcb032e80719cbf05b3446d2019b76a10b91fb99ec08814e8108e5490b879fb09a190cb2c129dfd98335bd5de000020b1da1198bacacf2adc0d863929d77c285ce3a26e736203d0c0a69a1312255fb2207ee8aa092f49348bd89f9c4bf004b0bee2241a2d0acfe7b3ce08e414b04a5717205b0dda71eac8a4e4cdc6a7b939748c0a78abb54f2547a780e6df67b25530330f000020fc358fb9d1e0d36461e015ac8e35f97072a9f9e750a3c25722a2b1a858fcb82d203c52c9fac6d4694b351390158334a9166bc3478ceb9bea2b0b244915f918239e20d526344a24ff19ee6a9f5c5beb833f4eb6d51191590350e26fa50b138493473f005200000000000000000000002077c404fec0a4265568951dbd096572787d109fab105213f4f292a5f53ce72fca00000020b8d1c7a386eaba83ce83ee0700d4ca9b86e75d147d670ea05123e438231d895000004801250b090a0a010b0f0c0305030c090c05040e02010d0f0f0b0407000f06050d090c02020a0006202af2097cf9d3f42e49f6b3c3dd254e7cbdab3485b029721cbbbf1ad0455a810852000000000000002055170506f4b18bc573a909b51cb21bdd5d303ec511f6cdfb1c6a1ab8d8a1dad020ee774c1b9fe1d8ea8d05823837d959da48af74f384d52f06c42c9d146c5258e300000000000000000072000000204457a6fe530ee953ad1f9caf63daf7f86719c9986df2d0b6917021eb379800f00020406bfc79da4ba6f37452a679d13cca252585d34f7e94a480b047bad9427f233e00000000201ce15a2373d28e0dc5f2000cf308f155d06f72070a29e5af1528c8f05f29d248000000000000004301200601060c0601060e06030605040f0700000000000000000000000000000000072091b83866bbd7450115b462e8d48601af3c3e9a35e7018d2b98a23e107c15c200090307000410a328e800", + "success" : true + }`) + + var p GetProof + require.NoError(t, json.Unmarshal(js, &p)) + require.Equal(t, 8, len(p.Result.Proof)) + for i := range p.Result.Proof { // smoke test that every chunk is correctly encoded node + r := io.NewBinReaderFromBuf(p.Result.Proof[i]) + var n mpt.NodeObject + n.DecodeBinary(r) + require.NoError(t, r.Err) + require.NotNil(t, n.Node) + } + }) +} + +func TestProofWithKey_EncodeString(t *testing.T) { + expected := testProofWithKey() + var actual ProofWithKey + require.NoError(t, actual.FromString(expected.String())) + require.Equal(t, expected, &actual) +} + +func TestVerifyProof_MarshalJSON(t *testing.T) { + t.Run("Good", func(t *testing.T) { + vp := &VerifyProof{random.Bytes(100)} + testserdes.MarshalUnmarshalJSON(t, vp, new(VerifyProof)) + }) + 
t.Run("NoValue", func(t *testing.T) { + vp := new(VerifyProof) + testserdes.MarshalUnmarshalJSON(t, vp, &VerifyProof{[]byte{1, 2, 3}}) + }) +} diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index a91502e5e..675efc9b7 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -15,6 +15,7 @@ import ( "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -95,6 +96,9 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon "getpeers": (*Server).getPeers, "getrawmempool": (*Server).getRawMempool, "getrawtransaction": (*Server).getrawtransaction, + "getproof": (*Server).getProof, + "getstateheight": (*Server).getStateHeight, + "getstateroot": (*Server).getStateRoot, "getstorage": (*Server).getStorage, "gettransactionheight": (*Server).getTransactionHeight, "gettxout": (*Server).getTxOut, @@ -108,6 +112,7 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon "sendrawtransaction": (*Server).sendrawtransaction, "submitblock": (*Server).submitBlock, "validateaddress": (*Server).validateAddress, + "verifyproof": (*Server).verifyProof, } var rpcWsHandlers = map[string]func(*Server, request.Params, *subscriber) (interface{}, *response.Error){ @@ -389,8 +394,8 @@ func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Er func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) { var hash util.Uint256 - param, ok := reqParams.Value(0) - if !ok { + param := reqParams.Value(0) + if param == nil { return nil, response.ErrInvalidParams } @@ -416,7 +421,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro return nil, response.NewInternalServerError(fmt.Sprintf("Problem locating block with hash: %s", hash), err) } - if len(reqParams) == 2 && reqParams[1].Value == 1 { + if reqParams.Value(1).GetBoolean() { return result.NewBlock(block, s.chain), nil } writer := io.NewBufBinWriter() @@ -425,8 +430,8 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro } func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.NumberT) - if !ok { + param := reqParams.ValueWithType(0, request.NumberT) + if param == nil { return nil, response.ErrInvalidParams } num, err := s.blockHeightFromParam(param) @@ -463,20 +468,15 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error) } func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { + param := reqParams.Value(0) + if param == nil { return nil, response.ErrInvalidParams } return validateAddress(param.Value), nil } func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - - paramAssetID, err := param.GetUint256() + paramAssetID, err := reqParams.ValueWithType(0, request.StringT).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -490,12 +490,7 @@ func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response // getApplicationLog returns the contract log based on the 
specified txid. func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - txHash, err := param.GetUint256() + txHash, err := reqParams.Value(0).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -522,10 +517,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *resp } func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } + p := ps.ValueWithType(0, request.StringT) u, err := p.GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams @@ -574,11 +566,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) } func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromHex() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } @@ -607,11 +595,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Erro } func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromAddress() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams } @@ -708,25 +692,96 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 return d, nil } -func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) { - param, ok := ps.Value(0) - if !ok { +func (s *Server) getProof(ps request.Params) (interface{}, *response.Error) { + root, err := ps.Value(0).GetUint256() + if err != nil { return nil, response.ErrInvalidParams } + sc, err := ps.Value(1).GetUint160FromHex() + if err != nil { + return nil, response.ErrInvalidParams + } + sc = sc.Reverse() + key, err := ps.Value(2).GetBytesHex() + if err != nil { + return nil, response.ErrInvalidParams + } + skey := mpt.ToNeoStorageKey(append(sc.BytesBE(), key...)) + proof, err := s.chain.GetStateProof(root, skey) + return &result.GetProof{ + Result: result.ProofWithKey{ + Key: skey, + Proof: proof, + }, + Success: err == nil, + }, nil +} - scriptHash, err := param.GetUint160FromHex() +func (s *Server) verifyProof(ps request.Params) (interface{}, *response.Error) { + root, err := ps.Value(0).GetUint256() + if err != nil { + return nil, response.ErrInvalidParams + } + proofStr, err := ps.Value(1).GetString() + if err != nil { + return nil, response.ErrInvalidParams + } + var p result.ProofWithKey + if err := p.FromString(proofStr); err != nil { + return nil, response.ErrInvalidParams + } + vp := new(result.VerifyProof) + val, ok := mpt.VerifyProof(root, p.Key, p.Proof) + if ok { + var si state.StorageItem + r := io.NewBinReaderFromBuf(val[1:]) + si.DecodeBinary(r) + if r.Err != nil { + return nil, response.NewInternalServerError("invalid item in trie", r.Err) + } + vp.Value = si.Value + } + return vp, nil +} + +func (s *Server) getStateHeight(_ request.Params) (interface{}, *response.Error) { + return &result.StateHeight{ + BlockHeight: s.chain.BlockHeight(), + StateHeight: s.chain.StateHeight(), + }, nil +} + +func (s 
*Server) getStateRoot(ps request.Params) (interface{}, *response.Error) {
+	p := ps.Value(0)
+	if p == nil {
+		return nil, response.NewRPCError("Invalid parameter.", "", nil)
+	}
+	var rt *state.MPTRootState
+	var h util.Uint256
+	height, err := p.GetInt()
+	if err == nil {
+		rt, err = s.chain.GetStateRoot(uint32(height))
+	} else if h, err = p.GetUint256(); err == nil {
+		hdr, err := s.chain.GetHeader(h)
+		if err == nil {
+			rt, err = s.chain.GetStateRoot(hdr.Index)
+		}
+	}
+	if err != nil {
+		return nil, response.NewRPCError("Unknown state root.", "", err)
+	}
+	return rt, nil
+}
+
+func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) {
+	scriptHash, err := ps.Value(0).GetUint160FromHex()
 	if err != nil {
 		return nil, response.ErrInvalidParams
 	}
 
 	scriptHash = scriptHash.Reverse()
 
-	param, ok = ps.Value(1)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-
-	key, err := param.GetBytesHex()
+	key, err := ps.Value(1).GetBytesHex()
 	if err != nil {
 		return nil, response.ErrInvalidParams
 	}
@@ -743,30 +798,17 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
 	var resultsErr *response.Error
 	var results interface{}
 
-	if param0, ok := reqParams.Value(0); !ok {
-		return nil, response.ErrInvalidParams
-	} else if txHash, err := param0.GetUint256(); err != nil {
+	if txHash, err := reqParams.Value(0).GetUint256(); err != nil {
 		resultsErr = response.ErrInvalidParams
 	} else if tx, height, err := s.chain.GetTransaction(txHash); err != nil {
 		err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash)
 		return nil, response.NewRPCError("Unknown transaction", err.Error(), err)
-	} else if len(reqParams) >= 2 {
+	} else if reqParams.Value(1).GetBoolean() {
 		_header := s.chain.GetHeaderHash(int(height))
 		header, err := s.chain.GetHeader(_header)
 		if err != nil {
 			resultsErr = response.NewInvalidParamsError(err.Error(), err)
-		}
-
-		param1, _ := reqParams.Value(1)
-		switch v := param1.Value.(type) {
-
-		case int, float64, bool, string:
-			if v == 0 || v == "0" || v == 0.0 || v == false || v == "false" {
-				results = hex.EncodeToString(tx.Bytes())
-			} else {
-				results = result.NewTransactionOutputRaw(tx, header, s.chain)
-			}
-		default:
+		} else {
 			results = result.NewTransactionOutputRaw(tx, header, s.chain)
 		}
 	} else {
@@ -777,12 +819,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
 }
 
 func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) {
-	p, ok := ps.Value(0)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-
-	h, err := p.GetUint256()
+	h, err := ps.Value(0).GetUint256()
 	if err != nil {
 		return nil, response.ErrInvalidParams
 	}
@@ -796,22 +833,12 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response
 }
 
 func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) {
-	p, ok := ps.Value(0)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-
-	h, err := p.GetUint256()
+	h, err := ps.Value(0).GetUint256()
 	if err != nil {
 		return nil, response.ErrInvalidParams
 	}
 
-	p, ok = ps.ValueWithType(1, request.NumberT)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-
-	num, err := p.GetInt()
+	num, err := ps.ValueWithType(1, request.NumberT).GetInt()
 	if err != nil || num < 0 {
 		return nil, response.ErrInvalidParams
 	}
@@ -833,18 +860,15 @@ func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) {
 func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) {
 	var results interface{}
 
-	param, ok := reqParams.ValueWithType(0, request.StringT)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	} else if scriptHash, err := param.GetUint160FromHex(); err != nil {
+	scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
+	if err != nil {
 		return nil, response.ErrInvalidParams
+	}
+	cs := s.chain.GetContractState(scriptHash)
+	if cs != nil {
+		results = result.NewContractState(cs)
 	} else {
-		cs := s.chain.GetContractState(scriptHash)
-		if cs != nil {
-			results = result.NewContractState(cs)
-		} else {
-			return nil, response.NewRPCError("Unknown contract", "", nil)
-		}
+		return nil, response.NewRPCError("Unknown contract", "", nil)
 	}
 	return results, nil
 }
@@ -862,33 +886,31 @@ func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (in
 	var resultsErr *response.Error
 	var results interface{}
 
-	param, ok := reqParams.ValueWithType(0, request.StringT)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	} else if scriptHash, err := param.GetUint160FromAddress(); err != nil {
+	param := reqParams.ValueWithType(0, request.StringT)
+	scriptHash, err := param.GetUint160FromAddress()
+	if err != nil {
 		return nil, response.ErrInvalidParams
+	}
+	as := s.chain.GetAccountState(scriptHash)
+	if as == nil {
+		as = state.NewAccount(scriptHash)
+	}
+	if unspents {
+		str, err := param.GetString()
+		if err != nil {
+			return nil, response.ErrInvalidParams
+		}
+		results = result.NewUnspents(as, s.chain, str)
 	} else {
-		as := s.chain.GetAccountState(scriptHash)
-		if as == nil {
-			as = state.NewAccount(scriptHash)
-		}
-		if unspents {
-			str, err := param.GetString()
-			if err != nil {
-				return nil, response.ErrInvalidParams
-			}
-			results = result.NewUnspents(as, s.chain, str)
-		} else {
-			results = result.NewAccountState(as)
-		}
+		results = result.NewAccountState(as)
 	}
 	return results, resultsErr
 }
 
 // getBlockSysFee returns the system fees of the block, based on the specified index.
 func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) {
-	param, ok := reqParams.ValueWithType(0, request.NumberT)
-	if !ok {
+	param := reqParams.ValueWithType(0, request.NumberT)
+	if param == nil {
 		return 0, response.ErrInvalidParams
 	}
 
@@ -913,26 +935,12 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *respons
 
 // getBlockHeader returns the corresponding block header information according to the specified script hash.
 func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) {
-	var verbose bool
-
-	param, ok := reqParams.ValueWithType(0, request.StringT)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-	hash, err := param.GetUint256()
+	hash, err := reqParams.ValueWithType(0, request.StringT).GetUint256()
 	if err != nil {
 		return nil, response.ErrInvalidParams
 	}
 
-	param, ok = reqParams.ValueWithType(1, request.NumberT)
-	if ok {
-		v, err := param.GetInt()
-		if err != nil {
-			return nil, response.ErrInvalidParams
-		}
-		verbose = v != 0
-	}
-
+	verbose := reqParams.Value(1).GetBoolean()
 	h, err := s.chain.GetHeader(hash)
 	if err != nil {
 		return nil, response.NewRPCError("unknown block", "", nil)
@@ -952,11 +960,7 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *respons
 
 // getUnclaimed returns unclaimed GAS amount of the specified address.
 func (s *Server) getUnclaimed(ps request.Params) (interface{}, *response.Error) {
-	p, ok := ps.ValueWithType(0, request.StringT)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-	u, err := p.GetUint160FromAddress()
+	u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress()
 	if err != nil {
 		return nil, response.ErrInvalidParams
 	}
@@ -997,19 +1001,11 @@ func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error)
 
 // invoke implements the `invoke` RPC call.
 func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) {
-	scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-	scriptHash, err := scriptHashHex.GetUint160FromHex()
+	scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
 	if err != nil {
 		return nil, response.ErrInvalidParams
 	}
-	sliceP, ok := reqParams.ValueWithType(1, request.ArrayT)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-	slice, err := sliceP.GetArray()
+	slice, err := reqParams.ValueWithType(1, request.ArrayT).GetArray()
 	if err != nil {
 		return nil, response.ErrInvalidParams
 	}
@@ -1022,11 +1018,7 @@ func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error)
 
 // invokescript implements the `invokescript` RPC call.
 func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) {
-	scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-	scriptHash, err := scriptHashHex.GetUint160FromHex()
+	scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
 	if err != nil {
 		return nil, response.ErrInvalidParams
 	}
@@ -1069,11 +1061,7 @@ func (s *Server) runScriptInVM(script []byte) *result.Invoke {
 
 // submitBlock broadcasts a raw block over the NEO network.
 func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) {
-	param, ok := reqParams.ValueWithType(0, request.StringT)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-	blockBytes, err := param.GetBytesHex()
+	blockBytes, err := reqParams.ValueWithType(0, request.StringT).GetBytesHex()
 	if err != nil {
 		return nil, response.ErrInvalidParams
 	}
@@ -1134,11 +1122,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res
 
 // subscribe handles subscription requests from websocket clients.
 func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
-	p, ok := reqParams.Value(0)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-	streamName, err := p.GetString()
+	streamName, err := reqParams.Value(0).GetString()
	if err != nil {
 		return nil, response.ErrInvalidParams
 	}
@@ -1148,8 +1132,7 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface
 	}
 	// Optional filter.
 	var filter interface{}
-	p, ok = reqParams.Value(1)
-	if ok {
+	if p := reqParams.Value(1); p != nil {
 		// It doesn't accept filters.
 		if event == response.BlockEventID {
 			return nil, response.ErrInvalidParams
@@ -1224,11 +1207,7 @@ func (s *Server) subscribeToChannel(event response.EventID) {
 
 // unsubscribe handles unsubscription requests from websocket clients.
 func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
-	p, ok := reqParams.Value(0)
-	if !ok {
-		return nil, response.ErrInvalidParams
-	}
-	id, err := p.GetInt()
+	id, err := reqParams.Value(0).GetInt()
 	if err != nil || id < 0 {
 		return nil, response.ErrInvalidParams
 	}
diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go
index 6103156f7..509867121 100644
--- a/pkg/rpc/server/server_test.go
+++ b/pkg/rpc/server/server_test.go
@@ -16,6 +16,8 @@ import (
 
 	"github.com/gorilla/websocket"
 	"github.com/nspcc-dev/neo-go/pkg/core"
+	"github.com/nspcc-dev/neo-go/pkg/core/mpt"
+	"github.com/nspcc-dev/neo-go/pkg/core/state"
 	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
 	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
 	"github.com/nspcc-dev/neo-go/pkg/encoding/address"
@@ -213,6 +215,54 @@ var rpcTestCases = map[string][]rpcTestCase{
 			},
 		},
 	},
+	"getproof": {
+		{
+			name:   "no params",
+			params: `[]`,
+			fail:   true,
+		},
+		{
+			name:   "invalid root",
+			params: `["0xabcdef"]`,
+			fail:   true,
+		},
+		{
+			name:   "invalid contract",
+			params: `["0000000000000000000000000000000000000000000000000000000000000000", "0xabcdef"]`,
+			fail:   true,
+		},
+		{
+			name:   "invalid key",
+			params: `["0000000000000000000000000000000000000000000000000000000000000000", "` + testContractHash + `", "notahex"]`,
+			fail:   true,
+		},
+	},
+	"getstateheight": {
+		{
+			name:   "positive",
+			params: `[]`,
+			result: func(_ *executor) interface{} { return new(result.StateHeight) },
+			check: func(t *testing.T, e *executor, res interface{}) {
+				sh, ok := res.(*result.StateHeight)
+				require.True(t, ok)
+
+				require.Equal(t, e.chain.BlockHeight(), sh.BlockHeight)
+				require.Equal(t, e.chain.StateHeight(), sh.StateHeight)
+			},
+		},
+	},
+	"getstateroot": {
+		{
+			name:   "no params",
+			params: `[]`,
+			fail:   true,
+		},
+		{
+			name:   "invalid hash",
+			params: `["0x1234567890"]`,
+			fail:   true,
+		},
+	},
 	"getstorage": {
 		{
 			name:   "positive",
@@ -928,6 +978,52 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) []
 		})
 	}
 
+	t.Run("getproof", func(t *testing.T) {
+		r, err := chain.GetStateRoot(205)
+		require.NoError(t, err)
+
+		rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getproof", "params": ["%s", "%s", "%x"]}`,
+			r.Root.StringLE(), testContractHash, []byte("testkey"))
+		fmt.Println(rpc)
+		body := doRPCCall(rpc, httpSrv.URL, t)
+		fmt.Println(string(body))
+		rawRes := checkErrGetResult(t, body, false)
+		res := new(result.GetProof)
+		require.NoError(t, json.Unmarshal(rawRes, res))
+		require.True(t, res.Success)
+		h, _ := hex.DecodeString(testContractHash)
+		skey := append(h, []byte("testkey")...)
+		require.Equal(t, mpt.ToNeoStorageKey(skey), res.Result.Key)
+		require.True(t, len(res.Result.Proof) > 0)
+
+		rpc = fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "verifyproof", "params": ["%s", "%s"]}`,
+			r.Root.StringLE(), res.Result.String())
+		body = doRPCCall(rpc, httpSrv.URL, t)
+		rawRes = checkErrGetResult(t, body, false)
+		vp := new(result.VerifyProof)
+		require.NoError(t, json.Unmarshal(rawRes, vp))
+		require.Equal(t, []byte("testvalue"), vp.Value)
+	})
+
+	t.Run("getstateroot", func(t *testing.T) {
+		testRoot := func(t *testing.T, p string) {
+			rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getstateroot", "params": [%s]}`, p)
+			fmt.Println(rpc)
+			body := doRPCCall(rpc, httpSrv.URL, t)
+			rawRes := checkErrGetResult(t, body, false)
+
+			res := new(state.MPTRootState)
+			require.NoError(t, json.Unmarshal(rawRes, res))
+			require.NotEqual(t, util.Uint256{}, res.Root) // be sure this test uses valid height
+
+			expected, err := e.chain.GetStateRoot(205)
+			require.NoError(t, err)
+			require.Equal(t, expected, res)
+		}
+		t.Run("ByHeight", func(t *testing.T) { testRoot(t, strconv.FormatInt(205, 10)) })
+		t.Run("ByHash", func(t *testing.T) { testRoot(t, `"`+chain.GetHeaderHash(205).StringLE()+`"`) })
+	})
+
 	t.Run("getrawtransaction", func(t *testing.T) {
 		block, _ := chain.GetBlock(chain.GetHeaderHash(0))
 		TXHash := block.Transactions[1].Hash()