oracle: reprocess request on fail

This commit is contained in:
Evgenii Stratonikov 2020-10-09 10:44:31 +03:00 committed by Evgeniy Stratonikov
parent aa852aaaac
commit e4528e59dc
6 changed files with 134 additions and 19 deletions

View file

@ -7,6 +7,8 @@ type OracleConfiguration struct {
Enabled bool `yaml:"Enabled"`
AllowPrivateHost bool `yaml:"AllowPrivateHost"`
Nodes []string `yaml:"Nodes"`
MaxTaskTimeout time.Duration `yaml:"MaxTaskTimeout"`
RefreshInterval time.Duration `yaml:"RefreshInterval"`
MaxConcurrentRequests int `yaml:"MaxConcurrentRequests"`
RequestTimeout time.Duration `yaml:"RequestTimeout"`
ResponseTimeout time.Duration `yaml:"ResponseTimeout"`

View file

@ -33,6 +33,7 @@ func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config
Log: zaptest.NewLogger(t),
Network: netmode.UnitTestNet,
MainCfg: config.OracleConfiguration{
RefreshInterval: time.Second,
UnlockWallet: config.Wallet{
Path: path.Join(oracleModulePath, w),
Password: pass,

View file

@ -40,6 +40,8 @@ type (
// respMtx protects responses map.
respMtx sync.RWMutex
responses map[uint64]*incompleteTx
// removed contains ids of requests which won't be processed further due to expiration.
removed map[uint64]bool
wallet *wallet.Wallet
}
@ -82,6 +84,12 @@ type (
const (
// defaultRequestTimeout is the request timeout applied when
// MainCfg.RequestTimeout is left at zero.
defaultRequestTimeout = time.Second * 5
// defaultMaxTaskTimeout is the value used for MainCfg.MaxTaskTimeout when it
// is left at zero. A request whose last processing attempt is older than this
// timeout is dropped and won't be processed any further.
defaultMaxTaskTimeout = time.Hour
// defaultRefreshInterval is the value used for MainCfg.RefreshInterval when it
// is left at zero. It is both the period of the reprocessing ticker and the
// minimum age of the last attempt after which a request is queued again.
defaultRefreshInterval = time.Minute * 3
)
// NewOracle returns new oracle instance.
@ -92,6 +100,7 @@ func NewOracle(cfg Config) (*Oracle, error) {
close: make(chan struct{}),
requestMap: make(chan map[uint64]*state.OracleRequest, 1),
responses: make(map[uint64]*incompleteTx),
removed: make(map[uint64]bool),
}
if o.MainCfg.RequestTimeout == 0 {
o.MainCfg.RequestTimeout = defaultRequestTimeout
@ -100,6 +109,12 @@ func NewOracle(cfg Config) (*Oracle, error) {
o.MainCfg.MaxConcurrentRequests = defaultMaxConcurrentRequests
}
o.requestCh = make(chan request, o.MainCfg.MaxConcurrentRequests)
if o.MainCfg.MaxTaskTimeout == 0 {
o.MainCfg.MaxTaskTimeout = defaultMaxTaskTimeout
}
if o.MainCfg.RefreshInterval == 0 {
o.MainCfg.RefreshInterval = defaultRefreshInterval
}
var err error
w := cfg.MainCfg.UnlockWallet
@ -147,10 +162,35 @@ func (o *Oracle) Run() {
for i := 0; i < o.MainCfg.MaxConcurrentRequests; i++ {
go o.runRequestWorker()
}
tick := time.NewTicker(o.MainCfg.RefreshInterval)
for {
select {
case <-o.close:
tick.Stop()
return
case <-tick.C:
var reprocess []uint64
o.respMtx.RLock()
o.removed = make(map[uint64]bool)
for id, incTx := range o.responses {
incTx.RLock()
since := time.Since(incTx.time)
if since > o.MainCfg.MaxTaskTimeout {
o.removed[id] = true
} else if since > o.MainCfg.RefreshInterval {
reprocess = append(reprocess, id)
}
incTx.RUnlock()
}
for id := range o.removed {
delete(o.responses, id)
}
o.respMtx.Unlock()
for _, id := range reprocess {
o.requestCh <- request{ID: id}
}
case reqs := <-o.requestMap:
for id, req := range reqs {
o.requestCh <- request{

View file

@ -4,6 +4,7 @@ import (
"errors"
"net/http"
"net/url"
"time"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
@ -29,7 +30,7 @@ func (o *Oracle) runRequestWorker() {
if acc == nil {
continue
}
err := o.processRequest(acc.PrivateKey(), req.ID, req.Req)
err := o.processRequest(acc.PrivateKey(), req)
if err != nil {
o.Log.Debug("can't process request", zap.Uint64("id", req.ID), zap.Error(err))
}
@ -75,23 +76,32 @@ func (o *Oracle) ProcessRequestsInternal(reqs map[uint64]*state.OracleRequest) {
}
// Process actual requests.
for id := range reqs {
if err := o.processRequest(acc.PrivateKey(), id, reqs[id]); err != nil {
for id, req := range reqs {
if err := o.processRequest(acc.PrivateKey(), request{ID: id, Req: req}); err != nil {
o.Log.Debug("can't process request", zap.Error(err))
}
}
}
func (o *Oracle) processRequest(priv *keys.PrivateKey, id uint64, req *state.OracleRequest) error {
resp := &transaction.OracleResponse{ID: id}
u, err := url.ParseRequestURI(req.URL)
func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
if req.Req == nil {
o.processFailedRequest(priv, req)
return nil
}
incTx := o.getResponse(req.ID, true)
if incTx == nil {
return nil
}
resp := &transaction.OracleResponse{ID: req.ID}
u, err := url.ParseRequestURI(req.Req.URL)
if err == nil && !o.MainCfg.AllowPrivateHost {
err = o.URIValidator(u)
}
if err != nil {
resp.Code = transaction.Forbidden
} else if u.Scheme == "http" {
r, err := o.Client.Get(req.URL)
r, err := o.Client.Get(req.Req.URL)
switch {
case err != nil:
resp.Code = transaction.Error
@ -119,7 +129,7 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, id uint64, req *state.Ora
}
currentHeight := o.Chain.BlockHeight()
_, h, err := o.Chain.GetTransaction(req.OriginalTxID)
_, h, err := o.Chain.GetTransaction(req.Req.OriginalTxID)
if err != nil {
if !errors.Is(err, storage.ErrKeyNotFound) {
return err
@ -127,20 +137,20 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, id uint64, req *state.Ora
// The only reason tx can be not found is if it wasn't yet persisted from DAO.
h = currentHeight
}
tx, err := o.CreateResponseTx(int64(req.GasForResponse), h, resp)
tx, err := o.CreateResponseTx(int64(req.Req.GasForResponse), h, resp)
if err != nil {
return err
}
backupTx, err := o.CreateResponseTx(int64(req.GasForResponse), h, &transaction.OracleResponse{
ID: id,
backupTx, err := o.CreateResponseTx(int64(req.Req.GasForResponse), h, &transaction.OracleResponse{
ID: req.ID,
Code: transaction.ConsensusUnreachable,
})
if err != nil {
return err
}
incTx := o.getResponse(id)
incTx.Lock()
incTx.request = req.Req
incTx.tx = tx
incTx.backupTx = backupTx
incTx.reverifyTx()
@ -151,11 +161,13 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, id uint64, req *state.Ora
backupSig := priv.Sign(backupTx.GetSignedPart())
incTx.addResponse(priv.PublicKey(), backupSig, true)
readyTx, ready := incTx.finalize(o.getOracleNodes())
readyTx, ready := incTx.finalize(o.getOracleNodes(), false)
if ready {
ready = !incTx.isSent
incTx.isSent = true
}
incTx.time = time.Now()
incTx.attempts++
incTx.Unlock()
o.getBroadcaster().SendResponse(priv, resp, txSig)
@ -164,3 +176,33 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, id uint64, req *state.Ora
}
return nil
}
// processFailedRequest handles a request that was queued for reprocessing,
// i.e. one that arrives with an ID only and no request body.
//
// If no incomplete transaction is tracked for the ID, nothing is done. If a
// response transaction has already been sent, it is handed to the
// transaction callback once more. Otherwise the request is not fetched
// again: only the backup transaction is considered for finalization, this
// node's backup signature is broadcast together with a failed response, and
// the backup transaction is submitted when it has become ready.
func (o *Oracle) processFailedRequest(priv *keys.PrivateKey, req request) {
	// Request is being processed again.
	incTx := o.getResponse(req.ID, false)
	if incTx == nil {
		// Request was processed by other oracle nodes.
		return
	}

	// isSent and tx are modified by other goroutines under the incompleteTx
	// lock, so they must be inspected under the same lock.
	incTx.Lock()
	if incTx.isSent {
		// Tx was sent but not yet persisted. Try to pool it again.
		// NOTE(review): this always re-pools the main tx, even when the backup
		// tx was the one actually sent — confirm this is intended.
		sentTx := incTx.tx
		// Release the lock before invoking the callback so that it never runs
		// while the incomplete transaction is locked.
		incTx.Unlock()
		o.getOnTransaction()(sentTx)
		return
	}

	// Don't process request again, fallback to backup tx.
	readyTx, ready := incTx.finalize(o.getOracleNodes(), true)
	if ready {
		// isSent is known to be false here, so a ready tx is always a fresh one.
		incTx.isSent = true
	}
	incTx.time = time.Now()
	incTx.attempts++
	txSig := incTx.backupSigs[string(priv.PublicKey().Bytes())].sig
	incTx.Unlock()

	o.getBroadcaster().SendResponse(priv, getFailedResponse(req.ID), txSig)
	if ready {
		o.getOnTransaction()(readyTx)
	}
}

View file

@ -17,11 +17,11 @@ import (
"go.uber.org/zap"
)
func (o *Oracle) getResponse(reqID uint64) *incompleteTx {
func (o *Oracle) getResponse(reqID uint64, create bool) *incompleteTx {
o.respMtx.Lock()
defer o.respMtx.Unlock()
incTx, ok := o.responses[reqID]
if !ok {
if !ok && create && !o.removed[reqID] {
incTx = newIncompleteTx()
o.responses[reqID] = incTx
}
@ -31,7 +31,10 @@ func (o *Oracle) getResponse(reqID uint64) *incompleteTx {
// AddResponse processes oracle response from node pub.
// sig is response transaction signature.
func (o *Oracle) AddResponse(pub *keys.PublicKey, reqID uint64, txSig []byte) {
incTx := o.getResponse(reqID)
incTx := o.getResponse(reqID, true)
if incTx == nil {
return
}
incTx.Lock()
isBackup := false
@ -49,7 +52,7 @@ func (o *Oracle) AddResponse(pub *keys.PublicKey, reqID uint64, txSig []byte) {
}
}
incTx.addResponse(pub, txSig, isBackup)
readyTx, ready := incTx.finalize(o.getOracleNodes())
readyTx, ready := incTx.finalize(o.getOracleNodes(), false)
if ready {
ready = !incTx.isSent
incTx.isSent = true
@ -151,3 +154,10 @@ func isVerifyOk(v *vm.VM) bool {
ok, err := v.Estack().Pop().Item().TryBool()
return err == nil && ok
}
// getFailedResponse builds an oracle response for the request with the given
// id carrying the generic Error code; all other fields keep their zero values.
func getFailedResponse(id uint64) *transaction.OracleResponse {
	failed := new(transaction.OracleResponse)
	failed.ID = id
	failed.Code = transaction.Error
	return failed
}

View file

@ -2,7 +2,9 @@ package oracle
import (
"sync"
"time"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/io"
@ -15,6 +17,12 @@ type (
sync.RWMutex
// isSent is true if tx was already broadcasted.
isSent bool
// attempts is how many times request was processed.
attempts int
// time is the time when request was last processed.
time time.Time
// request is oracle request.
request *state.OracleRequest
// tx is oracle response transaction.
tx *transaction.Transaction
// sigs contains signature from every oracle node.
@ -74,8 +82,8 @@ func (t *incompleteTx) addResponse(pub *keys.PublicKey, sig []byte, isBackup boo
// finalize checks if either main or backup tx has sufficient number of signatures and returns
// tx and bool value indicating if it is ready to be broadcasted.
func (t *incompleteTx) finalize(oracleNodes keys.PublicKeys) (*transaction.Transaction, bool) {
if finalizeTx(oracleNodes, t.tx, t.sigs) {
func (t *incompleteTx) finalize(oracleNodes keys.PublicKeys, backupOnly bool) (*transaction.Transaction, bool) {
if !backupOnly && finalizeTx(oracleNodes, t.tx, t.sigs) {
return t.tx, true
}
return t.backupTx, finalizeTx(oracleNodes, t.backupTx, t.backupSigs)
@ -107,3 +115,15 @@ func finalizeTx(oracleNodes keys.PublicKeys, tx *transaction.Transaction, txSigs
tx.Scripts[1].InvocationScript = w.Bytes()
return true
}
// getRequest returns the oracle request stored in t, reading it under the
// read lock.
func (t *incompleteTx) getRequest() *state.OracleRequest {
	t.RLock()
	storedReq := t.request
	t.RUnlock()
	return storedReq
}
// getTime returns the time stored in t (the moment the request was last
// processed), reading it under the read lock.
func (t *incompleteTx) getTime() time.Time {
	t.RLock()
	lastProcessed := t.time
	t.RUnlock()
	return lastProcessed
}