Merge pull request #1126 from nspcc-dev/fix/variable

compiler: support variable shadowing
This commit is contained in:
Roman Khimov 2020-06-30 11:27:52 +03:00 committed by GitHub
commit ee0d869815
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 145 additions and 20 deletions

View file

@ -166,10 +166,9 @@ func (c *codegen) emitStoreStructField(i int) {
// according to current scope. // according to current scope.
func (c *codegen) getVarIndex(name string) (varType, int) { func (c *codegen) getVarIndex(name string) (varType, int) {
if c.scope != nil { if c.scope != nil {
if i, ok := c.scope.arguments[name]; ok { vt, val := c.scope.vars.getVarIndex(name)
return varArgument, i if val >= 0 {
} else if i, ok := c.scope.locals[name]; ok { return vt, val
return varLocal, i
} }
} }
if i, ok := c.globals[name]; ok { if i, ok := c.globals[name]; ok {
@ -311,6 +310,9 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl) {
emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)}) emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)})
} }
f.vars.newScope()
defer f.vars.dropScope()
// We need to handle methods, which in Go, is just syntactic sugar. // We need to handle methods, which in Go, is just syntactic sugar.
// The method receiver will be passed in as first argument. // The method receiver will be passed in as first argument.
// We check if this declaration has a receiver and load it into scope. // We check if this declaration has a receiver and load it into scope.
@ -497,6 +499,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil return nil
case *ast.IfStmt: case *ast.IfStmt:
c.scope.vars.newScope()
defer c.scope.vars.dropScope()
lIf := c.newLabel() lIf := c.newLabel()
lElse := c.newLabel() lElse := c.newLabel()
lElseEnd := c.newLabel() lElseEnd := c.newLabel()
@ -551,6 +556,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
} }
} }
c.scope.vars.newScope()
c.setLabel(lStart) c.setLabel(lStart)
last := len(cc.Body) - 1 last := len(cc.Body) - 1
for j, stmt := range cc.Body { for j, stmt := range cc.Body {
@ -562,6 +569,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
} }
emit.Jmp(c.prog.BinWriter, opcode.JMPL, switchEnd) emit.Jmp(c.prog.BinWriter, opcode.JMPL, switchEnd)
c.setLabel(lEnd) c.setLabel(lEnd)
c.scope.vars.dropScope()
} }
c.setLabel(switchEnd) c.setLabel(switchEnd)
@ -881,7 +890,20 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil return nil
case *ast.BlockStmt:
c.scope.vars.newScope()
defer c.scope.vars.dropScope()
for i := range n.List {
ast.Walk(c, n.List[i])
}
return nil
case *ast.ForStmt: case *ast.ForStmt:
c.scope.vars.newScope()
defer c.scope.vars.dropScope()
fstart, label := c.generateLabel(labelStart) fstart, label := c.generateLabel(labelStart)
fend := c.newNamedLabel(labelEnd, label) fend := c.newNamedLabel(labelEnd, label)
fpost := c.newNamedLabel(labelPost, label) fpost := c.newNamedLabel(labelPost, label)
@ -924,6 +946,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil return nil
case *ast.RangeStmt: case *ast.RangeStmt:
c.scope.vars.newScope()
defer c.scope.vars.dropScope()
start, label := c.generateLabel(labelStart) start, label := c.generateLabel(labelStart)
end := c.newNamedLabel(labelEnd, label) end := c.newNamedLabel(labelEnd, label)
post := c.newNamedLabel(labelPost, label) post := c.newNamedLabel(labelPost, label)

View file

