Merge pull request #1340 from nspcc-dev/compiler/shortjumps

Emit short jumps where possible
This commit is contained in:
Roman Khimov 2020-08-21 10:20:22 +03:00 committed by GitHub
commit 790693fc6d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 264 additions and 22 deletions

View file

@ -691,25 +691,21 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
case *ast.BinaryExpr: case *ast.BinaryExpr:
switch n.Op { switch n.Op {
case token.LAND: case token.LAND:
next := c.newLabel()
end := c.newLabel() end := c.newLabel()
ast.Walk(c, n.X) ast.Walk(c, n.X)
emit.Jmp(c.prog.BinWriter, opcode.JMPIFL, next) emit.Instruction(c.prog.BinWriter, opcode.JMPIF, []byte{2 + 1 + 5})
emit.Opcode(c.prog.BinWriter, opcode.PUSHF) emit.Opcode(c.prog.BinWriter, opcode.PUSHF)
emit.Jmp(c.prog.BinWriter, opcode.JMPL, end) emit.Jmp(c.prog.BinWriter, opcode.JMPL, end)
c.setLabel(next)
ast.Walk(c, n.Y) ast.Walk(c, n.Y)
c.setLabel(end) c.setLabel(end)
return nil return nil
case token.LOR: case token.LOR:
next := c.newLabel()
end := c.newLabel() end := c.newLabel()
ast.Walk(c, n.X) ast.Walk(c, n.X)
emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOTL, next) emit.Instruction(c.prog.BinWriter, opcode.JMPIFNOT, []byte{2 + 1 + 5})
emit.Opcode(c.prog.BinWriter, opcode.PUSHT) emit.Opcode(c.prog.BinWriter, opcode.PUSHT)
emit.Jmp(c.prog.BinWriter, opcode.JMPL, end) emit.Jmp(c.prog.BinWriter, opcode.JMPL, end)
c.setLabel(next)
ast.Walk(c, n.Y) ast.Walk(c, n.Y)
c.setLabel(end) c.setLabel(end)
return nil return nil
@ -1613,8 +1609,8 @@ func CodeGen(info *buildInfo) ([]byte, *DebugInfo, error) {
return nil, nil, err return nil, nil, err
} }
buf := c.prog.Bytes() buf, err := c.writeJumps(c.prog.Bytes())
if err := c.writeJumps(buf); err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return buf, c.emitDebugInfo(buf), nil return buf, c.emitDebugInfo(buf), nil
@ -1629,15 +1625,14 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) {
} }
} }
func (c *codegen) writeJumps(b []byte) error { func (c *codegen) writeJumps(b []byte) ([]byte, error) {
ctx := vm.NewContext(b) ctx := vm.NewContext(b)
for op, _, err := ctx.Next(); err == nil && ctx.NextIP() < len(b); op, _, err = ctx.Next() { var offsets []int
for op, _, err := ctx.Next(); err == nil && ctx.IP() < len(b); op, _, err = ctx.Next() {
switch op { switch op {
case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL, case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL,
opcode.JMPEQ, opcode.JMPNE, opcode.JMPEQ, opcode.JMPNE,
opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT:
// Noop, assumed to be correct already. If you're fixing #905,
// make sure not to break "len" and "append" handling above.
case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL,
opcode.JMPEQL, opcode.JMPNEL, opcode.JMPEQL, opcode.JMPNEL,
opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL,
@ -1648,15 +1643,135 @@ func (c *codegen) writeJumps(b []byte) error {
index := binary.LittleEndian.Uint16(arg) index := binary.LittleEndian.Uint16(arg)
if int(index) > len(c.l) { if int(index) > len(c.l) {
return fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l)) return nil, fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l))
} }
offset := c.l[index] - nextIP + 5 offset := c.l[index] - nextIP + 5
if offset > math.MaxInt32 || offset < math.MinInt32 { if offset > math.MaxInt32 || offset < math.MinInt32 {
return fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)", return nil, fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)",
nextIP-5, offset, math.MaxInt32, math.MinInt32) nextIP-5, offset, math.MaxInt32, math.MinInt32)
} }
if op != opcode.PUSHA && math.MinInt8 <= offset && offset <= math.MaxInt8 {
offsets = append(offsets, ctx.IP())
}
binary.LittleEndian.PutUint32(arg, uint32(offset)) binary.LittleEndian.PutUint32(arg, uint32(offset))
} }
} }
return nil // Correct function ip range.
// Note: indices are sorted in increasing order.
for _, f := range c.funcs {
loop:
for _, ind := range offsets {
switch {
case ind > int(f.rng.End):
break loop
case ind < int(f.rng.Start):
f.rng.Start -= longToShortRemoveCount
f.rng.End -= longToShortRemoveCount
case ind >= int(f.rng.Start):
f.rng.End -= longToShortRemoveCount
}
}
}
return shortenJumps(b, offsets), nil
}
// longToShortRemoveCount is a difference between short and long instruction sizes in bytes.
const longToShortRemoveCount = 3
// shortenJumps returns converts b to a program where all long JMP*/CALL* specified by absolute offsets,
// are replaced with their corresponding short counterparts. It panics if either b or offsets are invalid.
// This is done in 2 passes:
// 1. Alter jump offsets taking into account parts to be removed.
// 2. Perform actual removal of jump targets.
// Note: after jump offsets altering, there can appear new candidates for conversion.
// These are ignored for now.
func shortenJumps(b []byte, offsets []int) []byte {
if len(offsets) == 0 {
return b
}
// 1. Alter existing jump offsets.
ctx := vm.NewContext(b)
for op, _, err := ctx.Next(); err == nil && ctx.IP() < len(b); op, _, err = ctx.Next() {
// we can't use arg returned by ctx.Next() because it is copied
nextIP := ctx.NextIP()
ip := ctx.IP()
switch op {
case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL,
opcode.JMPEQ, opcode.JMPNE,
opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT:
offset := int(int8(b[nextIP-1]))
offset += calcOffsetCorrection(ip, ip+offset, offsets)
b[nextIP-1] = byte(offset)
case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL,
opcode.JMPEQL, opcode.JMPNEL,
opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL,
opcode.CALLL, opcode.PUSHA:
arg := b[nextIP-4:]
offset := int(int32(binary.LittleEndian.Uint32(arg)))
offset += calcOffsetCorrection(ip, ip+offset, offsets)
binary.LittleEndian.PutUint32(arg, uint32(offset))
}
}
// 2. Convert instructions.
copyOffset := 0
l := len(offsets)
b[offsets[0]] = toShortForm(b[offsets[0]])
for i := 0; i < l; i++ {
start := offsets[i] + 2
end := len(b)
if i != l-1 {
end = offsets[i+1] + 2
b[offsets[i+1]] = toShortForm(b[offsets[i+1]])
}
copy(b[start-copyOffset:], b[start+3:end])
copyOffset += longToShortRemoveCount
}
return b[:len(b)-copyOffset]
}
func calcOffsetCorrection(ip, target int, offsets []int) int {
cnt := 0
start := sort.Search(len(offsets), func(i int) bool {
return offsets[i] >= ip || offsets[i] >= target
})
for i := start; i < len(offsets) && (offsets[i] < target || offsets[i] <= ip); i++ {
ind := offsets[i]
if ip <= ind && ind < target ||
ind != ip && target <= ind && ind <= ip {
cnt += longToShortRemoveCount
}
}
if ip < target {
return -cnt
}
return cnt
}
func toShortForm(b byte) byte {
switch op := opcode.Opcode(b); op {
case opcode.JMPL:
return byte(opcode.JMP)
case opcode.JMPIFL:
return byte(opcode.JMPIF)
case opcode.JMPIFNOTL:
return byte(opcode.JMPIFNOT)
case opcode.JMPEQL:
return byte(opcode.JMPEQ)
case opcode.JMPNEL:
return byte(opcode.JMPNE)
case opcode.JMPGTL:
return byte(opcode.JMPGT)
case opcode.JMPGEL:
return byte(opcode.JMPGE)
case opcode.JMPLEL:
return byte(opcode.JMPLE)
case opcode.JMPLTL:
return byte(opcode.JMPLT)
case opcode.CALLL:
return byte(opcode.CALL)
default:
panic(fmt.Errorf("invalid opcode: %s", op))
}
} }

