diff --git a/pkg/rpc/response/result/invoke.go b/pkg/rpc/response/result/invoke.go index c18ef9acf..438ba240e 100644 --- a/pkg/rpc/response/result/invoke.go +++ b/pkg/rpc/response/result/invoke.go @@ -10,8 +10,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" - "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" - "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -30,21 +28,11 @@ type Invoke struct { maxIteratorResultItems int Session uuid.UUID finalize func() - onNewSession OnNewSession - // invocationParams is non-nil iff MPT-based iterator sessions are supported. - invocationParams *InvocationParams + registerIterator RegisterIterator } -type OnNewSession func(sessionID string, iterators []IteratorIdentifier, params *InvocationParams, finalize func()) - -// InvocationParams is a set of parameters used for invoke* calls. -type InvocationParams struct { - Trigger trigger.Type - Script []byte - ContractScriptHash util.Uint160 - Transaction *transaction.Transaction - NextBlockHeight uint32 -} +// RegisterIterator is a callback used to register new iterator on the server side. +type RegisterIterator func(sessionID string, item stackitem.Item, id int, finalize func()) uuid.UUID // InvokeDiag is an additional diagnostic data for invocation. type InvokeDiag struct { @@ -53,7 +41,7 @@ type InvokeDiag struct { } // NewInvoke returns a new Invoke structure with the given fields set. -func NewInvoke(ic *interop.Context, script []byte, faultException string, registerSession OnNewSession, maxIteratorResultItems int, params *InvocationParams) *Invoke { +func NewInvoke(ic *interop.Context, script []byte, faultException string, registerIterator RegisterIterator, maxIteratorResultItems int) *Invoke { var diag *InvokeDiag tree := ic.VM.GetInvocationTree() if tree != nil { @@ -75,9 +63,8 @@ func NewInvoke(ic *interop.Context, script []byte, faultException string, regist Notifications: notifications, Diagnostics: diag, finalize: ic.Finalize, - onNewSession: registerSession, maxIteratorResultItems: maxIteratorResultItems, - invocationParams: params, + registerIterator: registerIterator, } } @@ -119,15 +106,6 @@ type Iterator struct { Truncated bool } -// IteratorIdentifier represents Iterator identifier on the server side. It is not for Client usage. -type IteratorIdentifier struct { - ID string - // Item represents Iterator stackitem. It is nil if SessionBackedByMPT is set to true. - Item stackitem.Item - // StackIndex represents Iterator stackitem index on the stack. It is valid iff Item is nil. - StackIndex int -} - // Finalize releases resources occupied by Iterators created at the script invocation. // This method will be called automatically on Invoke marshalling or by the Server's // sessions handler. @@ -144,9 +122,8 @@ func (r Invoke) MarshalJSON() ([]byte, error) { err error faultSep string arr = make([]json.RawMessage, len(r.Stack)) - sessionsEnabled = r.onNewSession != nil + sessionsEnabled = r.registerIterator != nil sessionID string - iterators []IteratorIdentifier ) if len(r.FaultException) != 0 { faultSep = " / " @@ -156,23 +133,19 @@ arrloop: var data []byte if (r.Stack[i].Type() == stackitem.InteropT) && iterator.IsIterator(r.Stack[i]) { if sessionsEnabled { - iteratorID := uuid.NewString() + if sessionID == "" { + sessionID = uuid.NewString() + } + iteratorID := r.registerIterator(sessionID, r.Stack[i], i, r.finalize) data, err = json.Marshal(iteratorAux{ Type: stackitem.InteropT.String(), Interface: iteratorInterfaceName, - ID: iteratorID, + ID: iteratorID.String(), }) if err != nil { r.FaultException += fmt.Sprintf("%sjson error: failed to marshal iterator: %v", faultSep, err) break } - ident := IteratorIdentifier{ID: iteratorID} - if r.invocationParams == nil { - ident.Item = r.Stack[i] - } else { - ident.StackIndex = i - } - iterators = append(iterators, ident) } else { iteratorValues, truncated := iterator.ValuesTruncated(r.Stack[i], r.maxIteratorResultItems) value := make([]json.RawMessage, len(iteratorValues)) @@ -203,17 +176,8 @@ arrloop: arr[i] = data } - if sessionsEnabled && len(iterators) != 0 { - sessionID = uuid.NewString() - if r.invocationParams == nil { - r.onNewSession(sessionID, iterators, nil, r.Finalize) - } else { - // Call finalizer manually if MPT-based iterator sessions are enabled. - defer r.Finalize() - r.onNewSession(sessionID, iterators, r.invocationParams, nil) - } - } else { - // Call finalizer manually if iterators are disabled or there's no iterator on stack. + if !sessionsEnabled || sessionID == "" { + // Call finalizer manually if iterators are disabled or there's no unnested iterators on estack. defer r.Finalize() } if err == nil { diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 19fbf666c..fd8d5e3c8 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -18,6 +18,7 @@ import ( "sync" "time" + "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/core" @@ -91,16 +92,37 @@ type ( } // session holds a set of iterators got after invoke* call with corresponding - // finalizer and session expiration time. + // finalizer and session expiration timer. session struct { - params *result.InvocationParams - iteratorsLock sync.Mutex - iteratorIdentifiers []result.IteratorIdentifier - // iterators stores the set of Iterator stackitems for the current session got from MPT-backed storage. - // iterators is non-nil iff SessionBackedByMPT is enabled. - iterators []stackitem.Item - timer *time.Timer - finalize func() + // iteratorsLock protects iteratorIdentifiers of the current session. + iteratorsLock sync.Mutex + // iteratorIdentifiers stores the set of Iterator stackitems got either from original invocation + // or from historic MPT-based invocation. In the second case, iteratorIdentifiers are supposed + // to be filled during the first `traverseiterator` call using corresponding params. + iteratorIdentifiers []*iteratorIdentifier + // params stores invocation params for historic MPT-based iterator traversing. It is nil in case + // of default non-MPT-based sessions mechanism enabled. + params *invocationParams + timer *time.Timer + finalize func() + } + // iteratorIdentifier represents Iterator on the server side, holding iterator ID, Iterator stackitem + // and iterator index on stack. + iteratorIdentifier struct { + ID string + // Item represents Iterator stackitem. It is nil if SessionBackedByMPT is set to true and no `traverseiterator` + // call was called for the corresponding session. + Item stackitem.Item + // StackIndex represents Iterator stackitem index on the stack. It can be used only for SessionBackedByMPT configuration. + StackIndex int + } + // invocationParams is a set of parameters used for invoke* calls. + invocationParams struct { + Trigger trigger.Type + Script []byte + ContractScriptHash util.Uint160 + Transaction *transaction.Transaction + NextBlockHeight uint32 } ) @@ -1974,54 +1996,61 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash if err != nil { faultException = err.Error() } - var ( - registerSession result.OnNewSession - params *result.InvocationParams - ) + var registerIterator result.RegisterIterator if s.config.SessionEnabled { - registerSession = s.registerSession - if s.config.SessionBackedByMPT { - params = &result.InvocationParams{ - Trigger: t, - Script: script, - ContractScriptHash: contractScriptHash, - Transaction: tx, - NextBlockHeight: ic.Block.Index, + registerIterator = func(sessionID string, item stackitem.Item, stackIndex int, finalize func()) uuid.UUID { + iterID := uuid.New() + s.sessionsLock.Lock() + sess, ok := s.sessions[sessionID] + if !ok { + timer := time.AfterFunc(time.Second*time.Duration(s.config.SessionExpirationTime), func() { + s.sessionsLock.Lock() + defer s.sessionsLock.Unlock() + if len(s.sessions) == 0 { + return + } + sess, ok := s.sessions[sessionID] + if !ok { + return + } + sess.iteratorsLock.Lock() + if sess.finalize != nil { + sess.finalize() + } + delete(s.sessions, sessionID) + sess.iteratorsLock.Unlock() + }) + sess = &session{ + finalize: finalize, + timer: timer, + } + if s.config.SessionBackedByMPT { + sess.params = &invocationParams{ + Trigger: t, + Script: script, + ContractScriptHash: contractScriptHash, + Transaction: tx, + NextBlockHeight: ic.Block.Index, + } + // Call finalizer manually if MPT-based iterator sessions are enabled. If disabled, then register finalizator. + if finalize != nil { + finalize() + sess.finalize = nil + } + item = nil + } } + sess.iteratorIdentifiers = append(sess.iteratorIdentifiers, &iteratorIdentifier{ + ID: iterID.String(), + Item: item, + StackIndex: stackIndex, + }) + s.sessions[sessionID] = sess + s.sessionsLock.Unlock() + return iterID } } - return result.NewInvoke(ic, script, faultException, registerSession, s.config.MaxIteratorResultItems, params), nil -} - -// registerSession is a callback used to add new iterator session to the sessions list. -// It performs no check whether sessions are enabled. -func (s *Server) registerSession(sessionID string, iterators []result.IteratorIdentifier, params *result.InvocationParams, finalize func()) { - s.sessionsLock.Lock() - timer := time.AfterFunc(time.Second*time.Duration(s.config.SessionExpirationTime), func() { - s.sessionsLock.Lock() - defer s.sessionsLock.Unlock() - if len(s.sessions) == 0 { - return - } - sess, ok := s.sessions[sessionID] - if !ok { - return - } - sess.iteratorsLock.Lock() - if sess.finalize != nil { - sess.finalize() - } - delete(s.sessions, sessionID) - sess.iteratorsLock.Unlock() - }) - sess := &session{ - params: params, - iteratorIdentifiers: iterators, - finalize: finalize, - timer: timer, - } - s.sessions[sessionID] = sess - s.sessionsLock.Unlock() + return result.NewInvoke(ic, script, faultException, registerIterator, s.config.MaxIteratorResultItems), nil } func (s *Server) traverseIterator(reqParams request.Params) (interface{}, *response.Error) { @@ -2064,40 +2093,40 @@ func (s *Server) traverseIterator(reqParams request.Params) (interface{}, *respo iVals []stackitem.Item respErr *response.Error ) - for i, it := range session.iteratorIdentifiers { + for _, it := range session.iteratorIdentifiers { if iIDStr == it.ID { - if it.Item != nil { // If Iterator stackitem is there, then use it to retrieve iterator elements. - iVals = iterator.Values(it.Item, count) - } else { // Otherwise, use MPT-backed historic call to retrieve and traverse iterator. - if len(session.iterators) == 0 { - var ( - b *block.Block - ic *interop.Context - ) - b, err = s.getFakeNextBlock(session.params.NextBlockHeight) - if err != nil { - session.iteratorsLock.Unlock() - return nil, response.NewInternalServerError(fmt.Sprintf("unable to prepare block for historic call: %s", err)) - } - ic, respErr = s.prepareInvocationContext(session.params.Trigger, session.params.Script, session.params.ContractScriptHash, session.params.Transaction, b, false) - if respErr != nil { - session.iteratorsLock.Unlock() - return nil, respErr - } - _ = ic.VM.Run() // No error check because FAULTed invocations could also contain iterator on stack. - stack := ic.VM.Estack().ToArray() - for _, itID := range session.iteratorIdentifiers { - j := itID.StackIndex - if (stack[j].Type() != stackitem.InteropT) || !iterator.IsIterator(stack[j]) { - session.iteratorsLock.Unlock() - return nil, response.NewInternalServerError(fmt.Sprintf("inconsistent historic call result: expected %s, got %s at stack position #%d", stackitem.InteropT, stack[j].Type(), j)) - } - session.iterators = append(session.iterators, stack[j]) - } - session.finalize = ic.Finalize + // If SessionBackedByMPT is enabled, then use MPT-backed historic call to retrieve and traverse iterator. + // Otherwise, iterator stackitem is ready and can be used. + if s.config.SessionBackedByMPT && it.Item == nil { + var ( + b *block.Block + ic *interop.Context + ) + b, err = s.getFakeNextBlock(session.params.NextBlockHeight) + if err != nil { + session.iteratorsLock.Unlock() + return nil, response.NewInternalServerError(fmt.Sprintf("unable to prepare block for historic call: %s", err)) } - iVals = iterator.Values(session.iterators[i], count) + ic, respErr = s.prepareInvocationContext(session.params.Trigger, session.params.Script, session.params.ContractScriptHash, session.params.Transaction, b, false) + if respErr != nil { + session.iteratorsLock.Unlock() + return nil, respErr + } + _ = ic.VM.Run() // No error check because FAULTed invocations could also contain iterator on stack. + stack := ic.VM.Estack().ToArray() + + // Fill in the whole set of iterators for the current session in order not to repeat test invocation one more time for other session iterators. + for _, itID := range session.iteratorIdentifiers { + j := itID.StackIndex + if (stack[j].Type() != stackitem.InteropT) || !iterator.IsIterator(stack[j]) { + session.iteratorsLock.Unlock() + return nil, response.NewInternalServerError(fmt.Sprintf("inconsistent historic call result: expected %s, got %s at stack position #%d", stackitem.InteropT, stack[j].Type(), j)) + } + session.iteratorIdentifiers[j].Item = stack[j] + } + session.finalize = ic.Finalize } + iVals = iterator.Values(it.Item, count) break } }