diff --git a/pkg/neotest/basic.go b/pkg/neotest/basic.go index c789d7ae6..331167548 100644 --- a/pkg/neotest/basic.go +++ b/pkg/neotest/basic.go @@ -401,6 +401,10 @@ func TestInvoke(bc *core.Blockchain, tx *transaction.Transaction) (*vm.VM, error ttx := *tx ic, _ := bc.GetTestVM(trigger.Application, &ttx, b) + if isCoverageEnabled() { + ic.VM.SetOnExecHook(coverageHook()) + } + defer ic.Finalize() ic.VM.LoadWithFlags(tx.Script, callflag.All) diff --git a/pkg/neotest/compile.go b/pkg/neotest/compile.go index 70645c11c..c8288de3a 100644 --- a/pkg/neotest/compile.go +++ b/pkg/neotest/compile.go @@ -35,16 +35,21 @@ func CompileSource(t testing.TB, sender util.Uint160, src io.Reader, opts *compi m, err := compiler.CreateManifest(di, opts) require.NoError(t, err) - return &Contract{ + c := Contract{ Hash: state.CreateContractHash(sender, ne.Checksum, m.Name), NEF: ne, Manifest: m, } + + collectCoverage(t, di, c.Hash) + + return &c } // CompileFile compiles a contract from the file and returns its NEF, manifest and hash. func CompileFile(t testing.TB, sender util.Uint160, srcPath string, configPath string) *Contract { if c, ok := contracts[srcPath]; ok { + collectCoverage(t, rawCoverage[c.Hash].debugInfo, c.Hash) return c } @@ -77,6 +82,20 @@ func CompileFile(t testing.TB, sender util.Uint160, srcPath string, configPath s NEF: ne, Manifest: m, } + + collectCoverage(t, di, c.Hash) + contracts[srcPath] = c return c } + +func collectCoverage(t testing.TB, di *compiler.DebugInfo, h util.Uint160) { + if isCoverageEnabled() { + if _, ok := rawCoverage[h]; !ok { + rawCoverage[h] = &scriptRawCoverage{debugInfo: di} + } + t.Cleanup(func() { + reportCoverage() + }) + } +} diff --git a/pkg/neotest/coverage.go b/pkg/neotest/coverage.go new file mode 100644 index 000000000..e010d47ff --- /dev/null +++ b/pkg/neotest/coverage.go @@ -0,0 +1,154 @@ +package neotest + +import ( + "flag" + "fmt" + "io" + "os" + + "github.com/nspcc-dev/neo-go/pkg/compiler" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" +) + +var rawCoverage = make(map[util.Uint160]*scriptRawCoverage) + +var enabled bool +var coverProfile = "" + +type scriptRawCoverage struct { + debugInfo *compiler.DebugInfo + offsetsVisited []int +} + +type coverBlock struct { + startLine uint // Line number for block start. + startCol uint // Column number for block start. + endLine uint // Line number for block end. + endCol uint // Column number for block end. + stmts uint // Number of statements included in this block. + counts uint +} + +type documentName = string + +func isCoverageEnabled() bool { + if enabled { + return true + } + const coverProfileFlag = "test.coverprofile" + flag.VisitAll(func(f *flag.Flag) { + if f.Name == coverProfileFlag && f.Value != nil { + enabled = true + coverProfile = f.Value.String() + } + }) + if enabled { + // this is needed so go cover tool doesn't overwrite + // the file with our coverage when all tests are done + flag.Set(coverProfileFlag, "") + } + return enabled +} + +func coverageHook() vm.OnExecHook { + return func(scriptHash util.Uint160, offset int, opcode opcode.Opcode) { + if cov, ok := rawCoverage[scriptHash]; ok { + cov.offsetsVisited = append(cov.offsetsVisited, offset) + } + } +} + +func reportCoverage() { + f, err := os.Create(coverProfile) + if err != nil { + panic(fmt.Sprintf("coverage: can't create file '%s' to write coverage report", coverProfile)) + } + defer f.Close() + writeCoverageReport(f) +} + +func writeCoverageReport(w io.Writer) { + fmt.Fprintf(w, "mode: set\n") + cover := processCover() + for name, blocks := range cover { + for _, b := range blocks { + c := 0 + if b.counts > 0 { + c = 1 + } + fmt.Fprintf(w, "%s:%d.%d,%d.%d %d %d\n", name, + b.startLine, b.startCol, + b.endLine, b.endCol, + b.stmts, + c, + ) + } + } +} + +func processCover() map[documentName][]coverBlock { + documents := make(map[documentName]struct{}) + for _, scriptRawCoverage := range rawCoverage { + for _, documentName := range scriptRawCoverage.debugInfo.Documents { + documents[documentName] = struct{}{} + } + } + + cover := make(map[documentName][]coverBlock) + + for documentName := range documents { + mappedBlocks := make(map[int]*coverBlock) + + for _, scriptRawCoverage := range rawCoverage { + di := scriptRawCoverage.debugInfo + documentSeqPoints := documentSeqPoints(di, documentName) + + for _, point := range documentSeqPoints { + b := coverBlock{ + startLine: uint(point.StartLine), + startCol: uint(point.StartCol), + endLine: uint(point.EndLine), + endCol: uint(point.EndCol), + stmts: 1 + uint(point.EndLine) - uint(point.StartLine), + counts: 0, + } + mappedBlocks[point.Opcode] = &b + } + } + + for _, scriptRawCoverage := range rawCoverage { + di := scriptRawCoverage.debugInfo + documentSeqPoints := documentSeqPoints(di, documentName) + + for _, offset := range scriptRawCoverage.offsetsVisited { + for _, point := range documentSeqPoints { + if point.Opcode == offset { + mappedBlocks[point.Opcode].counts++ + } + } + } + } + + var blocks []coverBlock + for _, b := range mappedBlocks { + blocks = append(blocks, *b) + } + cover[documentName] = blocks + } + + return cover +} + +func documentSeqPoints(di *compiler.DebugInfo, doc documentName) []compiler.DebugSeqPoint { + var res []compiler.DebugSeqPoint + for _, methodDebugInfo := range di.Methods { + for _, p := range methodDebugInfo.SeqPoints { + if di.Documents[p.Document] == doc { + res = append(res, p) + } + } + } + return res +} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index d3959f1e2..549e9f6fc 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -64,6 +64,15 @@ const ( // SyscallHandler is a type for syscall handler. type SyscallHandler = func(*VM, uint32) error +// OnExecHook is a type for a callback that is invoked +// for each executed instruction +type OnExecHook = func(scriptHash util.Uint160, offset int, opcode opcode.Opcode) + +// A struct that contains all VM hooks +type hooks struct { + onExec OnExecHook +} + // VM represents the virtual machine. type VM struct { state vmstate.State @@ -91,6 +100,10 @@ type VM struct { // invTree is a top-level invocation tree (if enabled). invTree *invocations.Tree + + // All registered hooks. + // Each hook should never be nil. + hooks hooks } var ( @@ -100,6 +113,10 @@ var ( bigTwo = big.NewInt(2) ) +var defaultHooks = hooks{ + onExec: func(scriptHash util.Uint160, offset int, opcode opcode.Opcode) {}, +} + // New returns a new VM object ready to load AVM bytecode scripts. func New() *VM { return NewWithTrigger(trigger.Application) @@ -110,6 +127,7 @@ func NewWithTrigger(t trigger.Type) *VM { vm := &VM{ state: vmstate.None, trigger: t, + hooks: defaultHooks, } vm.istack = make([]*Context, 0, 8) // Most of invocations use one-two contracts, but they're likely to have internal calls. @@ -117,6 +135,16 @@ func NewWithTrigger(t trigger.Type) *VM { return vm } +// SetOnExecHook sets the value of OnExecHook which +// will be invoked for each executed instruction. +// This function panics if the VM has been started. +func (v *VM) SetOnExecHook(hook OnExecHook) { + if v.state != vmstate.None { + panic("Cannot set onExec hook of a started VM") + } + v.hooks.onExec = hook +} + // SetPriceGetter registers the given PriceGetterFunc in v. // f accepts vm's Context, current instruction and instruction parameter. func (v *VM) SetPriceGetter(f func(opcode.Opcode, []byte) int64) { @@ -474,7 +502,9 @@ func (v *VM) Step() error { // step executes one instruction in the given context. func (v *VM) step(ctx *Context) error { + instruction_offset := v.Context().nextip op, param, err := ctx.Next() + v.hooks.onExec(v.GetCurrentScriptHash(), instruction_offset, op) if err != nil { v.state = vmstate.Fault return newError(ctx.ip, op, err)