diff --git a/pkg/io/binaryBufWriter.go b/pkg/io/binaryBufWriter.go index 31ba262ac..7015082e1 100644 --- a/pkg/io/binaryBufWriter.go +++ b/pkg/io/binaryBufWriter.go @@ -19,6 +19,11 @@ func NewBufBinWriter() *BufBinWriter { return &BufBinWriter{BinWriter: NewBinWriterFromIO(b), buf: b} } +// Len returns the number of bytes of the unread portion of the buffer. +func (bw *BufBinWriter) Len() int { + return bw.buf.Len() +} + // Bytes returns resulting buffer and makes future writes return an error. func (bw *BufBinWriter) Bytes() []byte { if bw.Err != nil { diff --git a/pkg/io/binaryWriter.go b/pkg/io/binaryWriter.go index 19086e255..d3d5503e1 100644 --- a/pkg/io/binaryWriter.go +++ b/pkg/io/binaryWriter.go @@ -93,6 +93,11 @@ func (w *BinWriter) WriteVarUint(val uint64) { } +// WriteVarBytes writes a variable byte into the underlying io.Writer without prefix. +func (w *BinWriter) WriteVarBytes(b []byte) { + w.WriteLE(b) +} + // WriteBytes writes a variable length byte array into the underlying io.Writer. func (w *BinWriter) WriteBytes(b []byte) { w.WriteVarUint(uint64(len(b))) diff --git a/pkg/io/binaryrw_test.go b/pkg/io/binaryrw_test.go index 672e14ffd..4cdfbac2d 100644 --- a/pkg/io/binaryrw_test.go +++ b/pkg/io/binaryrw_test.go @@ -53,6 +53,13 @@ func TestWriteBE(t *testing.T) { assert.Equal(t, val, readval) } +func TestBufBinWriter_Len(t *testing.T) { + val := []byte{0xde} + bw := NewBufBinWriter() + bw.WriteLE(val) + require.Equal(t, 1, bw.Len()) +} + func TestWriterErrHandling(t *testing.T) { var badio = &badRW{} bw := NewBinWriterFromIO(badio) diff --git a/pkg/vm/compiler/codegen.go b/pkg/vm/compiler/codegen.go index f54d0f844..b081784cb 100644 --- a/pkg/vm/compiler/codegen.go +++ b/pkg/vm/compiler/codegen.go @@ -1,7 +1,6 @@ package compiler import ( - "bytes" "encoding/binary" "go/ast" "go/constant" @@ -13,6 +12,7 @@ import ( "strings" "github.com/CityOfZion/neo-go/pkg/crypto" + "github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/vm" ) @@ -24,7 +24,7 @@ type codegen struct { buildInfo *buildInfo // prog holds the output buffer. - prog *bytes.Buffer + prog *io.BufBinWriter // Type information. typeInfo *types.Info @@ -56,6 +56,10 @@ func (c *codegen) pc() int { } func (c *codegen) emitLoadConst(t types.TypeAndValue) { + if c.prog.Err != nil { + log.Fatal(c.prog.Err) + return + } switch typ := t.Type.Underlying().(type) { case *types.Basic: switch typ.Kind() { @@ -201,6 +205,10 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl) { } func (c *codegen) Visit(node ast.Node) ast.Visitor { + if c.prog.Err != nil { + log.Fatal(c.prog.Err) + return nil + } switch n := node.(type) { // General declarations. @@ -761,11 +769,11 @@ func (c *codegen) newFunc(decl *ast.FuncDecl) *funcScope { } // CodeGen compiles the program to bytecode. -func CodeGen(info *buildInfo) (*bytes.Buffer, error) { +func CodeGen(info *buildInfo) ([]byte, error) { pkg := info.program.Package(info.initialPackage) c := &codegen{ buildInfo: info, - prog: new(bytes.Buffer), + prog: io.NewBufBinWriter(), l: []int{}, funcs: map[string]*funcScope{}, typeInfo: &pkg.Info, @@ -815,9 +823,12 @@ func CodeGen(info *buildInfo) (*bytes.Buffer, error) { } } - c.writeJumps() - - return c.prog, nil + if c.prog.Err != nil { + return nil, c.prog.Err + } + buf := c.prog.Bytes() + c.writeJumps(buf) + return buf, nil } func (c *codegen) resolveFuncDecls(f *ast.File) { @@ -831,8 +842,7 @@ func (c *codegen) resolveFuncDecls(f *ast.File) { } } -func (c *codegen) writeJumps() { - b := c.prog.Bytes() +func (c *codegen) writeJumps(b []byte) { for i, op := range b { j := i + 1 switch vm.Instruction(op) { diff --git a/pkg/vm/compiler/compiler.go b/pkg/vm/compiler/compiler.go index 5ead0d34a..cbc4a93be 100644 --- a/pkg/vm/compiler/compiler.go +++ b/pkg/vm/compiler/compiler.go @@ -60,7 +60,7 @@ func Compile(r io.Reader) ([]byte, error) { return nil, err } - return buf.Bytes(), nil + return buf, nil } type archive struct { diff --git a/pkg/vm/compiler/emit.go b/pkg/vm/compiler/emit.go index 92be632f8..296adea22 100644 --- a/pkg/vm/compiler/emit.go +++ b/pkg/vm/compiler/emit.go @@ -1,105 +1,114 @@ package compiler import ( - "bytes" "encoding/binary" "errors" "fmt" - "io" "math/big" + "github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/vm" ) -func emit(w *bytes.Buffer, instr vm.Instruction, b []byte) error { - if err := w.WriteByte(byte(instr)); err != nil { - return err - } - _, err := w.Write(b) - return err +// emit a VM Instruction with data to the given buffer. +func emit(w *io.BufBinWriter, instr vm.Instruction, b []byte) { + w.WriteLE(byte(instr)) + w.WriteVarBytes(b) } -func emitOpcode(w io.ByteWriter, instr vm.Instruction) error { - return w.WriteByte(byte(instr)) +// emitOpcode emits a single VM Instruction the given buffer. +func emitOpcode(w *io.BufBinWriter, instr vm.Instruction) { + w.WriteLE(byte(instr)) } -func emitBool(w io.ByteWriter, ok bool) error { +// emitBool emits a bool type the given buffer. +func emitBool(w *io.BufBinWriter, ok bool) { if ok { - return emitOpcode(w, vm.PUSHT) + emitOpcode(w, vm.PUSHT) + return } - return emitOpcode(w, vm.PUSHF) + emitOpcode(w, vm.PUSHF) } -func emitInt(w *bytes.Buffer, i int64) error { - if i == -1 { - return emitOpcode(w, vm.PUSHM1) - } - if i == 0 { - return emitOpcode(w, vm.PUSHF) - } - if i > 0 && i < 16 { +// emitInt emits a int type to the given buffer. +func emitInt(w *io.BufBinWriter, i int64) { + switch { + case i == -1: + emitOpcode(w, vm.PUSHM1) + return + case i == 0: + emitOpcode(w, vm.PUSHF) + return + case i > 0 && i < 16: val := vm.Instruction(int(vm.PUSH1) - 1 + int(i)) - return emitOpcode(w, val) + emitOpcode(w, val) + return } bInt := big.NewInt(i) val := util.ArrayReverse(bInt.Bytes()) - return emitBytes(w, val) + emitBytes(w, val) } -func emitString(w *bytes.Buffer, s string) error { - return emitBytes(w, []byte(s)) +// emitString emits a string to the given buffer. +func emitString(w *io.BufBinWriter, s string) { + emitBytes(w, []byte(s)) } -func emitBytes(w *bytes.Buffer, b []byte) error { - var ( - err error - n = len(b) - ) +// emitBytes emits a byte array to the given buffer. +func emitBytes(w *io.BufBinWriter, b []byte) { + n := len(b) switch { case n <= int(vm.PUSHBYTES75): - return emit(w, vm.Instruction(n), b) + emit(w, vm.Instruction(n), b) + return case n < 0x100: - err = emit(w, vm.PUSHDATA1, []byte{byte(n)}) + emit(w, vm.PUSHDATA1, []byte{byte(n)}) case n < 0x10000: buf := make([]byte, 2) binary.LittleEndian.PutUint16(buf, uint16(n)) - err = emit(w, vm.PUSHDATA2, buf) + emit(w, vm.PUSHDATA2, buf) default: buf := make([]byte, 4) binary.LittleEndian.PutUint32(buf, uint32(n)) - err = emit(w, vm.PUSHDATA4, buf) + emit(w, vm.PUSHDATA4, buf) + if w.Err != nil { + return + } } - if err != nil { - return err - } - _, err = w.Write(b) - return err + + w.WriteBytes(b) } -func emitSyscall(w *bytes.Buffer, api string) error { +// emitSyscall emits the syscall API to the given buffer. +// Syscall API string cannot be 0. +func emitSyscall(w *io.BufBinWriter, api string) { if len(api) == 0 { - return errors.New("syscall api cannot be of length 0") + w.Err = errors.New("syscall api cannot be of length 0") + return } buf := make([]byte, len(api)+1) buf[0] = byte(len(api)) copy(buf[1:], api) - return emit(w, vm.SYSCALL, buf) + emit(w, vm.SYSCALL, buf) } -func emitCall(w *bytes.Buffer, instr vm.Instruction, label int16) error { - return emitJmp(w, instr, label) +// emitCall emits a call Instruction with label to the given buffer. +func emitCall(w *io.BufBinWriter, instr vm.Instruction, label int16) { + emitJmp(w, instr, label) } -func emitJmp(w *bytes.Buffer, instr vm.Instruction, label int16) error { +// emitJmp emits a jump Instruction along with label to the given buffer. +func emitJmp(w *io.BufBinWriter, instr vm.Instruction, label int16) { if !isInstrJmp(instr) { - return fmt.Errorf("opcode %s is not a jump or call type", instr) + w.Err = fmt.Errorf("opcode %s is not a jump or call type", instr) + return } buf := make([]byte, 2) binary.LittleEndian.PutUint16(buf, uint16(label)) - return emit(w, instr, buf) + emit(w, instr, buf) } func isInstrJmp(instr vm.Instruction) bool {