Merge pull request #1241 from nspcc-dev/fix/string

Ensure strings are valid UTF-8 where appropriate
This commit is contained in:
Roman Khimov 2020-08-03 18:10:27 +03:00 committed by GitHub
commit 00671deb8f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 68 additions and 46 deletions

View file

@ -62,7 +62,7 @@ func getLogHandler(logs *[]string) vm.SyscallHandler {
return errors.New("syscall not found") return errors.New("syscall not found")
} }
msg := string(v.Estack().Pop().Bytes()) msg := v.Estack().Pop().String()
*logs = append(*logs, msg) *logs = append(*logs, msg)
return nil return nil
} }

View file

@ -129,7 +129,7 @@ func (s *storagePlugin) syscallHandler(v *vm.VM, id uint32) error {
} }
func (s *storagePlugin) Notify(v *vm.VM) error { func (s *storagePlugin) Notify(v *vm.VM) error {
name := string(v.Estack().Pop().Bytes()) name := v.Estack().Pop().String()
item := stackitem.NewArray(v.Estack().Pop().Array()) item := stackitem.NewArray(v.Estack().Pop().Array())
s.events = append(s.events, state.NotificationEvent{ s.events = append(s.events, state.NotificationEvent{
Name: name, Name: name,

View file

@ -205,7 +205,7 @@ func runtimeEncode(_ *interop.Context, v *vm.VM) error {
// runtimeDecode decodes top stack item from base64 string to byte array. // runtimeDecode decodes top stack item from base64 string to byte array.
func runtimeDecode(_ *interop.Context, v *vm.VM) error { func runtimeDecode(_ *interop.Context, v *vm.VM) error {
src := string(v.Estack().Pop().Bytes()) src := v.Estack().Pop().String()
result, err := base64.StdEncoding.DecodeString(src) result, err := base64.StdEncoding.DecodeString(src)
if err != nil { if err != nil {
return err return err

View file

@ -7,7 +7,6 @@ import (
"math" "math"
"math/big" "math/big"
"strings" "strings"
"unicode/utf8"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
@ -265,13 +264,10 @@ func runtimeGetTrigger(ic *interop.Context, v *vm.VM) error {
// runtimeNotify should pass stack item to the notify plugin to handle it, but // runtimeNotify should pass stack item to the notify plugin to handle it, but
// in neo-go the only meaningful thing to do here is to log. // in neo-go the only meaningful thing to do here is to log.
func runtimeNotify(ic *interop.Context, v *vm.VM) error { func runtimeNotify(ic *interop.Context, v *vm.VM) error {
name := v.Estack().Pop().Bytes() name := v.Estack().Pop().String()
if len(name) > MaxEventNameLen { if len(name) > MaxEventNameLen {
return fmt.Errorf("event name must be less than %d", MaxEventNameLen) return fmt.Errorf("event name must be less than %d", MaxEventNameLen)
} }
if !utf8.Valid(name) {
return errors.New("event name should be UTF8-encoded")
}
elem := v.Estack().Pop() elem := v.Estack().Pop()
args := elem.Array() args := elem.Array()
// But it has to be serializable, otherwise we either have some broken // But it has to be serializable, otherwise we either have some broken
@ -285,7 +281,7 @@ func runtimeNotify(ic *interop.Context, v *vm.VM) error {
} }
ne := state.NotificationEvent{ ne := state.NotificationEvent{
ScriptHash: v.GetCurrentScriptHash(), ScriptHash: v.GetCurrentScriptHash(),
Name: string(name), Name: name,
Item: stackitem.NewArray(args), Item: stackitem.NewArray(args),
} }
ic.Notifications = append(ic.Notifications, ne) ic.Notifications = append(ic.Notifications, ne)
@ -294,13 +290,10 @@ func runtimeNotify(ic *interop.Context, v *vm.VM) error {
// runtimeLog logs the message passed. // runtimeLog logs the message passed.
func runtimeLog(ic *interop.Context, v *vm.VM) error { func runtimeLog(ic *interop.Context, v *vm.VM) error {
state := v.Estack().Pop().Bytes() state := v.Estack().Pop().String()
if len(state) > MaxNotificationSize { if len(state) > MaxNotificationSize {
return fmt.Errorf("message length shouldn't exceed %v", MaxNotificationSize) return fmt.Errorf("message length shouldn't exceed %v", MaxNotificationSize)
} }
if !utf8.Valid(state) {
return errors.New("log message should be UTF8-encoded")
}
msg := fmt.Sprintf("%q", state) msg := fmt.Sprintf("%q", state)
ic.Log.Info("runtime log", ic.Log.Info("runtime log",
zap.Stringer("script", v.GetCurrentScriptHash()), zap.Stringer("script", v.GetCurrentScriptHash()),
@ -464,16 +457,16 @@ func storageContextAsReadOnly(ic *interop.Context, v *vm.VM) error {
// contractCall calls a contract. // contractCall calls a contract.
func contractCall(ic *interop.Context, v *vm.VM) error { func contractCall(ic *interop.Context, v *vm.VM) error {
h := v.Estack().Pop().Bytes() h := v.Estack().Pop().Bytes()
method := v.Estack().Pop().Item() method := v.Estack().Pop().String()
args := v.Estack().Pop().Item() args := v.Estack().Pop().Array()
return contractCallExInternal(ic, v, h, method, args, smartcontract.All) return contractCallExInternal(ic, v, h, method, args, smartcontract.All)
} }
// contractCallEx calls a contract with flags. // contractCallEx calls a contract with flags.
func contractCallEx(ic *interop.Context, v *vm.VM) error { func contractCallEx(ic *interop.Context, v *vm.VM) error {
h := v.Estack().Pop().Bytes() h := v.Estack().Pop().Bytes()
method := v.Estack().Pop().Item() method := v.Estack().Pop().String()
args := v.Estack().Pop().Item() args := v.Estack().Pop().Array()
flags := smartcontract.CallFlag(int32(v.Estack().Pop().BigInt().Int64())) flags := smartcontract.CallFlag(int32(v.Estack().Pop().BigInt().Int64()))
if flags&^smartcontract.All != 0 { if flags&^smartcontract.All != 0 {
return errors.New("call flags out of range") return errors.New("call flags out of range")
@ -481,7 +474,7 @@ func contractCallEx(ic *interop.Context, v *vm.VM) error {
return contractCallExInternal(ic, v, h, method, args, flags) return contractCallExInternal(ic, v, h, method, args, flags)
} }
func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, f smartcontract.CallFlag) error { func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, name string, args []stackitem.Item, f smartcontract.CallFlag) error {
u, err := util.Uint160DecodeBytesBE(h) u, err := util.Uint160DecodeBytesBE(h)
if err != nil { if err != nil {
return errors.New("invalid contract hash") return errors.New("invalid contract hash")
@ -490,11 +483,6 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac
if err != nil { if err != nil {
return errors.New("contract not found") return errors.New("contract not found")
} }
bs, err := method.TryBytes()
if err != nil {
return err
}
name := string(bs)
if strings.HasPrefix(name, "_") { if strings.HasPrefix(name, "_") {
return errors.New("invalid method name (starts with '_')") return errors.New("invalid method name (starts with '_')")
} }
@ -504,17 +492,13 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac
} }
curr, err := ic.DAO.GetContractState(v.GetCurrentScriptHash()) curr, err := ic.DAO.GetContractState(v.GetCurrentScriptHash())
if err == nil { if err == nil {
if !curr.Manifest.CanCall(&cs.Manifest, string(bs)) { if !curr.Manifest.CanCall(&cs.Manifest, name) {
return errors.New("disallowed method call") return errors.New("disallowed method call")
} }
} }
arr, ok := args.Value().([]stackitem.Item) if len(args) != len(md.Parameters) {
if !ok { return fmt.Errorf("invalid argument count: %d (expected %d)", len(args), len(md.Parameters))
return errors.New("second argument must be an array")
}
if len(arr) != len(md.Parameters) {
return fmt.Errorf("invalid argument count: %d (expected %d)", len(arr), len(md.Parameters))
} }
ic.Invocations[u]++ ic.Invocations[u]++
@ -528,10 +512,10 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac
} }
if isNative { if isNative {
v.Estack().PushVal(args) v.Estack().PushVal(args)
v.Estack().PushVal(method) v.Estack().PushVal(name)
} else { } else {
for i := len(arr) - 1; i >= 0; i-- { for i := len(args) - 1; i >= 0; i-- {
v.Estack().PushVal(arr[i]) v.Estack().PushVal(args[i])
} }
// use Jump not Call here because context was loaded in LoadScript above. // use Jump not Call here because context was loaded in LoadScript above.
v.Jump(v.Context(), md.Offset) v.Jump(v.Context(), md.Offset)

View file

@ -264,9 +264,9 @@ func TestRuntimeGetNotifications(t *testing.T) {
for i := range arr { for i := range arr {
elem := arr[i].Value().([]stackitem.Item) elem := arr[i].Value().([]stackitem.Item)
require.Equal(t, ic.Notifications[i].ScriptHash.BytesBE(), elem[0].Value()) require.Equal(t, ic.Notifications[i].ScriptHash.BytesBE(), elem[0].Value())
name, err := elem[1].TryBytes() name, err := stackitem.ToString(elem[1])
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, ic.Notifications[i].Name, string(name)) require.Equal(t, ic.Notifications[i].Name, name)
require.Equal(t, ic.Notifications[i].Item, elem[2]) require.Equal(t, ic.Notifications[i].Item, elem[2])
} }
}) })
@ -280,9 +280,9 @@ func TestRuntimeGetNotifications(t *testing.T) {
require.Equal(t, 1, len(arr)) require.Equal(t, 1, len(arr))
elem := arr[0].Value().([]stackitem.Item) elem := arr[0].Value().([]stackitem.Item)
require.Equal(t, h, elem[0].Value()) require.Equal(t, h, elem[0].Value())
name, err := elem[1].TryBytes() name, err := stackitem.ToString(elem[1])
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, ic.Notifications[1].Name, string(name)) require.Equal(t, ic.Notifications[1].Name, name)
require.Equal(t, ic.Notifications[1].Item, elem[2]) require.Equal(t, ic.Notifications[1].Item, elem[2])
}) })
} }
@ -443,7 +443,14 @@ func TestContractCall(t *testing.T) {
for i := range args { for i := range args {
v.Estack().PushVal(args[i]) v.Estack().PushVal(args[i])
} }
require.Error(t, contractCall(ic, v)) // interops can both return error and panic,
// we don't care which kind of error has occured
require.Panics(t, func() {
err := contractCall(ic, v)
if err != nil {
panic(err)
}
})
} }
} }

View file

@ -35,7 +35,7 @@ func Deploy(ic *interop.Context, _ *vm.VM) error {
// Call calls specified native contract method. // Call calls specified native contract method.
func Call(ic *interop.Context, v *vm.VM) error { func Call(ic *interop.Context, v *vm.VM) error {
name := string(v.Estack().Pop().Bytes()) name := v.Estack().Pop().String()
var c interop.Contract var c interop.Contract
for _, ctr := range ic.Natives { for _, ctr := range ic.Natives {
if ctr.Metadata().Name == name { if ctr.Metadata().Name == name {
@ -50,7 +50,7 @@ func Call(ic *interop.Context, v *vm.VM) error {
if !h.Equals(c.Metadata().Hash) { if !h.Equals(c.Metadata().Hash) {
return errors.New("it is not allowed to use Neo.Native.Call directly to call native contracts. System.Contract.Call should be used") return errors.New("it is not allowed to use Neo.Native.Call directly to call native contracts. System.Contract.Call should be used")
} }
operation := string(v.Estack().Pop().Bytes()) operation := v.Estack().Pop().String()
args := v.Estack().Pop().Array() args := v.Estack().Pop().Array()
m, ok := c.Metadata().Methods[operation] m, ok := c.Metadata().Methods[operation]
if !ok { if !ok {

View file

@ -67,16 +67,16 @@ func defaultSyscallHandler(v *VM, id uint32) error {
// runtimeLog handles the syscall "System.Runtime.Log" for printing and logging stuff. // runtimeLog handles the syscall "System.Runtime.Log" for printing and logging stuff.
func runtimeLog(vm *VM) error { func runtimeLog(vm *VM) error {
item := vm.Estack().Pop() msg := vm.Estack().Pop().String()
fmt.Printf("NEO-GO-VM (log) > %s\n", item.Value()) fmt.Printf("NEO-GO-VM (log) > %s\n", msg)
return nil return nil
} }
// runtimeNotify handles the syscall "System.Runtime.Notify" for printing and logging stuff. // runtimeNotify handles the syscall "System.Runtime.Notify" for printing and logging stuff.
func runtimeNotify(vm *VM) error { func runtimeNotify(vm *VM) error {
name := vm.Estack().Pop().Bytes() name := vm.Estack().Pop().String()
item := vm.Estack().Pop() item := vm.Estack().Pop()
fmt.Printf("NEO-GO-VM (notify) > [%s] %s\n", string(name), item.Value()) fmt.Printf("NEO-GO-VM (notify) > [%s] %s\n", name, item.Value())
return nil return nil
} }

View file

@ -96,6 +96,16 @@ func (e *Element) Bytes() []byte {
return bs return bs
} }
// String attempts to get string from the element value.
// It is assumed to be use in interops and panics if string is not a valid UTF-8 byte sequence.
func (e *Element) String() string {
s, err := stackitem.ToString(e.value)
if err != nil {
panic(err)
}
return s
}
// Array attempts to get the underlying value of the element as an array of // Array attempts to get the underlying value of the element as an array of
// other items. Will panic if the item type is different which will be caught // other items. Will panic if the item type is different which will be caught
// by the VM. // by the VM.

View file

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"reflect" "reflect"
"unicode/utf8"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
@ -120,6 +121,18 @@ func Make(v interface{}) Item {
} }
} }
// ToString converts Item to string if it is a valid UTF-8.
func ToString(item Item) (string, error) {
bs, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(bs) {
return "", errors.New("not a valid UTF-8")
}
return string(bs), nil
}
// convertPrimitive converts primitive item to a specified type. // convertPrimitive converts primitive item to a specified type.
func convertPrimitive(item Item, typ Type) (Item, error) { func convertPrimitive(item Item, typ Type) (Item, error) {
if item.Type() == typ { if item.Type() == typ {

View file

@ -64,9 +64,17 @@ func toJSON(buf *io.BufBinWriter, item Item) {
case *Map: case *Map:
w.WriteB('{') w.WriteB('{')
for i := range it.value { for i := range it.value {
bs, _ := it.value[i].Key.TryBytes() // map key can always be converted to []byte // map key can always be converted to []byte
// but are not always a valid UTF-8.
key, err := ToString(it.value[i].Key)
if err != nil {
if buf.Err == nil {
buf.Err = err
}
return
}
w.WriteB('"') w.WriteB('"')
w.WriteBytes(bs) w.WriteBytes([]byte(key))
w.WriteBytes([]byte(`":`)) w.WriteBytes([]byte(`":`))
toJSON(buf, it.value[i].Value) toJSON(buf, it.value[i].Value)
if i < len(it.value)-1 { if i < len(it.value)-1 {