compiler: reduce instructions in 2 stages

First replace parts to be removed with NOPs, then actually remove.

Signed-off-by: Evgeniy Stratonikov <evgeniy@nspcc.ru>
This commit is contained in:
Evgeniy Stratonikov 2022-07-12 13:16:32 +03:00
parent ce24451fde
commit 05efc57485
2 changed files with 48 additions and 57 deletions

View file

@ -2220,7 +2220,7 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) {
func (c *codegen) writeJumps(b []byte) ([]byte, error) { func (c *codegen) writeJumps(b []byte) ([]byte, error) {
ctx := vm.NewContext(b) ctx := vm.NewContext(b)
var offsets []int var nopOffsets []int
for op, param, err := ctx.Next(); err == nil && ctx.IP() < len(b); op, param, err = ctx.Next() { for op, param, err := ctx.Next(); err == nil && ctx.IP() < len(b); op, param, err = ctx.Next() {
switch op { switch op {
case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL, case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL,
@ -2244,13 +2244,15 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) {
return nil, err return nil, err
} }
if op != opcode.PUSHA && math.MinInt8 <= offset && offset <= math.MaxInt8 { if op != opcode.PUSHA && math.MinInt8 <= offset && offset <= math.MaxInt8 {
offsets = append(offsets, ctx.IP()) copy(b[ctx.IP():], []byte{byte(toShortForm(op)), byte(offset), byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP)})
nopOffsets = append(nopOffsets, ctx.IP()+2, ctx.IP()+3, ctx.IP()+4)
} }
case opcode.INITSLOT: case opcode.INITSLOT:
nextIP := ctx.NextIP() nextIP := ctx.NextIP()
info := c.reverseOffsetMap[ctx.IP()] info := c.reverseOffsetMap[ctx.IP()]
if argCount := b[nextIP-1]; info.count == 0 && argCount == 0 { if argCount := b[nextIP-1]; info.count == 0 && argCount == 0 {
offsets = append(offsets, ctx.IP()) copy(b[ctx.IP():], []byte{byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP)})
nopOffsets = append(nopOffsets, ctx.IP(), ctx.IP()+1, ctx.IP()+2)
continue continue
} }
@ -2262,20 +2264,20 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) {
} }
if c.deployEndOffset >= 0 { if c.deployEndOffset >= 0 {
_, end := correctRange(uint16(c.initEndOffset+1), uint16(c.deployEndOffset), offsets) _, end := correctRange(uint16(c.initEndOffset+1), uint16(c.deployEndOffset), nopOffsets)
c.deployEndOffset = int(end) c.deployEndOffset = int(end)
} }
if c.initEndOffset > 0 { if c.initEndOffset > 0 {
_, end := correctRange(0, uint16(c.initEndOffset), offsets) _, end := correctRange(0, uint16(c.initEndOffset), nopOffsets)
c.initEndOffset = int(end) c.initEndOffset = int(end)
} }
// Correct function ip range. // Correct function ip range.
// Note: indices are sorted in increasing order. // Note: indices are sorted in increasing order.
for _, f := range c.funcs { for _, f := range c.funcs {
f.rng.Start, f.rng.End = correctRange(f.rng.Start, f.rng.End, offsets) f.rng.Start, f.rng.End = correctRange(f.rng.Start, f.rng.End, nopOffsets)
} }
return shortenJumps(b, offsets), nil return removeNOPs(b, nopOffsets), nil
} }
func correctRange(start, end uint16, offsets []int) (uint16, uint16) { func correctRange(start, end uint16, offsets []int) (uint16, uint16) {
@ -2286,10 +2288,10 @@ loop:
case ind > int(end): case ind > int(end):
break loop break loop
case ind < int(start): case ind < int(start):
newStart -= longToShortRemoveCount newStart--
newEnd -= longToShortRemoveCount newEnd--
case ind >= int(start): case ind >= int(start):
newEnd -= longToShortRemoveCount newEnd--
} }
} }
return newStart, newEnd return newStart, newEnd
@ -2312,21 +2314,22 @@ func (c *codegen) replaceLabelWithOffset(ip int, arg []byte) (int, error) {
return offset, nil return offset, nil
} }
// longToShortRemoveCount is a difference between short and long instruction sizes in bytes. // removeNOPs converts b to a program where all long JMP*/CALL* specified by absolute offsets
// By pure coincidence, this is also the size of `INITSLOT` instruction.
const longToShortRemoveCount = 3
// shortenJumps converts b to a program where all long JMP*/CALL* specified by absolute offsets
// are replaced with their corresponding short counterparts. It panics if either b or offsets are invalid. // are replaced with their corresponding short counterparts. It panics if either b or offsets are invalid.
// This is done in 2 passes: // This is done in 2 passes:
// 1. Alter jump offsets taking into account parts to be removed. // 1. Alter jump offsets taking into account parts to be removed.
// 2. Perform actual removal of jump targets. // 2. Perform actual removal of jump targets.
// Note: after jump offsets altering, there can appear new candidates for conversion. // Note: after jump offsets altering, there can appear new candidates for conversion.
// These are ignored for now. // These are ignored for now.
func shortenJumps(b []byte, offsets []int) []byte { func removeNOPs(b []byte, nopOffsets []int) []byte {
if len(offsets) == 0 { if len(nopOffsets) == 0 {
return b return b
} }
for i := range nopOffsets {
if b[nopOffsets[i]] != byte(opcode.NOP) {
panic("NOP offset is invalid")
}
}
// 1. Alter existing jump offsets. // 1. Alter existing jump offsets.
ctx := vm.NewContext(b) ctx := vm.NewContext(b)
@ -2339,14 +2342,14 @@ func shortenJumps(b []byte, offsets []int) []byte {
opcode.JMPEQ, opcode.JMPNE, opcode.JMPEQ, opcode.JMPNE,
opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT, opcode.ENDTRY: opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT, opcode.ENDTRY:
offset := int(int8(b[nextIP-1])) offset := int(int8(b[nextIP-1]))
offset += calcOffsetCorrection(ip, ip+offset, offsets) offset += calcOffsetCorrection(ip, ip+offset, nopOffsets)
b[nextIP-1] = byte(offset) b[nextIP-1] = byte(offset)
case opcode.TRY: case opcode.TRY:
catchOffset := int(int8(b[nextIP-2])) catchOffset := int(int8(b[nextIP-2]))
catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets) catchOffset += calcOffsetCorrection(ip, ip+catchOffset, nopOffsets)
b[nextIP-1] = byte(catchOffset) b[nextIP-1] = byte(catchOffset)
finallyOffset := int(int8(b[nextIP-1])) finallyOffset := int(int8(b[nextIP-1]))
finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, nopOffsets)
b[nextIP-1] = byte(finallyOffset) b[nextIP-1] = byte(finallyOffset)
case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL,
opcode.JMPEQL, opcode.JMPNEL, opcode.JMPEQL, opcode.JMPNEL,
@ -2354,42 +2357,31 @@ func shortenJumps(b []byte, offsets []int) []byte {
opcode.CALLL, opcode.PUSHA, opcode.ENDTRYL: opcode.CALLL, opcode.PUSHA, opcode.ENDTRYL:
arg := b[nextIP-4:] arg := b[nextIP-4:]
offset := int(int32(binary.LittleEndian.Uint32(arg))) offset := int(int32(binary.LittleEndian.Uint32(arg)))
offset += calcOffsetCorrection(ip, ip+offset, offsets) offset += calcOffsetCorrection(ip, ip+offset, nopOffsets)
binary.LittleEndian.PutUint32(arg, uint32(offset)) binary.LittleEndian.PutUint32(arg, uint32(offset))
case opcode.TRYL: case opcode.TRYL:
arg := b[nextIP-8:] arg := b[nextIP-8:]
catchOffset := int(int32(binary.LittleEndian.Uint32(arg))) catchOffset := int(int32(binary.LittleEndian.Uint32(arg)))
catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets) catchOffset += calcOffsetCorrection(ip, ip+catchOffset, nopOffsets)
binary.LittleEndian.PutUint32(arg, uint32(catchOffset)) binary.LittleEndian.PutUint32(arg, uint32(catchOffset))
arg = b[nextIP-4:] arg = b[nextIP-4:]
finallyOffset := int(int32(binary.LittleEndian.Uint32(arg))) finallyOffset := int(int32(binary.LittleEndian.Uint32(arg)))
finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, nopOffsets)
binary.LittleEndian.PutUint32(arg, uint32(finallyOffset)) binary.LittleEndian.PutUint32(arg, uint32(finallyOffset))
} }
} }
// 2. Convert instructions. // 2. Convert instructions.
copyOffset := 0 copyOffset := 0
l := len(offsets) l := len(nopOffsets)
if op := opcode.Opcode(b[offsets[0]]); op != opcode.INITSLOT {
b[offsets[0]] = byte(toShortForm(op))
}
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
start := offsets[i] + 2 start := nopOffsets[i]
if b[offsets[i]] == byte(opcode.INITSLOT) {
start = offsets[i]
}
end := len(b) end := len(b)
if i != l-1 { if i != l-1 {
end = offsets[i+1] end = nopOffsets[i+1]
if op := opcode.Opcode(b[offsets[i+1]]); op != opcode.INITSLOT {
end += 2
b[offsets[i+1]] = byte(toShortForm(op))
}
} }
copy(b[start-copyOffset:], b[start+3:end]) copy(b[start-copyOffset:], b[start+1:end])
copyOffset += longToShortRemoveCount copyOffset++
} }
return b[:len(b)-copyOffset] return b[:len(b)-copyOffset]
} }
@ -2401,9 +2393,8 @@ func calcOffsetCorrection(ip, target int, offsets []int) int {
}) })
for i := start; i < len(offsets) && (offsets[i] < target || offsets[i] <= ip); i++ { for i := start; i < len(offsets) && (offsets[i] < target || offsets[i] <= ip); i++ {
ind := offsets[i] ind := offsets[i]
if ip <= ind && ind < target || if ip <= ind && ind < target || target <= ind && ind < ip {
ind != ip && target <= ind && ind <= ip { cnt++
cnt += longToShortRemoveCount
} }
} }
if ip < target { if ip < target {

View file

@ -12,7 +12,7 @@ func testShortenJumps(t *testing.T, before, after []opcode.Opcode, indices []int
for i := range before { for i := range before {
prog[i] = byte(before[i]) prog[i] = byte(before[i])
} }
raw := shortenJumps(prog, indices) raw := removeNOPs(prog, indices)
actual := make([]opcode.Opcode, len(raw)) actual := make([]opcode.Opcode, len(raw))
for i := range raw { for i := range raw {
actual[i] = opcode.Opcode(raw[i]) actual[i] = opcode.Opcode(raw[i])
@ -36,53 +36,53 @@ func TestShortenJumps(t *testing.T) {
for op, sop := range testCases { for op, sop := range testCases {
t.Run(op.String(), func(t *testing.T) { t.Run(op.String(), func(t *testing.T) {
before := []opcode.Opcode{ before := []opcode.Opcode{
op, 6, 0, 0, 0, opcode.PUSH1, opcode.NOP, // <- first jump to here sop, 6, opcode.NOP, opcode.NOP, opcode.NOP, opcode.PUSH1, opcode.NOP, // <- first jump to here
op, 9, 12, 0, 0, opcode.PUSH1, opcode.NOP, // <- last jump to here op, 9, 12, 0, 0, opcode.PUSH1, opcode.NOP, // <- last jump to here
op, 255, 0, 0, 0, op, 0xFF - 5, 0xFF, 0xFF, 0xFF, sop, 249, opcode.NOP, opcode.NOP, opcode.NOP, sop, 0xFF - 5, opcode.NOP, opcode.NOP, opcode.NOP,
} }
after := []opcode.Opcode{ after := []opcode.Opcode{
sop, 3, opcode.PUSH1, opcode.NOP, sop, 3, opcode.PUSH1, opcode.NOP,
op, 3, 12, 0, 0, opcode.PUSH1, opcode.NOP, op, 3, 12, 0, 0, opcode.PUSH1, opcode.NOP,
sop, 249, sop, 0xFF - 2, sop, 249, sop, 0xFF - 2,
} }
testShortenJumps(t, before, after, []int{0, 14, 19}) testShortenJumps(t, before, after, []int{2, 3, 4, 16, 17, 18, 21, 22, 23})
}) })
} }
t.Run("NoReplace", func(t *testing.T) { t.Run("NoReplace", func(t *testing.T) {
b := []byte{0, 1, 2, 3, 4, 5} b := []byte{0, 1, 2, 3, 4, 5}
expected := []byte{0, 1, 2, 3, 4, 5} expected := []byte{0, 1, 2, 3, 4, 5}
require.Equal(t, expected, shortenJumps(b, nil)) require.Equal(t, expected, removeNOPs(b, nil))
}) })
t.Run("InvalidIndex", func(t *testing.T) { t.Run("InvalidIndex", func(t *testing.T) {
before := []byte{byte(opcode.PUSH1), 0, 0, 0, 0} before := []byte{byte(opcode.PUSH1), 0, 0, 0, 0}
require.Panics(t, func() { require.Panics(t, func() {
shortenJumps(before, []int{0}) removeNOPs(before, []int{0})
}) })
}) })
t.Run("SideConditions", func(t *testing.T) { t.Run("SideConditions", func(t *testing.T) {
t.Run("Forward", func(t *testing.T) { t.Run("Forward", func(t *testing.T) {
before := []opcode.Opcode{ before := []opcode.Opcode{
opcode.JMPL, 5, 0, 0, 0, opcode.JMP, 5, opcode.NOP, opcode.NOP, opcode.NOP,
opcode.JMPL, 5, 0, 0, 0, opcode.JMP, 5, opcode.NOP, opcode.NOP, opcode.NOP,
} }
after := []opcode.Opcode{ after := []opcode.Opcode{
opcode.JMP, 2, opcode.JMP, 2,
opcode.JMP, 2, opcode.JMP, 2,
} }
testShortenJumps(t, before, after, []int{0, 5}) testShortenJumps(t, before, after, []int{2, 3, 4, 7, 8, 9})
}) })
t.Run("Backwards", func(t *testing.T) { t.Run("Backwards", func(t *testing.T) {
before := []opcode.Opcode{ before := []opcode.Opcode{
opcode.JMPL, 5, 0, 0, 0, opcode.JMP, 5, opcode.NOP, opcode.NOP, opcode.NOP,
opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF, opcode.JMP, 0xFF - 4, opcode.NOP, opcode.NOP, opcode.NOP,
opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF, opcode.JMP, 0xFF - 4, opcode.NOP, opcode.NOP, opcode.NOP,
} }
after := []opcode.Opcode{ after := []opcode.Opcode{
opcode.JMPL, 5, 0, 0, 0, opcode.JMP, 2,
opcode.JMP, 0xFF - 4, opcode.JMP, 0xFF - 1,
opcode.JMP, 0xFF - 1, opcode.JMP, 0xFF - 1,
} }
testShortenJumps(t, before, after, []int{5, 10}) testShortenJumps(t, before, after, []int{2, 3, 4, 7, 8, 9, 12, 13, 14})
}) })
}) })
} }