@ -31,8 +31,7 @@ type funcScope struct {
variables []string variables []string
// Local variables // Local variables
locals map[string]int vars varScope
arguments map[string]int
// voidCalls are basically functions that return their value // voidCalls are basically functions that return their value
// into nothing. The stack has their return value but there // into nothing. The stack has their return value but there
@ -55,8 +54,7 @@ func newFuncScope(decl *ast.FuncDecl, label uint16) *funcScope {
name: name, name: name,
decl: decl, decl: decl,
label: label, label: label,
locals: map[string]int{}, vars: newVarScope(),
arguments: map[string]int{},
voidCalls: map[*ast.CallExpr]bool{}, voidCalls: map[*ast.CallExpr]bool{},
variables: []string{}, variables: []string{},
i: -1, i: -1,
@ -139,18 +137,7 @@ func (c *funcScope) stackSize() int64 {
// newVariable creates a new local variable or argument in the scope of the function. // newVariable creates a new local variable or argument in the scope of the function.
func (c *funcScope) newVariable(t varType, name string) int { func (c *funcScope) newVariable(t varType, name string) int {
var n int return c.vars.newVariable(t, name)
switch t {
case varLocal:
n = len(c.locals)
c.locals[name] = n
case varArgument:
n = len(c.arguments)
c.arguments[name] = n
default:
panic("invalid type")
}
return n
} }
// newLocal creates a new local variable into the scope of the function. // newLocal creates a new local variable into the scope of the function.

View file

@ -1,6 +1,7 @@
package compiler_test package compiler_test
import ( import (
"fmt"
"math/big" "math/big"
"testing" "testing"
) )
@ -55,3 +56,52 @@ func TestMultiDeclarationLocalCompound(t *testing.T) {
}` }`
eval(t, src, big.NewInt(6)) eval(t, src, big.NewInt(6))
} }
func TestShadow(t *testing.T) {
srcTmpl := `package foo
func Main() int {
x := 1
y := 10
%s
x += 1 // increase old local
x := 30 // introduce new local
y += x // make sure is means something
}
return x+y
}`
runCase := func(b string) func(t *testing.T) {
return func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, b)
eval(t, src, big.NewInt(42))
}
}
t.Run("If", runCase("if true {"))
t.Run("For", runCase("for i := 0; i < 1; i++ {"))
t.Run("Range", runCase("for range []int{1} {"))
t.Run("Switch", runCase("switch true {\ncase false: x += 2\ncase true:"))
t.Run("Block", runCase("{"))
}
func TestArgumentLocal(t *testing.T) {
srcTmpl := `package foo
func some(a int) int {
if a > 42 {
a := 24
_ = a
}
return a
}
func Main() int {
return some(%d)
}`
t.Run("Override", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, 50)
eval(t, src, big.NewInt(50))
})
t.Run("NoOverride", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, 40)
eval(t, src, big.NewInt(40))
})
}

63
pkg/compiler/vars.go Normal file
View file

@ -0,0 +1,63 @@
package compiler
type varScope struct {
localsCnt int
argCnt int
arguments map[string]int
locals []map[string]int
}
func newVarScope() varScope {
return varScope{
arguments: make(map[string]int),
}
}
func (c *varScope) newScope() {
c.locals = append(c.locals, map[string]int{})
}
func (c *varScope) dropScope() {
c.locals = c.locals[:len(c.locals)-1]
}
func (c *varScope) getVarIndex(name string) (varType, int) {
for i := len(c.locals) - 1; i >= 0; i-- {
if i, ok := c.locals[i][name]; ok {
return varLocal, i
}
}
if i, ok := c.arguments[name]; ok {
return varArgument, i
}
return 0, -1
}
// newVariable creates a new local variable or argument in the scope of the function.
func (c *varScope) newVariable(t varType, name string) int {
var n int
switch t {
case varLocal:
return c.newLocal(name)
case varArgument:
_, ok := c.arguments[name]
if ok {
panic("argument is already allocated")
}
n = len(c.arguments)
c.arguments[name] = n
default:
panic("invalid type")
}
return n
}
// newLocal creates a new local variable in the current scope.
func (c *varScope) newLocal(name string) int {
idx := len(c.locals) - 1
m := c.locals[idx]
m[name] = c.localsCnt
c.localsCnt++
c.locals[idx] = m
return c.localsCnt - 1
}