Merge pull request #1096 from nspcc-dev/neox-2.x

Neox 2.x
This commit is contained in:
Roman Khimov 2020-06-24 16:08:19 +03:00 committed by GitHub
commit fbadb317f5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
62 changed files with 3804 additions and 347 deletions

View file

@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/util"
)
@ -33,35 +34,7 @@ func toNeoStorageKey(key []byte) []byte {
if len(key) < util.Uint160Size {
panic("invalid key in storage")
}
var nkey []byte
for i := util.Uint160Size - 1; i >= 0; i-- {
nkey = append(nkey, key[i])
}
key = key[util.Uint160Size:]
index := 0
remain := len(key)
for remain >= 16 {
nkey = append(nkey, key[index:index+16]...)
nkey = append(nkey, 0)
index += 16
remain -= 16
}
if remain > 0 {
nkey = append(nkey, key[index:]...)
}
padding := 16 - remain
for i := 0; i < padding; i++ {
nkey = append(nkey, 0)
}
nkey = append(nkey, byte(padding))
return nkey
return mpt.ToNeoStorageKey(key)
}
// batchToMap converts batch to a map so that JSON is compatible

View file

@ -2,6 +2,8 @@ ProtocolConfiguration:
Magic: 1953787457
AddressVersion: 23
SecondsPerBlock: 15
EnableStateRoot: true
StateRootEnableIndex: 4380100
LowPriorityThreshold: 0.000
MemPoolSize: 50000
StandbyValidators:

View file

@ -2,6 +2,7 @@ ProtocolConfiguration:
Magic: 56753
AddressVersion: 23
SecondsPerBlock: 15
EnableStateRoot: true
LowPriorityThreshold: 0.000
MemPoolSize: 50000
StandbyValidators:

3
go.mod
View file

