rpc: move session maintenance related code out of the result.Invoke

It's server who should be responsible for iterator ID creation and
iterator registration.
This commit is contained in:
Anna Shaleva 2022-07-07 22:03:11 +03:00
parent 4581cc386b
commit 8f73ce08c8
2 changed files with 126 additions and 133 deletions

View file

@ -10,8 +10,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
)
@ -30,21 +28,11 @@ type Invoke struct {
maxIteratorResultItems int
Session uuid.UUID
finalize func()
onNewSession OnNewSession
// invocationParams is non-nil iff MPT-based iterator sessions are supported.
invocationParams *InvocationParams
registerIterator RegisterIterator
}
type OnNewSession func(sessionID string, iterators []IteratorIdentifier, params *InvocationParams, finalize func())
// InvocationParams is a set of parameters used for invoke* calls.
type InvocationParams struct {
Trigger trigger.Type
Script []byte
ContractScriptHash util.Uint160
Transaction *transaction.Transaction
NextBlockHeight uint32
}
// RegisterIterator is a callback used to register new iterator on the server side.
type RegisterIterator func(sessionID string, item stackitem.Item, id int, finalize func()) uuid.UUID
// InvokeDiag is an additional diagnostic data for invocation.
type InvokeDiag struct {
@ -53,7 +41,7 @@ type InvokeDiag struct {
}
// NewInvoke returns a new Invoke structure with the given fields set.
func NewInvoke(ic *interop.Context, script []byte, faultException string, registerSession OnNewSession, maxIteratorResultItems int, params *InvocationParams) *Invoke {
func NewInvoke(ic *interop.Context, script []byte, faultException string, registerIterator RegisterIterator, maxIteratorResultItems int) *Invoke {
var diag *InvokeDiag
tree := ic.VM.GetInvocationTree()
if tree != nil {
@ -75,9 +63,8 @@ func NewInvoke(ic *interop.Context, script []byte, faultException string, regist
Notifications: notifications,
Diagnostics: diag,
finalize: ic.Finalize,
onNewSession: registerSession,
maxIteratorResultItems: maxIteratorResultItems,
invocationParams: params,
registerIterator: registerIterator,
}
}
@ -119,15 +106,6 @@ type Iterator struct {
Truncated bool
}
// IteratorIdentifier represents Iterator identifier on the server side. It is not for Client usage.
type IteratorIdentifier struct {
ID string
// Item represents Iterator stackitem. It is nil if SessionBackedByMPT is set to true.
Item stackitem.Item
// StackIndex represents Iterator stackitem index on the stack. It is valid iff Item is nil.
StackIndex int
}
// Finalize releases resources occupied by Iterators created at the script invocation.
// This method will be called automatically on Invoke marshalling or by the Server's
// sessions handler.
@ -144,9 +122,8 @@ func (r Invoke) MarshalJSON() ([]byte, error) {
err error
faultSep string
arr = make([]json.RawMessage, len(r.Stack))
sessionsEnabled = r.onNewSession != nil
sessionsEnabled = r.registerIterator != nil
sessionID string
iterators []IteratorIdentifier
)
if len(r.FaultException) != 0 {
faultSep = " / "
@ -156,23 +133,19 @@ arrloop:
var data []byte
if (r.Stack[i].Type() == stackitem.InteropT) && iterator.IsIterator(r.Stack[i]) {
if sessionsEnabled {
iteratorID := uuid.NewString()
if sessionID == "" {
sessionID = uuid.NewString()
}
iteratorID := r.registerIterator(sessionID, r.Stack[i], i, r.finalize)
data, err = json.Marshal(iteratorAux{
Type: stackitem.InteropT.String(),
Interface: iteratorInterfaceName,
ID: iteratorID,
ID: iteratorID.String(),
})
if err != nil {
r.FaultException += fmt.Sprintf("%sjson error: failed to marshal iterator: %v", faultSep, err)
break
}
ident := IteratorIdentifier{ID: iteratorID}
if r.invocationParams == nil {
ident.Item = r.Stack[i]
} else {
ident.StackIndex = i
}
iterators = append(iterators, ident)
} else {
iteratorValues, truncated := iterator.ValuesTruncated(r.Stack[i], r.maxIteratorResultItems)
value := make([]json.RawMessage, len(iteratorValues))
@ -203,17 +176,8 @@ arrloop:
arr[i] = data
}
if sessionsEnabled && len(iterators) != 0 {
sessionID = uuid.NewString()
if r.invocationParams == nil {
r.onNewSession(sessionID, iterators, nil, r.Finalize)
} else {
// Call finalizer manually if MPT-based iterator sessions are enabled.
defer r.Finalize()
r.onNewSession(sessionID, iterators, r.invocationParams, nil)
}
} else {
// Call finalizer manually if iterators are disabled or there's no iterator on stack.
if !sessionsEnabled || sessionID == "" {
// Call finalizer manually if iterators are disabled or there's no unnested iterators on estack.
defer r.Finalize()
}
if err == nil {

View file

@ -18,6 +18,7 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core"
@ -91,16 +92,37 @@ type (
}
// session holds a set of iterators got after invoke* call with corresponding
// finalizer and session expiration time.
// finalizer and session expiration timer.
session struct {
params *result.InvocationParams
iteratorsLock sync.Mutex
iteratorIdentifiers []result.IteratorIdentifier
// iterators stores the set of Iterator stackitems for the current session got from MPT-backed storage.
// iterators is non-nil iff SessionBackedByMPT is enabled.
iterators []stackitem.Item
timer *time.Timer
finalize func()
// iteratorsLock protects iteratorIdentifiers of the current session.
iteratorsLock sync.Mutex
// iteratorIdentifiers stores the set of Iterator stackitems got either from original invocation
// or from historic MPT-based invocation. In the second case, iteratorIdentifiers are supposed
// to be filled during the first `traverseiterator` call using corresponding params.
iteratorIdentifiers []*iteratorIdentifier
// params stores invocation params for historic MPT-based iterator traversing. It is nil in case
// of default non-MPT-based sessions mechanism enabled.
params *invocationParams
timer *time.Timer
finalize func()
}
// iteratorIdentifier represents Iterator on the server side, holding iterator ID, Iterator stackitem
// and iterator index on stack.
iteratorIdentifier struct {
ID string
// Item represents Iterator stackitem. It is nil if SessionBackedByMPT is set to true and no `traverseiterator`
// call was called for the corresponding session.
Item stackitem.Item
// StackIndex represents Iterator stackitem index on the stack. It can be used only for SessionBackedByMPT configuration.
StackIndex int
}
// invocationParams is a set of parameters used for invoke* calls.
invocationParams struct {
Trigger trigger.Type
Script []byte
ContractScriptHash util.Uint160
Transaction *transaction.Transaction
NextBlockHeight uint32
}
)
@ -1974,54 +1996,61 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash
if err != nil {
faultException = err.Error()
}
var (
registerSession result.OnNewSession
params *result.InvocationParams
)
var registerIterator result.RegisterIterator
if s.config.SessionEnabled {
registerSession = s.registerSession
if s.config.SessionBackedByMPT {
params = &result.InvocationParams{
Trigger: t,
Script: script,
ContractScriptHash: contractScriptHash,
Transaction: tx,
NextBlockHeight: ic.Block.Index,
registerIterator = func(sessionID string, item stackitem.Item, stackIndex int, finalize func()) uuid.UUID {
iterID := uuid.New()
s.sessionsLock.Lock()
sess, ok := s.sessions[sessionID]
if !ok {
timer := time.AfterFunc(time.Second*time.Duration(s.config.SessionExpirationTime), func() {
s.sessionsLock.Lock()
defer s.sessionsLock.Unlock()
if len(s.sessions) == 0 {
return
}
sess, ok := s.sessions[sessionID]
if !ok {
return
}
sess.iteratorsLock.Lock()
if sess.finalize != nil {
sess.finalize()
}
delete(s.sessions, sessionID)
sess.iteratorsLock.Unlock()
})
sess = &session{
finalize: finalize,
timer: timer,
}
if s.config.SessionBackedByMPT {
sess.params = &invocationParams{
Trigger: t,
Script: script,
ContractScriptHash: contractScriptHash,
Transaction: tx,
NextBlockHeight: ic.Block.Index,
}
// Call finalizer manually if MPT-based iterator sessions are enabled. If disabled, then register finalizator.
if finalize != nil {
finalize()
sess.finalize = nil
}
item = nil
}
}
sess.iteratorIdentifiers = append(sess.iteratorIdentifiers, &iteratorIdentifier{
ID: iterID.String(),
Item: item,
StackIndex: stackIndex,
})
s.sessions[sessionID] = sess
s.sessionsLock.Unlock()
return iterID
}
}
return result.NewInvoke(ic, script, faultException, registerSession, s.config.MaxIteratorResultItems, params), nil
}
// registerSession is a callback used to add new iterator session to the sessions list.
// It performs no check whether sessions are enabled.
func (s *Server) registerSession(sessionID string, iterators []result.IteratorIdentifier, params *result.InvocationParams, finalize func()) {
s.sessionsLock.Lock()
timer := time.AfterFunc(time.Second*time.Duration(s.config.SessionExpirationTime), func() {
s.sessionsLock.Lock()
defer s.sessionsLock.Unlock()
if len(s.sessions) == 0 {
return
}
sess, ok := s.sessions[sessionID]
if !ok {
return
}
sess.iteratorsLock.Lock()
if sess.finalize != nil {
sess.finalize()
}
delete(s.sessions, sessionID)
sess.iteratorsLock.Unlock()
})
sess := &session{
params: params,
iteratorIdentifiers: iterators,
finalize: finalize,
timer: timer,
}
s.sessions[sessionID] = sess
s.sessionsLock.Unlock()
return result.NewInvoke(ic, script, faultException, registerIterator, s.config.MaxIteratorResultItems), nil
}
func (s *Server) traverseIterator(reqParams request.Params) (interface{}, *response.Error) {
@ -2064,40 +2093,40 @@ func (s *Server) traverseIterator(reqParams request.Params) (interface{}, *respo
iVals []stackitem.Item
respErr *response.Error
)
for i, it := range session.iteratorIdentifiers {
for _, it := range session.iteratorIdentifiers {
if iIDStr == it.ID {
if it.Item != nil { // If Iterator stackitem is there, then use it to retrieve iterator elements.
iVals = iterator.Values(it.Item, count)
} else { // Otherwise, use MPT-backed historic call to retrieve and traverse iterator.
if len(session.iterators) == 0 {
var (
b *block.Block
ic *interop.Context
)
b, err = s.getFakeNextBlock(session.params.NextBlockHeight)
if err != nil {
session.iteratorsLock.Unlock()
return nil, response.NewInternalServerError(fmt.Sprintf("unable to prepare block for historic call: %s", err))
}
ic, respErr = s.prepareInvocationContext(session.params.Trigger, session.params.Script, session.params.ContractScriptHash, session.params.Transaction, b, false)
if respErr != nil {
session.iteratorsLock.Unlock()
return nil, respErr
}
_ = ic.VM.Run() // No error check because FAULTed invocations could also contain iterator on stack.
stack := ic.VM.Estack().ToArray()
for _, itID := range session.iteratorIdentifiers {
j := itID.StackIndex
if (stack[j].Type() != stackitem.InteropT) || !iterator.IsIterator(stack[j]) {
session.iteratorsLock.Unlock()
return nil, response.NewInternalServerError(fmt.Sprintf("inconsistent historic call result: expected %s, got %s at stack position #%d", stackitem.InteropT, stack[j].Type(), j))
}
session.iterators = append(session.iterators, stack[j])
}
session.finalize = ic.Finalize
// If SessionBackedByMPT is enabled, then use MPT-backed historic call to retrieve and traverse iterator.
// Otherwise, iterator stackitem is ready and can be used.
if s.config.SessionBackedByMPT && it.Item == nil {
var (
b *block.Block
ic *interop.Context
)
b, err = s.getFakeNextBlock(session.params.NextBlockHeight)
if err != nil {
session.iteratorsLock.Unlock()
return nil, response.NewInternalServerError(fmt.Sprintf("unable to prepare block for historic call: %s", err))
}
iVals = iterator.Values(session.iterators[i], count)
ic, respErr = s.prepareInvocationContext(session.params.Trigger, session.params.Script, session.params.ContractScriptHash, session.params.Transaction, b, false)
if respErr != nil {
session.iteratorsLock.Unlock()
return nil, respErr
}
_ = ic.VM.Run() // No error check because FAULTed invocations could also contain iterator on stack.
stack := ic.VM.Estack().ToArray()
// Fill in the whole set of iterators for the current session in order not to repeat test invocation one more time for other session iterators.
for _, itID := range session.iteratorIdentifiers {
j := itID.StackIndex
if (stack[j].Type() != stackitem.InteropT) || !iterator.IsIterator(stack[j]) {
session.iteratorsLock.Unlock()
return nil, response.NewInternalServerError(fmt.Sprintf("inconsistent historic call result: expected %s, got %s at stack position #%d", stackitem.InteropT, stack[j].Type(), j))
}
session.iteratorIdentifiers[j].Item = stack[j]
}
session.finalize = ic.Finalize
}
iVals = iterator.Values(it.Item, count)
break
}
}