diff --git a/pkg/core/oracle_test.go b/pkg/core/oracle_test.go index f1a2148d0..32c37468e 100644 --- a/pkg/core/oracle_test.go +++ b/pkg/core/oracle_test.go @@ -9,6 +9,7 @@ import ( "net/url" "path" "strings" + "sync" "testing" "time" @@ -54,7 +55,7 @@ func getTestOracle(t *testing.T, bc *Blockchain, walletPath, pass string) ( m := make(map[uint64]*responseWithSig) ch := make(chan *transaction.Transaction, 5) orcCfg := getOracleConfig(t, bc, walletPath, pass) - orcCfg.ResponseHandler = saveToMapBroadcaster{m} + orcCfg.ResponseHandler = &saveToMapBroadcaster{m: m} orcCfg.OnTransaction = saveTxToChan(ch) orcCfg.URIValidator = func(u *url.URL) error { if strings.HasPrefix(u.Host, "private") { @@ -298,18 +299,74 @@ func TestOracleFull(t *testing.T) { require.True(t, txes[0].HasAttribute(transaction.OracleResponseT)) } -type saveToMapBroadcaster struct { - m map[uint64]*responseWithSig +func TestNotYetRunningOracle(t *testing.T) { + bc := initTestChain(t, nil, nil) + acc, orc, _, _ := getTestOracle(t, bc, "./testdata/oracle2.json", "two") + mp := bc.GetMemPool() + orc.OnTransaction = func(tx *transaction.Transaction) { _ = mp.Add(tx, bc) } + bc.SetOracle(orc) + + cs := getOracleContractState(bc.contracts.Oracle.Hash, bc.contracts.Std.Hash) + require.NoError(t, bc.contracts.Management.PutContractState(bc.dao, cs)) + + go bc.Run() + bc.setNodesByRole(t, true, noderoles.Oracle, keys.PublicKeys{acc.PrivateKey().PublicKey()}) + + var req state.OracleRequest + var reqs = make(map[uint64]*state.OracleRequest) + for i := uint64(0); i < 3; i++ { + reqs[i] = &req + } + orc.AddRequests(reqs) // 0, 1, 2 added to pending. + + var ids = []uint64{0, 1} + orc.RemoveRequests(ids) // 0, 1 removed from pending, 2 left. + + reqs = make(map[uint64]*state.OracleRequest) + for i := uint64(3); i < 5; i++ { + reqs[i] = &req + } + orc.AddRequests(reqs) // 3, 4 added to pending -> 2, 3, 4 in pending. + + ids = []uint64{3} + orc.RemoveRequests(ids) // 3 removed from pending -> 2, 4 in pending. + + go orc.Run() + t.Cleanup(orc.Shutdown) + + require.Eventually(t, func() bool { return mp.Count() == 2 }, + time.Second*3, time.Millisecond*200) + txes := mp.GetVerifiedTransactions() + require.Len(t, txes, 2) + var txids []uint64 + for _, tx := range txes { + for _, attr := range tx.Attributes { + if attr.Type == transaction.OracleResponseT { + resp := attr.Value.(*transaction.OracleResponse) + txids = append(txids, resp.ID) + } + } + } + require.Len(t, txids, 2) + require.Contains(t, txids, uint64(2)) + require.Contains(t, txids, uint64(4)) } -func (b saveToMapBroadcaster) SendResponse(_ *keys.PrivateKey, resp *transaction.OracleResponse, txSig []byte) { +type saveToMapBroadcaster struct { + mtx sync.RWMutex + m map[uint64]*responseWithSig +} + +func (b *saveToMapBroadcaster) SendResponse(_ *keys.PrivateKey, resp *transaction.OracleResponse, txSig []byte) { + b.mtx.Lock() + defer b.mtx.Unlock() b.m[resp.ID] = &responseWithSig{ resp: resp, txSig: txSig, } } -func (saveToMapBroadcaster) Run() {} -func (saveToMapBroadcaster) Shutdown() {} +func (*saveToMapBroadcaster) Run() {} +func (*saveToMapBroadcaster) Shutdown() {} type responseWithSig struct { resp *transaction.OracleResponse diff --git a/pkg/services/oracle/oracle.go b/pkg/services/oracle/oracle.go index 8369add20..0b0430d83 100644 --- a/pkg/services/oracle/oracle.go +++ b/pkg/services/oracle/oracle.go @@ -43,8 +43,13 @@ type ( requestCh chan request requestMap chan map[uint64]*state.OracleRequest - // respMtx protects responses map. - respMtx sync.RWMutex + // respMtx protects responses and pending maps. + respMtx sync.RWMutex + // running is false until Run() is invoked. + running bool + // pending contains requests for not yet started service. + pending map[uint64]*state.OracleRequest + // responses contains active not completely processed requests. responses map[uint64]*incompleteTx // removed contains ids of requests which won't be processed further due to expiration. removed map[uint64]bool @@ -102,6 +107,7 @@ func NewOracle(cfg Config) (*Oracle, error) { close: make(chan struct{}), requestMap: make(chan map[uint64]*state.OracleRequest, 1), + pending: make(map[uint64]*state.OracleRequest), responses: make(map[uint64]*incompleteTx), removed: make(map[uint64]bool), } @@ -165,7 +171,18 @@ func (o *Oracle) Shutdown() { // Run runs must be executed in a separate goroutine. func (o *Oracle) Run() { + o.respMtx.Lock() + if o.running { + o.respMtx.Unlock() + return + } o.Log.Info("starting oracle service") + + o.requestMap <- o.pending // Guaranteed to not block, only AddRequests sends to it. + o.pending = nil + o.running = true + o.respMtx.Unlock() + for i := 0; i < o.MainCfg.MaxConcurrentRequests; i++ { go o.runRequestWorker() } diff --git a/pkg/services/oracle/request.go b/pkg/services/oracle/request.go index 4c59cbbaa..483eb40e1 100644 --- a/pkg/services/oracle/request.go +++ b/pkg/services/oracle/request.go @@ -46,8 +46,14 @@ func (o *Oracle) runRequestWorker() { func (o *Oracle) RemoveRequests(ids []uint64) { o.respMtx.Lock() defer o.respMtx.Unlock() - for _, id := range ids { - delete(o.responses, id) + if !o.running { + for _, id := range ids { + delete(o.pending, id) + } + } else { + for _, id := range ids { + delete(o.responses, id) + } } } @@ -57,6 +63,16 @@ func (o *Oracle) AddRequests(reqs map[uint64]*state.OracleRequest) { return } + o.respMtx.Lock() + if !o.running { + for id, r := range reqs { + o.pending[id] = r + } + o.respMtx.Unlock() + return + } + o.respMtx.Unlock() + select { case o.requestMap <- reqs: default: @@ -172,6 +188,7 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error { o.Log.Debug("oracle request processed", zap.String("url", req.Req.URL), zap.Int("code", int(resp.Code)), zap.String("result", string(resp.Result))) currentHeight := o.Chain.BlockHeight() + vubInc := o.Chain.GetConfig().MaxValidUntilBlockIncrement _, h, err := o.Chain.GetTransaction(req.Req.OriginalTxID) if err != nil { if !errors.Is(err, storage.ErrKeyNotFound) { @@ -180,10 +197,14 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error { // The only reason tx can be not found is if it wasn't yet persisted from DAO. h = currentHeight } + h += vubInc // Main tx is only valid for RequestHeight + ValidUntilBlock. tx, err := o.CreateResponseTx(int64(req.Req.GasForResponse), h, resp) if err != nil { return err } + for h <= currentHeight { // Backup tx must be valid in any event. + h += vubInc + } backupTx, err := o.CreateResponseTx(int64(req.Req.GasForResponse), h, &transaction.OracleResponse{ ID: req.ID, Code: transaction.ConsensusUnreachable, diff --git a/pkg/services/oracle/response.go b/pkg/services/oracle/response.go index aef834c29..6b5f9d541 100644 --- a/pkg/services/oracle/response.go +++ b/pkg/services/oracle/response.go @@ -81,10 +81,10 @@ func readResponse(rc gio.ReadCloser, limit int) ([]byte, error) { } // CreateResponseTx creates unsigned oracle response transaction. -func (o *Oracle) CreateResponseTx(gasForResponse int64, height uint32, resp *transaction.OracleResponse) (*transaction.Transaction, error) { +func (o *Oracle) CreateResponseTx(gasForResponse int64, vub uint32, resp *transaction.OracleResponse) (*transaction.Transaction, error) { tx := transaction.New(o.oracleResponse, 0) tx.Nonce = uint32(resp.ID) - tx.ValidUntilBlock = height + o.Chain.GetConfig().MaxValidUntilBlockIncrement + tx.ValidUntilBlock = vub tx.Attributes = []transaction.Attribute{{ Type: transaction.OracleResponseT, Value: resp,