View file

@ -221,7 +221,7 @@ func TestBuiltinDoesNotCompile(t *testing.T) {
ctx := v.Context() ctx := v.Context()
retCount := 0 retCount := 0
for op, _, err := ctx.Next(); err == nil; op, _, err = ctx.Next() { for op, _, err := ctx.Next(); err == nil; op, _, err = ctx.Next() {
if ctx.IP() > len(ctx.Program()) { if ctx.IP() >= len(ctx.Program()) {
break break
} }
if op == opcode.RET { if op == opcode.RET {

129
pkg/compiler/jumps_test.go Normal file
View file

@ -0,0 +1,129 @@
package compiler
import (
"testing"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/stretchr/testify/require"
)
func testShortenJumps(t *testing.T, before, after []opcode.Opcode, indices []int) {
prog := make([]byte, len(before))
for i := range before {
prog[i] = byte(before[i])
}
raw := shortenJumps(prog, indices)
actual := make([]opcode.Opcode, len(raw))
for i := range raw {
actual[i] = opcode.Opcode(raw[i])
}
require.Equal(t, after, actual)
}
func TestShortenJumps(t *testing.T) {
testCases := map[opcode.Opcode]opcode.Opcode{
opcode.JMPL: opcode.JMP,
opcode.JMPIFL: opcode.JMPIF,
opcode.JMPIFNOTL: opcode.JMPIFNOT,
opcode.JMPEQL: opcode.JMPEQ,
opcode.JMPNEL: opcode.JMPNE,
opcode.JMPGTL: opcode.JMPGT,
opcode.JMPGEL: opcode.JMPGE,
opcode.JMPLEL: opcode.JMPLE,
opcode.JMPLTL: opcode.JMPLT,
opcode.CALLL: opcode.CALL,
}
for op, sop := range testCases {
t.Run(op.String(), func(t *testing.T) {
before := []opcode.Opcode{
op, 6, 0, 0, 0, opcode.PUSH1, opcode.NOP, // <- first jump to here
op, 9, 12, 0, 0, opcode.PUSH1, opcode.NOP, // <- last jump to here
op, 255, 0, 0, 0, op, 0xFF - 5, 0xFF, 0xFF, 0xFF,
}
after := []opcode.Opcode{
sop, 3, opcode.PUSH1, opcode.NOP,
op, 3, 12, 0, 0, opcode.PUSH1, opcode.NOP,
sop, 249, sop, 0xFF - 2,
}
testShortenJumps(t, before, after, []int{0, 14, 19})
})
}
t.Run("NoReplace", func(t *testing.T) {
b := []byte{0, 1, 2, 3, 4, 5}
expected := []byte{0, 1, 2, 3, 4, 5}
require.Equal(t, expected, shortenJumps(b, nil))
})
t.Run("InvalidIndex", func(t *testing.T) {
before := []byte{byte(opcode.PUSH1), 0, 0, 0, 0}
require.Panics(t, func() {
shortenJumps(before, []int{0})
})
})
t.Run("SideConditions", func(t *testing.T) {
t.Run("Forward", func(t *testing.T) {
before := []opcode.Opcode{
opcode.JMPL, 5, 0, 0, 0,
opcode.JMPL, 5, 0, 0, 0,
}
after := []opcode.Opcode{
opcode.JMP, 2,
opcode.JMP, 2,
}
testShortenJumps(t, before, after, []int{0, 5})
})
t.Run("Backwards", func(t *testing.T) {
before := []opcode.Opcode{
opcode.JMPL, 5, 0, 0, 0,
opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF,
opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF,
}
after := []opcode.Opcode{
opcode.JMPL, 5, 0, 0, 0,
opcode.JMP, 0xFF - 4,
opcode.JMP, 0xFF - 1,
}
testShortenJumps(t, before, after, []int{5, 10})
})
})
}
func TestWriteJumps(t *testing.T) {
c := new(codegen)
c.l = []int{10}
before := []byte{
byte(opcode.NOP), byte(opcode.JMP), 2, byte(opcode.RET),
byte(opcode.CALLL), 0, 0, 0, 0, byte(opcode.RET),
byte(opcode.PUSH2), byte(opcode.RET),
}
c.funcs = map[string]*funcScope{
"init": {rng: DebugRange{Start: 0, End: 3}},
"main": {rng: DebugRange{Start: 4, End: 9}},
"method": {rng: DebugRange{Start: 10, End: 11}},
}
expProg := []byte{
byte(opcode.NOP), byte(opcode.JMP), 2, byte(opcode.RET),
byte(opcode.CALL), 3, byte(opcode.RET),
byte(opcode.PUSH2), byte(opcode.RET),
}
expFuncs := map[string]*funcScope{
"init": {rng: DebugRange{Start: 0, End: 3}},
"main": {rng: DebugRange{Start: 4, End: 6}},
"method": {rng: DebugRange{Start: 7, End: 8}},
}
buf, err := c.writeJumps(before)
require.NoError(t, err)
require.Equal(t, expProg, buf)
require.Equal(t, expFuncs, c.funcs)
}
func TestWriteJumpsLastJump(t *testing.T) {
c := new(codegen)
c.l = []int{2}
prog := []byte{byte(opcode.JMP), 3, byte(opcode.RET), byte(opcode.JMPL), 0, 0, 0, 0}
expected := []byte{byte(opcode.JMP), 3, byte(opcode.RET), byte(opcode.JMP), 0xFF}
actual, err := c.writeJumps(prog)
require.NoError(t, err)
require.Equal(t, expected, actual)
}

View file

@ -433,8 +433,8 @@ func handleOps(c *ishell.Context) {
} }
func changePrompt(c ishell.Actions, v *vm.VM) { func changePrompt(c ishell.Actions, v *vm.VM) {
if v.Ready() && v.Context().IP()-1 >= 0 { if v.Ready() && v.Context().IP() >= 0 {
c.SetPrompt(fmt.Sprintf("NEO-GO-VM %d > ", v.Context().IP()-1)) c.SetPrompt(fmt.Sprintf("NEO-GO-VM %d > ", v.Context().IP()))
} else { } else {
c.SetPrompt("NEO-GO-VM > ") c.SetPrompt("NEO-GO-VM > ")
} }

View file

@ -141,11 +141,9 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) {
return instr, parameter, nil return instr, parameter, nil
} }
// IP returns the absolute instruction without taking 0 into account. // IP returns current instruction offset in the context script.
// If that program starts the ip = 0 but IP() will return 1, cause its
// the first instruction.
func (c *Context) IP() int { func (c *Context) IP() int {
return c.ip + 1 return c.ip
} }
// LenInstr returns the number of instructions loaded. // LenInstr returns the number of instructions loaded.