diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index be22efec4..33cec05f3 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -62,9 +62,8 @@ type codegen struct { labels map[labelWithType]uint16 // A list of nested label names together with evaluation stack depth. labelList []labelWithStackSize - // inlineLabelOffsets contains size of labelList at the start of inline call processing. - // For such calls, we need to drop only the newly created part of stack. - inlineLabelOffsets []int + // inlineContext contains info about inlined function calls. + inlineContext []inlineContextSingle // globalInlineCount contains the amount of auxiliary variables introduced by // function inlining during global variables initialization. globalInlineCount int @@ -146,6 +145,14 @@ type nameWithLocals struct { count int } +type inlineContextSingle struct { + // labelOffset contains size of labelList at the start of inline call processing. + // For such calls, we need to drop only the newly created part of stack. + labelOffset int + // returnLabel contains label ID pointing to the first instruction right after the call. + returnLabel uint16 +} + type varType int const ( @@ -680,8 +687,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { cnt := 0 start := 0 - if len(c.inlineLabelOffsets) > 0 { - start = c.inlineLabelOffsets[len(c.inlineLabelOffsets)-1] + if len(c.inlineContext) > 0 { + start = c.inlineContext[len(c.inlineContext)-1].labelOffset } for i := start; i < len(c.labelList); i++ { cnt += c.labelList[i].sz @@ -711,6 +718,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { c.saveSequencePoint(n) if len(c.pkgInfoInline) == 0 { emit.Opcodes(c.prog.BinWriter, opcode.RET) + } else { + emit.Jmp(c.prog.BinWriter, opcode.JMPL, c.inlineContext[len(c.inlineContext)-1].returnLabel) } return nil @@ -2211,7 +2220,7 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) { func (c *codegen) writeJumps(b []byte) ([]byte, error) { ctx := vm.NewContext(b) - var offsets []int + var nopOffsets []int for op, param, err := ctx.Next(); err == nil && ctx.IP() < len(b); op, param, err = ctx.Next() { switch op { case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL, @@ -2235,13 +2244,20 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) { return nil, err } if op != opcode.PUSHA && math.MinInt8 <= offset && offset <= math.MaxInt8 { - offsets = append(offsets, ctx.IP()) + if op == opcode.JMPL && offset == 5 { + copy(b[ctx.IP():], []byte{byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP)}) + nopOffsets = append(nopOffsets, ctx.IP(), ctx.IP()+1, ctx.IP()+2, ctx.IP()+3, ctx.IP()+4) + } else { + copy(b[ctx.IP():], []byte{byte(toShortForm(op)), byte(offset), byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP)}) + nopOffsets = append(nopOffsets, ctx.IP()+2, ctx.IP()+3, ctx.IP()+4) + } } case opcode.INITSLOT: nextIP := ctx.NextIP() info := c.reverseOffsetMap[ctx.IP()] if argCount := b[nextIP-1]; info.count == 0 && argCount == 0 { - offsets = append(offsets, ctx.IP()) + copy(b[ctx.IP():], []byte{byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP)}) + nopOffsets = append(nopOffsets, ctx.IP(), ctx.IP()+1, ctx.IP()+2) continue } @@ -2253,20 +2269,20 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) { } if c.deployEndOffset >= 0 { - _, end := correctRange(uint16(c.initEndOffset+1), uint16(c.deployEndOffset), offsets) + _, end := correctRange(uint16(c.initEndOffset+1), uint16(c.deployEndOffset), nopOffsets) c.deployEndOffset = int(end) } if c.initEndOffset > 0 { - _, end := correctRange(0, uint16(c.initEndOffset), offsets) + _, end := correctRange(0, uint16(c.initEndOffset), nopOffsets) c.initEndOffset = int(end) } // Correct function ip range. // Note: indices are sorted in increasing order. for _, f := range c.funcs { - f.rng.Start, f.rng.End = correctRange(f.rng.Start, f.rng.End, offsets) + f.rng.Start, f.rng.End = correctRange(f.rng.Start, f.rng.End, nopOffsets) } - return shortenJumps(b, offsets), nil + return removeNOPs(b, nopOffsets), nil } func correctRange(start, end uint16, offsets []int) (uint16, uint16) { @@ -2277,10 +2293,10 @@ loop: case ind > int(end): break loop case ind < int(start): - newStart -= longToShortRemoveCount - newEnd -= longToShortRemoveCount + newStart-- + newEnd-- case ind >= int(start): - newEnd -= longToShortRemoveCount + newEnd-- } } return newStart, newEnd @@ -2303,21 +2319,22 @@ func (c *codegen) replaceLabelWithOffset(ip int, arg []byte) (int, error) { return offset, nil } -// longToShortRemoveCount is a difference between short and long instruction sizes in bytes. -// By pure coincidence, this is also the size of `INITSLOT` instruction. -const longToShortRemoveCount = 3 - -// shortenJumps converts b to a program where all long JMP*/CALL* specified by absolute offsets +// removeNOPs converts b to a program where all long JMP*/CALL* specified by absolute offsets // are replaced with their corresponding short counterparts. It panics if either b or offsets are invalid. // This is done in 2 passes: // 1. Alter jump offsets taking into account parts to be removed. // 2. Perform actual removal of jump targets. // Note: after jump offsets altering, there can appear new candidates for conversion. // These are ignored for now. -func shortenJumps(b []byte, offsets []int) []byte { - if len(offsets) == 0 { +func removeNOPs(b []byte, nopOffsets []int) []byte { + if len(nopOffsets) == 0 { return b } + for i := range nopOffsets { + if b[nopOffsets[i]] != byte(opcode.NOP) { + panic("NOP offset is invalid") + } + } // 1. Alter existing jump offsets. ctx := vm.NewContext(b) @@ -2330,14 +2347,14 @@ func shortenJumps(b []byte, offsets []int) []byte { opcode.JMPEQ, opcode.JMPNE, opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT, opcode.ENDTRY: offset := int(int8(b[nextIP-1])) - offset += calcOffsetCorrection(ip, ip+offset, offsets) + offset += calcOffsetCorrection(ip, ip+offset, nopOffsets) b[nextIP-1] = byte(offset) case opcode.TRY: catchOffset := int(int8(b[nextIP-2])) - catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets) + catchOffset += calcOffsetCorrection(ip, ip+catchOffset, nopOffsets) b[nextIP-1] = byte(catchOffset) finallyOffset := int(int8(b[nextIP-1])) - finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) + finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, nopOffsets) b[nextIP-1] = byte(finallyOffset) case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, opcode.JMPEQL, opcode.JMPNEL, @@ -2345,42 +2362,31 @@ func shortenJumps(b []byte, offsets []int) []byte { opcode.CALLL, opcode.PUSHA, opcode.ENDTRYL: arg := b[nextIP-4:] offset := int(int32(binary.LittleEndian.Uint32(arg))) - offset += calcOffsetCorrection(ip, ip+offset, offsets) + offset += calcOffsetCorrection(ip, ip+offset, nopOffsets) binary.LittleEndian.PutUint32(arg, uint32(offset)) case opcode.TRYL: arg := b[nextIP-8:] catchOffset := int(int32(binary.LittleEndian.Uint32(arg))) - catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets) + catchOffset += calcOffsetCorrection(ip, ip+catchOffset, nopOffsets) binary.LittleEndian.PutUint32(arg, uint32(catchOffset)) arg = b[nextIP-4:] finallyOffset := int(int32(binary.LittleEndian.Uint32(arg))) - finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) + finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, nopOffsets) binary.LittleEndian.PutUint32(arg, uint32(finallyOffset)) } } // 2. Convert instructions. copyOffset := 0 - l := len(offsets) - if op := opcode.Opcode(b[offsets[0]]); op != opcode.INITSLOT { - b[offsets[0]] = byte(toShortForm(op)) - } + l := len(nopOffsets) for i := 0; i < l; i++ { - start := offsets[i] + 2 - if b[offsets[i]] == byte(opcode.INITSLOT) { - start = offsets[i] - } - + start := nopOffsets[i] end := len(b) if i != l-1 { - end = offsets[i+1] - if op := opcode.Opcode(b[offsets[i+1]]); op != opcode.INITSLOT { - end += 2 - b[offsets[i+1]] = byte(toShortForm(op)) - } + end = nopOffsets[i+1] } - copy(b[start-copyOffset:], b[start+3:end]) - copyOffset += longToShortRemoveCount + copy(b[start-copyOffset:], b[start+1:end]) + copyOffset++ } return b[:len(b)-copyOffset] } @@ -2392,9 +2398,8 @@ func calcOffsetCorrection(ip, target int, offsets []int) int { }) for i := start; i < len(offsets) && (offsets[i] < target || offsets[i] <= ip); i++ { ind := offsets[i] - if ip <= ind && ind < target || - ind != ip && target <= ind && ind <= ip { - cnt += longToShortRemoveCount + if ip <= ind && ind < target || target <= ind && ind < ip { + cnt++ } } if ip < target { diff --git a/pkg/compiler/inline.go b/pkg/compiler/inline.go index 57b0ed707..115eb879e 100644 --- a/pkg/compiler/inline.go +++ b/pkg/compiler/inline.go @@ -21,12 +21,15 @@ import ( // // } func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) { - labelSz := len(c.labelList) - offSz := len(c.inlineLabelOffsets) - c.inlineLabelOffsets = append(c.inlineLabelOffsets, labelSz) + offSz := len(c.inlineContext) + c.inlineContext = append(c.inlineContext, inlineContextSingle{ + labelOffset: len(c.labelList), + returnLabel: c.newLabel(), + }) + defer func() { - c.inlineLabelOffsets = c.inlineLabelOffsets[:offSz] - c.labelList = c.labelList[:labelSz] + c.labelList = c.labelList[:c.inlineContext[offSz].labelOffset] + c.inlineContext = c.inlineContext[:offSz] }() pkg := c.packageCache[f.pkg.Path()] @@ -113,6 +116,7 @@ func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) { c.fillImportMap(f.file, pkg) ast.Inspect(f.decl, c.scope.analyzeVoidCalls) ast.Walk(c, f.decl.Body) + c.setLabel(c.inlineContext[offSz].returnLabel) if c.scope.voidCalls[n] { for i := 0; i < f.decl.Type.Results.NumFields(); i++ { emit.Opcodes(c.prog.BinWriter, opcode.DROP) diff --git a/pkg/compiler/inline_test.go b/pkg/compiler/inline_test.go index b68a816d3..9de0379aa 100644 --- a/pkg/compiler/inline_test.go +++ b/pkg/compiler/inline_test.go @@ -374,3 +374,46 @@ func TestInlinedMethodWithPointer(t *testing.T) { }` eval(t, src, big.NewInt(100542)) } + +func TestInlineConditionalReturn(t *testing.T) { + srcTmpl := `package foo + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/c" + func Main() int { + x := %d + if c.Is42(x) { + return 100 + } + return 10 + }` + t.Run("true", func(t *testing.T) { + eval(t, fmt.Sprintf(srcTmpl, 123), big.NewInt(10)) + }) + t.Run("false", func(t *testing.T) { + eval(t, fmt.Sprintf(srcTmpl, 42), big.NewInt(100)) + }) +} + +func TestInlineDoubleConditionalReturn(t *testing.T) { + srcTmpl := `package foo + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/c" + func Main() int { + return c.Transform(%d, %d) + }` + + testCase := []struct { + name string + a, b, result int + }{ + {"true, true, small", 42, 3, 6}, + {"true, true, big", 42, 15, 15}, + {"true, false", 42, 42, 42}, + {"false, true", 3, 11, 6}, + {"false, false", 3, 42, 6}, + } + + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + eval(t, fmt.Sprintf(srcTmpl, tc.a, tc.b), big.NewInt(int64(tc.result))) + }) + } +} diff --git a/pkg/compiler/jumps_test.go b/pkg/compiler/jumps_test.go index 7f18a1c05..9772138ee 100644 --- a/pkg/compiler/jumps_test.go +++ b/pkg/compiler/jumps_test.go @@ -12,7 +12,7 @@ func testShortenJumps(t *testing.T, before, after []opcode.Opcode, indices []int for i := range before { prog[i] = byte(before[i]) } - raw := shortenJumps(prog, indices) + raw := removeNOPs(prog, indices) actual := make([]opcode.Opcode, len(raw)) for i := range raw { actual[i] = opcode.Opcode(raw[i]) @@ -36,53 +36,53 @@ func TestShortenJumps(t *testing.T) { for op, sop := range testCases { t.Run(op.String(), func(t *testing.T) { before := []opcode.Opcode{ - op, 6, 0, 0, 0, opcode.PUSH1, opcode.NOP, // <- first jump to here + sop, 6, opcode.NOP, opcode.NOP, opcode.NOP, opcode.PUSH1, opcode.NOP, // <- first jump to here op, 9, 12, 0, 0, opcode.PUSH1, opcode.NOP, // <- last jump to here - op, 255, 0, 0, 0, op, 0xFF - 5, 0xFF, 0xFF, 0xFF, + sop, 249, opcode.NOP, opcode.NOP, opcode.NOP, sop, 0xFF - 5, opcode.NOP, opcode.NOP, opcode.NOP, } after := []opcode.Opcode{ sop, 3, opcode.PUSH1, opcode.NOP, op, 3, 12, 0, 0, opcode.PUSH1, opcode.NOP, sop, 249, sop, 0xFF - 2, } - testShortenJumps(t, before, after, []int{0, 14, 19}) + testShortenJumps(t, before, after, []int{2, 3, 4, 16, 17, 18, 21, 22, 23}) }) } t.Run("NoReplace", func(t *testing.T) { b := []byte{0, 1, 2, 3, 4, 5} expected := []byte{0, 1, 2, 3, 4, 5} - require.Equal(t, expected, shortenJumps(b, nil)) + require.Equal(t, expected, removeNOPs(b, nil)) }) t.Run("InvalidIndex", func(t *testing.T) { before := []byte{byte(opcode.PUSH1), 0, 0, 0, 0} require.Panics(t, func() { - shortenJumps(before, []int{0}) + removeNOPs(before, []int{0}) }) }) t.Run("SideConditions", func(t *testing.T) { t.Run("Forward", func(t *testing.T) { before := []opcode.Opcode{ - opcode.JMPL, 5, 0, 0, 0, - opcode.JMPL, 5, 0, 0, 0, + opcode.JMP, 5, opcode.NOP, opcode.NOP, opcode.NOP, + opcode.JMP, 5, opcode.NOP, opcode.NOP, opcode.NOP, } after := []opcode.Opcode{ opcode.JMP, 2, opcode.JMP, 2, } - testShortenJumps(t, before, after, []int{0, 5}) + testShortenJumps(t, before, after, []int{2, 3, 4, 7, 8, 9}) }) t.Run("Backwards", func(t *testing.T) { before := []opcode.Opcode{ - opcode.JMPL, 5, 0, 0, 0, - opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF, - opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF, + opcode.JMP, 5, opcode.NOP, opcode.NOP, opcode.NOP, + opcode.JMP, 0xFF - 4, opcode.NOP, opcode.NOP, opcode.NOP, + opcode.JMP, 0xFF - 4, opcode.NOP, opcode.NOP, opcode.NOP, } after := []opcode.Opcode{ - opcode.JMPL, 5, 0, 0, 0, - opcode.JMP, 0xFF - 4, + opcode.JMP, 2, + opcode.JMP, 0xFF - 1, opcode.JMP, 0xFF - 1, } - testShortenJumps(t, before, after, []int{5, 10}) + testShortenJumps(t, before, after, []int{2, 3, 4, 7, 8, 9, 12, 13, 14}) }) }) } diff --git a/pkg/compiler/testdata/inline/c/null.go b/pkg/compiler/testdata/inline/c/null.go new file mode 100644 index 000000000..932dd669b --- /dev/null +++ b/pkg/compiler/testdata/inline/c/null.go @@ -0,0 +1,22 @@ +package c + +func Is42(a int) bool { + if a == 42 { + return true + } + return false +} + +func MulIfSmall(n int) int { + if n < 10 { + return n * 2 + } + return n +} + +func Transform(a, b int) int { + if Is42(a) && !Is42(b) { + return MulIfSmall(b) + } + return MulIfSmall(a) +}