Merge pull request #1340 from nspcc-dev/compiler/shortjumps
Emit short jumps where possible
This commit is contained in:
commit
790693fc6d
5 changed files with 264 additions and 22 deletions
|
@ -691,25 +691,21 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
|||
case *ast.BinaryExpr:
|
||||
switch n.Op {
|
||||
case token.LAND:
|
||||
next := c.newLabel()
|
||||
end := c.newLabel()
|
||||
ast.Walk(c, n.X)
|
||||
emit.Jmp(c.prog.BinWriter, opcode.JMPIFL, next)
|
||||
emit.Instruction(c.prog.BinWriter, opcode.JMPIF, []byte{2 + 1 + 5})
|
||||
emit.Opcode(c.prog.BinWriter, opcode.PUSHF)
|
||||
emit.Jmp(c.prog.BinWriter, opcode.JMPL, end)
|
||||
c.setLabel(next)
|
||||
ast.Walk(c, n.Y)
|
||||
c.setLabel(end)
|
||||
return nil
|
||||
|
||||
case token.LOR:
|
||||
next := c.newLabel()
|
||||
end := c.newLabel()
|
||||
ast.Walk(c, n.X)
|
||||
emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOTL, next)
|
||||
emit.Instruction(c.prog.BinWriter, opcode.JMPIFNOT, []byte{2 + 1 + 5})
|
||||
emit.Opcode(c.prog.BinWriter, opcode.PUSHT)
|
||||
emit.Jmp(c.prog.BinWriter, opcode.JMPL, end)
|
||||
c.setLabel(next)
|
||||
ast.Walk(c, n.Y)
|
||||
c.setLabel(end)
|
||||
return nil
|
||||
|
@ -1613,8 +1609,8 @@ func CodeGen(info *buildInfo) ([]byte, *DebugInfo, error) {
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
buf := c.prog.Bytes()
|
||||
if err := c.writeJumps(buf); err != nil {
|
||||
buf, err := c.writeJumps(c.prog.Bytes())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return buf, c.emitDebugInfo(buf), nil
|
||||
|
@ -1629,15 +1625,14 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *codegen) writeJumps(b []byte) error {
|
||||
func (c *codegen) writeJumps(b []byte) ([]byte, error) {
|
||||
ctx := vm.NewContext(b)
|
||||
for op, _, err := ctx.Next(); err == nil && ctx.NextIP() < len(b); op, _, err = ctx.Next() {
|
||||
var offsets []int
|
||||
for op, _, err := ctx.Next(); err == nil && ctx.IP() < len(b); op, _, err = ctx.Next() {
|
||||
switch op {
|
||||
case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL,
|
||||
opcode.JMPEQ, opcode.JMPNE,
|
||||
opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT:
|
||||
// Noop, assumed to be correct already. If you're fixing #905,
|
||||
// make sure not to break "len" and "append" handling above.
|
||||
case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL,
|
||||
opcode.JMPEQL, opcode.JMPNEL,
|
||||
opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL,
|
||||
|
@ -1648,15 +1643,135 @@ func (c *codegen) writeJumps(b []byte) error {
|
|||
|
||||
index := binary.LittleEndian.Uint16(arg)
|
||||
if int(index) > len(c.l) {
|
||||
return fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l))
|
||||
return nil, fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l))
|
||||
}
|
||||
offset := c.l[index] - nextIP + 5
|
||||
if offset > math.MaxInt32 || offset < math.MinInt32 {
|
||||
return fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)",
|
||||
return nil, fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)",
|
||||
nextIP-5, offset, math.MaxInt32, math.MinInt32)
|
||||
}
|
||||
if op != opcode.PUSHA && math.MinInt8 <= offset && offset <= math.MaxInt8 {
|
||||
offsets = append(offsets, ctx.IP())
|
||||
}
|
||||
binary.LittleEndian.PutUint32(arg, uint32(offset))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
// Correct function ip range.
|
||||
// Note: indices are sorted in increasing order.
|
||||
for _, f := range c.funcs {
|
||||
loop:
|
||||
for _, ind := range offsets {
|
||||
switch {
|
||||
case ind > int(f.rng.End):
|
||||
break loop
|
||||
case ind < int(f.rng.Start):
|
||||
f.rng.Start -= longToShortRemoveCount
|
||||
f.rng.End -= longToShortRemoveCount
|
||||
case ind >= int(f.rng.Start):
|
||||
f.rng.End -= longToShortRemoveCount
|
||||
}
|
||||
}
|
||||
}
|
||||
return shortenJumps(b, offsets), nil
|
||||
}
|
||||
|
||||
// longToShortRemoveCount is a difference between short and long instruction sizes in bytes.
|
||||
const longToShortRemoveCount = 3
|
||||
|
||||
// shortenJumps returns converts b to a program where all long JMP*/CALL* specified by absolute offsets,
|
||||
// are replaced with their corresponding short counterparts. It panics if either b or offsets are invalid.
|
||||
// This is done in 2 passes:
|
||||
// 1. Alter jump offsets taking into account parts to be removed.
|
||||
// 2. Perform actual removal of jump targets.
|
||||
// Note: after jump offsets altering, there can appear new candidates for conversion.
|
||||
// These are ignored for now.
|
||||
func shortenJumps(b []byte, offsets []int) []byte {
|
||||
if len(offsets) == 0 {
|
||||
return b
|
||||
}
|
||||
|
||||
// 1. Alter existing jump offsets.
|
||||
ctx := vm.NewContext(b)
|
||||
for op, _, err := ctx.Next(); err == nil && ctx.IP() < len(b); op, _, err = ctx.Next() {
|
||||
// we can't use arg returned by ctx.Next() because it is copied
|
||||
nextIP := ctx.NextIP()
|
||||
ip := ctx.IP()
|
||||
switch op {
|
||||
case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL,
|
||||
opcode.JMPEQ, opcode.JMPNE,
|
||||
opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT:
|
||||
offset := int(int8(b[nextIP-1]))
|
||||
offset += calcOffsetCorrection(ip, ip+offset, offsets)
|
||||
b[nextIP-1] = byte(offset)
|
||||
case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL,
|
||||
opcode.JMPEQL, opcode.JMPNEL,
|
||||
opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL,
|
||||
opcode.CALLL, opcode.PUSHA:
|
||||
arg := b[nextIP-4:]
|
||||
offset := int(int32(binary.LittleEndian.Uint32(arg)))
|
||||
offset += calcOffsetCorrection(ip, ip+offset, offsets)
|
||||
binary.LittleEndian.PutUint32(arg, uint32(offset))
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Convert instructions.
|
||||
copyOffset := 0
|
||||
l := len(offsets)
|
||||
b[offsets[0]] = toShortForm(b[offsets[0]])
|
||||
for i := 0; i < l; i++ {
|
||||
start := offsets[i] + 2
|
||||
end := len(b)
|
||||
if i != l-1 {
|
||||
end = offsets[i+1] + 2
|
||||
b[offsets[i+1]] = toShortForm(b[offsets[i+1]])
|
||||
}
|
||||
copy(b[start-copyOffset:], b[start+3:end])
|
||||
copyOffset += longToShortRemoveCount
|
||||
}
|
||||
return b[:len(b)-copyOffset]
|
||||
}
|
||||
|
||||
func calcOffsetCorrection(ip, target int, offsets []int) int {
|
||||
cnt := 0
|
||||
start := sort.Search(len(offsets), func(i int) bool {
|
||||
return offsets[i] >= ip || offsets[i] >= target
|
||||
})
|
||||
for i := start; i < len(offsets) && (offsets[i] < target || offsets[i] <= ip); i++ {
|
||||
ind := offsets[i]
|
||||
if ip <= ind && ind < target ||
|
||||
ind != ip && target <= ind && ind <= ip {
|
||||
cnt += longToShortRemoveCount
|
||||
}
|
||||
}
|
||||
if ip < target {
|
||||
return -cnt
|
||||
}
|
||||
return cnt
|
||||
}
|
||||
|
||||
func toShortForm(b byte) byte {
|
||||
switch op := opcode.Opcode(b); op {
|
||||
case opcode.JMPL:
|
||||
return byte(opcode.JMP)
|
||||
case opcode.JMPIFL:
|
||||
return byte(opcode.JMPIF)
|
||||
case opcode.JMPIFNOTL:
|
||||
return byte(opcode.JMPIFNOT)
|
||||
case opcode.JMPEQL:
|
||||
return byte(opcode.JMPEQ)
|
||||
case opcode.JMPNEL:
|
||||
return byte(opcode.JMPNE)
|
||||
case opcode.JMPGTL:
|
||||
return byte(opcode.JMPGT)
|
||||
case opcode.JMPGEL:
|
||||
return byte(opcode.JMPGE)
|
||||
case opcode.JMPLEL:
|
||||
return byte(opcode.JMPLE)
|
||||
case opcode.JMPLTL:
|
||||
return byte(opcode.JMPLT)
|
||||
case opcode.CALLL:
|
||||
return byte(opcode.CALL)
|
||||
default:
|
||||
panic(fmt.Errorf("invalid opcode: %s", op))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -221,7 +221,7 @@ func TestBuiltinDoesNotCompile(t *testing.T) {
|
|||
ctx := v.Context()
|
||||
retCount := 0
|
||||
for op, _, err := ctx.Next(); err == nil; op, _, err = ctx.Next() {
|
||||
if ctx.IP() > len(ctx.Program()) {
|
||||
if ctx.IP() >= len(ctx.Program()) {
|
||||
break
|
||||
}
|
||||
if op == opcode.RET {
|
||||
|
|
129
pkg/compiler/jumps_test.go
Normal file
129
pkg/compiler/jumps_test.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
package compiler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testShortenJumps(t *testing.T, before, after []opcode.Opcode, indices []int) {
|
||||
prog := make([]byte, len(before))
|
||||
for i := range before {
|
||||
prog[i] = byte(before[i])
|
||||
}
|
||||
raw := shortenJumps(prog, indices)
|
||||
actual := make([]opcode.Opcode, len(raw))
|
||||
for i := range raw {
|
||||
actual[i] = opcode.Opcode(raw[i])
|
||||
}
|
||||
require.Equal(t, after, actual)
|
||||
}
|
||||
|
||||
func TestShortenJumps(t *testing.T) {
|
||||
testCases := map[opcode.Opcode]opcode.Opcode{
|
||||
opcode.JMPL: opcode.JMP,
|
||||
opcode.JMPIFL: opcode.JMPIF,
|
||||
opcode.JMPIFNOTL: opcode.JMPIFNOT,
|
||||
opcode.JMPEQL: opcode.JMPEQ,
|
||||
opcode.JMPNEL: opcode.JMPNE,
|
||||
opcode.JMPGTL: opcode.JMPGT,
|
||||
opcode.JMPGEL: opcode.JMPGE,
|
||||
opcode.JMPLEL: opcode.JMPLE,
|
||||
opcode.JMPLTL: opcode.JMPLT,
|
||||
opcode.CALLL: opcode.CALL,
|
||||
}
|
||||
for op, sop := range testCases {
|
||||
t.Run(op.String(), func(t *testing.T) {
|
||||
before := []opcode.Opcode{
|
||||
op, 6, 0, 0, 0, opcode.PUSH1, opcode.NOP, // <- first jump to here
|
||||
op, 9, 12, 0, 0, opcode.PUSH1, opcode.NOP, // <- last jump to here
|
||||
op, 255, 0, 0, 0, op, 0xFF - 5, 0xFF, 0xFF, 0xFF,
|
||||
}
|
||||
after := []opcode.Opcode{
|
||||
sop, 3, opcode.PUSH1, opcode.NOP,
|
||||
op, 3, 12, 0, 0, opcode.PUSH1, opcode.NOP,
|
||||
sop, 249, sop, 0xFF - 2,
|
||||
}
|
||||
testShortenJumps(t, before, after, []int{0, 14, 19})
|
||||
})
|
||||
}
|
||||
t.Run("NoReplace", func(t *testing.T) {
|
||||
b := []byte{0, 1, 2, 3, 4, 5}
|
||||
expected := []byte{0, 1, 2, 3, 4, 5}
|
||||
require.Equal(t, expected, shortenJumps(b, nil))
|
||||
})
|
||||
t.Run("InvalidIndex", func(t *testing.T) {
|
||||
before := []byte{byte(opcode.PUSH1), 0, 0, 0, 0}
|
||||
require.Panics(t, func() {
|
||||
shortenJumps(before, []int{0})
|
||||
})
|
||||
})
|
||||
t.Run("SideConditions", func(t *testing.T) {
|
||||
t.Run("Forward", func(t *testing.T) {
|
||||
before := []opcode.Opcode{
|
||||
opcode.JMPL, 5, 0, 0, 0,
|
||||
opcode.JMPL, 5, 0, 0, 0,
|
||||
}
|
||||
after := []opcode.Opcode{
|
||||
opcode.JMP, 2,
|
||||
opcode.JMP, 2,
|
||||
}
|
||||
testShortenJumps(t, before, after, []int{0, 5})
|
||||
})
|
||||
t.Run("Backwards", func(t *testing.T) {
|
||||
before := []opcode.Opcode{
|
||||
opcode.JMPL, 5, 0, 0, 0,
|
||||
opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF,
|
||||
opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF,
|
||||
}
|
||||
after := []opcode.Opcode{
|
||||
opcode.JMPL, 5, 0, 0, 0,
|
||||
opcode.JMP, 0xFF - 4,
|
||||
opcode.JMP, 0xFF - 1,
|
||||
}
|
||||
testShortenJumps(t, before, after, []int{5, 10})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriteJumps(t *testing.T) {
|
||||
c := new(codegen)
|
||||
c.l = []int{10}
|
||||
before := []byte{
|
||||
byte(opcode.NOP), byte(opcode.JMP), 2, byte(opcode.RET),
|
||||
byte(opcode.CALLL), 0, 0, 0, 0, byte(opcode.RET),
|
||||
byte(opcode.PUSH2), byte(opcode.RET),
|
||||
}
|
||||
c.funcs = map[string]*funcScope{
|
||||
"init": {rng: DebugRange{Start: 0, End: 3}},
|
||||
"main": {rng: DebugRange{Start: 4, End: 9}},
|
||||
"method": {rng: DebugRange{Start: 10, End: 11}},
|
||||
}
|
||||
|
||||
expProg := []byte{
|
||||
byte(opcode.NOP), byte(opcode.JMP), 2, byte(opcode.RET),
|
||||
byte(opcode.CALL), 3, byte(opcode.RET),
|
||||
byte(opcode.PUSH2), byte(opcode.RET),
|
||||
}
|
||||
expFuncs := map[string]*funcScope{
|
||||
"init": {rng: DebugRange{Start: 0, End: 3}},
|
||||
"main": {rng: DebugRange{Start: 4, End: 6}},
|
||||
"method": {rng: DebugRange{Start: 7, End: 8}},
|
||||
}
|
||||
|
||||
buf, err := c.writeJumps(before)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expProg, buf)
|
||||
require.Equal(t, expFuncs, c.funcs)
|
||||
}
|
||||
|
||||
func TestWriteJumpsLastJump(t *testing.T) {
|
||||
c := new(codegen)
|
||||
c.l = []int{2}
|
||||
prog := []byte{byte(opcode.JMP), 3, byte(opcode.RET), byte(opcode.JMPL), 0, 0, 0, 0}
|
||||
expected := []byte{byte(opcode.JMP), 3, byte(opcode.RET), byte(opcode.JMP), 0xFF}
|
||||
actual, err := c.writeJumps(prog)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, actual)
|
||||
}
|
|
@ -433,8 +433,8 @@ func handleOps(c *ishell.Context) {
|
|||
}
|
||||
|
||||
func changePrompt(c ishell.Actions, v *vm.VM) {
|
||||
if v.Ready() && v.Context().IP()-1 >= 0 {
|
||||
c.SetPrompt(fmt.Sprintf("NEO-GO-VM %d > ", v.Context().IP()-1))
|
||||
if v.Ready() && v.Context().IP() >= 0 {
|
||||
c.SetPrompt(fmt.Sprintf("NEO-GO-VM %d > ", v.Context().IP()))
|
||||
} else {
|
||||
c.SetPrompt("NEO-GO-VM > ")
|
||||
}
|
||||
|
|
|
@ -141,11 +141,9 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) {
|
|||
return instr, parameter, nil
|
||||
}
|
||||
|
||||
// IP returns the absolute instruction without taking 0 into account.
|
||||
// If that program starts the ip = 0 but IP() will return 1, cause its
|
||||
// the first instruction.
|
||||
// IP returns current instruction offset in the context script.
|
||||
func (c *Context) IP() int {
|
||||
return c.ip + 1
|
||||
return c.ip
|
||||
}
|
||||
|
||||
// LenInstr returns the number of instructions loaded.
|
||||
|
|
Loading…
Reference in a new issue