diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 47afc481a..20f558d71 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -1613,8 +1613,8 @@ func CodeGen(info *buildInfo) ([]byte, *DebugInfo, error) { return nil, nil, err } - buf := c.prog.Bytes() - if err := c.writeJumps(buf); err != nil { + buf, err := c.writeJumps(c.prog.Bytes()) + if err != nil { return nil, nil, err } return buf, c.emitDebugInfo(buf), nil @@ -1629,15 +1629,14 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) { } } -func (c *codegen) writeJumps(b []byte) error { +func (c *codegen) writeJumps(b []byte) ([]byte, error) { ctx := vm.NewContext(b) + var offsets []int for op, _, err := ctx.Next(); err == nil && ctx.NextIP() < len(b); op, _, err = ctx.Next() { switch op { case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL, opcode.JMPEQ, opcode.JMPNE, opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: - // Noop, assumed to be correct already. If you're fixing #905, - // make sure not to break "len" and "append" handling above. case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, opcode.JMPEQL, opcode.JMPNEL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL, @@ -1648,15 +1647,135 @@ func (c *codegen) writeJumps(b []byte) error { index := binary.LittleEndian.Uint16(arg) if int(index) > len(c.l) { - return fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l)) + return nil, fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l)) } offset := c.l[index] - nextIP + 5 if offset > math.MaxInt32 || offset < math.MinInt32 { - return fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)", + return nil, fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)", nextIP-5, offset, math.MaxInt32, math.MinInt32) } + if op != opcode.PUSHA && math.MinInt8 <= offset && offset <= math.MaxInt8 { + offsets = append(offsets, ctx.IP()) + } binary.LittleEndian.PutUint32(arg, uint32(offset)) } } - return nil + // Correct function ip range. + // Note: indices are sorted in increasing order. + for _, f := range c.funcs { + loop: + for _, ind := range offsets { + switch { + case ind > int(f.rng.End): + break loop + case ind < int(f.rng.Start): + f.rng.Start -= longToShortRemoveCount + f.rng.End -= longToShortRemoveCount + case ind >= int(f.rng.Start): + f.rng.End -= longToShortRemoveCount + } + } + } + return shortenJumps(b, offsets), nil +} + +// longToShortRemoveCount is a difference between short and long instruction sizes in bytes. +const longToShortRemoveCount = 3 + +// shortenJumps returns converts b to a program where all long JMP*/CALL* specified by absolute offsets, +// are replaced with their corresponding short counterparts. It panics if either b or offsets are invalid. +// This is done in 2 passes: +// 1. Alter jump offsets taking into account parts to be removed. +// 2. Perform actual removal of jump targets. +// Note: after jump offsets altering, there can appear new candidates for conversion. +// These are ignored for now. +func shortenJumps(b []byte, offsets []int) []byte { + if len(offsets) == 0 { + return b + } + + // 1. Alter existing jump offsets. + ctx := vm.NewContext(b) + for op, _, err := ctx.Next(); err == nil && ctx.IP() < len(b); op, _, err = ctx.Next() { + // we can't use arg returned by ctx.Next() because it is copied + nextIP := ctx.NextIP() + ip := ctx.IP() + switch op { + case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL, + opcode.JMPEQ, opcode.JMPNE, + opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: + offset := int(int8(b[nextIP-1])) + offset += calcOffsetCorrection(ip, ip+offset, offsets) + b[nextIP-1] = byte(offset) + case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, + opcode.JMPEQL, opcode.JMPNEL, + opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL, + opcode.CALLL, opcode.PUSHA: + arg := b[nextIP-4:] + offset := int(int32(binary.LittleEndian.Uint32(arg))) + offset += calcOffsetCorrection(ip, ip+offset, offsets) + binary.LittleEndian.PutUint32(arg, uint32(offset)) + } + } + + // 2. Convert instructions. + copyOffset := 0 + l := len(offsets) + b[offsets[0]] = toShortForm(b[offsets[0]]) + for i := 0; i < l; i++ { + start := offsets[i] + 2 + end := len(b) + if i != l-1 { + end = offsets[i+1] + 2 + b[offsets[i+1]] = toShortForm(b[offsets[i+1]]) + } + copy(b[start-copyOffset:], b[start+3:end]) + copyOffset += longToShortRemoveCount + } + return b[:len(b)-copyOffset] +} + +func calcOffsetCorrection(ip, target int, offsets []int) int { + cnt := 0 + start := sort.Search(len(offsets), func(i int) bool { + return offsets[i] >= ip || offsets[i] >= target + }) + for i := start; i < len(offsets) && (offsets[i] < target || offsets[i] <= ip); i++ { + ind := offsets[i] + if ip <= ind && ind < target || + ind != ip && target <= ind && ind <= ip { + cnt += longToShortRemoveCount + } + } + if ip < target { + return -cnt + } + return cnt +} + +func toShortForm(b byte) byte { + switch op := opcode.Opcode(b); op { + case opcode.JMPL: + return byte(opcode.JMP) + case opcode.JMPIFL: + return byte(opcode.JMPIF) + case opcode.JMPIFNOTL: + return byte(opcode.JMPIFNOT) + case opcode.JMPEQL: + return byte(opcode.JMPEQ) + case opcode.JMPNEL: + return byte(opcode.JMPNE) + case opcode.JMPGTL: + return byte(opcode.JMPGT) + case opcode.JMPGEL: + return byte(opcode.JMPGE) + case opcode.JMPLEL: + return byte(opcode.JMPLE) + case opcode.JMPLTL: + return byte(opcode.JMPLT) + case opcode.CALLL: + return byte(opcode.CALL) + default: + panic(fmt.Errorf("invalid opcode: %s", op)) + } } diff --git a/pkg/compiler/jumps_test.go b/pkg/compiler/jumps_test.go new file mode 100644 index 000000000..cdfce463e --- /dev/null +++ b/pkg/compiler/jumps_test.go @@ -0,0 +1,119 @@ +package compiler + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "github.com/stretchr/testify/require" +) + +func testShortenJumps(t *testing.T, before, after []opcode.Opcode, indices []int) { + prog := make([]byte, len(before)) + for i := range before { + prog[i] = byte(before[i]) + } + raw := shortenJumps(prog, indices) + actual := make([]opcode.Opcode, len(raw)) + for i := range raw { + actual[i] = opcode.Opcode(raw[i]) + } + require.Equal(t, after, actual) +} + +func TestShortenJumps(t *testing.T) { + testCases := map[opcode.Opcode]opcode.Opcode{ + opcode.JMPL: opcode.JMP, + opcode.JMPIFL: opcode.JMPIF, + opcode.JMPIFNOTL: opcode.JMPIFNOT, + opcode.JMPEQL: opcode.JMPEQ, + opcode.JMPNEL: opcode.JMPNE, + opcode.JMPGTL: opcode.JMPGT, + opcode.JMPGEL: opcode.JMPGE, + opcode.JMPLEL: opcode.JMPLE, + opcode.JMPLTL: opcode.JMPLT, + opcode.CALLL: opcode.CALL, + } + for op, sop := range testCases { + t.Run(op.String(), func(t *testing.T) { + before := []opcode.Opcode{ + op, 6, 0, 0, 0, opcode.PUSH1, opcode.NOP, // <- first jump to here + op, 9, 12, 0, 0, opcode.PUSH1, opcode.NOP, // <- last jump to here + op, 255, 0, 0, 0, op, 0xFF - 5, 0xFF, 0xFF, 0xFF, + } + after := []opcode.Opcode{ + sop, 3, opcode.PUSH1, opcode.NOP, + op, 3, 12, 0, 0, opcode.PUSH1, opcode.NOP, + sop, 249, sop, 0xFF - 2, + } + testShortenJumps(t, before, after, []int{0, 14, 19}) + }) + } + t.Run("NoReplace", func(t *testing.T) { + b := []byte{0, 1, 2, 3, 4, 5} + expected := []byte{0, 1, 2, 3, 4, 5} + require.Equal(t, expected, shortenJumps(b, nil)) + }) + t.Run("InvalidIndex", func(t *testing.T) { + before := []byte{byte(opcode.PUSH1), 0, 0, 0, 0} + require.Panics(t, func() { + shortenJumps(before, []int{0}) + }) + }) + t.Run("SideConditions", func(t *testing.T) { + t.Run("Forward", func(t *testing.T) { + before := []opcode.Opcode{ + opcode.JMPL, 5, 0, 0, 0, + opcode.JMPL, 5, 0, 0, 0, + } + after := []opcode.Opcode{ + opcode.JMP, 2, + opcode.JMP, 2, + } + testShortenJumps(t, before, after, []int{0, 5}) + }) + t.Run("Backwards", func(t *testing.T) { + before := []opcode.Opcode{ + opcode.JMPL, 5, 0, 0, 0, + opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF, + opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF, + } + after := []opcode.Opcode{ + opcode.JMPL, 5, 0, 0, 0, + opcode.JMP, 0xFF - 4, + opcode.JMP, 0xFF - 1, + } + testShortenJumps(t, before, after, []int{5, 10}) + }) + }) +} + +func TestWriteJumps(t *testing.T) { + c := new(codegen) + c.l = []int{10} + before := []byte{ + byte(opcode.NOP), byte(opcode.JMP), 2, byte(opcode.RET), + byte(opcode.CALLL), 0, 0, 0, 0, byte(opcode.RET), + byte(opcode.PUSH2), byte(opcode.RET), + } + c.funcs = map[string]*funcScope{ + "init": {rng: DebugRange{Start: 0, End: 3}}, + "main": {rng: DebugRange{Start: 4, End: 9}}, + "method": {rng: DebugRange{Start: 10, End: 11}}, + } + + expProg := []byte{ + byte(opcode.NOP), byte(opcode.JMP), 2, byte(opcode.RET), + byte(opcode.CALL), 3, byte(opcode.RET), + byte(opcode.PUSH2), byte(opcode.RET), + } + expFuncs := map[string]*funcScope{ + "init": {rng: DebugRange{Start: 0, End: 3}}, + "main": {rng: DebugRange{Start: 4, End: 6}}, + "method": {rng: DebugRange{Start: 7, End: 8}}, + } + + buf, err := c.writeJumps(before) + require.NoError(t, err) + require.Equal(t, expProg, buf) + require.Equal(t, expFuncs, c.funcs) +}