@ -3,12 +3,13 @@ module github.com/nspcc-dev/neo-go
require (
github.com/Workiva/go-datastructures v1.0.50
github.com/alicebob/miniredis v2.5.0+incompatible
github.com/btcsuite/btcd v0.20.1-beta
github.com/dgraph-io/badger/v2 v2.0.3
github.com/go-redis/redis v6.10.2+incompatible
github.com/go-yaml/yaml v2.1.0+incompatible
github.com/gorilla/websocket v1.4.2
github.com/mr-tron/base58 v1.1.2
github.com/nspcc-dev/dbft v0.0.0-20200531081613-7a39e7b757ac
github.com/nspcc-dev/dbft v0.0.0-20200623100921-5a182c20965e
github.com/nspcc-dev/rfc6979 v0.2.0
github.com/pkg/errors v0.8.1
github.com/prometheus/client_golang v1.2.1

32
go.sum
View file

@ -13,6 +13,8 @@ github.com/abiosoft/ishell v2.0.0+incompatible h1:zpwIuEHc37EzrsIYah3cpevrIc8Oma
github.com/abiosoft/ishell v2.0.0+incompatible/go.mod h1:HQR9AqF2R3P4XXpMpI0NAzgHf/aS6+zVXRj14cVk9qg=
github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db h1:CjPUSXOiYptLbTdr1RceuZgSFDQ7U15ITERUGrUORx8=
github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db/go.mod h1:rB3B4rKii8V21ydCbIzH5hZiCQE7f5E9SzUb/ZZx530=
github.com/aead/siphash v1.0.1 h1:FwHfE/T45KPKYuuSAKyyvE+oPWcaQ+CUmFW0bPlM+kg=
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
@ -28,6 +30,22 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/btcsuite/btcd v0.20.1-beta h1:Ik4hyJqN8Jfyv3S4AGBOmyouMsYE3EdYODkMbQjwPGw=
github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ=
github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f h1:bAs4lUbRJpnnkd9VhRV3jjAVU7DJVjMaK+IsvSeZvFo=
github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA=
github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d h1:yJzD/yFppdVCf6ApMkVy8cUxV0XrxdP9rVf6D87/Mng=
github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg=
github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd h1:R/opQEbFEy9JGkIguV40SvRY1uliPX8ifOvi6ICsFCw=
github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd/go.mod h1:HHNXQzUsZCxOoE+CPiyCTO6x34Zs86zZUiwtpXoGdtg=
github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd h1:qdGvebPBDuYDPGi1WCPjy1tGyMpmDK8IEapSsszn7HE=
github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd/go.mod h1:F+uVaaLLH7j4eDXPRvw78tMflu7Ie2bzYOH4Y8rRKBY=
github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723 h1:ZA/jbKoGcVAnER6pCHPEkGdZOV7U1oLUedErBHCUMs0=
github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723/go.mod h1:8woku9dyThutzjeg+3xrA5iCpBRH8XEEg3lh6TiUghc=
github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792 h1:R8vQdOQdZ9Y3SkEwmHoWBmX1DNXhXZqlTpq6s4tyJGc=
github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792/go.mod h1:ghJtEyQwv5/p4Mg4C0fgbePVuGr935/5ddU9Z3TmDRY=
github.com/btcsuite/winsvc v1.0.0 h1:J9B4L7e3oqhXOcm+2IuNApwzQec85lE+QaikUcCs+dk=
github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46fmI40EZs=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.0 h1:yTUvW7Vhb89inJ+8irsUqiWjh8iT6sQPZiQzI6ReGkA=
@ -41,6 +59,7 @@ github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@ -91,9 +110,15 @@ github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89 h1:12K8AlpT0/6QUXSfV0yi4Q0jkbq8NDtIKFtF61AoqV0=
github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jrick/logrotate v1.0.0 h1:lQ1bL/n9mBNeIXoTUoYRlK4dHuNJVofX9oWqBtPnSzI=
github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23 h1:FOOIBWrEkLgmlgGfMuZT83xIwfPDxEI2OHu6xUmJMFE=
github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4=
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@ -131,8 +156,8 @@ github.com/nspcc-dev/dbft v0.0.0-20200117124306-478e5cfbf03a h1:ajvxgEe9qY4vvoSm
github.com/nspcc-dev/dbft v0.0.0-20200117124306-478e5cfbf03a/go.mod h1:/YFK+XOxxg0Bfm6P92lY5eDSLYfp06XOdL8KAVgXjVk=
github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1 h1:yEx9WznS+rjE0jl0dLujCxuZSIb+UTjF+005TJu/nNI=
github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1/go.mod h1:O0qtn62prQSqizzoagHmuuKoz8QMkU3SzBoKdEvm3aQ=
github.com/nspcc-dev/dbft v0.0.0-20200531081613-7a39e7b757ac h1:cXPgsp4avJ7cR1nPRdpFRHmWoMSRZ41FSvlNjpsyTiA=
github.com/nspcc-dev/dbft v0.0.0-20200531081613-7a39e7b757ac/go.mod h1:1FYQXSbb6/9HQIkoF8XO7W/S8N7AZRkBsgwbcXRvk0E=
github.com/nspcc-dev/dbft v0.0.0-20200623100921-5a182c20965e h1:QOT9slflIkEKb5wY0ZUC0dCmCgoqGlhOAh9+xWMIxfg=
github.com/nspcc-dev/dbft v0.0.0-20200623100921-5a182c20965e/go.mod h1:1FYQXSbb6/9HQIkoF8XO7W/S8N7AZRkBsgwbcXRvk0E=
github.com/nspcc-dev/neo-go v0.73.1-pre.0.20200303142215-f5a1b928ce09/go.mod h1:pPYwPZ2ks+uMnlRLUyXOpLieaDQSEaf4NM3zHVbRjmg=
github.com/nspcc-dev/neofs-crypto v0.2.0 h1:ftN+59WqxSWz/RCgXYOfhmltOOqU+udsNQSvN6wkFck=
github.com/nspcc-dev/neofs-crypto v0.2.0/go.mod h1:F/96fUzPM3wR+UGsPi3faVNmFlA9KAEAUQR7dMxZmNA=
@ -144,10 +169,12 @@ github.com/nspcc-dev/rfc6979 v0.2.0 h1:3e1WNxrN60/6N0DW7+UYisLeZJyfqZTNOjeV/toYv
github.com/nspcc-dev/rfc6979 v0.2.0/go.mod h1:exhIh1PdpDC5vQmyEsGvc4YDM/lyQp/452QxGq/UEso=
github.com/onsi/ginkgo v1.6.0 h1:Ix8l273rp3QzYgXSR+c8d1fTG7UPgYkOSELPhiY/YGw=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.10.3 h1:OoxbjfXVZyod1fmWYhI7SEyaD8B00ynP3T+D5GiyHOY=
github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.4.2 h1:3mYCb7aPxS/RU7TI1y4rkEn1oKmPRjNJLNEXgw7MH2I=
github.com/onsi/gomega v1.4.2/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.7.1 h1:K0jcRCwNQM3vFGh1ppMtDh/+7ApJrjldlX8fA0jDTLQ=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
@ -212,6 +239,7 @@ go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/zap v1.10.0 h1:ORx85nbTijNz8ljznvCMR1ZBIPKFn3jQrag10X2AsuM=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=

View file

@ -11,6 +11,10 @@ var syscalls = map[string]map[string]string{
"GetUsage": "Neo.Attribute.GetUsage",
"GetData": "Neo.Attribute.GetData",
},
"crypto": {
"Secp256k1Recover": "Neo.Cryptography.Secp256k1Recover",
"Secp256r1Recover": "Neo.Cryptography.Secp256r1Recover",
},
"enumerator": {
"Concat": "Neo.Enumerator.Concat",
"Create": "Neo.Enumerator.Create",

View file

@ -20,6 +20,8 @@ const (
type (
ProtocolConfiguration struct {
AddressVersion byte `yaml:"AddressVersion"`
// EnableStateRoot specifies if exchange of state roots should be enabled.
EnableStateRoot bool `yaml:"EnableStateRoot"`
// FeePerExtraByte sets the expected per-byte fee for
// transactions exceeding the MaxFreeTransactionSize.
FeePerExtraByte float64 `yaml:"FeePerExtraByte"`
@ -34,11 +36,13 @@ type (
MaxFreeTransactionsPerBlock int `yaml:"MaxFreeTransactionsPerBlock"`
MemPoolSize int `yaml:"MemPoolSize"`
// SaveStorageBatch enables storage batch saving before every persist.
SaveStorageBatch bool `yaml:"SaveStorageBatch"`
SecondsPerBlock int `yaml:"SecondsPerBlock"`
SeedList []string `yaml:"SeedList"`
StandbyValidators []string `yaml:"StandbyValidators"`
SystemFee SystemFee `yaml:"SystemFee"`
SaveStorageBatch bool `yaml:"SaveStorageBatch"`
SecondsPerBlock int `yaml:"SecondsPerBlock"`
SeedList []string `yaml:"SeedList"`
StandbyValidators []string `yaml:"StandbyValidators"`
// StateRootEnableIndex specifies starting height for state root calculations and exchange.
StateRootEnableIndex uint32 `yaml:"StateRootEnableIndex"`
SystemFee SystemFee `yaml:"SystemFee"`
// Whether to verify received blocks.
VerifyBlocks bool `yaml:"VerifyBlocks"`
// Whether to verify transactions in received blocks.

View file

@ -8,6 +8,9 @@ import (
// commit represents dBFT Commit message.
type commit struct {
signature [signatureSize]byte
stateSig [signatureSize]byte
stateRootEnabled bool
}
// signatureSize is an rfc6989 signature size in bytes
@ -19,11 +22,17 @@ var _ payload.Commit = (*commit)(nil)
// EncodeBinary implements io.Serializable interface.
func (c *commit) EncodeBinary(w *io.BinWriter) {
w.WriteBytes(c.signature[:])
if c.stateRootEnabled {
w.WriteBytes(c.stateSig[:])
}
}
// DecodeBinary implements io.Serializable interface.
func (c *commit) DecodeBinary(r *io.BinReader) {
r.ReadBytes(c.signature[:])
if c.stateRootEnabled {
r.ReadBytes(c.stateSig[:])
}
}
// Signature implements payload.Commit interface.

View file

@ -2,6 +2,7 @@ package consensus
import (
"errors"
"fmt"
"math/rand"
"sort"
"time"
@ -13,7 +14,9 @@ import (
"github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/neo-go/pkg/core"
coreb "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/cache"
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
@ -49,9 +52,9 @@ type service struct {
log *zap.Logger
// cache is a fifo cache which stores recent payloads.
cache *relayCache
cache *cache.HashCache
// txx is a fifo cache which stores miner transactions.
txx *relayCache
txx *cache.HashCache
dbft *dbft.DBFT
// messages and transactions are channels needed to process
// everything in single thread.
@ -70,7 +73,7 @@ type Config struct {
Logger *zap.Logger
// Broadcast is a callback which is called to notify server
// about new consensus payload to sent.
Broadcast func(p *Payload)
Broadcast func(cache.Hashable)
// Chain is a core.Blockchainer instance.
Chain core.Blockchainer
// RequestTx is a callback to which will be called
@ -96,8 +99,8 @@ func NewService(cfg Config) (Service, error) {
Config: cfg,
log: cfg.Logger,
cache: newFIFOCache(cacheMaxCapacity),
txx: newFIFOCache(cacheMaxCapacity),
cache: cache.NewFIFOCache(cacheMaxCapacity),
txx: cache.NewFIFOCache(cacheMaxCapacity),
messages: make(chan Payload, 100),
transactions: make(chan *transaction.Transaction, 100),
@ -120,7 +123,6 @@ func NewService(cfg Config) (Service, error) {
dbft.WithLogger(srv.log),
dbft.WithSecondsPerBlock(cfg.TimePerBlock),
dbft.WithGetKeyPair(srv.getKeyPair),
dbft.WithTxPerBlock(10000),
dbft.WithRequestTx(cfg.RequestTx),
dbft.WithGetTx(srv.getTx),
dbft.WithGetVerified(srv.getVerifiedTx),
@ -135,13 +137,14 @@ func NewService(cfg Config) (Service, error) {
dbft.WithGetValidators(srv.getValidators),
dbft.WithGetConsensusAddress(srv.getConsensusAddress),
dbft.WithNewConsensusPayload(func() payload.ConsensusPayload { p := new(Payload); p.message = &message{}; return p }),
dbft.WithNewPrepareRequest(func() payload.PrepareRequest { return new(prepareRequest) }),
dbft.WithNewConsensusPayload(srv.newPayload),
dbft.WithNewPrepareRequest(srv.newPrepareRequest),
dbft.WithNewPrepareResponse(func() payload.PrepareResponse { return new(prepareResponse) }),
dbft.WithNewChangeView(func() payload.ChangeView { return new(changeView) }),
dbft.WithNewCommit(func() payload.Commit { return new(commit) }),
dbft.WithNewCommit(srv.newCommit),
dbft.WithNewRecoveryRequest(func() payload.RecoveryRequest { return new(recoveryRequest) }),
dbft.WithNewRecoveryMessage(func() payload.RecoveryMessage { return new(recoveryMessage) }),
dbft.WithVerifyPrepareRequest(srv.verifyRequest),
)
if srv.dbft == nil {
@ -210,6 +213,53 @@ func (s *service) eventLoop() {
}
}
func (s *service) newPayload() payload.ConsensusPayload {
return &Payload{
message: &message{
stateRootEnabled: s.stateRootEnabled(),
},
}
}
// stateRootEnabled checks if state root feature is enabled on current height.
// It should be called only from dbft callbacks and is not protected by any mutex.
func (s *service) stateRootEnabled() bool {
return s.Chain.GetConfig().EnableStateRoot
}
func (s *service) newPrepareRequest() payload.PrepareRequest {
if !s.stateRootEnabled() {
return new(prepareRequest)
}
sr, err := s.Chain.GetStateRoot(s.Chain.BlockHeight())
if err == nil {
return &prepareRequest{
stateRootEnabled: true,
proposalStateRoot: sr.MPTRootBase,
}
}
return &prepareRequest{stateRootEnabled: true}
}
func (s *service) newCommit() payload.Commit {
if !s.stateRootEnabled() {
return new(commit)
}
c := &commit{stateRootEnabled: true}
for _, p := range s.dbft.Context.PreparationPayloads {
if p != nil && p.ViewNumber() == s.dbft.ViewNumber && p.Type() == payload.PrepareRequestType {
pr := p.GetPrepareRequest().(*prepareRequest)
data := pr.proposalStateRoot.GetSignedPart()
sign, err := s.dbft.Priv.Sign(data)
if err == nil {
copy(c.stateSig[:], sign)
}
break
}
}
return c
}
func (s *service) validatePayload(p *Payload) bool {
validators := s.getValidators()
if int(p.validatorIndex) >= len(validators) {
@ -262,8 +312,8 @@ func (s *service) OnPayload(cp *Payload) {
// decode payload data into message
if cp.message == nil {
if err := cp.decodeData(); err != nil {
log.Debug("can't decode payload data")
if err := cp.decodeData(s.stateRootEnabled()); err != nil {
log.Debug("can't decode payload data", zap.Error(err))
return
}
}
@ -340,6 +390,21 @@ func (s *service) verifyBlock(b block.Block) bool {
return true
}
func (s *service) verifyRequest(p payload.ConsensusPayload) error {
if !s.stateRootEnabled() {
return nil
}
r, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1)
if err != nil {
return fmt.Errorf("can't get local state root: %v", err)
}
rb := &p.GetPrepareRequest().(*prepareRequest).proposalStateRoot
if !r.Equals(rb) {
return errors.New("state root mismatch")
}
return nil
}
func (s *service) processBlock(b block.Block) {
bb := &b.(*neoBlock).Block
bb.Script = *(s.getBlockWitness(bb))
@ -351,16 +416,36 @@ func (s *service) processBlock(b block.Block) {
s.log.Warn("error on add block", zap.Error(err))
}
}
var rb *state.MPTRootBase
for _, p := range s.dbft.PreparationPayloads {
if p != nil && p.Type() == payload.PrepareRequestType {
rb = &p.GetPrepareRequest().(*prepareRequest).proposalStateRoot
}
}
w := s.getWitness(func(p payload.Commit) []byte { return p.(*commit).stateSig[:] })
r := &state.MPTRoot{
MPTRootBase: *rb,
Witness: w,
}
if err := s.Chain.AddStateRoot(r); err != nil {
s.log.Warn("errors while adding state root", zap.Error(err))
}
s.Broadcast(r)
}
func (s *service) getBlockWitness(b *coreb.Block) *transaction.Witness {
func (s *service) getBlockWitness(_ *coreb.Block) *transaction.Witness {
return s.getWitness(func(p payload.Commit) []byte { return p.Signature() })
}
func (s *service) getWitness(f func(p payload.Commit) []byte) *transaction.Witness {
dctx := s.dbft.Context
pubs := convertKeys(dctx.Validators)
sigs := make(map[*keys.PublicKey][]byte)
for i := range pubs {
if p := dctx.CommitPayloads[i]; p != nil && p.ViewNumber() == dctx.ViewNumber {
sigs[pubs[i]] = p.GetCommit().Signature()
sigs[pubs[i]] = f(p.GetCommit())
}
}
@ -397,7 +482,7 @@ func (s *service) getBlock(h util.Uint256) block.Block {
return &neoBlock{Block: *b}
}
func (s *service) getVerifiedTx(count int) []block.Transaction {
func (s *service) getVerifiedTx() []block.Transaction {
pool := s.Config.Chain.GetMemPool()
var txx []mempool.TxWithFee

View file

@ -7,6 +7,7 @@ import (
"github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/cache"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
@ -25,7 +26,7 @@ func TestNewService(t *testing.T) {
require.NoError(t, srv.Chain.PoolTx(tx))
var txx []block.Transaction
require.NotPanics(t, func() { txx = srv.getVerifiedTx(1) })
require.NotPanics(t, func() { txx = srv.getVerifiedTx() })
require.Len(t, txx, 2)
require.Equal(t, tx, txx[1])
srv.Chain.Close()
@ -58,7 +59,7 @@ func TestService_GetVerified(t *testing.T) {
srv.dbft.ViewNumber = 1
t.Run("new transactions will be proposed in case of failure", func(t *testing.T) {
txx := srv.getVerifiedTx(10)
txx := srv.getVerifiedTx()
require.Equal(t, 2, len(txx), "there is only 1 tx in mempool")
require.Equal(t, txs[3], txx[1])
})
@ -68,7 +69,7 @@ func TestService_GetVerified(t *testing.T) {
require.NoError(t, srv.Chain.PoolTx(tx))
}
txx := srv.getVerifiedTx(10)
txx := srv.getVerifiedTx()
require.Contains(t, txx, txs[0])
require.Contains(t, txx, txs[1])
require.NotContains(t, txx, txs[2])
@ -182,7 +183,7 @@ func shouldNotReceive(t *testing.T, ch chan Payload) {
func newTestService(t *testing.T) *service {
srv, err := NewService(Config{
Logger: zaptest.NewLogger(t),
Broadcast: func(*Payload) {},
Broadcast: func(cache.Hashable) {},
Chain: newTestChain(t),
RequestTx: func(...util.Uint256) {},
Wallet: &wallet.Config{

View file

@ -22,6 +22,8 @@ type (
Type messageType
ViewNumber byte
stateRootEnabled bool
payload io.Serializable
}
@ -283,15 +285,21 @@ func (m *message) DecodeBinary(r *io.BinReader) {
cv.newViewNumber = m.ViewNumber + 1
m.payload = cv
case prepareRequestType:
m.payload = new(prepareRequest)
m.payload = &prepareRequest{
stateRootEnabled: m.stateRootEnabled,
}
case prepareResponseType:
m.payload = new(prepareResponse)
case commitType:
m.payload = new(commit)
m.payload = &commit{
stateRootEnabled: m.stateRootEnabled,
}
case recoveryRequestType:
m.payload = new(recoveryRequest)
case recoveryMessageType:
m.payload = new(recoveryMessage)
m.payload = &recoveryMessage{
stateRootEnabled: m.stateRootEnabled,
}
default:
r.Err = errors.Errorf("invalid type: 0x%02x", byte(m.Type))
return
@ -319,9 +327,9 @@ func (t messageType) String() string {
}
}
// decode data of payload into it's message
func (p *Payload) decodeData() error {
m := new(message)
// decodeData decodes data of payload into it's message.
func (p *Payload) decodeData(stateRootEnabled bool) error {
m := &message{stateRootEnabled: stateRootEnabled}
br := io.NewBinReaderFromBuf(p.data)
m.DecodeBinary(br)
if br.Err != nil {

View file

@ -94,13 +94,13 @@ func TestConsensusPayload_Serializable(t *testing.T) {
// message is nil after decoding as we didn't yet call decodeData
require.Nil(t, actual.message)
// message should now be decoded from actual.data byte array
assert.NoError(t, actual.decodeData())
assert.NoError(t, actual.decodeData(false))
require.Equal(t, p, actual)
data = p.MarshalUnsigned()
pu := new(Payload)
require.NoError(t, pu.UnmarshalUnsigned(data))
assert.NoError(t, pu.decodeData())
assert.NoError(t, pu.decodeData(false))
p.Witness = transaction.Witness{}
require.Equal(t, p, pu)
@ -144,14 +144,14 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) {
p := new(Payload)
require.NoError(t, testserdes.DecodeBinary(buf, p))
// decode `data` into `message`
assert.NoError(t, p.decodeData())
assert.NoError(t, p.decodeData(false))
require.Equal(t, expected, p)
// invalid type
buf[typeIndex] = 0xFF
actual := new(Payload)
require.NoError(t, testserdes.DecodeBinary(buf, actual))
require.Error(t, actual.decodeData())
require.Error(t, actual.decodeData(false))
// invalid format
buf[delimeterIndex] = 0
@ -165,9 +165,16 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) {
require.Error(t, testserdes.DecodeBinary(buf, new(Payload)))
}
func testEncodeDecode(srEnabled bool, mt messageType, actual io.Serializable) func(t *testing.T) {
return func(t *testing.T) {
expected := randomMessage(t, mt, srEnabled)
testserdes.EncodeDecodeBinary(t, expected, actual)
}
}
func TestCommit_Serializable(t *testing.T) {
c := randomMessage(t, commitType)
testserdes.EncodeDecodeBinary(t, c, new(commit))
t.Run("WithStateRoot", testEncodeDecode(true, commitType, &commit{stateRootEnabled: true}))
t.Run("NoStateRoot", testEncodeDecode(false, commitType, &commit{stateRootEnabled: false}))
}
func TestPrepareResponse_Serializable(t *testing.T) {
@ -176,8 +183,8 @@ func TestPrepareResponse_Serializable(t *testing.T) {
}
func TestPrepareRequest_Serializable(t *testing.T) {
req := randomMessage(t, prepareRequestType)
testserdes.EncodeDecodeBinary(t, req, new(prepareRequest))
t.Run("WithStateRoot", testEncodeDecode(true, prepareRequestType, &prepareRequest{stateRootEnabled: true}))
t.Run("NoStateRoot", testEncodeDecode(false, prepareRequestType, &prepareRequest{stateRootEnabled: false}))
}
func TestRecoveryRequest_Serializable(t *testing.T) {
@ -186,8 +193,8 @@ func TestRecoveryRequest_Serializable(t *testing.T) {
}
func TestRecoveryMessage_Serializable(t *testing.T) {
msg := randomMessage(t, recoveryMessageType)
testserdes.EncodeDecodeBinary(t, msg, new(recoveryMessage))
t.Run("WithStateRoot", testEncodeDecode(true, recoveryMessageType, &recoveryMessage{stateRootEnabled: true}))
t.Run("NoStateRoot", testEncodeDecode(false, recoveryMessageType, &recoveryMessage{stateRootEnabled: false}))
}
func randomPayload(t *testing.T, mt messageType) *Payload {
@ -215,31 +222,35 @@ func randomPayload(t *testing.T, mt messageType) *Payload {
return p
}
func randomMessage(t *testing.T, mt messageType) io.Serializable {
func randomMessage(t *testing.T, mt messageType, srEnabled ...bool) io.Serializable {
switch mt {
case changeViewType:
return &changeView{
timestamp: rand.Uint32(),
}
case prepareRequestType:
return randomPrepareRequest(t)
return randomPrepareRequest(t, srEnabled...)
case prepareResponseType:
return &prepareResponse{preparationHash: random.Uint256()}
case commitType:
var c commit
random.Fill(c.signature[:])
if len(srEnabled) > 0 && srEnabled[0] {
c.stateRootEnabled = true
random.Fill(c.stateSig[:])
}
return &c
case recoveryRequestType:
return &recoveryRequest{timestamp: rand.Uint32()}
case recoveryMessageType:
return randomRecoveryMessage(t)
return randomRecoveryMessage(t, srEnabled...)
default:
require.Fail(t, "invalid type")
return nil
}
}
func randomPrepareRequest(t *testing.T) *prepareRequest {
func randomPrepareRequest(t *testing.T, srEnabled ...bool) *prepareRequest {
const txCount = 3
req := &prepareRequest{
@ -255,15 +266,22 @@ func randomPrepareRequest(t *testing.T) *prepareRequest {
}
req.nextConsensus = random.Uint160()
if len(srEnabled) > 0 && srEnabled[0] {
req.stateRootEnabled = true
req.proposalStateRoot.Index = rand.Uint32()
req.proposalStateRoot.PrevHash = random.Uint256()
req.proposalStateRoot.Root = random.Uint256()
}
return req
}
func randomRecoveryMessage(t *testing.T) *recoveryMessage {
result := randomMessage(t, prepareRequestType)
func randomRecoveryMessage(t *testing.T, srEnabled ...bool) *recoveryMessage {
result := randomMessage(t, prepareRequestType, srEnabled...)
require.IsType(t, (*prepareRequest)(nil), result)
prepReq := result.(*prepareRequest)
return &recoveryMessage{
rec := &recoveryMessage{
preparationPayloads: []*preparationCompact{
{
ValidatorIndex: 1,
@ -297,6 +315,15 @@ func randomRecoveryMessage(t *testing.T) *recoveryMessage {
payload: prepReq,
},
}
if len(srEnabled) > 0 && srEnabled[0] {
rec.stateRootEnabled = true
rec.prepareRequest.stateRootEnabled = true
for _, c := range rec.commitPayloads {
c.stateRootEnabled = true
random.Fill(c.StateSignature[:])
}
}
return rec
}
func TestPayload_Sign(t *testing.T) {

View file

@ -2,6 +2,7 @@ package consensus
import (
"github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
@ -14,6 +15,9 @@ type prepareRequest struct {
transactionHashes []util.Uint256
minerTx transaction.Transaction
nextConsensus util.Uint160
proposalStateRoot state.MPTRootBase
stateRootEnabled bool
}
var _ payload.PrepareRequest = (*prepareRequest)(nil)
@ -25,6 +29,9 @@ func (p *prepareRequest) EncodeBinary(w *io.BinWriter) {
w.WriteBytes(p.nextConsensus[:])
w.WriteArray(p.transactionHashes)
p.minerTx.EncodeBinary(w)
if p.stateRootEnabled {
p.proposalStateRoot.EncodeBinary(w)
}
}
// DecodeBinary implements io.Serializable interface.
@ -34,6 +41,9 @@ func (p *prepareRequest) DecodeBinary(r *io.BinReader) {
r.ReadBytes(p.nextConsensus[:])
r.ReadArray(&p.transactionHashes)
p.minerTx.DecodeBinary(r)
if p.stateRootEnabled {
p.proposalStateRoot.DecodeBinary(r)
}
}
// Timestamp implements payload.PrepareRequest interface.

View file

@ -3,6 +3,7 @@ package consensus
import (
"github.com/nspcc-dev/dbft/crypto"
"github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/pkg/errors"
@ -16,6 +17,8 @@ type (
commitPayloads []*commitCompact
changeViewPayloads []*changeViewCompact
prepareRequest *message
stateRootEnabled bool
}
changeViewCompact struct {
@ -29,7 +32,10 @@ type (
ViewNumber byte
ValidatorIndex uint16
Signature [signatureSize]byte
StateSignature [signatureSize]byte
InvocationScript []byte
stateRootEnabled bool
}
preparationCompact struct {
@ -46,7 +52,7 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) {
var hasReq = r.ReadBool()
if hasReq {
m.prepareRequest = new(message)
m.prepareRequest = &message{stateRootEnabled: m.stateRootEnabled}
m.prepareRequest.DecodeBinary(r)
if r.Err == nil && m.prepareRequest.Type != prepareRequestType {
r.Err = errors.New("recovery message PrepareRequest has wrong type")
@ -67,7 +73,16 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) {
}
r.ReadArray(&m.preparationPayloads)
r.ReadArray(&m.commitPayloads)
lu := r.ReadVarUint()
if lu > state.MaxValidatorsVoted {
r.Err = errors.New("too many preparation payloads")
return
}
m.commitPayloads = make([]*commitCompact, lu)
for i := uint64(0); i < lu; i++ {
m.commitPayloads[i] = &commitCompact{stateRootEnabled: m.stateRootEnabled}
m.commitPayloads[i].DecodeBinary(r)
}
}
// EncodeBinary implements io.Serializable interface.
@ -96,7 +111,7 @@ func (p *changeViewCompact) DecodeBinary(r *io.BinReader) {
p.ValidatorIndex = r.ReadU16LE()
p.OriginalViewNumber = r.ReadB()
p.Timestamp = r.ReadU32LE()
p.InvocationScript = r.ReadVarBytes()
p.InvocationScript = r.ReadVarBytes(1024)
}
// EncodeBinary implements io.Serializable interface.
@ -112,7 +127,10 @@ func (p *commitCompact) DecodeBinary(r *io.BinReader) {
p.ViewNumber = r.ReadB()
p.ValidatorIndex = r.ReadU16LE()
r.ReadBytes(p.Signature[:])
p.InvocationScript = r.ReadVarBytes()
if p.stateRootEnabled {
r.ReadBytes(p.StateSignature[:])
}
p.InvocationScript = r.ReadVarBytes(1024)
}
// EncodeBinary implements io.Serializable interface.
@ -120,13 +138,16 @@ func (p *commitCompact) EncodeBinary(w *io.BinWriter) {
w.WriteB(p.ViewNumber)
w.WriteU16LE(p.ValidatorIndex)
w.WriteBytes(p.Signature[:])
if p.stateRootEnabled {
w.WriteBytes(p.StateSignature[:])
}
w.WriteVarBytes(p.InvocationScript)
}
// DecodeBinary implements io.Serializable interface.
func (p *preparationCompact) DecodeBinary(r *io.BinReader) {
p.ValidatorIndex = r.ReadU16LE()
p.InvocationScript = r.ReadVarBytes()
p.InvocationScript = r.ReadVarBytes(1024)
}
// EncodeBinary implements io.Serializable interface.
@ -143,6 +164,8 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) {
Type: prepareRequestType,
ViewNumber: p.ViewNumber(),
payload: p.GetPrepareRequest().(*prepareRequest),
stateRootEnabled: m.stateRootEnabled,
}
h := p.Hash()
m.preparationHash = &h
@ -169,9 +192,11 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) {
})
case payload.CommitType:
m.commitPayloads = append(m.commitPayloads, &commitCompact{
stateRootEnabled: m.stateRootEnabled,
ValidatorIndex: p.ValidatorIndex(),
ViewNumber: p.ViewNumber(),
Signature: p.GetCommit().(*commit).signature,
StateSignature: p.GetCommit().(*commit).stateSig,
InvocationScript: p.(*Payload).Witness.InvocationScript,
})
}
@ -234,6 +259,7 @@ func (m *recoveryMessage) GetChangeViews(p payload.ConsensusPayload, validators
newViewNumber: cv.OriginalViewNumber + 1,
timestamp: cv.Timestamp,
})
c.message.ViewNumber = cv.OriginalViewNumber
c.SetValidatorIndex(cv.ValidatorIndex)
c.Witness.InvocationScript = cv.InvocationScript
c.Witness.VerificationScript = getVerificationScript(cv.ValidatorIndex, validators)
@ -249,7 +275,12 @@ func (m *recoveryMessage) GetCommits(p payload.ConsensusPayload, validators []cr
ps := make([]payload.ConsensusPayload, len(m.commitPayloads))
for i, c := range m.commitPayloads {
cc := fromPayload(commitType, p.(*Payload), &commit{signature: c.Signature})
cc := fromPayload(commitType, p.(*Payload), &commit{
signature: c.Signature,
stateSig: c.StateSignature,
stateRootEnabled: m.stateRootEnabled,
})
cc.SetValidatorIndex(c.ValidatorIndex)
cc.Witness.InvocationScript = c.InvocationScript
cc.Witness.VerificationScript = getVerificationScript(c.ValidatorIndex, validators)
@ -289,6 +320,8 @@ func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload {
Type: t,
ViewNumber: recovery.message.ViewNumber,
payload: p,
stateRootEnabled: recovery.stateRootEnabled,
},
version: recovery.Version(),
prevHash: recovery.PrevHash(),

View file

@ -13,9 +13,11 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/dao"
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
@ -229,6 +231,11 @@ func (bc *Blockchain) init() error {
}
bc.blockHeight = bHeight
bc.persistedHeight = bHeight
if bc.config.EnableStateRoot {
if err = bc.dao.InitMPT(bHeight); err != nil {
return errors.Wrapf(err, "can't init MPT at height %d", bHeight)
}
}
hashes, err := bc.dao.GetHeaderHashes()
if err != nil {
@ -551,6 +558,23 @@ func (bc *Blockchain) getSystemFeeAmount(h util.Uint256) uint32 {
return sf
}
// GetStateProof returns proof of having key in the MPT with the specified root.
func (bc *Blockchain) GetStateProof(root util.Uint256, key []byte) ([][]byte, error) {
if !bc.config.EnableStateRoot {
return nil, errors.New("state root feature is not enabled")
}
tr := mpt.NewTrie(mpt.NewHashNode(root), storage.NewMemCachedStore(bc.dao.Store))
return tr.GetProof(key)
}
// GetStateRoot returns state root for a given height.
func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) {
if !bc.config.EnableStateRoot {
return nil, errors.New("state root feature is not enabled")
}
return bc.dao.GetStateRoot(height)
}
// TODO: storeBlock needs some more love, its implemented as in the original
// project. This for the sake of development speed and understanding of what
// is happening here, quite allot as you can see :). If things are wired together
@ -819,6 +843,28 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
}
}
if bc.config.EnableStateRoot {
root := bc.dao.MPT.StateRoot()
var prevHash util.Uint256
if block.Index > 0 {
prev, err := bc.dao.GetStateRoot(block.Index - 1)
if err != nil {
return errors.WithMessagef(err, "can't get previous state root")
}
prevHash = hash.DoubleSha256(prev.GetSignedPart())
}
err := bc.AddStateRoot(&state.MPTRoot{
MPTRootBase: state.MPTRootBase{
Index: block.Index,
PrevHash: prevHash,
Root: root,
},
})
if err != nil {
return err
}
}
if bc.config.SaveStorageBatch {
bc.lastBatch = cache.DAO.GetBatch()
}
@ -829,6 +875,15 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
bc.lock.Unlock()
return err
}
if bc.config.EnableStateRoot {
bc.dao.MPT.Flush()
// Every persist cycle we also compact our in-memory MPT.
persistedHeight := atomic.LoadUint32(&bc.persistedHeight)
if persistedHeight == block.Index-1 {
// 10 is good and roughly estimated to fit remaining trie into 1M of memory.
bc.dao.MPT.Collapse(10)
}
}
bc.topBlock.Store(block)
atomic.StoreUint32(&bc.blockHeight, block.Index)
bc.memPool.RemoveStale(bc.isTxStillRelevant)
@ -1492,12 +1547,13 @@ func (bc *Blockchain) ApplyPolicyToTxSet(txes []mempool.TxWithFee) []mempool.TxW
txes = txes[:bc.config.MaxTransactionsPerBlock]
}
maxFree := bc.config.MaxFreeTransactionsPerBlock
if maxFree != 0 {
lowStart := sort.Search(len(txes), func(i int) bool {
return bc.IsLowPriority(txes[i].Fee)
if maxFree != 0 && len(txes) > maxFree {
// Transactions are sorted by fee, so we just find the first free one.
freeStart := sort.Search(len(txes), func(i int) bool {
return txes[i].Fee == 0
})
if lowStart+maxFree < len(txes) {
txes = txes[:lowStart+maxFree]
if freeStart+maxFree < len(txes) {
txes = txes[:freeStart+maxFree]
}
}
return txes
@ -1732,6 +1788,90 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction) bool {
}
// StateHeight returns height of the verified state root.
func (bc *Blockchain) StateHeight() uint32 {
h, _ := bc.dao.GetCurrentStateRootHeight()
return h
}
// AddStateRoot add new (possibly unverified) state root to the blockchain.
func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error {
if !bc.config.EnableStateRoot {
bc.log.Warn("state root is being added but not enabled in config")
return nil
}
our, err := bc.GetStateRoot(r.Index)
if err == nil {
if our.Flag == state.Verified {
return bc.updateStateHeight(r.Index)
} else if r.Witness == nil && our.Witness != nil {
r.Witness = our.Witness
}
}
if err := bc.verifyStateRoot(r); err != nil {
return errors.WithMessage(err, "invalid state root")
}
if r.Index > bc.BlockHeight() { // just put it into the store for future checks
return bc.dao.PutStateRoot(&state.MPTRootState{
MPTRoot: *r,
Flag: state.Unverified,
})
}
flag := state.Unverified
if r.Witness != nil {
if err := bc.verifyStateRootWitness(r); err != nil {
return errors.WithMessage(err, "can't verify signature")
}
flag = state.Verified
}
err = bc.dao.PutStateRoot(&state.MPTRootState{
MPTRoot: *r,
Flag: flag,
})
if err != nil {
return err
}
return bc.updateStateHeight(r.Index)
}
func (bc *Blockchain) updateStateHeight(newHeight uint32) error {
h, err := bc.dao.GetCurrentStateRootHeight()
if err != nil {
return errors.WithMessage(err, "can't get current state root height")
} else if newHeight == h+1 {
updateStateHeightMetric(newHeight)
return bc.dao.PutCurrentStateRootHeight(h + 1)
}
return nil
}
// verifyStateRoot checks if state root is valid.
func (bc *Blockchain) verifyStateRoot(r *state.MPTRoot) error {
if r.Index == 0 {
return nil
}
prev, err := bc.GetStateRoot(r.Index - 1)
if err != nil {
return errors.New("can't get previous state root")
} else if !r.PrevHash.Equals(hash.DoubleSha256(prev.GetSignedPart())) {
return errors.New("previous hash mismatch")
} else if prev.Version != r.Version {
return errors.New("version mismatch")
}
return nil
}
// verifyStateRootWitness verifies that state root signature is correct.
func (bc *Blockchain) verifyStateRootWitness(r *state.MPTRoot) error {
b, err := bc.GetBlock(bc.GetHeaderHash(int(r.Index)))
if err != nil {
return err
}
interopCtx := bc.newInteropContext(trigger.Verification, bc.dao, nil, nil)
return bc.verifyHashAgainstScript(b.NextConsensus, r.Witness, hash.Sha256(r.GetSignedPart()), interopCtx, true)
}
// VerifyTx verifies whether a transaction is bonafide or not. Block parameter
// is used for easy interop access and can be omitted for transactions that are
// not yet added into any block.

View file

@ -18,6 +18,7 @@ type Blockchainer interface {
GetConfig() config.ProtocolConfiguration
AddHeaders(...*block.Header) error
AddBlock(*block.Block) error
AddStateRoot(r *state.MPTRoot) error
BlockHeight() uint32
CalculateClaimable(value util.Fixed8, startHeight, endHeight uint32) (util.Fixed8, util.Fixed8, error)
Close()
@ -38,6 +39,8 @@ type Blockchainer interface {
GetNEP5Balances(util.Uint160) *state.NEP5Balances
GetValidators(txes ...*transaction.Transaction) ([]*keys.PublicKey, error)
GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error)
GetStateProof(root util.Uint256, key []byte) ([][]byte, error)
GetStateRoot(height uint32) (*state.MPTRootState, error)
GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem
GetStorageItems(hash util.Uint160) (map[string]*state.StorageItem, error)
GetTestVM() *vm.VM
@ -46,6 +49,7 @@ type Blockchainer interface {
References(t *transaction.Transaction) ([]transaction.InOut, error)
mempool.Feer // fee interface
PoolTx(*transaction.Transaction) error
StateHeight() uint32
SubscribeForBlocks(ch chan<- *block.Block)
SubscribeForExecutions(ch chan<- *state.AppExecResult)
SubscribeForNotifications(ch chan<- *state.NotificationEvent)

View file

@ -1,4 +1,4 @@
package consensus
package cache
import (
"container/list"
@ -7,9 +7,9 @@ import (
"github.com/nspcc-dev/neo-go/pkg/util"
)
// relayCache is a payload cache which is used to store
// HashCache is a payload cache which is used to store
// last consensus payloads.
type relayCache struct {
type HashCache struct {
*sync.RWMutex
maxCap int
@ -17,13 +17,14 @@ type relayCache struct {
queue *list.List
}
// hashable is a type of items which can be stored in the relayCache.
type hashable interface {
// Hashable is a type of items which can be stored in the HashCache.
type Hashable interface {
Hash() util.Uint256
}
func newFIFOCache(capacity int) *relayCache {
return &relayCache{
// NewFIFOCache returns new FIFO cache with the specified capacity.
func NewFIFOCache(capacity int) *HashCache {
return &HashCache{
RWMutex: new(sync.RWMutex),
maxCap: capacity,
@ -33,7 +34,7 @@ func newFIFOCache(capacity int) *relayCache {
}
// Add adds payload into a cache if it doesn't already exist.
func (c *relayCache) Add(p hashable) {
func (c *HashCache) Add(p Hashable) {
c.Lock()
defer c.Unlock()
@ -45,7 +46,7 @@ func (c *relayCache) Add(p hashable) {
if c.queue.Len() >= c.maxCap {
first := c.queue.Front()
c.queue.Remove(first)
delete(c.elems, first.Value.(hashable).Hash())
delete(c.elems, first.Value.(Hashable).Hash())
}
e := c.queue.PushBack(p)
@ -53,7 +54,7 @@ func (c *relayCache) Add(p hashable) {
}
// Has checks if an item is already in cache.
func (c *relayCache) Has(h util.Uint256) bool {
func (c *HashCache) Has(h util.Uint256) bool {
c.RLock()
defer c.RUnlock()
@ -61,13 +62,13 @@ func (c *relayCache) Has(h util.Uint256) bool {
}
// Get returns payload with the specified hash from cache.
func (c *relayCache) Get(h util.Uint256) hashable {
func (c *HashCache) Get(h util.Uint256) Hashable {
c.RLock()
defer c.RUnlock()
e, ok := c.elems[h]
if !ok {
return hashable(nil)
return Hashable(nil)
}
return e.Value.(hashable)
return e.Value.(Hashable)
}

View file

@ -1,17 +1,19 @@
package consensus
package cache
import (
"math/rand"
"testing"
"github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
func TestRelayCache_Add(t *testing.T) {
const capacity = 3
payloads := getDifferentPayloads(t, capacity+1)
c := newFIFOCache(capacity)
payloads := getDifferentItems(t, capacity+1)
c := NewFIFOCache(capacity)
require.Equal(t, 0, c.queue.Len())
require.Equal(t, 0, len(c.elems))
@ -46,19 +48,15 @@ func TestRelayCache_Add(t *testing.T) {
require.Equal(t, nil, c.Get(payloads[1].Hash()))
}
func getDifferentPayloads(t *testing.T, n int) (payloads []Payload) {
payloads = make([]Payload, n)
for i := range payloads {
var sign [signatureSize]byte
random.Fill(sign[:])
type testHashable []byte
payloads[i].message = &message{}
payloads[i].SetValidatorIndex(uint16(i))
payloads[i].SetType(payload.MessageType(commitType))
payloads[i].payload = &commit{
signature: sign,
}
// Hash implements Hashable.
func (h testHashable) Hash() util.Uint256 { return hash.Sha256(h) }
func getDifferentItems(t *testing.T, n int) []testHashable {
items := make([]testHashable, n)
for i := range items {
items[i] = random.Bytes(rand.Int() % 10)
}
return
return items
}

View file

@ -245,6 +245,7 @@ func (cd *Cached) FlushStorage() error {
return err
}
}
ti.State |= flushedState
}
}
return nil
@ -275,7 +276,7 @@ func (cd *Cached) getStorageItemNoCache(scripthash util.Uint160, key []byte) *st
func (cd *Cached) getStorageItemInt(scripthash util.Uint160, key []byte, putToCache bool) *state.StorageItem {
ti := cd.storage.getItem(scripthash, key)
if ti != nil {
if ti.State == delOp {
if ti.State&delOp != 0 {
return nil
}
return copyItem(&ti.StorageItem)
@ -303,8 +304,10 @@ func (cd *Cached) PutStorageItem(scripthash util.Uint160, key []byte, si *state.
item := copyItem(si)
ti := cd.storage.getItem(scripthash, key)
if ti != nil {
if ti.State == delOp || ti.State == getOp {
if ti.State&(delOp|getOp) != 0 {
ti.State = putOp
} else {
ti.State = addOp
}
ti.StorageItem = *item
return nil
@ -357,7 +360,7 @@ func (cd *Cached) GetStorageItemsIterator(hash util.Uint160, prefix []byte) (Sto
for ; keyIndex < len(cd.storage.keys[hash]); keyIndex++ {
k := cd.storage.keys[hash][keyIndex]
v := cache[k]
if v.State != delOp && bytes.HasPrefix([]byte(k), prefix) {
if v.State&delOp == 0 && bytes.HasPrefix([]byte(k), prefix) {
val := make([]byte, len(v.StorageItem.Value))
copy(val, v.StorageItem.Value)
return []byte(k), val, nil
@ -404,7 +407,7 @@ func (cd *Cached) GetStorageItems(hash util.Uint160, prefix []byte) ([]StorageIt
for _, k := range cd.storage.keys[hash] {
v := cache[k]
if v.State != delOp {
if v.State&delOp == 0 {
val := make([]byte, len(v.StorageItem.Value))
copy(val, v.StorageItem.Value)
result = append(result, StorageItemWithKey{

View file

@ -7,6 +7,7 @@ import (
"sort"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
@ -31,9 +32,12 @@ type DAO interface {
GetContractState(hash util.Uint160) (*state.Contract, error)
GetCurrentBlockHeight() (uint32, error)
GetCurrentHeaderHeight() (i uint32, h util.Uint256, err error)
GetCurrentStateRootHeight() (uint32, error)
GetHeaderHashes() ([]util.Uint256, error)
GetNEP5Balances(acc util.Uint160) (*state.NEP5Balances, error)
GetNEP5TransferLog(acc util.Uint160, index uint32) (*state.NEP5TransferLog, error)
GetStateRoot(height uint32) (*state.MPTRootState, error)
PutStateRoot(root *state.MPTRootState) error
GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem
GetStorageItems(hash util.Uint160, prefix []byte) ([]StorageItemWithKey, error)
GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error)
@ -70,12 +74,14 @@ type DAO interface {
// Simple is memCached wrapper around DB, simple DAO implementation.
type Simple struct {
MPT *mpt.Trie
Store *storage.MemCachedStore
}
// NewSimple creates new simple dao using provided backend store.
func NewSimple(backend storage.Store) *Simple {
return &Simple{Store: storage.NewMemCachedStore(backend)}
st := storage.NewMemCachedStore(backend)
return &Simple{Store: st, MPT: mpt.NewTrie(nil, st)}
}
// GetBatch returns currently accumulated DB changeset.
@ -86,7 +92,9 @@ func (dao *Simple) GetBatch() *storage.MemBatch {
// GetWrapped returns new DAO instance with another layer of wrapped
// MemCachedStore around the current DAO Store.
func (dao *Simple) GetWrapped() DAO {
return NewSimple(dao.Store)
d := NewSimple(dao.Store)
d.MPT = dao.MPT
return d
}
// GetAndDecode performs get operation and decoding with serializable structures.
@ -406,6 +414,63 @@ func (dao *Simple) PutAppExecResult(aer *state.AppExecResult) error {
// -- start storage item.
func makeStateRootKey(height uint32) []byte {
key := make([]byte, 5)
key[0] = byte(storage.DataMPT)
binary.LittleEndian.PutUint32(key[1:], height)
return key
}
// InitMPT initializes MPT at the given height.
func (dao *Simple) InitMPT(height uint32) error {
if height == 0 {
dao.MPT = mpt.NewTrie(nil, dao.Store)
return nil
}
r, err := dao.GetStateRoot(height)
if err != nil {
return err
}
dao.MPT = mpt.NewTrie(mpt.NewHashNode(r.Root), dao.Store)
return nil
}
// GetCurrentStateRootHeight returns current state root height.
func (dao *Simple) GetCurrentStateRootHeight() (uint32, error) {
key := []byte{byte(storage.DataMPT)}
val, err := dao.Store.Get(key)
if err != nil {
if err == storage.ErrKeyNotFound {
err = nil
}
return 0, err
}
return binary.LittleEndian.Uint32(val), nil
}
// PutCurrentStateRootHeight updates current state root height.
func (dao *Simple) PutCurrentStateRootHeight(height uint32) error {
key := []byte{byte(storage.DataMPT)}
val := make([]byte, 4)
binary.LittleEndian.PutUint32(val, height)
return dao.Store.Put(key, val)
}
// GetStateRoot returns state root of a given height.
func (dao *Simple) GetStateRoot(height uint32) (*state.MPTRootState, error) {
r := new(state.MPTRootState)
err := dao.GetAndDecode(r, makeStateRootKey(height))
if err != nil {
return nil, err
}
return r, nil
}
// PutStateRoot puts state root of a given height into the store.
func (dao *Simple) PutStateRoot(r *state.MPTRootState) error {
return dao.Put(r, makeStateRootKey(r.Index))
}
// GetStorageItem returns StorageItem if it exists in the given store.
func (dao *Simple) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem {
b, err := dao.Store.Get(makeStorageItemKey(scripthash, key))
@ -426,13 +491,24 @@ func (dao *Simple) GetStorageItem(scripthash util.Uint160, key []byte) *state.St
// PutStorageItem puts given StorageItem for given script with given
// key into the given store.
func (dao *Simple) PutStorageItem(scripthash util.Uint160, key []byte, si *state.StorageItem) error {
return dao.Put(si, makeStorageItemKey(scripthash, key))
stKey := makeStorageItemKey(scripthash, key)
k := mpt.ToNeoStorageKey(stKey[1:]) // strip STStorage prefix
v := mpt.ToNeoStorageValue(si)
if err := dao.MPT.Put(k, v); err != nil && err != mpt.ErrNotFound {
return err
}
return dao.Store.Put(stKey, v[1:])
}
// DeleteStorageItem drops storage item for the given script with the
// given key from the store.
func (dao *Simple) DeleteStorageItem(scripthash util.Uint160, key []byte) error {
return dao.Store.Delete(makeStorageItemKey(scripthash, key))
stKey := makeStorageItemKey(scripthash, key)
k := mpt.ToNeoStorageKey(stKey[1:]) // strip STStorage prefix
if err := dao.MPT.Delete(k); err != nil && err != mpt.ErrNotFound {
return err
}
return dao.Store.Delete(stKey)
}
// StorageItemWithKey is a Key-Value pair together with possible const modifier.

View file

@ -24,6 +24,7 @@ const (
delOp
addOp
putOp
flushedState
)
func newItemCache() *itemCache {

View file

@ -1,10 +1,13 @@
package core
import (
"crypto/elliptic"
"errors"
"fmt"
"math"
"math/big"
"github.com/btcsuite/btcd/btcec"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
@ -600,6 +603,33 @@ func (ic *interopContext) contractMigrate(v *vm.VM) error {
return ic.contractDestroy(v)
}
// secp256k1Recover recovers speck256k1 public key.
func (ic *interopContext) secp256k1Recover(v *vm.VM) error {
return ic.eccRecover(btcec.S256(), v)
}
// secp256r1Recover recovers speck256r1 public key.
func (ic *interopContext) secp256r1Recover(v *vm.VM) error {
return ic.eccRecover(elliptic.P256(), v)
}
// eccRecover recovers public key using ECCurve set
func (ic *interopContext) eccRecover(curve elliptic.Curve, v *vm.VM) error {
rBytes := v.Estack().Pop().Bytes()
sBytes := v.Estack().Pop().Bytes()
r := new(big.Int).SetBytes(rBytes)
s := new(big.Int).SetBytes(sBytes)
isEven := v.Estack().Pop().Bool()
messageHash := v.Estack().Pop().Bytes()
pKey, err := keys.KeyRecover(curve, r, s, messageHash, isEven)
if err != nil {
v.Estack().PushVal([]byte{})
return nil
}
v.Estack().PushVal(pKey.UncompressedBytes()[1:])
return nil
}
// assetCreate creates an asset.
func (ic *interopContext) assetCreate(v *vm.VM) error {
if ic.trigger != trigger.Application {

View file

@ -1,14 +1,17 @@
package core
import (
"bytes"
"math/big"
"testing"
"github.com/btcsuite/btcd/btcec"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/dao"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
@ -457,8 +460,81 @@ func TestAssetGetPrecision(t *testing.T) {
require.Equal(t, big.NewInt(int64(assetState.Precision)), precision)
}
func TestSecp256k1Recover(t *testing.T) {
v, context, chain := createVM(t)
defer chain.Close()
privateKey, err := btcec.NewPrivateKey(btcec.S256())
require.NoError(t, err)
message := []byte("The quick brown fox jumps over the lazy dog")
signature, err := privateKey.Sign(message)
require.NoError(t, err)
require.True(t, signature.Verify(message, privateKey.PubKey()))
pubKey := keys.PublicKey{
X: privateKey.PubKey().X,
Y: privateKey.PubKey().Y,
}
expected := pubKey.UncompressedBytes()[1:]
// We don't know which of two recovered keys suites, so let's try both.
putOnStackGetResult := func(isEven bool) []byte {
v.Estack().PushVal(message)
v.Estack().PushVal(isEven)
v.Estack().PushVal(signature.S.Bytes())
v.Estack().PushVal(signature.R.Bytes())
err = context.secp256k1Recover(v)
require.NoError(t, err)
return v.Estack().Pop().Value().([]byte)
}
// First one:
actualFalse := putOnStackGetResult(false)
// Second one:
actualTrue := putOnStackGetResult(true)
require.True(t, bytes.Compare(expected, actualFalse) != bytes.Compare(expected, actualTrue))
}
func TestSecp256r1Recover(t *testing.T) {
v, context, chain := createVM(t)
defer chain.Close()
privateKey, err := keys.NewPrivateKey()
require.NoError(t, err)
message := []byte("The quick brown fox jumps over the lazy dog")
messageHash := hash.Sha256(message).BytesBE()
signature := privateKey.Sign(message)
require.True(t, privateKey.PublicKey().Verify(signature, messageHash))
expected := privateKey.PublicKey().UncompressedBytes()[1:]
// We don't know which of two recovered keys suites, so let's try both.
putOnStackGetResult := func(isEven bool) []byte {
v.Estack().PushVal(messageHash)
v.Estack().PushVal(isEven)
v.Estack().PushVal(signature[32:64])
v.Estack().PushVal(signature[0:32])
err = context.secp256r1Recover(v)
require.NoError(t, err)
return v.Estack().Pop().Value().([]byte)
}
// First one:
actualFalse := putOnStackGetResult(false)
// Second one:
actualTrue := putOnStackGetResult(true)
require.True(t, bytes.Compare(expected, actualFalse) != bytes.Compare(expected, actualTrue))
}
// Helper functions to create VM, InteropContext, TX, Account, Contract, Asset.
func createVM(t *testing.T) (*vm.VM, *interopContext, *Blockchain) {
v := vm.New()
chain := newTestChain(t)
context := chain.newInteropContext(trigger.Application, dao.NewSimple(storage.NewMemoryStore()), nil, nil)
return v, context, chain
}
func createVMAndPushBlock(t *testing.T) (*vm.VM, *block.Block, *interopContext, *Blockchain) {
v := vm.New()
block := newDumbBlock()

View file

@ -52,6 +52,9 @@ func (ic *interopContext) SpawnVM() *vm.VM {
})
vm.RegisterInteropGetter(ic.getSystemInterop)
vm.RegisterInteropGetter(ic.getNeoInterop)
if ic.bc != nil && ic.bc.GetConfig().EnableStateRoot {
vm.RegisterInteropGetter(ic.getNeoxInterop)
}
return vm
}
@ -77,6 +80,12 @@ func (ic *interopContext) getNeoInterop(id uint32) *vm.InteropFuncPrice {
return ic.getInteropFromSlice(id, neoInterops)
}
// getNeoxInterop returns matching interop function from the NeoX extension
// for a given id in the current context.
func (ic *interopContext) getNeoxInterop(id uint32) *vm.InteropFuncPrice {
return ic.getInteropFromSlice(id, neoxInterops)
}
// getInteropFromSlice returns matching interop function from the given slice of
// interop functions in the current context.
func (ic *interopContext) getInteropFromSlice(id uint32, slice []interopedFunction) *vm.InteropFuncPrice {
@ -276,6 +285,11 @@ var neoInterops = []interopedFunction{
{Name: "AntShares.Transaction.GetType", Func: (*interopContext).txGetType, Price: 1},
}
var neoxInterops = []interopedFunction{
{Name: "Neo.Cryptography.Secp256k1Recover", Func: (*interopContext).secp256k1Recover, Price: 100},
{Name: "Neo.Cryptography.Secp256r1Recover", Func: (*interopContext).secp256r1Recover, Price: 100},
}
// initIDinInteropsSlice initializes IDs from names in one given
// interopedFunction slice and then sorts it.
func initIDinInteropsSlice(iops []interopedFunction) {
@ -291,4 +305,5 @@ func initIDinInteropsSlice(iops []interopedFunction) {
func init() {
initIDinInteropsSlice(systemInterops)
initIDinInteropsSlice(neoInterops)
initIDinInteropsSlice(neoxInterops)
}

84
pkg/core/mpt/base.go Normal file
View file

@ -0,0 +1,84 @@
package mpt
import (
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// BaseNode implements basic things every node needs like caching hash and
// serialized representation. It's a basic node building block intended to be
// included into all node types.
type BaseNode struct {
hash util.Uint256
bytes []byte
hashValid bool
bytesValid bool
isFlushed bool
}
// BaseNodeIface abstracts away basic Node functions.
type BaseNodeIface interface {
Hash() util.Uint256
Type() NodeType
Bytes() []byte
IsFlushed() bool
SetFlushed()
}
// getHash returns a hash of this BaseNode.
func (b *BaseNode) getHash(n Node) util.Uint256 {
if !b.hashValid {
b.updateHash(n)
}
return b.hash
}
// getBytes returns a slice of bytes representing this node.
func (b *BaseNode) getBytes(n Node) []byte {
if !b.bytesValid {
b.updateBytes(n)
}
return b.bytes
}
// updateHash updates hash field for this BaseNode.
func (b *BaseNode) updateHash(n Node) {
if n.Type() == HashT {
panic("can't update hash for hash node")
}
b.hash = hash.DoubleSha256(b.getBytes(n))
b.hashValid = true
}
// updateCache updates hash and bytes fields for this BaseNode.
func (b *BaseNode) updateBytes(n Node) {
buf := io.NewBufBinWriter()
encodeNodeWithType(n, buf.BinWriter)
b.bytes = buf.Bytes()
b.bytesValid = true
}
// invalidateCache sets all cache fields to invalid state.
func (b *BaseNode) invalidateCache() {
b.bytesValid = false
b.hashValid = false
b.isFlushed = false
}
// IsFlushed checks for node flush status.
func (b *BaseNode) IsFlushed() bool {
return b.isFlushed
}
// SetFlushed sets 'flushed' flag to true for this node.
func (b *BaseNode) SetFlushed() {
b.isFlushed = true
}
// encodeNodeWithType encodes node together with it's type.
func encodeNodeWithType(n Node, w *io.BinWriter) {
w.WriteB(byte(n.Type()))
n.EncodeBinary(w)
}

91
pkg/core/mpt/branch.go Normal file
View file

@ -0,0 +1,91 @@
package mpt
import (
"encoding/json"
"errors"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
const (
// childrenCount represents a number of children of a branch node.
childrenCount = 17
// lastChild is the index of the last child.
lastChild = childrenCount - 1
)
// BranchNode represents MPT's branch node.
type BranchNode struct {
BaseNode
Children [childrenCount]Node
}
var _ Node = (*BranchNode)(nil)
// NewBranchNode returns new branch node.
func NewBranchNode() *BranchNode {
b := new(BranchNode)
for i := 0; i < childrenCount; i++ {
b.Children[i] = new(HashNode)
}
return b
}
// Type implements Node interface.
func (b *BranchNode) Type() NodeType { return BranchT }
// Hash implements BaseNode interface.
func (b *BranchNode) Hash() util.Uint256 {
return b.getHash(b)
}
// Bytes implements BaseNode interface.
func (b *BranchNode) Bytes() []byte {
return b.getBytes(b)
}
// EncodeBinary implements io.Serializable.
func (b *BranchNode) EncodeBinary(w *io.BinWriter) {
for i := 0; i < childrenCount; i++ {
if hn, ok := b.Children[i].(*HashNode); ok {
hn.EncodeBinary(w)
continue
}
n := NewHashNode(b.Children[i].Hash())
n.EncodeBinary(w)
}
}
// DecodeBinary implements io.Serializable.
func (b *BranchNode) DecodeBinary(r *io.BinReader) {
for i := 0; i < childrenCount; i++ {
b.Children[i] = new(HashNode)
b.Children[i].DecodeBinary(r)
}
}
// MarshalJSON implements json.Marshaler.
func (b *BranchNode) MarshalJSON() ([]byte, error) {
return json.Marshal(b.Children)
}
// UnmarshalJSON implements json.Unmarshaler.
func (b *BranchNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*BranchNode); ok {
*b = *u
return nil
}
return errors.New("expected branch node")
}
// splitPath splits path for a branch node.
func splitPath(path []byte) (byte, []byte) {
if len(path) != 0 {
return path[0], path[1:]
}
return lastChild, path
}

45
pkg/core/mpt/doc.go Normal file
View file

@ -0,0 +1,45 @@
/*
Package mpt implements MPT (Merkle-Patricia Tree).
MPT stores key-value pairs and is a trie over 16-symbol alphabet. https://en.wikipedia.org/wiki/Trie
Trie is a tree where values are stored in leafs and keys are paths from root to the leaf node.
MPT consists of 4 type of nodes:
- Leaf node contains only value.
- Extension node contains both key and value.
- Branch node contains 2 or more children.
- Hash node is a compressed node and contains only actual node's hash.
The actual node must be retrieved from storage or over the network.
As an example here is a trie containing 3 pairs:
- 0x1201 -> val1
- 0x1203 -> val2
- 0x1224 -> val3
- 0x12 -> val4
ExtensionNode(0x0102), Next
_______________________|
|
BranchNode [0, 1, 2, ...], Last -> Leaf(val4)
| |
| ExtensionNode [0x04], Next -> Leaf(val3)
|
BranchNode [0, 1, 2, 3, ...], Last -> HashNode(nil)
| |
| Leaf(val2)
|
Leaf(val1)
There are 3 invariants that this implementation has:
- Branch node cannot have <= 1 children
- Extension node cannot have zero-length key
- Extension node cannot have another Extension node in it's next field
Thank to these restrictions, there is a single root hash for every set of key-value pairs
irregardless of the order they were added/removed with.
The actual trie structure can vary because of node -> HashNode compressing.
There is also one optimization which cost us almost nothing in terms of complexity but is very beneficial:
When we perform get/put/delete on a speficic path, every Hash node which was retreived from storage is
replaced by its uncompressed form, so that subsequent hits of this not don't use storage.
*/
package mpt

87
pkg/core/mpt/extension.go Normal file
View file

@ -0,0 +1,87 @@
package mpt
import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// MaxKeyLength is the max length of the extension node key.
const MaxKeyLength = 1125
// ExtensionNode represents MPT's extension node.
type ExtensionNode struct {
BaseNode
key []byte
next Node
}
var _ Node = (*ExtensionNode)(nil)
// NewExtensionNode returns hash node with the specified key and next node.
// Note: because it is a part of Trie, key must be mangled, i.e. must contain only bytes with high half = 0.
func NewExtensionNode(key []byte, next Node) *ExtensionNode {
return &ExtensionNode{
key: key,
next: next,
}
}
// Type implements Node interface.
func (e ExtensionNode) Type() NodeType { return ExtensionT }
// Hash implements BaseNode interface.
func (e *ExtensionNode) Hash() util.Uint256 {
return e.getHash(e)
}
// Bytes implements BaseNode interface.
func (e *ExtensionNode) Bytes() []byte {
return e.getBytes(e)
}
// DecodeBinary implements io.Serializable.
func (e *ExtensionNode) DecodeBinary(r *io.BinReader) {
sz := r.ReadVarUint()
if sz > MaxKeyLength {
r.Err = fmt.Errorf("extension node key is too big: %d", sz)
return
}
e.key = make([]byte, sz)
r.ReadBytes(e.key)
e.next = new(HashNode)
e.next.DecodeBinary(r)
e.invalidateCache()
}
// EncodeBinary implements io.Serializable.
func (e ExtensionNode) EncodeBinary(w *io.BinWriter) {
w.WriteVarBytes(e.key)
n := NewHashNode(e.next.Hash())
n.EncodeBinary(w)
}
// MarshalJSON implements json.Marshaler.
func (e *ExtensionNode) MarshalJSON() ([]byte, error) {
m := map[string]interface{}{
"key": hex.EncodeToString(e.key),
"next": e.next,
}
return json.Marshal(m)
}
// UnmarshalJSON implements json.Unmarshaler.
func (e *ExtensionNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*ExtensionNode); ok {
*e = *u
return nil
}
return errors.New("expected extension node")
}

88
pkg/core/mpt/hash.go Normal file
View file

@ -0,0 +1,88 @@
package mpt
import (
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// HashNode represents MPT's hash node.
type HashNode struct {
BaseNode
}
var _ Node = (*HashNode)(nil)
// NewHashNode returns hash node with the specified hash.
func NewHashNode(h util.Uint256) *HashNode {
return &HashNode{
BaseNode: BaseNode{
hash: h,
hashValid: true,
},
}
}
// Type implements Node interface.
func (h *HashNode) Type() NodeType { return HashT }
// Hash implements Node interface.
func (h *HashNode) Hash() util.Uint256 {
if !h.hashValid {
panic("can't get hash of an empty HashNode")
}
return h.hash
}
// IsEmpty returns true iff h is an empty node i.e. contains no hash.
func (h *HashNode) IsEmpty() bool { return !h.hashValid }
// Bytes returns serialized HashNode.
func (h *HashNode) Bytes() []byte {
return h.getBytes(h)
}
// DecodeBinary implements io.Serializable.
func (h *HashNode) DecodeBinary(r *io.BinReader) {
sz := r.ReadVarUint()
switch sz {
case 0:
h.hashValid = false
case util.Uint256Size:
h.hashValid = true
r.ReadBytes(h.hash[:])
default:
r.Err = fmt.Errorf("invalid hash node size: %d", sz)
}
}
// EncodeBinary implements io.Serializable.
func (h HashNode) EncodeBinary(w *io.BinWriter) {
if !h.hashValid {
w.WriteVarUint(0)
return
}
w.WriteVarBytes(h.hash[:])
}
// MarshalJSON implements json.Marshaler.
func (h *HashNode) MarshalJSON() ([]byte, error) {
if !h.hashValid {
return []byte(`{}`), nil
}
return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil
}
// UnmarshalJSON implements json.Unmarshaler.
func (h *HashNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*HashNode); ok {
*h = *u
return nil
}
return errors.New("expected hash node")
}

86
pkg/core/mpt/helpers.go Normal file
View file

@ -0,0 +1,86 @@
package mpt
import (
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// lcp returns longest common prefix of a and b.
// Note: it does no allocations.
func lcp(a, b []byte) []byte {
if len(a) < len(b) {
return lcp(b, a)
}
var i int
for i = 0; i < len(b); i++ {
if a[i] != b[i] {
break
}
}
return a[:i]
}
// copySlice is a helper for copying slice if needed.
func copySlice(a []byte) []byte {
b := make([]byte, len(a))
copy(b, a)
return b
}
// toNibbles mangles path by splitting every byte into 2 containing low- and high- 4-byte part.
func toNibbles(path []byte) []byte {
result := make([]byte, len(path)*2)
for i := range path {
result[i*2] = path[i] >> 4
result[i*2+1] = path[i] & 0x0F
}
return result
}
// ToNeoStorageKey converts storage key to C# neo node's format.
// Key is expected to be at least 20 bytes in length.
// our format: script hash in BE + key
// neo format: script hash in LE + key with 0 between every 16 bytes, padded to len 16.
func ToNeoStorageKey(key []byte) []byte {
const groupSize = 16
var nkey []byte
for i := util.Uint160Size - 1; i >= 0; i-- {
nkey = append(nkey, key[i])
}
key = key[util.Uint160Size:]
index := 0
remain := len(key)
for remain >= groupSize {
nkey = append(nkey, key[index:index+groupSize]...)
nkey = append(nkey, 0)
index += groupSize
remain -= groupSize
}
if remain > 0 {
nkey = append(nkey, key[index:]...)
}
padding := groupSize - remain
for i := 0; i < padding; i++ {
nkey = append(nkey, 0)
}
return append(nkey, byte(padding))
}
// ToNeoStorageValue serializes si to a C# neo node's format.
// It has additional version (0x00) byte at the beginning.
func ToNeoStorageValue(si *state.StorageItem) []byte {
const version = 0
buf := io.NewBufBinWriter()
buf.BinWriter.WriteB(version)
si.EncodeBinary(buf.BinWriter)
return buf.Bytes()
}

View file

@ -0,0 +1,30 @@
package mpt
import (
"encoding/hex"
"testing"
"github.com/stretchr/testify/require"
)
func TestToNeoStorageKey(t *testing.T) {
testCases := []struct{ key, res string }{
{
"0102030405060708091011121314151617181920",
"20191817161514131211100908070605040302010000000000000000000000000000000010",
},
{
"010203040506070809101112131415161718192021222324",
"2019181716151413121110090807060504030201212223240000000000000000000000000c",
},
{
"0102030405060708091011121314151617181920212223242526272829303132333435363738",
"20191817161514131211100908070605040302012122232425262728293031323334353600373800000000000000000000000000000e",
},
}
for _, tc := range testCases {
key, _ := hex.DecodeString(tc.key)
res, _ := hex.DecodeString(tc.res)
require.Equal(t, res, ToNeoStorageKey(key))
}
}

73
pkg/core/mpt/leaf.go Normal file
View file

@ -0,0 +1,73 @@
package mpt
import (
"encoding/hex"
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// MaxValueLength is a max length of a leaf node value.
const MaxValueLength = 1024 * 1024
// LeafNode represents MPT's leaf node.
type LeafNode struct {
BaseNode
value []byte
}
var _ Node = (*LeafNode)(nil)
// NewLeafNode returns hash node with the specified value.
func NewLeafNode(value []byte) *LeafNode {
return &LeafNode{value: value}
}
// Type implements Node interface.
func (n LeafNode) Type() NodeType { return LeafT }
// Hash implements BaseNode interface.
func (n *LeafNode) Hash() util.Uint256 {
return n.getHash(n)
}
// Bytes implements BaseNode interface.
func (n *LeafNode) Bytes() []byte {
return n.getBytes(n)
}
// DecodeBinary implements io.Serializable.
func (n *LeafNode) DecodeBinary(r *io.BinReader) {
sz := r.ReadVarUint()
if sz > MaxValueLength {
r.Err = fmt.Errorf("leaf node value is too big: %d", sz)
return
}
n.value = make([]byte, sz)
r.ReadBytes(n.value)
n.invalidateCache()
}
// EncodeBinary implements io.Serializable.
func (n LeafNode) EncodeBinary(w *io.BinWriter) {
w.WriteVarBytes(n.value)
}
// MarshalJSON implements json.Marshaler.
func (n *LeafNode) MarshalJSON() ([]byte, error) {
return []byte(`{"value":"` + hex.EncodeToString(n.value) + `"}`), nil
}
// UnmarshalJSON implements json.Unmarshaler.
func (n *LeafNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*LeafNode); ok {
*n = *u
return nil
}
return errors.New("expected leaf node")
}

134
pkg/core/mpt/node.go Normal file
View file

@ -0,0 +1,134 @@
package mpt
import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// NodeType represents node type..
type NodeType byte
// Node types definitions.
const (
BranchT NodeType = 0x00
ExtensionT NodeType = 0x01
HashT NodeType = 0x02
LeafT NodeType = 0x03
)
// NodeObject represents Node together with it's type.
// It is used for serialization/deserialization where type info
// is also expected.
type NodeObject struct {
Node
}
// Node represents common interface of all MPT nodes.
type Node interface {
io.Serializable
json.Marshaler
json.Unmarshaler
BaseNodeIface
}
// EncodeBinary implements io.Serializable.
func (n NodeObject) EncodeBinary(w *io.BinWriter) {
encodeNodeWithType(n.Node, w)
}
// DecodeBinary implements io.Serializable.
func (n *NodeObject) DecodeBinary(r *io.BinReader) {
typ := NodeType(r.ReadB())
switch typ {
case BranchT:
n.Node = new(BranchNode)
case ExtensionT:
n.Node = new(ExtensionNode)
case HashT:
n.Node = new(HashNode)
case LeafT:
n.Node = new(LeafNode)
default:
r.Err = fmt.Errorf("invalid node type: %x", typ)
return
}
n.Node.DecodeBinary(r)
}
// UnmarshalJSON implements json.Unmarshaler.
func (n *NodeObject) UnmarshalJSON(data []byte) error {
var m map[string]json.RawMessage
err := json.Unmarshal(data, &m)
if err != nil { // it can be a branch node
var nodes []NodeObject
if err := json.Unmarshal(data, &nodes); err != nil {
return err
} else if len(nodes) != childrenCount {
return errors.New("invalid length of branch node")
}
b := NewBranchNode()
for i := range b.Children {
b.Children[i] = nodes[i].Node
}
n.Node = b
return nil
}
switch len(m) {
case 0:
n.Node = new(HashNode)
case 1:
if v, ok := m["hash"]; ok {
var h util.Uint256
if err := json.Unmarshal(v, &h); err != nil {
return err
}
n.Node = NewHashNode(h)
} else if v, ok = m["value"]; ok {
b, err := unmarshalHex(v)
if err != nil {
return err
} else if len(b) > MaxValueLength {
return errors.New("leaf value is too big")
}
n.Node = NewLeafNode(b)
} else {
return errors.New("invalid field")
}
case 2:
keyRaw, ok1 := m["key"]
nextRaw, ok2 := m["next"]
if !ok1 || !ok2 {
return errors.New("invalid field")
}
key, err := unmarshalHex(keyRaw)
if err != nil {
return err
} else if len(key) > MaxKeyLength {
return errors.New("extension key is too big")
}
var next NodeObject
if err := json.Unmarshal(nextRaw, &next); err != nil {
return err
}
n.Node = NewExtensionNode(key, next.Node)
default:
return errors.New("0, 1 or 2 fields expected")
}
return nil
}
func unmarshalHex(data json.RawMessage) ([]byte, error) {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return nil, err
}
return hex.DecodeString(s)
}

156
pkg/core/mpt/node_test.go Normal file
View file

@ -0,0 +1,156 @@
package mpt
import (
"encoding/json"
"testing"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) {
return func(t *testing.T) {
t.Run("IO", func(t *testing.T) {
bs, err := testserdes.EncodeBinary(expected)
require.NoError(t, err)
err = testserdes.DecodeBinary(bs, actual)
if !ok {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, expected.Type(), actual.Type())
require.Equal(t, expected.Hash(), actual.Hash())
})
t.Run("JSON", func(t *testing.T) {
bs, err := json.Marshal(expected)
require.NoError(t, err)
err = json.Unmarshal(bs, actual)
if !ok {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, expected.Type(), actual.Type())
require.Equal(t, expected.Hash(), actual.Hash())
})
}
}
func TestNode_Serializable(t *testing.T) {
t.Run("Leaf", func(t *testing.T) {
t.Run("Good", func(t *testing.T) {
l := NewLeafNode(random.Bytes(123))
t.Run("Raw", getTestFuncEncode(true, l, new(LeafNode)))
t.Run("WithType", getTestFuncEncode(true, &NodeObject{l}, new(NodeObject)))
})
t.Run("BigValue", getTestFuncEncode(false,
NewLeafNode(random.Bytes(MaxValueLength+1)), new(LeafNode)))
})
t.Run("Extension", func(t *testing.T) {
t.Run("Good", func(t *testing.T) {
e := NewExtensionNode(random.Bytes(42), NewLeafNode(random.Bytes(10)))
t.Run("Raw", getTestFuncEncode(true, e, new(ExtensionNode)))
t.Run("WithType", getTestFuncEncode(true, &NodeObject{e}, new(NodeObject)))
})
t.Run("BigKey", getTestFuncEncode(false,
NewExtensionNode(random.Bytes(MaxKeyLength+1), NewLeafNode(random.Bytes(10))), new(ExtensionNode)))
})
t.Run("Branch", func(t *testing.T) {
b := NewBranchNode()
b.Children[0] = NewLeafNode(random.Bytes(10))
b.Children[lastChild] = NewHashNode(random.Uint256())
t.Run("Raw", getTestFuncEncode(true, b, new(BranchNode)))
t.Run("WithType", getTestFuncEncode(true, &NodeObject{b}, new(NodeObject)))
})
t.Run("Hash", func(t *testing.T) {
t.Run("Good", func(t *testing.T) {
h := NewHashNode(random.Uint256())
t.Run("Raw", getTestFuncEncode(true, h, new(HashNode)))
t.Run("WithType", getTestFuncEncode(true, &NodeObject{h}, new(NodeObject)))
})
t.Run("Empty", func(t *testing.T) { // compare nodes, not hashes
testserdes.EncodeDecodeBinary(t, new(HashNode), new(HashNode))
})
t.Run("InvalidSize", func(t *testing.T) {
buf := io.NewBufBinWriter()
buf.BinWriter.WriteVarBytes(make([]byte, 13))
require.Error(t, testserdes.DecodeBinary(buf.Bytes(), new(HashNode)))
})
})
t.Run("Invalid", func(t *testing.T) {
require.Error(t, testserdes.DecodeBinary([]byte{0xFF}, new(NodeObject)))
})
}
// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198
func TestJSONSharp(t *testing.T) {
tr := NewTrie(nil, newTestStore())
require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11}))
require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22}))
require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac}))
require.NoError(t, tr.Delete([]byte{0xac, 0x11}))
require.NoError(t, tr.Delete([]byte{0xac, 0x22}))
js, err := tr.root.MarshalJSON()
require.NoError(t, err)
require.JSONEq(t, `{"key":"0a0c", "next":{"value":"ac"}}`, string(js))
}
func TestInvalidJSON(t *testing.T) {
t.Run("InvalidChildrenCount", func(t *testing.T) {
var cs [childrenCount + 1]Node
for i := range cs {
cs[i] = new(HashNode)
}
data, err := json.Marshal(cs)
require.NoError(t, err)
var n NodeObject
require.Error(t, json.Unmarshal(data, &n))
})
testCases := []struct {
name string
data []byte
}{
{"WrongFieldCount", []byte(`{"key":"0102", "next": {}, "field": {}}`)},
{"InvalidField1", []byte(`{"next":{}}`)},
{"InvalidField2", []byte(`{"key":"0102", "hash":{}}`)},
{"InvalidKey", []byte(`{"key":"xy", "next":{}}`)},
{"InvalidNext", []byte(`{"key":"01", "next":[]}`)},
{"InvalidHash", []byte(`{"hash":"01"}`)},
{"InvalidValue", []byte(`{"value":1}`)},
{"InvalidBranch", []byte(`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]`)},
}
for _, tc := range testCases {
var n NodeObject
assert.Errorf(t, json.Unmarshal(tc.data, &n), "no error in "+tc.name)
}
}
// C# interoperability test
// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135
func TestRootHash(t *testing.T) {
b := NewBranchNode()
r := NewExtensionNode([]byte{0x0A, 0x0C}, b)
v1 := NewLeafNode([]byte{0xAB, 0xCD})
l1 := NewExtensionNode([]byte{0x01}, v1)
b.Children[0] = l1
v2 := NewLeafNode([]byte{0x22, 0x22})
l2 := NewExtensionNode([]byte{0x09}, v2)
b.Children[9] = l2
r1 := NewExtensionNode([]byte{0x0A, 0x0C, 0x00, 0x01}, v1)
require.Equal(t, "dea3ab46e9461e885ed7091c1e533e0a8030b248d39cbc638962394eaca0fbb3", r1.Hash().StringLE())
require.Equal(t, "93e8e1ffe2f83dd92fca67330e273bcc811bf64b8f8d9d1b25d5e7366b47d60d", r.Hash().StringLE())
}

74
pkg/core/mpt/proof.go Normal file
View file

@ -0,0 +1,74 @@
package mpt
import (
"bytes"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// GetProof returns a proof that key belongs to t.
// Proof consist of serialized nodes occuring on path from the root to the leaf of key.
func (t *Trie) GetProof(key []byte) ([][]byte, error) {
var proof [][]byte
path := toNibbles(key)
r, err := t.getProof(t.root, path, &proof)
if err != nil {
return proof, err
}
t.root = r
return proof, nil
}
func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) {
switch n := curr.(type) {
case *LeafNode:
if len(path) == 0 {
*proofs = append(*proofs, copySlice(n.Bytes()))
return n, nil
}
case *BranchNode:
*proofs = append(*proofs, copySlice(n.Bytes()))
i, path := splitPath(path)
r, err := t.getProof(n.Children[i], path, proofs)
if err != nil {
return nil, err
}
n.Children[i] = r
return n, nil
case *ExtensionNode:
if bytes.HasPrefix(path, n.key) {
*proofs = append(*proofs, copySlice(n.Bytes()))
r, err := t.getProof(n.next, path[len(n.key):], proofs)
if err != nil {
return nil, err
}
n.next = r
return n, nil
}
case *HashNode:
if !n.IsEmpty() {
r, err := t.getFromStore(n.Hash())
if err != nil {
return nil, err
}
return t.getProof(r, path, proofs)
}
}
return nil, ErrNotFound
}
// VerifyProof verifies that path indeed belongs to a MPT with the specified root hash.
// It also returns value for the key.
func VerifyProof(rh util.Uint256, key []byte, proofs [][]byte) ([]byte, bool) {
path := toNibbles(key)
tr := NewTrie(NewHashNode(rh), storage.NewMemCachedStore(storage.NewMemoryStore()))
for i := range proofs {
h := hash.DoubleSha256(proofs[i])
// no errors in Put to memory store
_ = tr.Store.Put(makeStorageKey(h[:]), proofs[i])
}
_, bs, err := tr.getWithPath(tr.root, path)
return bs, err == nil
}

View file

@ -0,0 +1,73 @@
package mpt
import (
"testing"
"github.com/stretchr/testify/require"
)
func newProofTrie(t *testing.T) *Trie {
l := NewLeafNode([]byte("somevalue"))
e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
l2 := NewLeafNode([]byte("invalid"))
e2 := NewExtensionNode([]byte{0x05}, NewHashNode(l2.Hash()))
b := NewBranchNode()
b.Children[4] = NewHashNode(e.Hash())
b.Children[5] = e2
tr := NewTrie(b, newTestStore())
require.NoError(t, tr.Put([]byte{0x12, 0x31}, []byte("value1")))
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
tr.putToStore(l)
tr.putToStore(e)
return tr
}
func TestTrie_GetProof(t *testing.T) {
tr := newProofTrie(t)
t.Run("MissingKey", func(t *testing.T) {
_, err := tr.GetProof([]byte{0x12})
require.Error(t, err)
})
t.Run("Valid", func(t *testing.T) {
_, err := tr.GetProof([]byte{0x12, 0x31})
require.NoError(t, err)
})
t.Run("MissingHashNode", func(t *testing.T) {
_, err := tr.GetProof([]byte{0x55})
require.Error(t, err)
})
}
func TestVerifyProof(t *testing.T) {
tr := newProofTrie(t)
t.Run("Simple", func(t *testing.T) {
proof, err := tr.GetProof([]byte{0x12, 0x32})
require.NoError(t, err)
t.Run("Good", func(t *testing.T) {
v, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x32}, proof)
require.True(t, ok)
require.Equal(t, []byte("value2"), v)
})
t.Run("Bad", func(t *testing.T) {
_, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x31}, proof)
require.False(t, ok)
})
})
t.Run("InsideHash", func(t *testing.T) {
key := []byte{0x45, 0x67}
proof, err := tr.GetProof(key)
require.NoError(t, err)
v, ok := VerifyProof(tr.root.Hash(), key, proof)
require.True(t, ok)
require.Equal(t, []byte("somevalue"), v)
})
}

390
pkg/core/mpt/trie.go Normal file
View file

@ -0,0 +1,390 @@
package mpt
import (
"bytes"
"errors"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// Trie is an MPT trie storing all key-value pairs.
type Trie struct {
Store *storage.MemCachedStore
root Node
}
// ErrNotFound is returned when requested trie item is missing.
var ErrNotFound = errors.New("item not found")
// NewTrie returns new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors
// so that all storage errors are processed during `store.Persist()` at the caller.
// This also has the benefit, that every `Put` can be considered an atomic operation.
func NewTrie(root Node, store *storage.MemCachedStore) *Trie {
if root == nil {
root = new(HashNode)
}
return &Trie{
Store: store,
root: root,
}
}
// Get returns value for the provided key in t.
func (t *Trie) Get(key []byte) ([]byte, error) {
path := toNibbles(key)
r, bs, err := t.getWithPath(t.root, path)
if err != nil {
return nil, err
}
t.root = r
return bs, nil
}
// getWithPath returns value the provided path in a subtrie rooting in curr.
// It also returns a current node with all hash nodes along the path
// replaced to their "unhashed" counterparts.
func (t *Trie) getWithPath(curr Node, path []byte) (Node, []byte, error) {
switch n := curr.(type) {
case *LeafNode:
if len(path) == 0 {
return curr, copySlice(n.value), nil
}
case *BranchNode:
i, path := splitPath(path)
r, bs, err := t.getWithPath(n.Children[i], path)
if err != nil {
return nil, nil, err
}
n.Children[i] = r
return n, bs, nil
case *HashNode:
if !n.IsEmpty() {
if r, err := t.getFromStore(n.hash); err == nil {
return t.getWithPath(r, path)
}
}
case *ExtensionNode:
if bytes.HasPrefix(path, n.key) {
r, bs, err := t.getWithPath(n.next, path[len(n.key):])
if err != nil {
return nil, nil, err
}
n.next = r
return curr, bs, err
}
default:
panic("invalid MPT node type")
}
return curr, nil, ErrNotFound
}
// Put puts key-value pair in t.
func (t *Trie) Put(key, value []byte) error {
if len(key) > MaxKeyLength {
return errors.New("key is too big")
} else if len(value) > MaxValueLength {
return errors.New("value is too big")
}
if len(value) == 0 {
return t.Delete(key)
}
path := toNibbles(key)
n := NewLeafNode(value)
r, err := t.putIntoNode(t.root, path, n)
if err != nil {
return err
}
t.root = r
return nil
}
// putIntoLeaf puts val to trie if current node is a Leaf.
// It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) {
v := val.(*LeafNode)
if len(path) == 0 {
return v, nil
}
b := NewBranchNode()
b.Children[path[0]] = newSubTrie(path[1:], v)
b.Children[lastChild] = curr
return b, nil
}
// putIntoBranch puts val to trie if current node is a Branch.
// It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) {
i, path := splitPath(path)
r, err := t.putIntoNode(curr.Children[i], path, val)
if err != nil {
return nil, err
}
curr.Children[i] = r
curr.invalidateCache()
return curr, nil
}
// putIntoExtension puts val to trie if current node is an Extension.
// It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) {
if bytes.HasPrefix(path, curr.key) {
r, err := t.putIntoNode(curr.next, path[len(curr.key):], val)
if err != nil {
return nil, err
}
curr.next = r
curr.invalidateCache()
return curr, nil
}
pref := lcp(curr.key, path)
lp := len(pref)
keyTail := curr.key[lp:]
pathTail := path[lp:]
s1 := newSubTrie(keyTail[1:], curr.next)
b := NewBranchNode()
b.Children[keyTail[0]] = s1
i, pathTail := splitPath(pathTail)
s2 := newSubTrie(pathTail, val)
b.Children[i] = s2
if lp > 0 {
return NewExtensionNode(copySlice(pref), b), nil
}
return b, nil
}
// putIntoHash puts val to trie if current node is a HashNode.
// It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) {
if curr.IsEmpty() {
return newSubTrie(path, val), nil
}
result, err := t.getFromStore(curr.hash)
if err != nil {
return nil, err
}
return t.putIntoNode(result, path, val)
}
// newSubTrie create new trie containing node at provided path.
func newSubTrie(path []byte, val Node) Node {
if len(path) == 0 {
return val
}
return NewExtensionNode(path, val)
}
func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) {
switch n := curr.(type) {
case *LeafNode:
return t.putIntoLeaf(n, path, val)
case *BranchNode:
return t.putIntoBranch(n, path, val)
case *ExtensionNode:
return t.putIntoExtension(n, path, val)
case *HashNode:
return t.putIntoHash(n, path, val)
default:
panic("invalid MPT node type")
}
}
// Delete removes key from trie.
// It returns no error on missing key.
func (t *Trie) Delete(key []byte) error {
path := toNibbles(key)
r, err := t.deleteFromNode(t.root, path)
if err != nil {
return err
}
t.root = r
return nil
}
func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) {
i, path := splitPath(path)
r, err := t.deleteFromNode(b.Children[i], path)
if err != nil {
return nil, err
}
b.Children[i] = r
b.invalidateCache()
var count, index int
for i := range b.Children {
h, ok := b.Children[i].(*HashNode)
if !ok || !h.IsEmpty() {
index = i
count++
}
}
// count is >= 1 because branch node had at least 2 children before deletion.
if count > 1 {
return b, nil
}
c := b.Children[index]
if index == lastChild {
return c, nil
}
if h, ok := c.(*HashNode); ok {
c, err = t.getFromStore(h.Hash())
if err != nil {
return nil, err
}
}
if e, ok := c.(*ExtensionNode); ok {
e.key = append([]byte{byte(index)}, e.key...)
e.invalidateCache()
return e, nil
}
return NewExtensionNode([]byte{byte(index)}, c), nil
}
func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) {
if !bytes.HasPrefix(path, n.key) {
return nil, ErrNotFound
}
r, err := t.deleteFromNode(n.next, path[len(n.key):])
if err != nil {
return nil, err
}
switch nxt := r.(type) {
case *ExtensionNode:
n.key = append(n.key, nxt.key...)
n.next = nxt.next
case *HashNode:
if nxt.IsEmpty() {
return nxt, nil
}
default:
n.next = r
}
n.invalidateCache()
return n, nil
}
func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) {
switch n := curr.(type) {
case *LeafNode:
if len(path) == 0 {
return new(HashNode), nil
}
return nil, ErrNotFound
case *BranchNode:
return t.deleteFromBranch(n, path)
case *ExtensionNode:
return t.deleteFromExtension(n, path)
case *HashNode:
if n.IsEmpty() {
return nil, ErrNotFound
}
newNode, err := t.getFromStore(n.Hash())
if err != nil {
return nil, err
}
return t.deleteFromNode(newNode, path)
default:
panic("invalid MPT node type")
}
}
// StateRoot returns root hash of t.
func (t *Trie) StateRoot() util.Uint256 {
if hn, ok := t.root.(*HashNode); ok && hn.IsEmpty() {
return util.Uint256{}
}
return t.root.Hash()
}
func makeStorageKey(mptKey []byte) []byte {
return append([]byte{byte(storage.DataMPT)}, mptKey...)
}
// Flush puts every node in the trie except Hash ones to the storage.
// Because we care only about block-level changes, there is no need to put every
// new node to storage. Normally, flush should be called with every StateRoot persist, i.e.
// after every block.
func (t *Trie) Flush() {
t.flush(t.root)
}
func (t *Trie) flush(node Node) {
if node.IsFlushed() {
return
}
switch n := node.(type) {
case *BranchNode:
for i := range n.Children {
t.flush(n.Children[i])
}
case *ExtensionNode:
t.flush(n.next)
case *HashNode:
return
}
t.putToStore(node)
}
func (t *Trie) putToStore(n Node) {
if n.Type() == HashT {
panic("can't put hash node in trie")
}
_ = t.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes()) // put in MemCached returns no errors
n.SetFlushed()
}
func (t *Trie) getFromStore(h util.Uint256) (Node, error) {
data, err := t.Store.Get(makeStorageKey(h.BytesBE()))
if err != nil {
return nil, err
}
var n NodeObject
r := io.NewBinReaderFromBuf(data)
n.DecodeBinary(r)
if r.Err != nil {
return nil, r.Err
}
return n.Node, nil
}
// Collapse compresses all nodes at depth n to the hash nodes.
// Note: this function does not perform any kind of storage flushing so
// `Flush()` should be called explicitly before invoking function.
func (t *Trie) Collapse(depth int) {
if depth < 0 {
panic("negative depth")
}
t.root = collapse(depth, t.root)
}
func collapse(depth int, node Node) Node {
if _, ok := node.(*HashNode); ok {
return node
} else if depth == 0 {
return NewHashNode(node.Hash())
}
switch n := node.(type) {
case *BranchNode:
for i := range n.Children {
n.Children[i] = collapse(depth-1, n.Children[i])
}
case *ExtensionNode:
n.next = collapse(depth-1, n.next)
case *LeafNode:
case *HashNode:
default:
panic("invalid MPT node type")
}
return node
}

446
pkg/core/mpt/trie_test.go Normal file
View file

@ -0,0 +1,446 @@
package mpt
import (
"testing"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/stretchr/testify/require"
)
func newTestStore() *storage.MemCachedStore {
return storage.NewMemCachedStore(storage.NewMemoryStore())
}
func newTestTrie(t *testing.T) *Trie {
b := NewBranchNode()
l1 := NewLeafNode([]byte{0xAB, 0xCD})
b.Children[0] = NewExtensionNode([]byte{0x01}, l1)
l2 := NewLeafNode([]byte{0x22, 0x22})
b.Children[9] = NewExtensionNode([]byte{0x09}, l2)
v := NewLeafNode([]byte("hello"))
h := NewHashNode(v.Hash())
b.Children[10] = NewExtensionNode([]byte{0x0e}, h)
e := NewExtensionNode(toNibbles([]byte{0xAC}), b)
tr := NewTrie(e, newTestStore())
tr.putToStore(e)
tr.putToStore(b)
tr.putToStore(l1)
tr.putToStore(l2)
tr.putToStore(v)
tr.putToStore(b.Children[0])
tr.putToStore(b.Children[9])
tr.putToStore(b.Children[10])
return tr
}
func TestTrie_PutIntoBranchNode(t *testing.T) {
b := NewBranchNode()
l := NewLeafNode([]byte{0x8})
b.Children[0x7] = NewHashNode(l.Hash())
b.Children[0x8] = NewHashNode(random.Uint256())
tr := NewTrie(b, newTestStore())
// next
require.NoError(t, tr.Put([]byte{}, []byte{0x12, 0x34}))
tr.testHas(t, []byte{}, []byte{0x12, 0x34})
// empty hash node child
require.NoError(t, tr.Put([]byte{0x66}, []byte{0x56}))
tr.testHas(t, []byte{0x66}, []byte{0x56})
require.True(t, isValid(tr.root))
// missing hash
require.Error(t, tr.Put([]byte{0x70}, []byte{0x42}))
require.True(t, isValid(tr.root))
// hash is in store
tr.putToStore(l)
require.NoError(t, tr.Put([]byte{0x70}, []byte{0x42}))
require.True(t, isValid(tr.root))
}
func TestTrie_PutIntoExtensionNode(t *testing.T) {
l := NewLeafNode([]byte{0x11})
key := []byte{0x12}
e := NewExtensionNode(toNibbles(key), NewHashNode(l.Hash()))
tr := NewTrie(e, newTestStore())
// missing hash
require.Error(t, tr.Put(key, []byte{0x42}))
tr.putToStore(l)
require.NoError(t, tr.Put(key, []byte{0x42}))
tr.testHas(t, key, []byte{0x42})
require.True(t, isValid(tr.root))
}
func TestTrie_PutIntoHashNode(t *testing.T) {
b := NewBranchNode()
l := NewLeafNode(random.Bytes(5))
e := NewExtensionNode([]byte{0x02}, l)
b.Children[1] = NewHashNode(e.Hash())
b.Children[9] = NewHashNode(random.Uint256())
tr := NewTrie(b, newTestStore())
tr.putToStore(e)
t.Run("MissingLeafHash", func(t *testing.T) {
_, err := tr.Get([]byte{0x12})
require.Error(t, err)
})
tr.putToStore(l)
val := random.Bytes(3)
require.NoError(t, tr.Put([]byte{0x12, 0x34}, val))
tr.testHas(t, []byte{0x12, 0x34}, val)
tr.testHas(t, []byte{0x12}, l.value)
require.True(t, isValid(tr.root))
}
func TestTrie_Put(t *testing.T) {
trExp := newTestTrie(t)
trAct := NewTrie(nil, newTestStore())
require.NoError(t, trAct.Put([]byte{0xAC, 0x01}, []byte{0xAB, 0xCD}))
require.NoError(t, trAct.Put([]byte{0xAC, 0x99}, []byte{0x22, 0x22}))
require.NoError(t, trAct.Put([]byte{0xAC, 0xAE}, []byte("hello")))
// Note: the exact tries differ because of ("acae":"hello") node is stored as Hash node in test trie.
require.Equal(t, trExp.root.Hash(), trAct.root.Hash())
require.True(t, isValid(trAct.root))
}
func TestTrie_PutInvalid(t *testing.T) {
tr := NewTrie(nil, newTestStore())
key, value := []byte("key"), []byte("value")
// big key
require.Error(t, tr.Put(make([]byte, MaxKeyLength+1), value))
// big value
require.Error(t, tr.Put(key, make([]byte, MaxValueLength+1)))
// this is ok though
require.NoError(t, tr.Put(key, value))
tr.testHas(t, key, value)
}
func TestTrie_BigPut(t *testing.T) {
tr := NewTrie(nil, newTestStore())
items := []struct{ k, v string }{
{"item with long key", "value1"},
{"item with matching prefix", "value2"},
{"another prefix", "value3"},
{"another prefix 2", "value4"},
{"another ", "value5"},
}
for i := range items {
require.NoError(t, tr.Put([]byte(items[i].k), []byte(items[i].v)))
}
for i := range items {
tr.testHas(t, []byte(items[i].k), []byte(items[i].v))
}
t.Run("Rewrite", func(t *testing.T) {
k, v := []byte(items[0].k), []byte{0x01, 0x23}
require.NoError(t, tr.Put(k, v))
tr.testHas(t, k, v)
})
t.Run("Remove", func(t *testing.T) {
k := []byte(items[1].k)
require.NoError(t, tr.Put(k, []byte{}))
tr.testHas(t, k, nil)
})
}
func (tr *Trie) testHas(t *testing.T, key, value []byte) {
v, err := tr.Get(key)
if value == nil {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, value, v)
}
// isValid checks for 3 invariants:
// - BranchNode contains > 1 children
// - ExtensionNode do not contain another extension node
// - ExtensionNode do not have nil key
// It is used only during testing to catch possible bugs.
func isValid(curr Node) bool {
switch n := curr.(type) {
case *BranchNode:
var count int
for i := range n.Children {
if !isValid(n.Children[i]) {
return false
}
hn, ok := n.Children[i].(*HashNode)
if !ok || !hn.IsEmpty() {
count++
}
}
return count > 1
case *ExtensionNode:
_, ok := n.next.(*ExtensionNode)
return len(n.key) != 0 && !ok
default:
return true
}
}
func TestTrie_Get(t *testing.T) {
t.Run("HashNode", func(t *testing.T) {
tr := newTestTrie(t)
tr.testHas(t, []byte{0xAC, 0xAE}, []byte("hello"))
})
t.Run("UnfoldRoot", func(t *testing.T) {
tr := newTestTrie(t)
single := NewTrie(NewHashNode(tr.root.Hash()), tr.Store)
single.testHas(t, []byte{0xAC}, nil)
single.testHas(t, []byte{0xAC, 0x01}, []byte{0xAB, 0xCD})
single.testHas(t, []byte{0xAC, 0x99}, []byte{0x22, 0x22})
single.testHas(t, []byte{0xAC, 0xAE}, []byte("hello"))
})
}
func TestTrie_Flush(t *testing.T) {
pairs := map[string][]byte{
"": []byte("value0"),
"key1": []byte("value1"),
"key2": []byte("value2"),
}
tr := NewTrie(nil, newTestStore())
for k, v := range pairs {
require.NoError(t, tr.Put([]byte(k), v))
}
tr.Flush()
tr = NewTrie(NewHashNode(tr.StateRoot()), tr.Store)
for k, v := range pairs {
actual, err := tr.Get([]byte(k))
require.NoError(t, err)
require.Equal(t, v, actual)
}
}
func TestTrie_Delete(t *testing.T) {
t.Run("Hash", func(t *testing.T) {
t.Run("FromStore", func(t *testing.T) {
l := NewLeafNode([]byte{0x12})
tr := NewTrie(NewHashNode(l.Hash()), newTestStore())
t.Run("NotInStore", func(t *testing.T) {
require.Error(t, tr.Delete([]byte{}))
})
tr.putToStore(l)
tr.testHas(t, []byte{}, []byte{0x12})
require.NoError(t, tr.Delete([]byte{}))
tr.testHas(t, []byte{}, nil)
})
t.Run("Empty", func(t *testing.T) {
tr := NewTrie(nil, newTestStore())
require.Error(t, tr.Delete([]byte{}))
})
})
t.Run("Leaf", func(t *testing.T) {
l := NewLeafNode([]byte{0x12, 0x34})
tr := NewTrie(l, newTestStore())
t.Run("NonExistentKey", func(t *testing.T) {
require.Error(t, tr.Delete([]byte{0x12}))
tr.testHas(t, []byte{}, []byte{0x12, 0x34})
})
require.NoError(t, tr.Delete([]byte{}))
tr.testHas(t, []byte{}, nil)
})
t.Run("Extension", func(t *testing.T) {
t.Run("SingleKey", func(t *testing.T) {
l := NewLeafNode([]byte{0x12, 0x34})
e := NewExtensionNode([]byte{0x0A, 0x0B}, l)
tr := NewTrie(e, newTestStore())
t.Run("NonExistentKey", func(t *testing.T) {
require.Error(t, tr.Delete([]byte{}))
tr.testHas(t, []byte{0xAB}, []byte{0x12, 0x34})
})
require.NoError(t, tr.Delete([]byte{0xAB}))
require.True(t, tr.root.(*HashNode).IsEmpty())
})
t.Run("MultipleKeys", func(t *testing.T) {
b := NewBranchNode()
b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x12, 0x34}))
b.Children[6] = NewExtensionNode([]byte{0x07}, NewLeafNode([]byte{0x56, 0x78}))
e := NewExtensionNode([]byte{0x01, 0x02}, b)
tr := NewTrie(e, newTestStore())
h := e.Hash()
require.NoError(t, tr.Delete([]byte{0x12, 0x01}))
tr.testHas(t, []byte{0x12, 0x01}, nil)
tr.testHas(t, []byte{0x12, 0x67}, []byte{0x56, 0x78})
require.NotEqual(t, h, tr.root.Hash())
require.Equal(t, toNibbles([]byte{0x12, 0x67}), e.key)
require.IsType(t, (*LeafNode)(nil), e.next)
})
})
t.Run("Branch", func(t *testing.T) {
t.Run("3 Children", func(t *testing.T) {
b := NewBranchNode()
b.Children[lastChild] = NewLeafNode([]byte{0x12})
b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x34}))
b.Children[1] = NewExtensionNode([]byte{0x06}, NewLeafNode([]byte{0x56}))
tr := NewTrie(b, newTestStore())
require.NoError(t, tr.Delete([]byte{0x16}))
tr.testHas(t, []byte{}, []byte{0x12})
tr.testHas(t, []byte{0x01}, []byte{0x34})
tr.testHas(t, []byte{0x16}, nil)
})
t.Run("2 Children", func(t *testing.T) {
newt := func(t *testing.T) *Trie {
b := NewBranchNode()
b.Children[lastChild] = NewLeafNode([]byte{0x12})
l := NewLeafNode([]byte{0x34})
e := NewExtensionNode([]byte{0x06}, l)
b.Children[5] = NewHashNode(e.Hash())
tr := NewTrie(b, newTestStore())
tr.putToStore(l)
tr.putToStore(e)
return tr
}
t.Run("DeleteLast", func(t *testing.T) {
t.Run("MergeExtension", func(t *testing.T) {
tr := newt(t)
require.NoError(t, tr.Delete([]byte{}))
tr.testHas(t, []byte{}, nil)
tr.testHas(t, []byte{0x56}, []byte{0x34})
require.IsType(t, (*ExtensionNode)(nil), tr.root)
})
t.Run("LeaveLeaf", func(t *testing.T) {
c := NewBranchNode()
c.Children[5] = NewLeafNode([]byte{0x05})
c.Children[6] = NewLeafNode([]byte{0x06})
b := NewBranchNode()
b.Children[lastChild] = NewLeafNode([]byte{0x12})
b.Children[5] = c
tr := NewTrie(b, newTestStore())
require.NoError(t, tr.Delete([]byte{}))
tr.testHas(t, []byte{}, nil)
tr.testHas(t, []byte{0x55}, []byte{0x05})
tr.testHas(t, []byte{0x56}, []byte{0x06})
require.IsType(t, (*ExtensionNode)(nil), tr.root)
})
})
t.Run("DeleteMiddle", func(t *testing.T) {
tr := newt(t)
require.NoError(t, tr.Delete([]byte{0x56}))
tr.testHas(t, []byte{}, []byte{0x12})
tr.testHas(t, []byte{0x56}, nil)
require.IsType(t, (*LeafNode)(nil), tr.root)
})
})
})
}
func TestTrie_PanicInvalidRoot(t *testing.T) {
tr := &Trie{Store: newTestStore()}
require.Panics(t, func() { _ = tr.Put([]byte{1}, []byte{2}) })
require.Panics(t, func() { _, _ = tr.Get([]byte{1}) })
require.Panics(t, func() { _ = tr.Delete([]byte{1}) })
}
func TestTrie_Collapse(t *testing.T) {
t.Run("PanicNegative", func(t *testing.T) {
tr := newTestTrie(t)
require.Panics(t, func() { tr.Collapse(-1) })
})
t.Run("Depth=0", func(t *testing.T) {
tr := newTestTrie(t)
h := tr.root.Hash()
_, ok := tr.root.(*HashNode)
require.False(t, ok)
tr.Collapse(0)
_, ok = tr.root.(*HashNode)
require.True(t, ok)
require.Equal(t, h, tr.root.Hash())
})
t.Run("Branch,Depth=1", func(t *testing.T) {
b := NewBranchNode()
e := NewExtensionNode([]byte{0x01}, NewLeafNode([]byte("value1")))
he := e.Hash()
b.Children[0] = e
hb := b.Hash()
tr := NewTrie(b, newTestStore())
tr.Collapse(1)
newb, ok := tr.root.(*BranchNode)
require.True(t, ok)
require.Equal(t, hb, newb.Hash())
require.IsType(t, (*HashNode)(nil), b.Children[0])
require.Equal(t, he, b.Children[0].Hash())
})
t.Run("Extension,Depth=1", func(t *testing.T) {
l := NewLeafNode([]byte("value"))
hl := l.Hash()
e := NewExtensionNode([]byte{0x01}, l)
h := e.Hash()
tr := NewTrie(e, newTestStore())
tr.Collapse(1)
newe, ok := tr.root.(*ExtensionNode)
require.True(t, ok)
require.Equal(t, h, newe.Hash())
require.IsType(t, (*HashNode)(nil), newe.next)
require.Equal(t, hl, newe.next.Hash())
})
t.Run("Leaf", func(t *testing.T) {
l := NewLeafNode([]byte("value"))
tr := NewTrie(l, newTestStore())
tr.Collapse(10)
require.Equal(t, NewLeafNode([]byte("value")), tr.root)
})
t.Run("Hash", func(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
tr := NewTrie(new(HashNode), newTestStore())
require.NotPanics(t, func() { tr.Collapse(1) })
hn, ok := tr.root.(*HashNode)
require.True(t, ok)
require.True(t, hn.IsEmpty())
})
h := random.Uint256()
hn := NewHashNode(h)
tr := NewTrie(hn, newTestStore())
tr.Collapse(10)
newRoot, ok := tr.root.(*HashNode)
require.True(t, ok)
require.Equal(t, NewHashNode(h), newRoot)
})
}

View file

@ -30,6 +30,14 @@ var (
Namespace: "neogo",
},
)
//stateHeight prometheus metric.
stateHeight = prometheus.NewGauge(
prometheus.GaugeOpts{
Help: "Current verified state height",
Name: "current_state_height",
Namespace: "neogo",
},
)
)
func init() {
@ -51,3 +59,7 @@ func updateHeaderHeightMetric(hHeight int) {
func updateBlockHeightMetric(bHeight uint32) {
blockHeight.Set(float64(bHeight))
}
func updateStateHeightMetric(sHeight uint32) {
stateHeight.Set(float64(sHeight))
}

146
pkg/core/state/mpt_root.go Normal file
View file

@ -0,0 +1,146 @@
package state
import (
"encoding/json"
"errors"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// MPTRootBase represents storage state root.
type MPTRootBase struct {
Version byte `json:"version"`
Index uint32 `json:"index"`
PrevHash util.Uint256 `json:"prehash"`
Root util.Uint256 `json:"stateroot"`
}
// MPTRoot represents storage state root together with sign info.
type MPTRoot struct {
MPTRootBase
Witness *transaction.Witness `json:"witness,omitempty"`
}
// MPTRootStateFlag represents verification state of the state root.
type MPTRootStateFlag byte
// Possible verification states of MPTRoot.
const (
Unverified MPTRootStateFlag = 0x00
Verified MPTRootStateFlag = 0x01
Invalid MPTRootStateFlag = 0x03
)
// MPTRootState represents state root together with its verification state.
type MPTRootState struct {
MPTRoot `json:"stateroot"`
Flag MPTRootStateFlag `json:"flag"`
}
// EncodeBinary implements io.Serializable.
func (s *MPTRootState) EncodeBinary(w *io.BinWriter) {
w.WriteB(byte(s.Flag))
s.MPTRoot.EncodeBinary(w)
}
// DecodeBinary implements io.Serializable.
func (s *MPTRootState) DecodeBinary(r *io.BinReader) {
s.Flag = MPTRootStateFlag(r.ReadB())
s.MPTRoot.DecodeBinary(r)
}
// GetSignedPart returns part of MPTRootBase which needs to be signed.
func (s *MPTRootBase) GetSignedPart() []byte {
buf := io.NewBufBinWriter()
s.EncodeBinary(buf.BinWriter)
return buf.Bytes()
}
// Equals checks if s == other.
func (s *MPTRootBase) Equals(other *MPTRootBase) bool {
return s.Version == other.Version && s.Index == other.Index &&
s.PrevHash.Equals(other.PrevHash) && s.Root.Equals(other.Root)
}
// Hash returns hash of s.
func (s *MPTRootBase) Hash() util.Uint256 {
return hash.DoubleSha256(s.GetSignedPart())
}
// DecodeBinary implements io.Serializable.
func (s *MPTRootBase) DecodeBinary(r *io.BinReader) {
s.Version = r.ReadB()
s.Index = r.ReadU32LE()
s.PrevHash.DecodeBinary(r)
s.Root.DecodeBinary(r)
}
// EncodeBinary implements io.Serializable.
func (s *MPTRootBase) EncodeBinary(w *io.BinWriter) {
w.WriteB(s.Version)
w.WriteU32LE(s.Index)
s.PrevHash.EncodeBinary(w)
s.Root.EncodeBinary(w)
}
// DecodeBinary implements io.Serializable.
func (s *MPTRoot) DecodeBinary(r *io.BinReader) {
s.MPTRootBase.DecodeBinary(r)
var ws []transaction.Witness
r.ReadArray(&ws, 1)
if len(ws) == 1 {
s.Witness = &ws[0]
}
}
// EncodeBinary implements io.Serializable.
func (s *MPTRoot) EncodeBinary(w *io.BinWriter) {
s.MPTRootBase.EncodeBinary(w)
if s.Witness == nil {
w.WriteVarUint(0)
} else {
w.WriteArray([]*transaction.Witness{s.Witness})
}
}
// String implements fmt.Stringer.
func (f MPTRootStateFlag) String() string {
switch f {
case Unverified:
return "Unverified"
case Verified:
return "Verified"
case Invalid:
return "Invalid"
default:
return ""
}
}
// MarshalJSON implements json.Marshaler.
func (f MPTRootStateFlag) MarshalJSON() ([]byte, error) {
return []byte(`"` + f.String() + `"`), nil
}
// UnmarshalJSON implements json.Unmarshaler.
func (f *MPTRootStateFlag) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
switch s {
case "Unverified":
*f = Unverified
case "Verified":
*f = Verified
case "Invalid":
*f = Invalid
default:
return errors.New("unknown flag")
}
return nil
}

View file

@ -0,0 +1,100 @@
package state
import (
"encoding/json"
"math/rand"
"testing"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
func testStateRoot() *MPTRoot {
return &MPTRoot{
MPTRootBase: MPTRootBase{
Version: byte(rand.Uint32()),
Index: rand.Uint32(),
PrevHash: random.Uint256(),
Root: random.Uint256(),
},
}
}
func TestStateRoot_Serializable(t *testing.T) {
r := testStateRoot()
testserdes.EncodeDecodeBinary(t, r, new(MPTRoot))
t.Run("WithWitness", func(t *testing.T) {
r.Witness = &transaction.Witness{
InvocationScript: random.Bytes(10),
VerificationScript: random.Bytes(11),
}
testserdes.EncodeDecodeBinary(t, r, new(MPTRoot))
})
}
func TestStateRootEquals(t *testing.T) {
r1 := testStateRoot()
r2 := *r1
require.True(t, r1.Equals(&r2.MPTRootBase))
r2.MPTRootBase.Index++
require.False(t, r1.Equals(&r2.MPTRootBase))
}
func TestMPTRootState_Serializable(t *testing.T) {
rs := &MPTRootState{
MPTRoot: *testStateRoot(),
Flag: 0x04,
}
rs.MPTRoot.Witness = &transaction.Witness{
InvocationScript: random.Bytes(10),
VerificationScript: random.Bytes(11),
}
testserdes.EncodeDecodeBinary(t, rs, new(MPTRootState))
}
func TestMPTRootStateUnverifiedByDefault(t *testing.T) {
var r MPTRootState
require.Equal(t, Unverified, r.Flag)
}
func TestMPTRoot_MarshalJSON(t *testing.T) {
t.Run("Good", func(t *testing.T) {
r := testStateRoot()
rs := &MPTRootState{
MPTRoot: *r,
Flag: Verified,
}
testserdes.MarshalUnmarshalJSON(t, rs, new(MPTRootState))
})
t.Run("Compatibility", func(t *testing.T) {
js := []byte(`{
"flag": "Unverified",
"stateroot": {
"version": 1,
"index": 3000000,
"prehash": "0x4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90",
"stateroot": "0xb2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3"
}}`)
rs := new(MPTRootState)
require.NoError(t, json.Unmarshal(js, &rs))
require.EqualValues(t, 1, rs.Version)
require.EqualValues(t, 3000000, rs.Index)
require.Nil(t, rs.Witness)
u, err := util.Uint256DecodeStringLE("4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90")
require.NoError(t, err)
require.Equal(t, u, rs.PrevHash)
u, err = util.Uint256DecodeStringLE("b2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3")
require.NoError(t, err)
require.Equal(t, u, rs.Root)
})
}

View file

@ -9,6 +9,7 @@ import (
const (
DataBlock KeyPrefix = 0x01
DataTransaction KeyPrefix = 0x02
DataMPT KeyPrefix = 0x03
STAccount KeyPrefix = 0x40
STCoin KeyPrefix = 0x44
STSpentCoin KeyPrefix = 0x45

View file

@ -1,7 +1,6 @@
package keys
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/x509"
@ -10,6 +9,7 @@ import (
"fmt"
"math/big"
"github.com/btcsuite/btcd/btcec"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/io"
@ -18,6 +18,9 @@ import (
"github.com/pkg/errors"
)
// coordLen is the number of bytes in serialized X or Y coordinate.
const coordLen = 32
// PublicKeys is a list of public keys.
type PublicKeys []*PublicKey
@ -94,23 +97,49 @@ func NewPublicKeyFromString(s string) (*PublicKey, error) {
return pubKey, nil
}
// Bytes returns the byte array representation of the public key.
func (p *PublicKey) Bytes() []byte {
// getBytes serializes X and Y using compressed or uncompressed format.
func (p *PublicKey) getBytes(compressed bool) []byte {
if p.IsInfinity() {
return []byte{0x00}
}
var (
x = p.X.Bytes()
paddedX = append(bytes.Repeat([]byte{0x00}, 32-len(x)), x...)
prefix = byte(0x03)
)
if p.Y.Bit(0) == 0 {
prefix = byte(0x02)
var resLen = 1 + coordLen
if !compressed {
resLen += coordLen
}
var res = make([]byte, resLen)
var prefix byte
return append([]byte{prefix}, paddedX...)
xBytes := p.X.Bytes()
copy(res[1+coordLen-len(xBytes):], xBytes)
if compressed {
if p.Y.Bit(0) == 0 {
prefix = 0x02
} else {
prefix = 0x03
}
} else {
prefix = 0x04
yBytes := p.Y.Bytes()
copy(res[1+coordLen+coordLen-len(yBytes):], yBytes)
}
res[0] = prefix
return res
}
// Bytes returns byte array representation of the public key in compressed
// form (33 bytes with 0x02 or 0x03 prefix, except infinity which is always 0).
func (p *PublicKey) Bytes() []byte {
return p.getBytes(true)
}
// UncompressedBytes returns byte array representation of the public key in
// uncompressed form (65 bytes with 0x04 prefix, except infinity which is
// always 0).
func (p *PublicKey) UncompressedBytes() []byte {
return p.getBytes(false)
}
// NewPublicKeyFromASN1 returns a NEO PublicKey from the ASN.1 serialized key.
@ -134,15 +163,25 @@ func NewPublicKeyFromASN1(data []byte) (*PublicKey, error) {
}
// decodeCompressedY performs decompression of Y coordinate for given X and Y's least significant bit.
func decodeCompressedY(x *big.Int, ylsb uint) (*big.Int, error) {
c := elliptic.P256()
cp := c.Params()
three := big.NewInt(3)
/* y**2 = x**3 + a*x + b % p */
xCubed := new(big.Int).Exp(x, three, cp.P)
threeX := new(big.Int).Mul(x, three)
threeX.Mod(threeX, cp.P)
ySquared := new(big.Int).Sub(xCubed, threeX)
// We use here a short-form Weierstrass curve (https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html)
// y² = x³ + ax + b. Two types of elliptic curves are supported:
// 1. Secp256k1 (Koblitz curve): y² = x³ + b,
// 2. Secp256r1 (Random curve): y² = x³ - 3x + b.
// To decode compressed curve point we perform the following operation: y = sqrt(x³ + ax + b mod p)
// where `p` denotes the order of the underlying curve field
func decodeCompressedY(x *big.Int, ylsb uint, curve elliptic.Curve) (*big.Int, error) {
var a *big.Int
switch curve.(type) {
case *btcec.KoblitzCurve:
a = big.NewInt(0)
default:
a = big.NewInt(3)
}
cp := curve.Params()
xCubed := new(big.Int).Exp(x, big.NewInt(3), cp.P)
aX := new(big.Int).Mul(x, a)
aX.Mod(aX, cp.P)
ySquared := new(big.Int).Sub(xCubed, aX)
ySquared.Add(ySquared, cp.B)
ySquared.Mod(ySquared, cp.P)
y := new(big.Int).ModSqrt(ySquared, cp.P)
@ -196,7 +235,7 @@ func (p *PublicKey) DecodeBinary(r *io.BinReader) {
}
x = new(big.Int).SetBytes(xbytes)
ylsb := uint(prefix & 0x1)
y, err = decodeCompressedY(x, ylsb)
y, err = decodeCompressedY(x, ylsb, p256)
if err != nil {
r.Err = err
return
@ -306,3 +345,84 @@ func (p *PublicKey) UnmarshalJSON(data []byte) error {
return nil
}
// KeyRecover recovers public key from the given signature (r, s) on the given message hash using given elliptic curve.
// Algorithm source: SEC 1 Ver 2.0, section 4.1.6, pages 47-48 (https://www.secg.org/sec1-v2.pdf).
// Flag isEven denotes Y's least significant bit in decompression algorithm.
func KeyRecover(curve elliptic.Curve, r, s *big.Int, messageHash []byte, isEven bool) (PublicKey, error) {
var (
res PublicKey
err error
)
if r.Cmp(big.NewInt(1)) == -1 || s.Cmp(big.NewInt(1)) == -1 {
return res, errors.New("invalid signature")
}
params := curve.Params()
// Calculate h = (Q + 1 + 2 * Sqrt(Q)) / N
// num := new(big.Int).Add(new(big.Int).Add(params.P, big.NewInt(1)), new(big.Int).Mul(big.NewInt(2), new(big.Int).Sqrt(params.P)))
// h := new(big.Int).Div(num, params.N)
// We are skipping this step for secp256k1 and secp256r1 because we know cofactor of these curves (h=1)
// (see section 2.4 of http://www.secg.org/sec2-v2.pdf)
h := 1
for i := 0; i <= h; i++ {
// Step 1.1: x = (n * i) + r
Rx := new(big.Int).Mul(params.N, big.NewInt(int64(i)))
Rx.Add(Rx, r)
if Rx.Cmp(params.P) == 1 {
break
}
// Steps 1.2 and 1.3: get point R (Ry)
var R *big.Int
if isEven {
R, err = decodeCompressedY(Rx, 0, curve)
} else {
R, err = decodeCompressedY(Rx, 1, curve)
}
if err != nil {
return res, err
}
// Step 1.4: check n*R is point at infinity
nRx, nR := curve.ScalarMult(Rx, R, params.N.Bytes())
if nRx.Sign() != 0 || nR.Sign() != 0 {
continue
}
// Step 1.5: compute e
e := hashToInt(messageHash, curve)
// Step 1.6: Q = r^-1 (sR-eG)
invr := new(big.Int).ModInverse(r, params.N)
// First term.
invrS := new(big.Int).Mul(invr, s)
invrS.Mod(invrS, params.N)
sRx, sR := curve.ScalarMult(Rx, R, invrS.Bytes())
// Second term.
e.Neg(e)
e.Mod(e, params.N)
e.Mul(e, invr)
e.Mod(e, params.N)
minuseGx, minuseGy := curve.ScalarBaseMult(e.Bytes())
Qx, Qy := curve.Add(sRx, sR, minuseGx, minuseGy)
res.X = Qx
res.Y = Qy
}
return res, nil
}
// copied from crypto/ecdsa
func hashToInt(hash []byte, c elliptic.Curve) *big.Int {
orderBits := c.Params().N.BitLen()
orderBytes := (orderBits + 7) / 8
if len(hash) > orderBytes {
hash = hash[:orderBytes]
}
ret := new(big.Int).SetBytes(hash)
excess := len(hash)*8 - orderBits
if excess > 0 {
ret.Rsh(ret, uint(excess))
}
return ret
}

View file

@ -1,12 +1,16 @@
package keys
import (
"crypto/elliptic"
"encoding/hex"
"encoding/json"
"math/big"
"math/rand"
"sort"
"testing"
"github.com/btcsuite/btcd/btcec"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/stretchr/testify/require"
)
@ -85,10 +89,14 @@ func TestPubkeyToAddress(t *testing.T) {
func TestDecodeBytes(t *testing.T) {
pubKey := getPubKey(t)
decodedPubKey := &PublicKey{}
err := decodedPubKey.DecodeBytes(pubKey.Bytes())
require.NoError(t, err)
require.Equal(t, pubKey, decodedPubKey)
var testBytesFunction = func(t *testing.T, bytesFunction func() []byte) {
decodedPubKey := &PublicKey{}
err := decodedPubKey.DecodeBytes(bytesFunction())
require.NoError(t, err)
require.Equal(t, pubKey, decodedPubKey)
}
t.Run("compressed", func(t *testing.T) { testBytesFunction(t, pubKey.Bytes) })
t.Run("uncompressed", func(t *testing.T) { testBytesFunction(t, pubKey.UncompressedBytes) })
}
func TestDecodeBytesBadInfinity(t *testing.T) {
@ -179,3 +187,91 @@ func TestUnmarshallJSONBadFormat(t *testing.T) {
err := json.Unmarshal([]byte(str), actual)
require.Error(t, err)
}
func TestRecoverSecp256r1(t *testing.T) {
privateKey, err := NewPrivateKey()
require.NoError(t, err)
message := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100}
messageHash := hash.Sha256(message).BytesBE()
signature := privateKey.Sign(message)
r := new(big.Int).SetBytes(signature[0:32])
s := new(big.Int).SetBytes(signature[32:64])
require.True(t, privateKey.PublicKey().Verify(signature, messageHash))
// To test this properly, we should provide correct isEven flag. This flag denotes which one of
// the two recovered R points in decodeCompressedY method should be chosen. Let's suppose that we
// don't know which of them suites, so to test KeyRecover we should check both and only
// one of them gives us the correct public key.
recoveredKeyFalse, err := KeyRecover(elliptic.P256(), r, s, messageHash, false)
require.NoError(t, err)
recoveredKeyTrue, err := KeyRecover(elliptic.P256(), r, s, messageHash, true)
require.NoError(t, err)
require.True(t, privateKey.PublicKey().Equal(&recoveredKeyFalse) != privateKey.PublicKey().Equal(&recoveredKeyTrue))
}
func TestRecoverSecp256r1Static(t *testing.T) {
// These data were taken from the reference KeyRecoverTest: https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_ECDsa.cs#L22
// To update this test, run the reference KeyRecover(ECCurve.Secp256r1) testcase and fetch the following data from it:
// privateKey -> b
// message -> messageHash
// signatures[0] -> r
// signatures[1] -> s
// v -> isEven
// Note, that C# BigInteger has different byte order from that used in Go.
b := []byte{123, 245, 126, 56, 3, 123, 197, 199, 26, 31, 212, 186, 120, 195, 168, 153, 57, 108, 234, 49, 107, 203, 44, 207, 185, 212, 187, 129, 74, 43, 225, 69}
privateKey, err := NewPrivateKeyFromBytes(b)
require.NoError(t, err)
messageHash := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100}
r := new(big.Int).SetBytes([]byte{1, 85, 226, 63, 133, 113, 217, 188, 249, 22, 213, 203, 225, 199, 32, 131, 118, 23, 28, 101, 139, 211, 13, 111, 242, 158, 193, 227, 196, 106, 3, 4})
s := new(big.Int).SetBytes([]byte{65, 174, 206, 164, 81, 34, 76, 104, 5, 49, 51, 20, 221, 183, 157, 199, 199, 47, 78, 137, 172, 99, 212, 110, 129, 72, 236, 59, 250, 81, 200, 13})
// Just ensure it's a valid signature.
require.True(t, privateKey.PublicKey().Verify(append(r.Bytes(), s.Bytes()...), messageHash))
recoveredKey, err := KeyRecover(elliptic.P256(), r, s, messageHash, false)
require.NoError(t, err)
require.True(t, privateKey.PublicKey().Equal(&recoveredKey))
}
func TestRecoverSecp256k1(t *testing.T) {
privateKey, err := btcec.NewPrivateKey(btcec.S256())
message := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100}
signature, err := privateKey.Sign(message)
require.NoError(t, err)
require.True(t, signature.Verify(message, privateKey.PubKey()))
// To test this properly, we should provide correct isEven flag. This flag denotes which one of
// the two recovered R points in decodeCompressedY method should be chosen. Let's suppose that we
// don't know which of them suites, so to test KeyRecover we should check both and only
// one of them gives us the correct public key.
recoveredKeyFalse, err := KeyRecover(btcec.S256(), signature.R, signature.S, message, false)
require.NoError(t, err)
recoveredKeyTrue, err := KeyRecover(btcec.S256(), signature.R, signature.S, message, true)
require.NoError(t, err)
require.True(t, (privateKey.PubKey().X.Cmp(recoveredKeyFalse.X) == 0 &&
privateKey.PubKey().Y.Cmp(recoveredKeyFalse.Y) == 0) !=
(privateKey.PubKey().X.Cmp(recoveredKeyTrue.X) == 0 &&
privateKey.PubKey().Y.Cmp(recoveredKeyTrue.Y) == 0))
}
func TestRecoverSecp256k1Static(t *testing.T) {
// These data were taken from the reference testcase: https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_ECDsa.cs#L22
// To update this test, run the reference KeyRecover(ECCurve.Secp256k1) testcase and fetch the following data from it:
// privateKey -> b
// message -> messageHash
// signatures[0] -> r
// signatures[1] -> s
// v -> isEven
// Note, that C# BigInteger has different byte order from that used in Go.
b := []byte{156, 3, 247, 58, 246, 250, 236, 27, 118, 60, 180, 177, 18, 92, 204, 206, 144, 245, 148, 141, 86, 212, 151, 181, 15, 113, 172, 180, 177, 228, 100, 32}
_, publicKey := btcec.PrivKeyFromBytes(btcec.S256(), b)
messageHash := []byte{72, 101, 108, 108, 111, 87, 111, 114, 108, 100}
r := new(big.Int).SetBytes([]byte{88, 169, 242, 111, 210, 184, 180, 46, 67, 108, 176, 77, 57, 250, 58, 36, 110, 81, 225, 65, 90, 47, 215, 91, 27, 227, 57, 6, 9, 228, 100, 50})
s := new(big.Int).SetBytes([]byte{86, 150, 81, 190, 17, 181, 212, 241, 184, 36, 136, 116, 232, 207, 46, 45, 149, 167, 15, 98, 113, 137, 66, 98, 214, 165, 38, 232, 98, 96, 79, 197})
signature := btcec.Signature{
R: r,
S: s,
}
// Just ensure it's a valid signature.
require.True(t, signature.Verify(messageHash, publicKey))
recoveredKey, err := KeyRecover(btcec.S256(), r, s, messageHash, false)
require.NoError(t, err)
require.True(t, new(big.Int).SetBytes([]byte{112, 186, 29, 131, 169, 21, 212, 95, 81, 172, 201, 145, 168, 108, 129, 90, 6, 111, 80, 39, 136, 157, 15, 181, 98, 108, 133, 108, 144, 80, 23, 225}).Cmp(recoveredKey.X) == 0)
require.True(t, new(big.Int).SetBytes([]byte{187, 102, 202, 42, 152, 133, 222, 55, 137, 228, 154, 80, 182, 35, 133, 14, 55, 165, 36, 64, 178, 55, 13, 112, 224, 143, 66, 143, 208, 18, 2, 211}).Cmp(recoveredKey.Y) == 0)
}

View file

@ -1,5 +1,5 @@
/*
Package crypto provides an interface to VM cryptographic instructions.
Package crypto provides an interface to VM cryptographic instructions and syscalls.
*/
package crypto
@ -30,3 +30,23 @@ func Hash256(b []byte) []byte {
func VerifySignature(msg []byte, sig []byte, pub []byte) bool {
return false
}
// Secp256k1Recover recovers public key from the given signature (r, s) on the
// given message hash using Secp256k1 elliptic curve. Flag isEven denotes Y's
// least significant bit in decompression algorithm. The return value is byte
// array representation of the public key which is either empty (if it's not
// possible to recover key) or contains 32 bytes in BE for X point (in case of
// success). This function uses Neo.Cryptography.Secp256k1Recover syscall.
func Secp256k1Recover(r []byte, s []byte, messageHash []byte, isEven bool) []byte {
return nil
}
// Secp256r1Recover recovers public key from the given signature (r, s) on the
// given message hash using Secp256r1 elliptic curve. Flag isEven denotes Y's
// least significant bit in decompression algorithm. The return value is byte
// array representation of the public key which is either empty (if it's not
// possible to recover key) or contains 32 bytes in BE for X point (in case of
// success). This function uses Neo.Cryptography.Secp256r1Recover syscall.
func Secp256r1Recover(r []byte, s []byte, messageHash []byte, isEven bool) []byte {
return nil
}

View file

@ -8,9 +8,9 @@ import (
"reflect"
)
// maxArraySize is a maximums size of an array which can be decoded.
// MaxArraySize is the maximum size of an array which can be decoded.
// It is taken from https://github.com/neo-project/neo/blob/master/neo/IO/Helper.cs#L130
const maxArraySize = 0x1000000
const MaxArraySize = 0x1000000
// BinReader is a convenient wrapper around a io.Reader and err object.
// Used to simplify error handling when reading into a struct with many fields.
@ -110,7 +110,7 @@ func (r *BinReader) ReadArray(t interface{}, maxSize ...int) {
elemType := sliceType.Elem()
isPtr := elemType.Kind() == reflect.Ptr
ms := maxArraySize
ms := MaxArraySize
if len(maxSize) != 0 {
ms = maxSize[0]
}
@ -168,8 +168,16 @@ func (r *BinReader) ReadVarUint() uint64 {
// ReadVarBytes reads the next set of bytes from the underlying reader.
// ReadVarUInt() is used to determine how large that slice is
func (r *BinReader) ReadVarBytes() []byte {
func (r *BinReader) ReadVarBytes(maxSize ...int) []byte {
n := r.ReadVarUint()
ms := MaxArraySize
if len(maxSize) != 0 {
ms = maxSize[0]
}
if n > uint64(ms) {
r.Err = fmt.Errorf("byte-slice is too big (%d)", n)
return nil
}
b := make([]byte, n)
r.ReadBytes(b)
return b

View file

@ -143,6 +143,35 @@ func TestBufBinWriter_Len(t *testing.T) {
require.Equal(t, 1, bw.Len())
}
func TestBinReader_ReadVarBytes(t *testing.T) {
buf := make([]byte, 11)
for i := range buf {
buf[i] = byte(i)
}
w := NewBufBinWriter()
w.WriteVarBytes(buf)
require.NoError(t, w.Err)
data := w.Bytes()
t.Run("NoArguments", func(t *testing.T) {
r := NewBinReaderFromBuf(data)
actual := r.ReadVarBytes()
require.NoError(t, r.Err)
require.Equal(t, buf, actual)
})
t.Run("Good", func(t *testing.T) {
r := NewBinReaderFromBuf(data)
actual := r.ReadVarBytes(11)
require.NoError(t, r.Err)
require.Equal(t, buf, actual)
})
t.Run("Bad", func(t *testing.T) {
r := NewBinReaderFromBuf(data)
r.ReadVarBytes(10)
require.Error(t, r.Err)
})
}
func TestWriterErrHandling(t *testing.T) {
var badio = &badRW{}
bw := NewBinWriterFromIO(badio)

View file

@ -59,6 +59,9 @@ func (chain *testChain) AddBlock(block *block.Block) error {
}
return nil
}
func (chain *testChain) AddStateRoot(r *state.MPTRoot) error {
panic("TODO")
}
func (chain *testChain) BlockHeight() uint32 {
return atomic.LoadUint32(&chain.blockheight)
}
@ -105,6 +108,12 @@ func (chain testChain) GetEnrollments() ([]*state.Validator, error) {
func (chain testChain) GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) {
panic("TODO")
}
func (chain testChain) GetStateProof(util.Uint256, []byte) ([][]byte, error) {
panic("TODO")
}
func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) {
panic("TODO")
}
func (chain testChain) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem {
panic("TODO")
}
@ -145,7 +154,9 @@ func (chain testChain) IsLowPriority(util.Fixed8) bool {
func (chain testChain) PoolTx(*transaction.Transaction) error {
panic("TODO")
}
func (chain testChain) StateHeight() uint32 {
panic("TODO")
}
func (chain testChain) SubscribeForBlocks(ch chan<- *block.Block) {
panic("TODO")
}

View file

@ -8,6 +8,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/consensus"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io"
@ -59,12 +60,15 @@ const (
CMDGetBlocks CommandType = "getblocks"
CMDGetData CommandType = "getdata"
CMDGetHeaders CommandType = "getheaders"
CMDGetRoots CommandType = "getroots"
CMDHeaders CommandType = "headers"
CMDInv CommandType = "inv"
CMDMempool CommandType = "mempool"
CMDMerkleBlock CommandType = "merkleblock"
CMDPing CommandType = "ping"
CMDPong CommandType = "pong"
CMDRoots CommandType = "roots"
CMDStateRoot CommandType = "stateroot"
CMDTX CommandType = "tx"
CMDUnknown CommandType = "unknown"
CMDVerack CommandType = "verack"
@ -124,6 +128,8 @@ func (m *Message) CommandType() CommandType {
return CMDGetData
case "getheaders":
return CMDGetHeaders
case "getroots":
return CMDGetRoots
case "headers":
return CMDHeaders
case "inv":
@ -136,6 +142,10 @@ func (m *Message) CommandType() CommandType {
return CMDPing
case "pong":
return CMDPong
case "roots":
return CMDRoots
case "stateroot":
return CMDStateRoot
case "tx":
return CMDTX
case "verack":
@ -191,6 +201,8 @@ func (m *Message) decodePayload(br *io.BinReader) error {
fallthrough
case CMDGetHeaders:
p = &payload.GetBlocks{}
case CMDGetRoots:
p = &payload.GetStateRoots{}
case CMDHeaders:
p = &payload.Headers{}
case CMDTX:
@ -199,6 +211,10 @@ func (m *Message) decodePayload(br *io.BinReader) error {
p = &payload.MerkleBlock{}
case CMDPing, CMDPong:
p = &payload.Ping{}
case CMDRoots:
p = &payload.StateRoots{}
case CMDStateRoot:
p = &state.MPTRoot{}
default:
return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command))
}

View file

@ -18,6 +18,8 @@ func (i InventoryType) String() string {
return "TX"
case 0x02:
return "block"
case StateRootType:
return "stateroot"
case 0xe0:
return "consensus"
default:
@ -27,13 +29,14 @@ func (i InventoryType) String() string {
// Valid returns true if the inventory (type) is known.
func (i InventoryType) Valid() bool {
return i == BlockType || i == TXType || i == ConsensusType
return i == BlockType || i == TXType || i == ConsensusType || i == StateRootType
}
// List of valid InventoryTypes.
const (
TXType InventoryType = 0x01 // 1
BlockType InventoryType = 0x02 // 2
StateRootType InventoryType = 0x03 // 3
ConsensusType InventoryType = 0xe0 // 224
)

View file

@ -0,0 +1,43 @@
package payload
import (
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/io"
)
// MaxStateRootsAllowed is a maxumum amount of state roots
// which can be sent in a single payload.
const MaxStateRootsAllowed = 2000
// StateRoots contains multiple StateRoots.
type StateRoots struct {
Roots []state.MPTRoot
}
// GetStateRoots represents request for state roots.
type GetStateRoots struct {
Start uint32
Count uint32
}
// EncodeBinary implements io.Serializable.
func (s *StateRoots) EncodeBinary(w *io.BinWriter) {
w.WriteArray(s.Roots)
}
// DecodeBinary implements io.Serializable.
func (s *StateRoots) DecodeBinary(r *io.BinReader) {
r.ReadArray(&s.Roots, MaxStateRootsAllowed)
}
// DecodeBinary implements io.Serializable.
func (g *GetStateRoots) DecodeBinary(r *io.BinReader) {
g.Start = r.ReadU32LE()
g.Count = r.ReadU32LE()
}
// EncodeBinary implements io.Serializable.
func (g *GetStateRoots) EncodeBinary(w *io.BinWriter) {
w.WriteU32LE(g.Start)
w.WriteU32LE(g.Count)
}

View file

@ -0,0 +1,51 @@
package payload
import (
"math/rand"
"testing"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
)
func TestStateRoots_Serializable(t *testing.T) {
expected := &StateRoots{
Roots: []state.MPTRoot{
{
MPTRootBase: state.MPTRootBase{
Index: rand.Uint32(),
PrevHash: random.Uint256(),
Root: random.Uint256(),
},
Witness: &transaction.Witness{
InvocationScript: random.Bytes(10),
VerificationScript: random.Bytes(11),
},
},
{
MPTRootBase: state.MPTRootBase{
Index: rand.Uint32(),
PrevHash: random.Uint256(),
Root: random.Uint256(),
},
Witness: &transaction.Witness{
InvocationScript: random.Bytes(10),
VerificationScript: random.Bytes(11),
},
},
},
}
testserdes.EncodeDecodeBinary(t, expected, new(StateRoots))
}
func TestGetStateRoots_Serializable(t *testing.T) {
expected := &GetStateRoots{
Start: rand.Uint32(),
Count: rand.Uint32(),
}
testserdes.EncodeDecodeBinary(t, expected, new(GetStateRoots))
}

View file

@ -13,6 +13,8 @@ import (
"github.com/nspcc-dev/neo-go/pkg/consensus"
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/cache"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util"
@ -28,6 +30,7 @@ const (
maxBlockBatch = 200
maxAddrsToSend = 200
minPoolCount = 30
stateRootCacheSize = 100
)
var (
@ -66,6 +69,7 @@ type (
transactions chan *transaction.Transaction
stateCache cache.HashCache
consensusStarted *atomic.Bool
log *zap.Logger
@ -98,6 +102,7 @@ func NewServer(config ServerConfig, chain core.Blockchainer, log *zap.Logger) (*
unregister: make(chan peerDrop),
peers: make(map[Peer]bool),
consensusStarted: atomic.NewBool(false),
stateCache: *cache.NewFIFOCache(stateRootCacheSize),
log: log,
transactions: make(chan *transaction.Transaction, 64),
}
@ -469,6 +474,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
cp := s.consensus.GetPayload(h)
return cp != nil
},
payload.StateRootType: s.stateCache.Has,
}
if exists := typExists[inv.Type]; exists != nil {
for _, hash := range inv.Hashes {
@ -507,6 +513,11 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
if err == nil {
msg = s.MkMsg(CMDBlock, b)
}
case payload.StateRootType:
r := s.stateCache.Get(hash)
if r != nil {
msg = s.MkMsg(CMDStateRoot, r.(*state.MPTRoot))
}
case payload.ConsensusType:
if cp := s.consensus.GetPayload(hash); cp != nil {
msg = s.MkMsg(CMDConsensus, cp)
@ -589,6 +600,87 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error {
return p.EnqueueP2PMessage(msg)
}
// handleGetRootsCmd processees `getroots` request.
func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error {
cfg := s.chain.GetConfig()
if !cfg.EnableStateRoot || gr.Start < cfg.StateRootEnableIndex {
return nil
}
count := gr.Count
if count > payload.MaxStateRootsAllowed {
count = payload.MaxStateRootsAllowed
}
var rs payload.StateRoots
for height := gr.Start; height < gr.Start+gr.Count; height++ {
r, err := s.chain.GetStateRoot(height)
if err != nil {
return err
} else if r.Flag == state.Verified {
rs.Roots = append(rs.Roots, r.MPTRoot)
}
}
msg := s.MkMsg(CMDRoots, &rs)
return p.EnqueueP2PMessage(msg)
}
// handleStateRootsCmd processees `roots` request.
func (s *Server) handleRootsCmd(p Peer, rs *payload.StateRoots) error {
if !s.chain.GetConfig().EnableStateRoot {
return nil
}
h := s.chain.StateHeight()
if h < s.chain.GetConfig().StateRootEnableIndex {
h = s.chain.GetConfig().StateRootEnableIndex
}
for i := range rs.Roots {
if rs.Roots[i].Index <= h {
continue
}
_ = s.chain.AddStateRoot(&rs.Roots[i])
}
// request more state roots from peer if needed
return s.requestStateRoot(p)
}
// requestStateRoot sends `getroots` message to get verified state roots.
func (s *Server) requestStateRoot(p Peer) error {
stateHeight := s.chain.StateHeight()
hdrHeight := s.chain.BlockHeight()
enableIndex := s.chain.GetConfig().StateRootEnableIndex
if hdrHeight < enableIndex {
return nil
}
if stateHeight < enableIndex {
stateHeight = enableIndex - 1
}
count := uint32(payload.MaxStateRootsAllowed)
if diff := hdrHeight - stateHeight; diff < count {
count = diff
}
if count == 0 {
return nil
}
gr := &payload.GetStateRoots{
Start: stateHeight + 1,
Count: count,
}
return p.EnqueueP2PMessage(s.MkMsg(CMDGetRoots, gr))
}
// handleStateRootCmd processees `stateroot` request.
func (s *Server) handleStateRootCmd(r *state.MPTRoot) error {
if !s.chain.GetConfig().EnableStateRoot {
return nil
}
// we ignore error, because there is nothing wrong if we already have this state root
err := s.chain.AddStateRoot(r)
if err == nil && !s.stateCache.Has(r.Hash()) {
s.stateCache.Add(r)
s.broadcastMessage(s.MkMsg(CMDStateRoot, r))
}
return nil
}
// handleConsensusCmd processes received consensus payload.
// It never returns an error.
func (s *Server) handleConsensusCmd(cp *consensus.Payload) error {
@ -697,6 +789,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
case CMDGetHeaders:
gh := msg.Payload.(*payload.GetBlocks)
return s.handleGetHeadersCmd(peer, gh)
case CMDGetRoots:
gr := msg.Payload.(*payload.GetStateRoots)
return s.handleGetRootsCmd(peer, gr)
case CMDHeaders:
headers := msg.Payload.(*payload.Headers)
go s.handleHeadersCmd(peer, headers)
@ -718,6 +813,12 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
case CMDPong:
pong := msg.Payload.(*payload.Ping)
return s.handlePong(peer, pong)
case CMDRoots:
rs := msg.Payload.(*payload.StateRoots)
return s.handleRootsCmd(peer, rs)
case CMDStateRoot:
r := msg.Payload.(*state.MPTRoot)
return s.handleStateRootCmd(r)
case CMDVersion, CMDVerack:
return fmt.Errorf("received '%s' after the handshake", msg.CommandType())
}
@ -741,11 +842,20 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return nil
}
func (s *Server) handleNewPayload(p *consensus.Payload) {
msg := s.MkMsg(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()}))
// It's high priority because it directly affects consensus process,
// even though it's just an inv.
s.broadcastHPMessage(msg)
func (s *Server) handleNewPayload(item cache.Hashable) {
switch p := item.(type) {
case *consensus.Payload:
msg := s.MkMsg(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()}))
// It's high priority because it directly affects consensus process,
// even though it's just an inv.
s.broadcastHPMessage(msg)
case *state.MPTRoot:
s.stateCache.Add(p)
msg := s.MkMsg(CMDStateRoot, p)
s.broadcastMessage(msg)
default:
s.log.Warn("unknown item type", zap.String("type", fmt.Sprintf("%T", p)))
}
}
func (s *Server) requestTx(hashes ...util.Uint256) {

View file

@ -251,6 +251,9 @@ func (p *TCPPeer) StartProtocol() {
if p.LastBlockIndex() > p.server.chain.BlockHeight() {
err = p.server.requestBlocks(p)
}
if err == nil && p.server.chain.GetConfig().EnableStateRoot {
err = p.server.requestStateRoot(p)
}
if err == nil {
timer.Reset(p.server.ProtoTickInterval)
}

View file

@ -181,8 +181,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
require.NoError(t, err)
},
func(t *testing.T, p *request.Params) {
param, ok := p.Value(1)
require.Equal(t, true, ok)
param := p.Value(1)
require.NotNil(t, param)
require.Equal(t, request.TxFilterT, param.Type)
filt, ok := param.Value.(request.TxFilter)
require.Equal(t, true, ok)
@ -196,8 +196,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
require.NoError(t, err)
},
func(t *testing.T, p *request.Params) {
param, ok := p.Value(1)
require.Equal(t, true, ok)
param := p.Value(1)
require.NotNil(t, param)
require.Equal(t, request.NotificationFilterT, param.Type)
filt, ok := param.Value.(request.NotificationFilter)
require.Equal(t, true, ok)
@ -211,8 +211,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
require.NoError(t, err)
},
func(t *testing.T, p *request.Params) {
param, ok := p.Value(1)
require.Equal(t, true, ok)
param := p.Value(1)
require.NotNil(t, param)
require.Equal(t, request.ExecutionFilterT, param.Type)
filt, ok := param.Value.(request.ExecutionFilter)
require.Equal(t, true, ok)

View file

@ -61,12 +61,17 @@ const (
ExecutionFilterT
)
var errMissingParameter = errors.New("parameter is missing")
func (p Param) String() string {
return fmt.Sprintf("%v", p.Value)
}
// GetString returns string value of the parameter.
func (p Param) GetString() (string, error) {
func (p *Param) GetString() (string, error) {
if p == nil {
return "", errMissingParameter
}
str, ok := p.Value.(string)
if !ok {
return "", errors.New("not a string")
@ -74,8 +79,26 @@ func (p Param) GetString() (string, error) {
return str, nil
}
// GetBoolean returns boolean value of the parameter.
func (p *Param) GetBoolean() bool {
if p == nil {
return false
}
switch p.Type {
case NumberT:
return p.Value != 0
case StringT:
return p.Value != ""
default:
return true
}
}
// GetInt returns int value of te parameter.
func (p Param) GetInt() (int, error) {
func (p *Param) GetInt() (int, error) {
if p == nil {
return 0, errMissingParameter
}
i, ok := p.Value.(int)
if ok {
return i, nil
@ -86,7 +109,10 @@ func (p Param) GetInt() (int, error) {
}
// GetArray returns a slice of Params stored in the parameter.
func (p Param) GetArray() ([]Param, error) {
func (p *Param) GetArray() ([]Param, error) {
if p == nil {
return nil, errMissingParameter
}
a, ok := p.Value.([]Param)
if !ok {
return nil, errors.New("not an array")
@ -95,7 +121,7 @@ func (p Param) GetArray() ([]Param, error) {
}
// GetUint256 returns Uint256 value of the parameter.
func (p Param) GetUint256() (util.Uint256, error) {
func (p *Param) GetUint256() (util.Uint256, error) {
s, err := p.GetString()
if err != nil {
return util.Uint256{}, err
@ -105,7 +131,7 @@ func (p Param) GetUint256() (util.Uint256, error) {
}
// GetUint160FromHex returns Uint160 value of the parameter encoded in hex.
func (p Param) GetUint160FromHex() (util.Uint160, error) {
func (p *Param) GetUint160FromHex() (util.Uint160, error) {
s, err := p.GetString()
if err != nil {
return util.Uint160{}, err
@ -119,7 +145,7 @@ func (p Param) GetUint160FromHex() (util.Uint160, error) {
// GetUint160FromAddress returns Uint160 value of the parameter that was
// supplied as an address.
func (p Param) GetUint160FromAddress() (util.Uint160, error) {
func (p *Param) GetUint160FromAddress() (util.Uint160, error) {
s, err := p.GetString()
if err != nil {
return util.Uint160{}, err
@ -129,7 +155,10 @@ func (p Param) GetUint160FromAddress() (util.Uint160, error) {
}
// GetFuncParam returns current parameter as a function call parameter.
func (p Param) GetFuncParam() (FuncParam, error) {
func (p *Param) GetFuncParam() (FuncParam, error) {
if p == nil {
return FuncParam{}, errMissingParameter
}
fp, ok := p.Value.(FuncParam)
if !ok {
return FuncParam{}, errors.New("not a function parameter")
@ -139,7 +168,7 @@ func (p Param) GetFuncParam() (FuncParam, error) {
// GetBytesHex returns []byte value of the parameter if
// it is a hex-encoded string.
func (p Param) GetBytesHex() ([]byte, error) {
func (p *Param) GetBytesHex() ([]byte, error) {
s, err := p.GetString()
if err != nil {
return nil, err
@ -163,6 +192,11 @@ func (p *Param) UnmarshalJSON(data []byte) error {
{ArrayT, &[]Param{}},
}
if bytes.Equal(data, []byte("null")) {
p.Type = defaultT
return nil
}
for _, cur := range attempts {
r := bytes.NewReader(data)
jd := json.NewDecoder(r)

View file

@ -14,7 +14,7 @@ import (
)
func TestParam_UnmarshalJSON(t *testing.T) {
msg := `["str1", 123, ["str2", 3], [{"type": "String", "value": "jajaja"}],
msg := `["str1", 123, null, ["str2", 3], [{"type": "String", "value": "jajaja"}],
{"type": "MinerTransaction"},
{"contract": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"},
{"state": "HALT"}]`
@ -29,6 +29,9 @@ func TestParam_UnmarshalJSON(t *testing.T) {
Type: NumberT,
Value: 123,
},
{
Type: defaultT,
},
{
Type: ArrayT,
Value: []Param{

View file

@ -7,20 +7,19 @@ type (
// Value returns the param struct for the given
// index if it exists.
func (p Params) Value(index int) (*Param, bool) {
func (p Params) Value(index int) *Param {
if len(p) > index {
return &p[index], true
return &p[index]
}
return nil, false
return nil
}
// ValueWithType returns the param struct at the given index if it
// exists and matches the given type.
func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) {
if val, ok := p.Value(index); ok && val.Type == valType {
return val, true
func (p Params) ValueWithType(index int, valType paramType) *Param {
if val := p.Value(index); val != nil && val.Type == valType {
return val
}
return nil, false
return nil
}

View file

@ -0,0 +1,122 @@
package result
import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"github.com/nspcc-dev/neo-go/pkg/io"
)
// StateHeight is a result of getstateheight RPC.
type StateHeight struct {
BlockHeight uint32 `json:"blockHeight"`
StateHeight uint32 `json:"stateHeight"`
}
// ProofWithKey represens key-proof pair.
type ProofWithKey struct {
Key []byte
Proof [][]byte
}
// GetProof is a result of getproof RPC.
type GetProof struct {
Result ProofWithKey `json:"proof"`
Success bool `json:"success"`
}
// VerifyProof is a result of verifyproof RPC.
// nil Value is considered invalid.
type VerifyProof struct {
Value []byte
}
// MarshalJSON implements json.Marshaler.
func (p *ProofWithKey) MarshalJSON() ([]byte, error) {
w := io.NewBufBinWriter()
p.EncodeBinary(w.BinWriter)
if w.Err != nil {
return nil, w.Err
}
return []byte(`"` + hex.EncodeToString(w.Bytes()) + `"`), nil
}
// EncodeBinary implements io.Serializable.
func (p *ProofWithKey) EncodeBinary(w *io.BinWriter) {
w.WriteVarBytes(p.Key)
w.WriteVarUint(uint64(len(p.Proof)))
for i := range p.Proof {
w.WriteVarBytes(p.Proof[i])
}
}
// DecodeBinary implements io.Serializable.
func (p *ProofWithKey) DecodeBinary(r *io.BinReader) {
p.Key = r.ReadVarBytes()
sz := r.ReadVarUint()
for i := uint64(0); i < sz; i++ {
p.Proof = append(p.Proof, r.ReadVarBytes())
}
}
// UnmarshalJSON implements json.Unmarshaler.
func (p *ProofWithKey) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
return p.FromString(s)
}
// String implements fmt.Stringer.
func (p *ProofWithKey) String() string {
w := io.NewBufBinWriter()
p.EncodeBinary(w.BinWriter)
return hex.EncodeToString(w.Bytes())
}
// FromString decodes p from hex-encoded string.
func (p *ProofWithKey) FromString(s string) error {
rawProof, err := hex.DecodeString(s)
if err != nil {
return err
}
r := io.NewBinReaderFromBuf(rawProof)
p.DecodeBinary(r)
return r.Err
}
// MarshalJSON implements json.Marshaler.
func (p *VerifyProof) MarshalJSON() ([]byte, error) {
if p.Value == nil {
return []byte(`"invalid"`), nil
}
return []byte(`{"value":"` + hex.EncodeToString(p.Value) + `"}`), nil
}
// UnmarshalJSON implements json.Unmarshaler.
func (p *VerifyProof) UnmarshalJSON(data []byte) error {
if bytes.Equal(data, []byte(`"invalid"`)) {
p.Value = nil
return nil
}
var m map[string]string
if err := json.Unmarshal(data, &m); err != nil {
return err
}
if len(m) != 1 {
return errors.New("must have single key")
}
v, ok := m["value"]
if !ok {
return errors.New("invalid json")
}
b, err := hex.DecodeString(v)
if err != nil {
return err
}
p.Value = b
return nil
}

View file

@ -0,0 +1,68 @@
package result
import (
"encoding/json"
"testing"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/stretchr/testify/require"
)
func testProofWithKey() *ProofWithKey {
return &ProofWithKey{
Key: random.Bytes(10),
Proof: [][]byte{
random.Bytes(12),
random.Bytes(0),
random.Bytes(34),
},
}
}
func TestGetProof_MarshalJSON(t *testing.T) {
t.Run("Good", func(t *testing.T) {
p := &GetProof{
Result: *testProofWithKey(),
Success: true,
}
testserdes.MarshalUnmarshalJSON(t, p, new(GetProof))
})
t.Run("Compatibility", func(t *testing.T) {
js := []byte(`{
"proof" : "25ddeb9aa1bfc353c9c54e21dffb470f65d9c22a0662616c616e63654f70000000000000000708fd12020020666eaa8a6e75d43a97d76e72b605c7e05189f0c57ec19d84acdb75810f18239d202c83028ce3d7abcf4e4f95d05fbfdfa5e18bde3a8fbb65a57559d6b5ea09425c2090c40d440744a848e3b407a00e4efb692a957245a1efc9cb8496cb05fd328ee620dd2652bf25dfc3ad5fee7b200ccf3e3ae50772ff8ed58907e4dab8e7d4b2489720d8a5d5ed75b5b0f256d0a2cf5c220b4ddae2a228ef0fc0212b689f3811dfa94620342cc0d73fabd2440ed2cc735a9608391a510e1981b321a9f4258682706adc9620ced036e52f39387b9c58ade7bf8c3ca8959b64d8031d36d9b1c62f3f1c51c7cb2031072c7c801b5c1614dae441383a65344acd238f13db28ff0a39c0626e597f002062552d64c616d8b2a6a93d22936055110c0065728aa2b4fbf4d76b108390b474203322d3c93c741674a307cf6455e77c02ceeda307d4ec23fd809a2a420b4243f82052ab92a9cedc6716ad4c66a8a3e423b195b05bdebde456f992bff48f2561e99720e6379995e7053823b8ba8fb8af9623cf48e89f60c989598445df5e711db42a6f20192894ed637e86561ff6a4b8dea4539dee8bddb2fb20bf4ae3499852985c88b120e0005edd09f2335aa6b59ff4723e1262b2192adaa5e3e56f79e662f07041f04c2033577f3e2c5bb0e58746980a07cdfad2f872e2b9a10bcc27b7c678c85576df8420f0f04180d15b6eaa0c43e62380084c75ad773d790700a7120c6c4da1fc51693000fd720100209648e8f10a5ff4c209009b9a09697babbe1b2150d0948c1970a560282a1bfa4720988af8f34859dd8309bffea0b1dff9c8cef0b9b0d6a1852d40786627729ae7be00206ebf4f1b7861bca041cbb8feca75158511ca43a1810d17e1e3017468e8cef0de20cac93064090a7da09f8202c17d1e6cbb9a16eb43afcb032e80719cbf05b3446d2019b76a10b91fb99ec08814e8108e5490b879fb09a190cb2c129dfd98335bd5de000020b1da1198bacacf2adc0d863929d77c285ce3a26e736203d0c0a69a1312255fb2207ee8aa092f49348bd89f9c4bf004b0bee2241a2d0acfe7b3ce08e414b04a5717205b0dda71eac8a4e4cdc6a7b939748c0a78abb54f2547a780e6df67b25530330f000020fc358fb9d1e0d36461e015ac8e35f97072a9f9e750a3c25722a2b1a858fcb82d203c52c9fac6d4694b351390158334a9166bc3478ceb9bea2b0b244915f918239e20d526344a24ff19ee6a9f5c5beb833f4eb6d51191590350e26fa50b138493473f005200000000000000000000002077c404fec0a4265568951dbd096572787d109fab105213f4f292a5f53ce72fca00000020b8d1c7a386eaba83ce83ee0700d4ca9b86e75d147d670ea05123e438231d895000004801250b090a0a010b0f0c0305030c090c05040e02010d0f0f0b0407000f06050d090c02020a0006202af2097cf9d3f42e49f6b3c3dd254e7cbdab3485b029721cbbbf1ad0455a810852000000000000002055170506f4b18bc573a909b51cb21bdd5d303ec511f6cdfb1c6a1ab8d8a1dad020ee774c1b9fe1d8ea8d05823837d959da48af74f384d52f06c42c9d146c5258e300000000000000000072000000204457a6fe530ee953ad1f9caf63daf7f86719c9986df2d0b6917021eb379800f00020406bfc79da4ba6f37452a679d13cca252585d34f7e94a480b047bad9427f233e00000000201ce15a2373d28e0dc5f2000cf308f155d06f72070a29e5af1528c8f05f29d248000000000000004301200601060c0601060e06030605040f0700000000000000000000000000000000072091b83866bbd7450115b462e8d48601af3c3e9a35e7018d2b98a23e107c15c200090307000410a328e800",
"success" : true
}`)
var p GetProof
require.NoError(t, json.Unmarshal(js, &p))
require.Equal(t, 8, len(p.Result.Proof))
for i := range p.Result.Proof { // smoke test that every chunk is correctly encoded node
r := io.NewBinReaderFromBuf(p.Result.Proof[i])
var n mpt.NodeObject
n.DecodeBinary(r)
require.NoError(t, r.Err)
require.NotNil(t, n.Node)
}
})
}
func TestProofWithKey_EncodeString(t *testing.T) {
expected := testProofWithKey()
var actual ProofWithKey
require.NoError(t, actual.FromString(expected.String()))
require.Equal(t, expected, &actual)
}
func TestVerifyProof_MarshalJSON(t *testing.T) {
t.Run("Good", func(t *testing.T) {
vp := &VerifyProof{random.Bytes(100)}
testserdes.MarshalUnmarshalJSON(t, vp, new(VerifyProof))
})
t.Run("NoValue", func(t *testing.T) {
vp := new(VerifyProof)
testserdes.MarshalUnmarshalJSON(t, vp, &VerifyProof{[]byte{1, 2, 3}})
})
}

View file

@ -15,6 +15,7 @@ import (
"github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
@ -95,6 +96,9 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon
"getpeers": (*Server).getPeers,
"getrawmempool": (*Server).getRawMempool,
"getrawtransaction": (*Server).getrawtransaction,
"getproof": (*Server).getProof,
"getstateheight": (*Server).getStateHeight,
"getstateroot": (*Server).getStateRoot,
"getstorage": (*Server).getStorage,
"gettransactionheight": (*Server).getTransactionHeight,
"gettxout": (*Server).getTxOut,
@ -108,6 +112,7 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon
"sendrawtransaction": (*Server).sendrawtransaction,
"submitblock": (*Server).submitBlock,
"validateaddress": (*Server).validateAddress,
"verifyproof": (*Server).verifyProof,
}
var rpcWsHandlers = map[string]func(*Server, request.Params, *subscriber) (interface{}, *response.Error){
@ -389,8 +394,8 @@ func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Er
func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) {
var hash util.Uint256
param, ok := reqParams.Value(0)
if !ok {
param := reqParams.Value(0)
if param == nil {
return nil, response.ErrInvalidParams
}
@ -416,7 +421,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro
return nil, response.NewInternalServerError(fmt.Sprintf("Problem locating block with hash: %s", hash), err)
}
if len(reqParams) == 2 && reqParams[1].Value == 1 {
if reqParams.Value(1).GetBoolean() {
return result.NewBlock(block, s.chain), nil
}
writer := io.NewBufBinWriter()
@ -425,8 +430,8 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro
}
func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.ValueWithType(0, request.NumberT)
if !ok {
param := reqParams.ValueWithType(0, request.NumberT)
if param == nil {
return nil, response.ErrInvalidParams
}
num, err := s.blockHeightFromParam(param)
@ -463,20 +468,15 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error)
}
func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.Value(0)
if !ok {
param := reqParams.Value(0)
if param == nil {
return nil, response.ErrInvalidParams
}
return validateAddress(param.Value), nil
}
func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
paramAssetID, err := param.GetUint256()
paramAssetID, err := reqParams.ValueWithType(0, request.StringT).GetUint256()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -490,12 +490,7 @@ func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response
// getApplicationLog returns the contract log based on the specified txid.
func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
txHash, err := param.GetUint256()
txHash, err := reqParams.Value(0).GetUint256()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -522,10 +517,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *resp
}
func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) {
p, ok := ps.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
p := ps.ValueWithType(0, request.StringT)
u, err := p.GetUint160FromAddress()
if err != nil {
return nil, response.ErrInvalidParams
@ -574,11 +566,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error)
}
func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) {
p, ok := ps.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
u, err := p.GetUint160FromHex()
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromHex()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -607,11 +595,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Erro
}
func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) {
p, ok := ps.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
u, err := p.GetUint160FromAddress()
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -708,25 +692,96 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6
return d, nil
}
func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) {
param, ok := ps.Value(0)
if !ok {
func (s *Server) getProof(ps request.Params) (interface{}, *response.Error) {
root, err := ps.Value(0).GetUint256()
if err != nil {
return nil, response.ErrInvalidParams
}
sc, err := ps.Value(1).GetUint160FromHex()
if err != nil {
return nil, response.ErrInvalidParams
}
sc = sc.Reverse()
key, err := ps.Value(2).GetBytesHex()
if err != nil {
return nil, response.ErrInvalidParams
}
skey := mpt.ToNeoStorageKey(append(sc.BytesBE(), key...))
proof, err := s.chain.GetStateProof(root, skey)
return &result.GetProof{
Result: result.ProofWithKey{
Key: skey,
Proof: proof,
},
Success: err == nil,
}, nil
}
scriptHash, err := param.GetUint160FromHex()
func (s *Server) verifyProof(ps request.Params) (interface{}, *response.Error) {
root, err := ps.Value(0).GetUint256()
if err != nil {
return nil, response.ErrInvalidParams
}
proofStr, err := ps.Value(1).GetString()
if err != nil {
return nil, response.ErrInvalidParams
}
var p result.ProofWithKey
if err := p.FromString(proofStr); err != nil {
return nil, response.ErrInvalidParams
}
vp := new(result.VerifyProof)
val, ok := mpt.VerifyProof(root, p.Key, p.Proof)
if ok {
var si state.StorageItem
r := io.NewBinReaderFromBuf(val[1:])
si.DecodeBinary(r)
if r.Err != nil {
return nil, response.NewInternalServerError("invalid item in trie", r.Err)
}
vp.Value = si.Value
}
return vp, nil
}
func (s *Server) getStateHeight(_ request.Params) (interface{}, *response.Error) {
return &result.StateHeight{
BlockHeight: s.chain.BlockHeight(),
StateHeight: s.chain.StateHeight(),
}, nil
}
func (s *Server) getStateRoot(ps request.Params) (interface{}, *response.Error) {
p := ps.Value(0)
if p == nil {
return nil, response.NewRPCError("Invalid parameter.", "", nil)
}
var rt *state.MPTRootState
var h util.Uint256
height, err := p.GetInt()
if err == nil {
rt, err = s.chain.GetStateRoot(uint32(height))
} else if h, err = p.GetUint256(); err == nil {
hdr, err := s.chain.GetHeader(h)
if err == nil {
rt, err = s.chain.GetStateRoot(hdr.Index)
}
}
if err != nil {
return nil, response.NewRPCError("Unknown state root.", "", err)
}
return rt, nil
}
func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) {
scriptHash, err := ps.Value(0).GetUint160FromHex()
if err != nil {
return nil, response.ErrInvalidParams
}
scriptHash = scriptHash.Reverse()
param, ok = ps.Value(1)
if !ok {
return nil, response.ErrInvalidParams
}
key, err := param.GetBytesHex()
key, err := ps.Value(1).GetBytesHex()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -743,30 +798,17 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
var resultsErr *response.Error
var results interface{}
if param0, ok := reqParams.Value(0); !ok {
return nil, response.ErrInvalidParams
} else if txHash, err := param0.GetUint256(); err != nil {
if txHash, err := reqParams.Value(0).GetUint256(); err != nil {
resultsErr = response.ErrInvalidParams
} else if tx, height, err := s.chain.GetTransaction(txHash); err != nil {
err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash)
return nil, response.NewRPCError("Unknown transaction", err.Error(), err)
} else if len(reqParams) >= 2 {
} else if reqParams.Value(1).GetBoolean() {
_header := s.chain.GetHeaderHash(int(height))
header, err := s.chain.GetHeader(_header)
if err != nil {
resultsErr = response.NewInvalidParamsError(err.Error(), err)
}
param1, _ := reqParams.Value(1)
switch v := param1.Value.(type) {
case int, float64, bool, string:
if v == 0 || v == "0" || v == 0.0 || v == false || v == "false" {
results = hex.EncodeToString(tx.Bytes())
} else {
results = result.NewTransactionOutputRaw(tx, header, s.chain)
}
default:
} else {
results = result.NewTransactionOutputRaw(tx, header, s.chain)
}
} else {
@ -777,12 +819,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
}
func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) {
p, ok := ps.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
h, err := p.GetUint256()
h, err := ps.Value(0).GetUint256()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -796,22 +833,12 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response
}
func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) {
p, ok := ps.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
h, err := p.GetUint256()
h, err := ps.Value(0).GetUint256()
if err != nil {
return nil, response.ErrInvalidParams
}
p, ok = ps.ValueWithType(1, request.NumberT)
if !ok {
return nil, response.ErrInvalidParams
}
num, err := p.GetInt()
num, err := ps.ValueWithType(1, request.NumberT).GetInt()
if err != nil || num < 0 {
return nil, response.ErrInvalidParams
}
@ -833,18 +860,15 @@ func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) {
func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) {
var results interface{}
param, ok := reqParams.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
} else if scriptHash, err := param.GetUint160FromHex(); err != nil {
scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
if err != nil {
return nil, response.ErrInvalidParams
}
cs := s.chain.GetContractState(scriptHash)
if cs != nil {
results = result.NewContractState(cs)
} else {
cs := s.chain.GetContractState(scriptHash)
if cs != nil {
results = result.NewContractState(cs)
} else {
return nil, response.NewRPCError("Unknown contract", "", nil)
}
return nil, response.NewRPCError("Unknown contract", "", nil)
}
return results, nil
}
@ -862,33 +886,31 @@ func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (in
var resultsErr *response.Error
var results interface{}
param, ok := reqParams.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
} else if scriptHash, err := param.GetUint160FromAddress(); err != nil {
param := reqParams.ValueWithType(0, request.StringT)
scriptHash, err := param.GetUint160FromAddress()
if err != nil {
return nil, response.ErrInvalidParams
}
as := s.chain.GetAccountState(scriptHash)
if as == nil {
as = state.NewAccount(scriptHash)
}
if unspents {
str, err := param.GetString()
if err != nil {
return nil, response.ErrInvalidParams
}
results = result.NewUnspents(as, s.chain, str)
} else {
as := s.chain.GetAccountState(scriptHash)
if as == nil {
as = state.NewAccount(scriptHash)
}
if unspents {
str, err := param.GetString()
if err != nil {
return nil, response.ErrInvalidParams
}
results = result.NewUnspents(as, s.chain, str)
} else {
results = result.NewAccountState(as)
}
results = result.NewAccountState(as)
}
return results, resultsErr
}
// getBlockSysFee returns the system fees of the block, based on the specified index.
func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.ValueWithType(0, request.NumberT)
if !ok {
param := reqParams.ValueWithType(0, request.NumberT)
if param == nil {
return 0, response.ErrInvalidParams
}
@ -913,26 +935,12 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *respons
// getBlockHeader returns the corresponding block header information according to the specified script hash.
func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) {
var verbose bool
param, ok := reqParams.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
hash, err := param.GetUint256()
hash, err := reqParams.ValueWithType(0, request.StringT).GetUint256()
if err != nil {
return nil, response.ErrInvalidParams
}
param, ok = reqParams.ValueWithType(1, request.NumberT)
if ok {
v, err := param.GetInt()
if err != nil {
return nil, response.ErrInvalidParams
}
verbose = v != 0
}
verbose := reqParams.Value(1).GetBoolean()
h, err := s.chain.GetHeader(hash)
if err != nil {
return nil, response.NewRPCError("unknown block", "", nil)
@ -952,11 +960,7 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *respons
// getUnclaimed returns unclaimed GAS amount of the specified address.
func (s *Server) getUnclaimed(ps request.Params) (interface{}, *response.Error) {
p, ok := ps.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
u, err := p.GetUint160FromAddress()
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -997,19 +1001,11 @@ func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error)
// invoke implements the `invoke` RPC call.
func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) {
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
scriptHash, err := scriptHashHex.GetUint160FromHex()
scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
if err != nil {
return nil, response.ErrInvalidParams
}
sliceP, ok := reqParams.ValueWithType(1, request.ArrayT)
if !ok {
return nil, response.ErrInvalidParams
}
slice, err := sliceP.GetArray()
slice, err := reqParams.ValueWithType(1, request.ArrayT).GetArray()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -1022,11 +1018,7 @@ func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error)
// invokescript implements the `invokescript` RPC call.
func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) {
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
scriptHash, err := scriptHashHex.GetUint160FromHex()
scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -1069,11 +1061,7 @@ func (s *Server) runScriptInVM(script []byte) *result.Invoke {
// submitBlock broadcasts a raw block over the NEO network.
func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
blockBytes, err := param.GetBytesHex()
blockBytes, err := reqParams.ValueWithType(0, request.StringT).GetBytesHex()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -1134,11 +1122,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res
// subscribe handles subscription requests from websocket clients.
func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
p, ok := reqParams.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
streamName, err := p.GetString()
streamName, err := reqParams.Value(0).GetString()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -1148,8 +1132,7 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface
}
// Optional filter.
var filter interface{}
p, ok = reqParams.Value(1)
if ok {
if p := reqParams.Value(1); p != nil {
// It doesn't accept filters.
if event == response.BlockEventID {
return nil, response.ErrInvalidParams
@ -1224,11 +1207,7 @@ func (s *Server) subscribeToChannel(event response.EventID) {
// unsubscribe handles unsubscription requests from websocket clients.
func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
p, ok := reqParams.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
id, err := p.GetInt()
id, err := reqParams.Value(0).GetInt()
if err != nil || id < 0 {
return nil, response.ErrInvalidParams
}

View file

@ -16,6 +16,8 @@ import (
"github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
@ -213,6 +215,54 @@ var rpcTestCases = map[string][]rpcTestCase{
},
},
},
"getproof": {
{
name: "no params",
params: `[]`,
fail: true,
},
{
name: "invalid root",
params: `["0xabcdef"]`,
fail: true,
},
{
name: "invalid contract",
params: `["0000000000000000000000000000000000000000000000000000000000000000", "0xabcdef"]`,
fail: true,
},
{
name: "invalid key",
params: `["0000000000000000000000000000000000000000000000000000000000000000", "` + testContractHash + `", "notahex"]`,
fail: true,
},
},
"getstateheight": {
{
name: "positive",
params: `[]`,
result: func(_ *executor) interface{} { return new(result.StateHeight) },
check: func(t *testing.T, e *executor, res interface{}) {
sh, ok := res.(*result.StateHeight)
require.True(t, ok)
require.Equal(t, e.chain.BlockHeight(), sh.BlockHeight)
require.Equal(t, e.chain.StateHeight(), sh.StateHeight)
},
},
},
"getstateroot": {
{
name: "no params",
params: `[]`,
fail: true,
},
{
name: "invalid hash",
params: `["0x1234567890"]`,
fail: true,
},
},
"getstorage": {
{
name: "positive",
@ -928,6 +978,52 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) []
})
}
t.Run("getproof", func(t *testing.T) {
r, err := chain.GetStateRoot(205)
require.NoError(t, err)
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getproof", "params": ["%s", "%s", "%x"]}`,
r.Root.StringLE(), testContractHash, []byte("testkey"))
fmt.Println(rpc)
body := doRPCCall(rpc, httpSrv.URL, t)
fmt.Println(string(body))
rawRes := checkErrGetResult(t, body, false)
res := new(result.GetProof)
require.NoError(t, json.Unmarshal(rawRes, res))
require.True(t, res.Success)
h, _ := hex.DecodeString(testContractHash)
skey := append(h, []byte("testkey")...)
require.Equal(t, mpt.ToNeoStorageKey(skey), res.Result.Key)
require.True(t, len(res.Result.Proof) > 0)
rpc = fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "verifyproof", "params": ["%s", "%s"]}`,
r.Root.StringLE(), res.Result.String())
body = doRPCCall(rpc, httpSrv.URL, t)
rawRes = checkErrGetResult(t, body, false)
vp := new(result.VerifyProof)
require.NoError(t, json.Unmarshal(rawRes, vp))
require.Equal(t, []byte("testvalue"), vp.Value)
})
t.Run("getstateroot", func(t *testing.T) {
testRoot := func(t *testing.T, p string) {
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getstateroot", "params": [%s]}`, p)
fmt.Println(rpc)
body := doRPCCall(rpc, httpSrv.URL, t)
rawRes := checkErrGetResult(t, body, false)
res := new(state.MPTRootState)
require.NoError(t, json.Unmarshal(rawRes, res))
require.NotEqual(t, util.Uint256{}, res.Root) // be sure this test uses valid height
expected, err := e.chain.GetStateRoot(205)
require.NoError(t, err)
require.Equal(t, expected, res)
}
t.Run("ByHeight", func(t *testing.T) { testRoot(t, strconv.FormatInt(205, 10)) })
t.Run("ByHash", func(t *testing.T) { testRoot(t, `"`+chain.GetHeaderHash(205).StringLE()+`"`) })
})
t.Run("getrawtransaction", func(t *testing.T) {
block, _ := chain.GetBlock(chain.GetHeaderHash(0))
TXHash := block.Transactions[1].Hash()