diff --git a/pkg/vm/stack/context.go b/pkg/vm/stack/context.go index 37e342739..343800afc 100644 --- a/pkg/vm/stack/context.go +++ b/pkg/vm/stack/context.go @@ -2,6 +2,7 @@ package stack import ( "encoding/binary" + "errors" ) // Context represent the current execution context of the VM. @@ -95,7 +96,8 @@ func (c *Context) String() string { return "execution context" } -func (c *Context) readUint32() uint32 { +// ReadUint32 reads a uint32 from the script +func (c *Context) ReadUint32() uint32 { start, end := c.IP(), c.IP()+4 if end > len(c.prog) { return 0 @@ -105,7 +107,8 @@ func (c *Context) readUint32() uint32 { return val } -func (c *Context) readUint16() uint16 { +// ReadUint16 reads a uint16 from the script +func (c *Context) ReadUint16() uint16 { start, end := c.IP(), c.IP()+2 if end > len(c.prog) { return 0 @@ -115,23 +118,33 @@ func (c *Context) readUint16() uint16 { return val } -func (c *Context) readByte() byte { - return c.readBytes(1)[0] +// ReadByte reads one byte from the script +func (c *Context) ReadByte() (byte, error) { + byt, err := c.ReadBytes(1) + if err != nil { + return 0, err + } + + return byt[0], nil } -func (c *Context) readBytes(n int) []byte { +// ReadBytes will read n bytes from the context +func (c *Context) ReadBytes(n int) ([]byte, error) { start, end := c.IP(), c.IP()+n if end > len(c.prog) { - return nil + return nil, errors.New("Too many bytes to read, pointer goes past end of program") } out := make([]byte, n) copy(out, c.prog[start:end]) c.ip += n - return out + return out, nil } -func (c *Context) readVarBytes() []byte { - n := c.readByte() - return c.readBytes(int(n)) +func (c *Context) readVarBytes() ([]byte, error) { + n, err := c.ReadByte() + if err != nil { + return nil, err + } + return c.ReadBytes(int(n)) }