Merge pull request #2612 from nspcc-dev/fancy-service-restart

Fancy service restart
This commit is contained in:
Roman Khimov 2022-08-02 14:11:44 +03:00 committed by GitHub
commit cfd2a35172
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 526 additions and 146 deletions

View file

@ -147,11 +147,11 @@ func newTestChain(t *testing.T, f func(*config.Config), run bool) (*core.Blockch
Chain: chain, Chain: chain,
ProtocolConfiguration: chain.GetConfig(), ProtocolConfiguration: chain.GetConfig(),
RequestTx: netSrv.RequestTx, RequestTx: netSrv.RequestTx,
Wallet: serverConfig.Wallet, Wallet: &cfg.ApplicationConfiguration.UnlockWallet,
TimePerBlock: serverConfig.TimePerBlock, TimePerBlock: serverConfig.TimePerBlock,
}) })
require.NoError(t, err) require.NoError(t, err)
netSrv.AddExtensibleHPService(cons, consensus.Category, cons.OnPayload, cons.OnTransaction) netSrv.AddConsensusService(cons, cons.OnPayload, cons.OnTransaction)
go netSrv.Start(make(chan error, 1)) go netSrv.Start(make(chan error, 1))
errCh := make(chan error, 2) errCh := make(chan error, 2)
rpcServer := rpcsrv.New(chain, cfg.ApplicationConfiguration.RPC, netSrv, nil, logger, errCh) rpcServer := rpcsrv.New(chain, cfg.ApplicationConfiguration.RPC, netSrv, nil, logger, errCh)

View file

@ -8,10 +8,11 @@ import (
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
"syscall" "time"
"github.com/nspcc-dev/neo-go/cli/options" "github.com/nspcc-dev/neo-go/cli/options"
"github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/consensus"
"github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
@ -385,14 +386,14 @@ func restoreDB(ctx *cli.Context) error {
return nil return nil
} }
func mkOracle(config network.ServerConfig, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (*oracle.Oracle, error) { func mkOracle(config config.OracleConfiguration, magic netmode.Magic, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (*oracle.Oracle, error) {
if !config.OracleCfg.Enabled { if !config.Enabled {
return nil, nil return nil, nil
} }
orcCfg := oracle.Config{ orcCfg := oracle.Config{
Log: log, Log: log,
Network: config.Net, Network: magic,
MainCfg: config.OracleCfg, MainCfg: config,
Chain: chain, Chain: chain,
OnTransaction: serv.RelayTxn, OnTransaction: serv.RelayTxn,
} }
@ -405,8 +406,8 @@ func mkOracle(config network.ServerConfig, chain *core.Blockchain, serv *network
return orc, nil return orc, nil
} }
func mkConsensus(config network.ServerConfig, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (consensus.Service, error) { func mkConsensus(config config.Wallet, tpb time.Duration, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (consensus.Service, error) {
if config.Wallet == nil { if len(config.Path) == 0 {
return nil, nil return nil, nil
} }
srv, err := consensus.NewService(consensus.Config{ srv, err := consensus.NewService(consensus.Config{
@ -415,26 +416,26 @@ func mkConsensus(config network.ServerConfig, chain *core.Blockchain, serv *netw
Chain: chain, Chain: chain,
ProtocolConfiguration: chain.GetConfig(), ProtocolConfiguration: chain.GetConfig(),
RequestTx: serv.RequestTx, RequestTx: serv.RequestTx,
Wallet: config.Wallet, Wallet: &config,
TimePerBlock: config.TimePerBlock, TimePerBlock: tpb,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("can't initialize Consensus module: %w", err) return nil, fmt.Errorf("can't initialize Consensus module: %w", err)
} }
serv.AddExtensibleHPService(srv, consensus.Category, srv.OnPayload, srv.OnTransaction) serv.AddConsensusService(srv, srv.OnPayload, srv.OnTransaction)
return srv, nil return srv, nil
} }
func mkP2PNotary(config network.ServerConfig, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (*notary.Notary, error) { func mkP2PNotary(config config.P2PNotary, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (*notary.Notary, error) {
if !config.P2PNotaryCfg.Enabled { if !config.Enabled {
return nil, nil return nil, nil
} }
if !chain.P2PSigExtensionsEnabled() { if !chain.P2PSigExtensionsEnabled() {
return nil, errors.New("P2PSigExtensions are disabled, but Notary service is enabled") return nil, errors.New("P2PSigExtensions are disabled, but Notary service is enabled")
} }
cfg := notary.Config{ cfg := notary.Config{
MainCfg: config.P2PNotaryCfg, MainCfg: config,
Chain: chain, Chain: chain,
Log: log, Log: log,
} }
@ -492,15 +493,15 @@ func startServer(ctx *cli.Context) error {
} }
serv.AddExtensibleService(sr, stateroot.Category, sr.OnPayload) serv.AddExtensibleService(sr, stateroot.Category, sr.OnPayload)
oracleSrv, err := mkOracle(serverConfig, chain, serv, log) oracleSrv, err := mkOracle(cfg.ApplicationConfiguration.Oracle, cfg.ProtocolConfiguration.Magic, chain, serv, log)
if err != nil { if err != nil {
return cli.NewExitError(err, 1) return cli.NewExitError(err, 1)
} }
_, err = mkConsensus(serverConfig, chain, serv, log) dbftSrv, err := mkConsensus(cfg.ApplicationConfiguration.UnlockWallet, serverConfig.TimePerBlock, chain, serv, log)
if err != nil { if err != nil {
return cli.NewExitError(err, 1) return cli.NewExitError(err, 1)
} }
_, err = mkP2PNotary(serverConfig, chain, serv, log) p2pNotary, err := mkP2PNotary(cfg.ApplicationConfiguration.P2PNotary, chain, serv, log)
if err != nil { if err != nil {
return cli.NewExitError(err, 1) return cli.NewExitError(err, 1)
} }
@ -513,8 +514,10 @@ func startServer(ctx *cli.Context) error {
rpcServer.Start() rpcServer.Start()
} }
sighupCh := make(chan os.Signal, 1) sigCh := make(chan os.Signal, 1)
signal.Notify(sighupCh, syscall.SIGHUP) signal.Notify(sigCh, sighup)
signal.Notify(sigCh, sigusr1)
signal.Notify(sigCh, sigusr2)
fmt.Fprintln(ctx.App.Writer, Logo()) fmt.Fprintln(ctx.App.Writer, Logo())
fmt.Fprintln(ctx.App.Writer, serv.UserAgent) fmt.Fprintln(ctx.App.Writer, serv.UserAgent)
@ -527,19 +530,97 @@ Main:
case err := <-errChan: case err := <-errChan:
shutdownErr = fmt.Errorf("server error: %w", err) shutdownErr = fmt.Errorf("server error: %w", err)
cancel() cancel()
case sig := <-sighupCh: case sig := <-sigCh:
log.Info("signal received", zap.Stringer("name", sig))
cfgnew, err := getConfigFromContext(ctx)
if err != nil {
log.Warn("can't reread the config file, signal ignored", zap.Error(err))
break // Continue working.
}
if !cfg.ProtocolConfiguration.Equals(&cfgnew.ProtocolConfiguration) {
log.Warn("ProtocolConfiguration changed, signal ignored")
break // Continue working.
}
if !cfg.ApplicationConfiguration.EqualsButServices(&cfgnew.ApplicationConfiguration) {
log.Warn("ApplicationConfiguration changed in incompatible way, signal ignored")
break // Continue working.
}
configureAddresses(&cfgnew.ApplicationConfiguration)
switch sig { switch sig {
case syscall.SIGHUP: case sighup:
log.Info("SIGHUP received, restarting rpc-server") serv.DelService(&rpcServer)
rpcServer.Shutdown() rpcServer.Shutdown()
rpcServer = rpcsrv.New(chain, cfg.ApplicationConfiguration.RPC, serv, oracleSrv, log, errChan) rpcServer = rpcsrv.New(chain, cfgnew.ApplicationConfiguration.RPC, serv, oracleSrv, log, errChan)
serv.AddService(&rpcServer) // Replaces old one by service name. serv.AddService(&rpcServer)
if !cfg.ApplicationConfiguration.RPC.StartWhenSynchronized || serv.IsInSync() { if !cfgnew.ApplicationConfiguration.RPC.StartWhenSynchronized || serv.IsInSync() {
rpcServer.Start() rpcServer.Start()
} }
pprof.ShutDown()
pprof = metrics.NewPprofService(cfgnew.ApplicationConfiguration.Pprof, log)
go pprof.Start()
prometheus.ShutDown()
prometheus = metrics.NewPrometheusService(cfgnew.ApplicationConfiguration.Prometheus, log)
go prometheus.Start()
case sigusr1:
if oracleSrv != nil {
serv.DelService(oracleSrv)
chain.SetOracle(nil)
rpcServer.SetOracleHandler(nil)
oracleSrv.Shutdown()
}
oracleSrv, err = mkOracle(cfgnew.ApplicationConfiguration.Oracle, cfgnew.ProtocolConfiguration.Magic, chain, serv, log)
if err != nil {
log.Error("failed to create oracle service", zap.Error(err))
break // Keep going.
}
if oracleSrv != nil {
rpcServer.SetOracleHandler(oracleSrv)
if serv.IsInSync() {
oracleSrv.Start()
}
}
if p2pNotary != nil {
serv.DelService(p2pNotary)
chain.SetNotary(nil)
p2pNotary.Shutdown()
}
p2pNotary, err = mkP2PNotary(cfgnew.ApplicationConfiguration.P2PNotary, chain, serv, log)
if err != nil {
log.Error("failed to create notary service", zap.Error(err))
break // Keep going.
}
if p2pNotary != nil && serv.IsInSync() {
p2pNotary.Start()
}
serv.DelExtensibleService(sr, stateroot.Category)
srMod.SetUpdateValidatorsCallback(nil)
sr.Shutdown()
sr, err = stateroot.New(cfgnew.ApplicationConfiguration.StateRoot, srMod, log, chain, serv.BroadcastExtensible)
if err != nil {
log.Error("failed to create state validation service", zap.Error(err))
break // The show must go on.
}
serv.AddExtensibleService(sr, stateroot.Category, sr.OnPayload)
if serv.IsInSync() {
sr.Start()
}
case sigusr2:
if dbftSrv != nil {
serv.DelConsensusService(dbftSrv)
dbftSrv.Shutdown()
}
dbftSrv, err = mkConsensus(cfgnew.ApplicationConfiguration.UnlockWallet, serverConfig.TimePerBlock, chain, serv, log)
if err != nil {
log.Error("failed to create consensus service", zap.Error(err))
break // Whatever happens, I'll leave it all to chance.
}
if dbftSrv != nil && serv.IsInSync() {
dbftSrv.Start()
}
} }
cfg = cfgnew
case <-grace.Done(): case <-grace.Done():
signal.Stop(sighupCh) signal.Stop(sigCh)
serv.Shutdown() serv.Shutdown()
break Main break Main
} }

View file

@ -0,0 +1,12 @@
//go:build !windows
// +build !windows
package server
import "syscall"
const (
sighup = syscall.SIGHUP
sigusr1 = syscall.SIGUSR1
sigusr2 = syscall.SIGUSR2
)

View file

@ -0,0 +1,13 @@
//go:build windows
// +build windows
package server
import "syscall"
const (
// Doesn't really matter, Windows can't do it.
sighup = syscall.SIGHUP
sigusr1 = syscall.Signal(0xa)
sigusr2 = syscall.Signal(0xc)
)

View file

@ -68,12 +68,26 @@ current height of the node.
### Restarting node services ### Restarting node services
To restart some node services without full node restart, send the SIGHUP On Unix-like platforms HUP, USR1 and USR2 signals can be used to control node
signal. List of the services to be restarted on SIGHUP receiving: services. Upon receiving any of these signals node rereads the configuration
file, checks for its compatibility (ProtocolConfiguration can't be changed and
ApplicationConfiguration can only be changed for services) and then
stops/starts services according to the old and new configurations. Services
are broadly split into three main categories:
* client-oriented
These provide some service to clients: RPC, Pprof and Prometheus
servers. They're controlled with the HUP signal.
* network-oriented
These provide some service to the network: Oracle, State validation and P2P
Notary. They're controlled with the USR1 signal.
* consensus
That's dBFT, it's a special one and it's controlled with USR2.
| Service | Action | Typical scenarios when this can be useful (without full node restart):
| --- | --- | * enabling some service
| RPC server | Restarting with the old configuration and updated TLS certificates | * changing RPC configuration
* updating TLS certificates for the RPC server
* resolving operational issues
### DB import/exports ### DB import/exports

View file

@ -29,3 +29,25 @@ type ApplicationConfiguration struct {
// ExtensiblePoolSize is the maximum amount of the extensible payloads from a single sender. // ExtensiblePoolSize is the maximum amount of the extensible payloads from a single sender.
ExtensiblePoolSize int `yaml:"ExtensiblePoolSize"` ExtensiblePoolSize int `yaml:"ExtensiblePoolSize"`
} }
// EqualsButServices returns true when the o is the same as a except for services
// (Oracle, P2PNotary, Pprof, Prometheus, RPC, StateRoot and UnlockWallet sections).
func (a *ApplicationConfiguration) EqualsButServices(o *ApplicationConfiguration) bool {
if a.Address != o.Address ||
a.AnnouncedNodePort != o.AnnouncedNodePort ||
a.AttemptConnPeers != o.AttemptConnPeers ||
a.DBConfiguration != o.DBConfiguration ||
a.DialTimeout != o.DialTimeout ||
a.ExtensiblePoolSize != o.ExtensiblePoolSize ||
a.LogPath != o.LogPath ||
a.MaxPeers != o.MaxPeers ||
a.MinPeers != o.MinPeers ||
a.NodePort != o.NodePort ||
a.PingInterval != o.PingInterval ||
a.PingTimeout != o.PingTimeout ||
a.ProtoTickInterval != o.ProtoTickInterval ||
a.Relay != o.Relay {
return false
}
return true
}

View file

@ -0,0 +1,22 @@
package config
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
func TestApplicationConfigurationEquals(t *testing.T) {
a := &ApplicationConfiguration{}
o := &ApplicationConfiguration{}
require.True(t, a.EqualsButServices(o))
require.True(t, o.EqualsButServices(a))
require.True(t, a.EqualsButServices(a))
cfg1, err := LoadFile(filepath.Join("..", "..", "config", "protocol.mainnet.yml"))
require.NoError(t, err)
cfg2, err := LoadFile(filepath.Join("..", "..", "config", "protocol.testnet.yml"))
require.NoError(t, err)
require.False(t, cfg1.ApplicationConfiguration.EqualsButServices(&cfg2.ApplicationConfiguration))
}

View file

@ -196,3 +196,78 @@ func (p *ProtocolConfiguration) GetNumOfCNs(height uint32) int {
func (p *ProtocolConfiguration) ShouldUpdateCommitteeAt(height uint32) bool { func (p *ProtocolConfiguration) ShouldUpdateCommitteeAt(height uint32) bool {
return height%uint32(p.GetCommitteeSize(height)) == 0 return height%uint32(p.GetCommitteeSize(height)) == 0
} }
// Equals allows to compare two ProtocolConfiguration instances, returns true if
// they're equal.
func (p *ProtocolConfiguration) Equals(o *ProtocolConfiguration) bool {
if p.GarbageCollectionPeriod != o.GarbageCollectionPeriod ||
p.InitialGASSupply != o.InitialGASSupply ||
p.KeepOnlyLatestState != o.KeepOnlyLatestState ||
p.Magic != o.Magic ||
p.MaxBlockSize != o.MaxBlockSize ||
p.MaxBlockSystemFee != o.MaxBlockSystemFee ||
p.MaxTraceableBlocks != o.MaxTraceableBlocks ||
p.MaxTransactionsPerBlock != o.MaxTransactionsPerBlock ||
p.MaxValidUntilBlockIncrement != o.MaxValidUntilBlockIncrement ||
p.MemPoolSize != o.MemPoolSize ||
p.P2PNotaryRequestPayloadPoolSize != o.P2PNotaryRequestPayloadPoolSize ||
p.P2PSigExtensions != o.P2PSigExtensions ||
p.P2PStateExchangeExtensions != o.P2PStateExchangeExtensions ||
p.RemoveUntraceableBlocks != o.RemoveUntraceableBlocks ||
p.ReservedAttributes != o.ReservedAttributes ||
p.SaveStorageBatch != o.SaveStorageBatch ||
p.SecondsPerBlock != o.SecondsPerBlock ||
p.StateRootInHeader != o.StateRootInHeader ||
p.StateSyncInterval != o.StateSyncInterval ||
p.ValidatorsCount != o.ValidatorsCount ||
p.VerifyBlocks != o.VerifyBlocks ||
p.VerifyTransactions != o.VerifyTransactions ||
len(p.CommitteeHistory) != len(o.CommitteeHistory) ||
len(p.Hardforks) != len(o.Hardforks) ||
len(p.NativeUpdateHistories) != len(o.NativeUpdateHistories) ||
len(p.SeedList) != len(o.SeedList) ||
len(p.StandbyCommittee) != len(o.StandbyCommittee) ||
len(p.ValidatorsHistory) != len(o.ValidatorsHistory) {
return false
}
for k, v := range p.CommitteeHistory {
vo, ok := o.CommitteeHistory[k]
if !ok || v != vo {
return false
}
}
for k, v := range p.Hardforks {
vo, ok := o.Hardforks[k]
if !ok || v != vo {
return false
}
}
for k, v := range p.NativeUpdateHistories {
vo := o.NativeUpdateHistories[k]
if len(v) != len(vo) {
return false
}
for i := range v {
if v[i] != vo[i] {
return false
}
}
}
for i := range p.SeedList {
if p.SeedList[i] != o.SeedList[i] {
return false
}
}
for i := range p.StandbyCommittee {
if p.StandbyCommittee[i] != o.StandbyCommittee[i] {
return false
}
}
for k, v := range p.ValidatorsHistory {
vo, ok := o.ValidatorsHistory[k]
if !ok || v != vo {
return false
}
}
return true
}

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -143,3 +144,76 @@ func TestGetCommitteeAndCNs(t *testing.T) {
require.Equal(t, 4, p.GetNumOfCNs(200)) require.Equal(t, 4, p.GetNumOfCNs(200))
require.Equal(t, 4, p.GetNumOfCNs(201)) require.Equal(t, 4, p.GetNumOfCNs(201))
} }
func TestProtocolConfigurationEquals(t *testing.T) {
p := &ProtocolConfiguration{}
o := &ProtocolConfiguration{}
require.True(t, p.Equals(o))
require.True(t, o.Equals(p))
require.True(t, p.Equals(p))
cfg1, err := LoadFile(filepath.Join("..", "..", "config", "protocol.mainnet.yml"))
require.NoError(t, err)
cfg2, err := LoadFile(filepath.Join("..", "..", "config", "protocol.testnet.yml"))
require.NoError(t, err)
require.False(t, cfg1.ProtocolConfiguration.Equals(&cfg2.ProtocolConfiguration))
cfg2, err = LoadFile(filepath.Join("..", "..", "config", "protocol.mainnet.yml"))
require.NoError(t, err)
p = &cfg1.ProtocolConfiguration
o = &cfg2.ProtocolConfiguration
require.True(t, p.Equals(o))
o.CommitteeHistory = map[uint32]int{111: 7}
p.CommitteeHistory = map[uint32]int{111: 7}
require.True(t, p.Equals(o))
p.CommitteeHistory[111] = 8
require.False(t, p.Equals(o))
o.CommitteeHistory = nil
p.CommitteeHistory = nil
p.Hardforks = map[string]uint32{"Fork": 42}
o.Hardforks = map[string]uint32{"Fork": 42}
require.True(t, p.Equals(o))
p.Hardforks = map[string]uint32{"Fork2": 42}
require.False(t, p.Equals(o))
p.Hardforks = nil
o.Hardforks = nil
p.NativeUpdateHistories = map[string][]uint32{"Contract": {1, 2, 3}}
o.NativeUpdateHistories = map[string][]uint32{"Contract": {1, 2, 3}}
require.True(t, p.Equals(o))
p.NativeUpdateHistories["Contract"] = []uint32{1, 2, 3, 4}
require.False(t, p.Equals(o))
p.NativeUpdateHistories["Contract"] = []uint32{1, 2, 4}
require.False(t, p.Equals(o))
p.NativeUpdateHistories = nil
o.NativeUpdateHistories = nil
p.SeedList = []string{"url1", "url2"}
o.SeedList = []string{"url1", "url2"}
require.True(t, p.Equals(o))
p.SeedList = []string{"url11", "url22"}
require.False(t, p.Equals(o))
p.SeedList = nil
o.SeedList = nil
p.StandbyCommittee = []string{"key1", "key2"}
o.StandbyCommittee = []string{"key1", "key2"}
require.True(t, p.Equals(o))
p.StandbyCommittee = []string{"key2", "key1"}
require.False(t, p.Equals(o))
p.StandbyCommittee = nil
o.StandbyCommittee = nil
o.ValidatorsHistory = map[uint32]int{111: 0}
p.ValidatorsHistory = map[uint32]int{111: 0}
require.True(t, p.Equals(o))
p.ValidatorsHistory = map[uint32]int{112: 0}
require.False(t, p.Equals(o))
}

View file

@ -40,9 +40,6 @@ const defaultTimePerBlock = 15 * time.Second
// Number of nanoseconds in millisecond. // Number of nanoseconds in millisecond.
const nsInMs = 1000000 const nsInMs = 1000000
// Category is a message category for extensible payloads.
const Category = "dBFT"
// Ledger is the interface to Blockchain sufficient for Service. // Ledger is the interface to Blockchain sufficient for Service.
type Ledger interface { type Ledger interface {
AddBlock(block *coreb.Block) error AddBlock(block *coreb.Block) error
@ -218,7 +215,7 @@ var (
func NewPayload(m netmode.Magic, stateRootEnabled bool) *Payload { func NewPayload(m netmode.Magic, stateRootEnabled bool) *Payload {
return &Payload{ return &Payload{
Extensible: npayload.Extensible{ Extensible: npayload.Extensible{
Category: Category, Category: npayload.ConsensusCategory,
}, },
message: message{ message: message{
stateRootEnabled: stateRootEnabled, stateRootEnabled: stateRootEnabled,
@ -276,6 +273,7 @@ func (s *service) Start() {
// Shutdown implements the Service interface. // Shutdown implements the Service interface.
func (s *service) Shutdown() { func (s *service) Shutdown() {
if s.started.CAS(true, false) { if s.started.CAS(true, false) {
s.log.Info("stopping consensus service")
close(s.quit) close(s.quit)
<-s.finished <-s.finished
} }

View file

@ -297,7 +297,7 @@ func getVerificationScript(i uint8, validators []crypto.PublicKey) []byte {
func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload { func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload {
return &Payload{ return &Payload{
Extensible: npayload.Extensible{ Extensible: npayload.Extensible{
Category: Category, Category: npayload.ConsensusCategory,
ValidBlockEnd: recovery.BlockIndex, ValidBlockEnd: recovery.BlockIndex,
}, },
message: message{ message: message{

View file

@ -303,20 +303,42 @@ func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration, log *zap.L
// must be called before `bc.Run()` to avoid data race. // must be called before `bc.Run()` to avoid data race.
func (bc *Blockchain) SetOracle(mod native.OracleService) { func (bc *Blockchain) SetOracle(mod native.OracleService) {
orc := bc.contracts.Oracle orc := bc.contracts.Oracle
md, ok := orc.GetMethod(manifest.MethodVerify, -1) if mod != nil {
if !ok { md, ok := orc.GetMethod(manifest.MethodVerify, -1)
panic(fmt.Errorf("%s method not found", manifest.MethodVerify)) if !ok {
panic(fmt.Errorf("%s method not found", manifest.MethodVerify))
}
mod.UpdateNativeContract(orc.NEF.Script, orc.GetOracleResponseScript(),
orc.Hash, md.MD.Offset)
keys, _, err := bc.contracts.Designate.GetDesignatedByRole(bc.dao, noderoles.Oracle, bc.BlockHeight())
if err != nil {
bc.log.Error("failed to get oracle key list")
return
}
mod.UpdateOracleNodes(keys)
reqs, err := bc.contracts.Oracle.GetRequests(bc.dao)
if err != nil {
bc.log.Error("failed to get current oracle request list")
return
}
mod.AddRequests(reqs)
} }
mod.UpdateNativeContract(orc.NEF.Script, orc.GetOracleResponseScript(), orc.Module.Store(&mod)
orc.Hash, md.MD.Offset) bc.contracts.Designate.OracleService.Store(&mod)
orc.Module.Store(mod)
bc.contracts.Designate.OracleService.Store(mod)
} }
// SetNotary sets notary module. It doesn't protected by mutex and // SetNotary sets notary module. It doesn't protected by mutex and
// must be called before `bc.Run()` to avoid data race. // must be called before `bc.Run()` to avoid data race.
func (bc *Blockchain) SetNotary(mod native.NotaryService) { func (bc *Blockchain) SetNotary(mod native.NotaryService) {
bc.contracts.Designate.NotaryService.Store(mod) if mod != nil {
keys, _, err := bc.contracts.Designate.GetDesignatedByRole(bc.dao, noderoles.P2PNotary, bc.BlockHeight())
if err != nil {
bc.log.Error("failed to get notary key list")
return
}
mod.UpdateNotaryNodes(keys)
}
bc.contracts.Designate.NotaryService.Store(&mod)
} }
func (bc *Blockchain) init() error { func (bc *Blockchain) init() error {

View file

@ -238,12 +238,12 @@ func (s *Designate) updateCachedRoleData(cache *DesignationCache, d *dao.Simple,
func (s *Designate) notifyRoleChanged(v *roleData, r noderoles.Role) { func (s *Designate) notifyRoleChanged(v *roleData, r noderoles.Role) {
switch r { switch r {
case noderoles.Oracle: case noderoles.Oracle:
if orc, _ := s.OracleService.Load().(OracleService); orc != nil { if orc, _ := s.OracleService.Load().(*OracleService); orc != nil && *orc != nil {
orc.UpdateOracleNodes(v.nodes.Copy()) (*orc).UpdateOracleNodes(v.nodes.Copy())
} }
case noderoles.P2PNotary: case noderoles.P2PNotary:
if ntr, _ := s.NotaryService.Load().(NotaryService); ntr != nil { if ntr, _ := s.NotaryService.Load().(*NotaryService); ntr != nil && *ntr != nil {
ntr.UpdateNotaryNodes(v.nodes.Copy()) (*ntr).UpdateNotaryNodes(v.nodes.Copy())
} }
case noderoles.StateValidator: case noderoles.StateValidator:
if s.StateRootService != nil { if s.StateRootService != nil {

View file

@ -113,7 +113,10 @@ func copyOracleCache(src, dst *OracleCache) {
} }
func newOracle() *Oracle { func newOracle() *Oracle {
o := &Oracle{ContractMD: *interop.NewContractMD(nativenames.Oracle, oracleContractID)} o := &Oracle{
ContractMD: *interop.NewContractMD(nativenames.Oracle, oracleContractID),
newRequests: make(map[uint64]*state.OracleRequest),
}
defer o.UpdateHash() defer o.UpdateHash()
o.oracleScript = CreateOracleResponseScript(o.Hash) o.oracleScript = CreateOracleResponseScript(o.Hash)
@ -161,11 +164,7 @@ func (o *Oracle) GetOracleResponseScript() []byte {
// OnPersist implements the Contract interface. // OnPersist implements the Contract interface.
func (o *Oracle) OnPersist(ic *interop.Context) error { func (o *Oracle) OnPersist(ic *interop.Context) error {
var err error return nil
if o.newRequests == nil {
o.newRequests, err = o.getRequests(ic.DAO)
}
return err
} }
// PostPersist represents `postPersist` method. // PostPersist represents `postPersist` method.
@ -177,7 +176,7 @@ func (o *Oracle) PostPersist(ic *interop.Context) error {
single := big.NewInt(p) single := big.NewInt(p)
var removedIDs []uint64 var removedIDs []uint64
orc, _ := o.Module.Load().(OracleService) orc, _ := o.Module.Load().(*OracleService)
for _, tx := range ic.Block.Transactions { for _, tx := range ic.Block.Transactions {
resp := getResponse(tx) resp := getResponse(tx)
if resp == nil { if resp == nil {
@ -189,7 +188,7 @@ func (o *Oracle) PostPersist(ic *interop.Context) error {
continue continue
} }
ic.DAO.DeleteStorageItem(o.ID, reqKey) ic.DAO.DeleteStorageItem(o.ID, reqKey)
if orc != nil { if orc != nil && *orc != nil {
removedIDs = append(removedIDs, resp.ID) removedIDs = append(removedIDs, resp.ID)
} }
@ -229,8 +228,8 @@ func (o *Oracle) PostPersist(ic *interop.Context) error {
o.GAS.mint(ic, nodes[i].GetScriptHash(), &reward[i], false) o.GAS.mint(ic, nodes[i].GetScriptHash(), &reward[i], false)
} }
if len(removedIDs) != 0 && orc != nil { if len(removedIDs) != 0 {
orc.RemoveRequests(removedIDs) (*orc).RemoveRequests(removedIDs)
} }
return o.updateCache(ic.DAO) return o.updateCache(ic.DAO)
} }
@ -415,7 +414,10 @@ func (o *Oracle) PutRequestInternal(id uint64, req *state.OracleRequest, d *dao.
if err := putConvertibleToDAO(o.ID, d, reqKey, req); err != nil { if err := putConvertibleToDAO(o.ID, d, reqKey, req); err != nil {
return err return err
} }
o.newRequests[id] = req orc, _ := o.Module.Load().(*OracleService)
if orc != nil && *orc != nil {
o.newRequests[id] = req
}
// Add request ID to the id list. // Add request ID to the id list.
lst := new(IDList) lst := new(IDList)
@ -493,8 +495,8 @@ func (o *Oracle) getOriginalTxID(d *dao.Simple, tx *transaction.Transaction) uti
return tx.Hash() return tx.Hash()
} }
// getRequests returns all requests which have not been finished yet. // GetRequests returns all requests which have not been finished yet.
func (o *Oracle) getRequests(d *dao.Simple) (map[uint64]*state.OracleRequest, error) { func (o *Oracle) GetRequests(d *dao.Simple) (map[uint64]*state.OracleRequest, error) {
var reqs = make(map[uint64]*state.OracleRequest) var reqs = make(map[uint64]*state.OracleRequest)
var err error var err error
d.Seek(o.ID, storage.SeekRange{Prefix: prefixRequest}, func(k, v []byte) bool { d.Seek(o.ID, storage.SeekRange{Prefix: prefixRequest}, func(k, v []byte) bool {
@ -534,8 +536,8 @@ func (o *Oracle) getConvertibleFromDAO(d *dao.Simple, key []byte, item stackitem
// updateCache updates cached Oracle values if they've been changed. // updateCache updates cached Oracle values if they've been changed.
func (o *Oracle) updateCache(d *dao.Simple) error { func (o *Oracle) updateCache(d *dao.Simple) error {
orc, _ := o.Module.Load().(OracleService) orc, _ := o.Module.Load().(*OracleService)
if orc == nil { if orc == nil || *orc == nil {
return nil return nil
} }
@ -547,7 +549,7 @@ func (o *Oracle) updateCache(d *dao.Simple) error {
delete(reqs, id) delete(reqs, id)
} }
} }
orc.AddRequests(reqs) (*orc).AddRequests(reqs)
return nil return nil
} }

View file

@ -10,7 +10,6 @@ import (
"github.com/nspcc-dev/neo-go/internal/fakechain" "github.com/nspcc-dev/neo-go/internal/fakechain"
"github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/consensus"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
@ -196,10 +195,6 @@ func newTestServerWithCustomCfg(t *testing.T, serverConfig ServerConfig, protoco
s, err := newServerFromConstructors(serverConfig, fakechain.NewFakeChainWithCustomCfg(protocolCfg), new(fakechain.FakeStateSync), zaptest.NewLogger(t), s, err := newServerFromConstructors(serverConfig, fakechain.NewFakeChainWithCustomCfg(protocolCfg), new(fakechain.FakeStateSync), zaptest.NewLogger(t),
newFakeTransp, newTestDiscovery) newFakeTransp, newTestDiscovery)
require.NoError(t, err) require.NoError(t, err)
if serverConfig.Wallet != nil {
cons := new(fakeConsensus)
s.AddExtensibleHPService(cons, consensus.Category, cons.OnPayload, cons.OnTransaction)
}
t.Cleanup(s.discovery.Close) t.Cleanup(s.discovery.Close)
return s return s
} }

View file

@ -11,6 +11,10 @@ import (
const maxExtensibleCategorySize = 32 const maxExtensibleCategorySize = 32
// ConsensusCategory is a message category for consensus-related extensible
// payloads.
const ConsensusCategory = "dBFT"
// Extensible represents a payload containing arbitrary data. // Extensible represents a payload containing arbitrary data.
type Extensible struct { type Extensible struct {
// Category is the payload type. // Category is the payload type.

View file

@ -101,10 +101,11 @@ type (
notaryRequestPool *mempool.Pool notaryRequestPool *mempool.Pool
extensiblePool *extpool.Pool extensiblePool *extpool.Pool
notaryFeer NotaryFeer notaryFeer NotaryFeer
services map[string]Service
extensHandlers map[string]func(*payload.Extensible) error serviceLock sync.RWMutex
extensHighPrio string services map[string]Service
txCallback func(*transaction.Transaction) extensHandlers map[string]func(*payload.Extensible) error
txCallback func(*transaction.Transaction)
txInLock sync.Mutex txInLock sync.Mutex
txInMap map[util.Uint256]struct{} txInMap map[util.Uint256]struct{}
@ -263,9 +264,11 @@ func (s *Server) Shutdown() {
} }
s.bQueue.discard() s.bQueue.discard()
s.bSyncQueue.discard() s.bSyncQueue.discard()
s.serviceLock.RLock()
for _, svc := range s.services { for _, svc := range s.services {
svc.Shutdown() svc.Shutdown()
} }
s.serviceLock.RUnlock()
if s.chain.P2PSigExtensionsEnabled() { if s.chain.P2PSigExtensionsEnabled() {
s.notaryRequestPool.StopSubscriptions() s.notaryRequestPool.StopSubscriptions()
} }
@ -274,20 +277,70 @@ func (s *Server) Shutdown() {
// AddService allows to add a service to be started/stopped by Server. // AddService allows to add a service to be started/stopped by Server.
func (s *Server) AddService(svc Service) { func (s *Server) AddService(svc Service) {
s.serviceLock.Lock()
defer s.serviceLock.Unlock()
s.addService(svc)
}
// addService is an unlocked version of AddService.
func (s *Server) addService(svc Service) {
s.services[svc.Name()] = svc s.services[svc.Name()] = svc
} }
// AddExtensibleService register a service that handles an extensible payload of some kind. // AddExtensibleService register a service that handles an extensible payload of some kind.
func (s *Server) AddExtensibleService(svc Service, category string, handler func(*payload.Extensible) error) { func (s *Server) AddExtensibleService(svc Service, category string, handler func(*payload.Extensible) error) {
s.extensHandlers[category] = handler s.serviceLock.Lock()
s.AddService(svc) defer s.serviceLock.Unlock()
s.addExtensibleService(svc, category, handler)
} }
// AddExtensibleHPService registers a high-priority service that handles an extensible payload of some kind. // addExtensibleService is an unlocked version of AddExtensibleService.
func (s *Server) AddExtensibleHPService(svc Service, category string, handler func(*payload.Extensible) error, txCallback func(*transaction.Transaction)) { func (s *Server) addExtensibleService(svc Service, category string, handler func(*payload.Extensible) error) {
s.extensHandlers[category] = handler
s.addService(svc)
}
// AddConsensusService registers consensus service that handles transactions and dBFT extensible payloads.
func (s *Server) AddConsensusService(svc Service, handler func(*payload.Extensible) error, txCallback func(*transaction.Transaction)) {
s.serviceLock.Lock()
defer s.serviceLock.Unlock()
s.txCallback = txCallback s.txCallback = txCallback
s.extensHighPrio = category s.addExtensibleService(svc, payload.ConsensusCategory, handler)
s.AddExtensibleService(svc, category, handler) }
// DelService drops a service from the list, use it when the service is stopped
// outside of the Server.
func (s *Server) DelService(svc Service) {
s.serviceLock.Lock()
defer s.serviceLock.Unlock()
s.delService(svc)
}
// delService is an unlocked version of DelService.
func (s *Server) delService(svc Service) {
delete(s.services, svc.Name())
}
// DelExtensibleService drops a service that handler extensible payloads from the
// list, use it when the service is stopped outside of the Server.
func (s *Server) DelExtensibleService(svc Service, category string) {
s.serviceLock.Lock()
defer s.serviceLock.Unlock()
s.delExtensibleService(svc, category)
}
// delExtensibleService is an unlocked version of DelExtensibleService.
func (s *Server) delExtensibleService(svc Service, category string) {
delete(s.extensHandlers, category)
s.delService(svc)
}
// DelConsensusService unregisters consensus service that handles transactions and dBFT extensible payloads.
func (s *Server) DelConsensusService(svc Service) {
s.serviceLock.Lock()
defer s.serviceLock.Unlock()
s.txCallback = nil
s.delExtensibleService(svc, payload.ConsensusCategory)
} }
// GetNotaryPool allows to retrieve notary pool, if it's configured. // GetNotaryPool allows to retrieve notary pool, if it's configured.
@ -428,9 +481,11 @@ func (s *Server) tryStartServices() {
if s.chain.P2PSigExtensionsEnabled() { if s.chain.P2PSigExtensionsEnabled() {
s.notaryRequestPool.RunSubscriptions() // WSClient is also a subscriber. s.notaryRequestPool.RunSubscriptions() // WSClient is also a subscriber.
} }
s.serviceLock.RLock()
for _, svc := range s.services { for _, svc := range s.services {
svc.Start() svc.Start()
} }
s.serviceLock.RUnlock()
} }
} }
@ -931,7 +986,9 @@ func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
if !ok { // payload is already in cache if !ok { // payload is already in cache
return nil return nil
} }
s.serviceLock.RLock()
handler := s.extensHandlers[e.Category] handler := s.extensHandlers[e.Category]
s.serviceLock.RUnlock()
if handler != nil { if handler != nil {
err = handler(e) err = handler(e)
if err != nil { if err != nil {
@ -944,7 +1001,7 @@ func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
func (s *Server) advertiseExtensible(e *payload.Extensible) { func (s *Server) advertiseExtensible(e *payload.Extensible) {
msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{e.Hash()})) msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{e.Hash()}))
if e.Category == s.extensHighPrio { if e.Category == payload.ConsensusCategory {
// It's high priority because it directly affects consensus process, // It's high priority because it directly affects consensus process,
// even though it's just an inv. // even though it's just an inv.
s.broadcastHPMessage(msg) s.broadcastHPMessage(msg)
@ -966,8 +1023,11 @@ func (s *Server) handleTxCmd(tx *transaction.Transaction) error {
} }
s.txInMap[tx.Hash()] = struct{}{} s.txInMap[tx.Hash()] = struct{}{}
s.txInLock.Unlock() s.txInLock.Unlock()
if s.txCallback != nil { s.serviceLock.RLock()
s.txCallback(tx) txCallback := s.txCallback
s.serviceLock.RUnlock()
if txCallback != nil {
txCallback(tx)
} }
if s.verifyAndPoolTX(tx) == nil { if s.verifyAndPoolTX(tx) == nil {
s.broadcastTX(tx, nil) s.broadcastTX(tx, nil)

View file

@ -64,9 +64,6 @@ type (
// Level of the internal logger. // Level of the internal logger.
LogLevel zapcore.Level LogLevel zapcore.Level
// Wallet is a wallet configuration.
Wallet *config.Wallet
// TimePerBlock is an interval which should pass between two successive blocks. // TimePerBlock is an interval which should pass between two successive blocks.
TimePerBlock time.Duration TimePerBlock time.Duration
@ -90,11 +87,6 @@ func NewServerConfig(cfg config.Config) ServerConfig {
appConfig := cfg.ApplicationConfiguration appConfig := cfg.ApplicationConfiguration
protoConfig := cfg.ProtocolConfiguration protoConfig := cfg.ProtocolConfiguration
var wc *config.Wallet
if appConfig.UnlockWallet.Path != "" {
wc = &appConfig.UnlockWallet
}
return ServerConfig{ return ServerConfig{
UserAgent: cfg.GenerateUserAgent(), UserAgent: cfg.GenerateUserAgent(),
Address: appConfig.Address, Address: appConfig.Address,
@ -110,7 +102,6 @@ func NewServerConfig(cfg config.Config) ServerConfig {
MaxPeers: appConfig.MaxPeers, MaxPeers: appConfig.MaxPeers,
AttemptConnPeers: appConfig.AttemptConnPeers, AttemptConnPeers: appConfig.AttemptConnPeers,
MinPeers: appConfig.MinPeers, MinPeers: appConfig.MinPeers,
Wallet: wc,
TimePerBlock: time.Duration(protoConfig.SecondsPerBlock) * time.Second, TimePerBlock: time.Duration(protoConfig.SecondsPerBlock) * time.Second,
OracleCfg: appConfig.Oracle, OracleCfg: appConfig.Oracle,
P2PNotaryCfg: appConfig.P2PNotary, P2PNotaryCfg: appConfig.P2PNotary,

View file

@ -109,7 +109,9 @@ func TestServerStartAndShutdown(t *testing.T) {
require.True(t, errors.Is(err, errServerShutdown)) require.True(t, errors.Is(err, errServerShutdown))
}) })
t.Run("with consensus", func(t *testing.T) { t.Run("with consensus", func(t *testing.T) {
s := newTestServer(t, ServerConfig{Wallet: new(config.Wallet)}) s := newTestServer(t, ServerConfig{})
cons := new(fakeConsensus)
s.AddConsensusService(cons, cons.OnPayload, cons.OnTransaction)
ch := startWithChannel(s) ch := startWithChannel(s)
p := newLocalPeer(t, s) p := newLocalPeer(t, s)
@ -409,7 +411,9 @@ func TestBlock(t *testing.T) {
} }
func TestConsensus(t *testing.T) { func TestConsensus(t *testing.T) {
s := newTestServer(t, ServerConfig{Wallet: new(config.Wallet)}) s := newTestServer(t, ServerConfig{})
cons := new(fakeConsensus)
s.AddConsensusService(cons, cons.OnPayload, cons.OnTransaction)
startWithCleanup(t, s) startWithCleanup(t, s)
atomic2.StoreUint32(&s.chain.(*fakechain.FakeChain).Blockheight, 4) atomic2.StoreUint32(&s.chain.(*fakechain.FakeChain).Blockheight, 4)
@ -420,7 +424,7 @@ func TestConsensus(t *testing.T) {
newConsensusMessage := func(start, end uint32) *Message { newConsensusMessage := func(start, end uint32) *Message {
pl := payload.NewExtensible() pl := payload.NewExtensible()
pl.Category = consensus.Category pl.Category = payload.ConsensusCategory
pl.ValidBlockStart = start pl.ValidBlockStart = start
pl.ValidBlockEnd = end pl.ValidBlockEnd = end
return NewMessage(CMDExtensible, pl) return NewMessage(CMDExtensible, pl)
@ -452,7 +456,9 @@ func TestConsensus(t *testing.T) {
} }
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {
s := newTestServer(t, ServerConfig{Wallet: new(config.Wallet)}) s := newTestServer(t, ServerConfig{})
cons := new(fakeConsensus)
s.AddConsensusService(cons, cons.OnPayload, cons.OnTransaction)
startWithCleanup(t, s) startWithCleanup(t, s)
t.Run("good", func(t *testing.T) { t.Run("good", func(t *testing.T) {

View file

@ -31,9 +31,12 @@ func (ms *Service) Start() {
// ShutDown stops the service. // ShutDown stops the service.
func (ms *Service) ShutDown() { func (ms *Service) ShutDown() {
if !ms.config.Enabled {
return
}
ms.log.Info("shutting down service", zap.String("endpoint", ms.Addr)) ms.log.Info("shutting down service", zap.String("endpoint", ms.Addr))
err := ms.Shutdown(context.Background()) err := ms.Shutdown(context.Background())
if err != nil { if err != nil {
ms.log.Panic("can't shut down service") ms.log.Error("can't shut service down", zap.Error(err))
} }
} }

View file

@ -219,6 +219,7 @@ func (n *Notary) Shutdown() {
if !n.started.CAS(true, false) { if !n.started.CAS(true, false) {
return return
} }
n.Config.Log.Info("stopping notary service")
close(n.stopCh) close(n.stopCh)
<-n.done <-n.done
} }

View file

@ -10,7 +10,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/rpcclient" "github.com/nspcc-dev/neo-go/pkg/rpcclient"
"github.com/nspcc-dev/neo-go/pkg/services/helpers/rpcbroadcaster" "github.com/nspcc-dev/neo-go/pkg/services/helpers/rpcbroadcaster"
"github.com/nspcc-dev/neo-go/pkg/services/oracle"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -20,16 +19,17 @@ const (
defaultChanCapacity = 16 defaultChanCapacity = 16
) )
type oracleBroadcaster struct { // OracleBroadcaster is an oracle broadcaster implementation.
type OracleBroadcaster struct {
rpcbroadcaster.RPCBroadcaster rpcbroadcaster.RPCBroadcaster
} }
// New returns a new struct capable of broadcasting oracle responses. // New returns a new struct capable of broadcasting oracle responses.
func New(cfg config.OracleConfiguration, log *zap.Logger) oracle.Broadcaster { func New(cfg config.OracleConfiguration, log *zap.Logger) *OracleBroadcaster {
if cfg.ResponseTimeout == 0 { if cfg.ResponseTimeout == 0 {
cfg.ResponseTimeout = defaultSendTimeout cfg.ResponseTimeout = defaultSendTimeout
} }
r := &oracleBroadcaster{ r := &OracleBroadcaster{
RPCBroadcaster: *rpcbroadcaster.NewRPCBroadcaster(log, cfg.ResponseTimeout), RPCBroadcaster: *rpcbroadcaster.NewRPCBroadcaster(log, cfg.ResponseTimeout),
} }
for i := range cfg.Nodes { for i := range cfg.Nodes {
@ -40,7 +40,7 @@ func New(cfg config.OracleConfiguration, log *zap.Logger) oracle.Broadcaster {
} }
// SendResponse implements interfaces.Broadcaster. // SendResponse implements interfaces.Broadcaster.
func (r *oracleBroadcaster) SendResponse(priv *keys.PrivateKey, resp *transaction.OracleResponse, txSig []byte) { func (r *OracleBroadcaster) SendResponse(priv *keys.PrivateKey, resp *transaction.OracleResponse, txSig []byte) {
pub := priv.PublicKey() pub := priv.PublicKey()
data := GetMessage(pub.Bytes(), resp.ID, txSig) data := GetMessage(pub.Bytes(), resp.ID, txSig)
msgSig := priv.Sign(data) msgSig := priv.Sign(data)

View file

@ -13,6 +13,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/services/oracle/broadcaster"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/util/slice" "github.com/nspcc-dev/neo-go/pkg/util/slice"
@ -43,9 +44,6 @@ type (
oracleScript []byte oracleScript []byte
verifyOffset int verifyOffset int
// mtx protects setting callbacks.
mtx sync.RWMutex
// accMtx protects account and oracle nodes. // accMtx protects account and oracle nodes.
accMtx sync.RWMutex accMtx sync.RWMutex
currAccount *wallet.Account currAccount *wallet.Account
@ -94,8 +92,6 @@ type (
Shutdown() Shutdown()
} }
defaultResponseHandler struct{}
// TxCallback executes on new transactions when they are ready to be pooled. // TxCallback executes on new transactions when they are ready to be pooled.
TxCallback = func(tx *transaction.Transaction) error TxCallback = func(tx *transaction.Transaction) error
) )
@ -165,7 +161,7 @@ func NewOracle(cfg Config) (*Oracle, error) {
} }
if o.ResponseHandler == nil { if o.ResponseHandler == nil {
o.ResponseHandler = defaultResponseHandler{} o.ResponseHandler = broadcaster.New(cfg.MainCfg, cfg.Log)
} }
if o.OnTransaction == nil { if o.OnTransaction == nil {
o.OnTransaction = func(*transaction.Transaction) error { return nil } o.OnTransaction = func(*transaction.Transaction) error { return nil }
@ -190,9 +186,10 @@ func (o *Oracle) Shutdown() {
if !o.running { if !o.running {
return return
} }
o.Log.Info("stopping oracle service")
o.running = false o.running = false
close(o.close) close(o.close)
o.getBroadcaster().Shutdown() o.ResponseHandler.Shutdown()
<-o.done <-o.done
} }
@ -217,6 +214,7 @@ func (o *Oracle) start() {
for i := 0; i < o.MainCfg.MaxConcurrentRequests; i++ { for i := 0; i < o.MainCfg.MaxConcurrentRequests; i++ {
go o.runRequestWorker() go o.runRequestWorker()
} }
go o.ResponseHandler.Run()
tick := time.NewTicker(o.MainCfg.RefreshInterval) tick := time.NewTicker(o.MainCfg.RefreshInterval)
main: main:
@ -284,28 +282,3 @@ func (o *Oracle) sendTx(tx *transaction.Transaction) {
zap.Error(err)) zap.Error(err))
} }
} }
func (o *Oracle) getBroadcaster() Broadcaster {
o.mtx.RLock()
defer o.mtx.RUnlock()
return o.ResponseHandler
}
// SetBroadcaster sets callback to broadcast response.
func (o *Oracle) SetBroadcaster(b Broadcaster) {
o.mtx.Lock()
defer o.mtx.Unlock()
o.ResponseHandler.Shutdown()
o.ResponseHandler = b
go b.Run()
}
// SendResponse implements Broadcaster interface.
func (defaultResponseHandler) SendResponse(*keys.PrivateKey, *transaction.OracleResponse, []byte) {
}
// Run implements Broadcaster interface.
func (defaultResponseHandler) Run() {}
// Shutdown implements Broadcaster interface.
func (defaultResponseHandler) Shutdown() {}

View file

@ -121,8 +121,8 @@ func TestCreateResponseTx(t *testing.T) {
Result: []byte{0}, Result: []byte{0},
} }
cInvoker.Invoke(t, stackitem.Null{}, "requestURL", req.URL, *req.Filter, req.CallbackMethod, req.UserData, int64(req.GasForResponse)) cInvoker.Invoke(t, stackitem.Null{}, "requestURL", req.URL, *req.Filter, req.CallbackMethod, req.UserData, int64(req.GasForResponse))
orc.UpdateOracleNodes(keys.PublicKeys{acc.PrivateKey().PublicKey()})
bc.SetOracle(orc) bc.SetOracle(orc)
orc.UpdateOracleNodes(keys.PublicKeys{acc.PrivateKey().PublicKey()})
tx, err = orc.CreateResponseTx(int64(req.GasForResponse), 1, resp) tx, err = orc.CreateResponseTx(int64(req.GasForResponse), 1, resp)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 166, tx.Size()) assert.Equal(t, 166, tx.Size())

View file

@ -234,7 +234,7 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
incTx.attempts++ incTx.attempts++
incTx.Unlock() incTx.Unlock()
o.getBroadcaster().SendResponse(priv, resp, txSig) o.ResponseHandler.SendResponse(priv, resp, txSig)
if ready { if ready {
o.sendTx(readyTx) o.sendTx(readyTx)
} }
@ -265,7 +265,7 @@ func (o *Oracle) processFailedRequest(priv *keys.PrivateKey, req request) {
txSig := incTx.backupSigs[string(priv.PublicKey().Bytes())].sig txSig := incTx.backupSigs[string(priv.PublicKey().Bytes())].sig
incTx.Unlock() incTx.Unlock()
o.getBroadcaster().SendResponse(priv, getFailedResponse(req.ID), txSig) o.ResponseHandler.SendResponse(priv, getFailedResponse(req.ID), txSig)
if ready { if ready {
o.sendTx(readyTx) o.sendTx(readyTx)
} }

View file

@ -43,7 +43,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/neorpc/result" "github.com/nspcc-dev/neo-go/pkg/neorpc/result"
"github.com/nspcc-dev/neo-go/pkg/network" "github.com/nspcc-dev/neo-go/pkg/network"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/services/oracle"
"github.com/nspcc-dev/neo-go/pkg/services/oracle/broadcaster" "github.com/nspcc-dev/neo-go/pkg/services/oracle/broadcaster"
"github.com/nspcc-dev/neo-go/pkg/services/rpcsrv/params" "github.com/nspcc-dev/neo-go/pkg/services/rpcsrv/params"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
@ -108,6 +107,11 @@ type (
mempool.Feer // fee interface mempool.Feer // fee interface
} }
// OracleHandler is the interface oracle service needs to provide for the Server.
OracleHandler interface {
AddResponse(pub *keys.PublicKey, reqID uint64, txSig []byte)
}
// Server represents the JSON-RPC 2.0 server. // Server represents the JSON-RPC 2.0 server.
Server struct { Server struct {
*http.Server *http.Server
@ -118,7 +122,7 @@ type (
network netmode.Magic network netmode.Magic
stateRootEnabled bool stateRootEnabled bool
coreServer *network.Server coreServer *network.Server
oracle *oracle.Oracle oracle *atomic.Value
log *zap.Logger log *zap.Logger
https *http.Server https *http.Server
shutdown chan struct{} shutdown chan struct{}
@ -248,7 +252,7 @@ var upgrader = websocket.Upgrader{}
// New creates a new Server struct. // New creates a new Server struct.
func New(chain Ledger, conf config.RPC, coreServer *network.Server, func New(chain Ledger, conf config.RPC, coreServer *network.Server,
orc *oracle.Oracle, log *zap.Logger, errChan chan error) Server { orc OracleHandler, log *zap.Logger, errChan chan error) Server {
httpServer := &http.Server{ httpServer := &http.Server{
Addr: conf.Address + ":" + strconv.FormatUint(uint64(conf.Port), 10), Addr: conf.Address + ":" + strconv.FormatUint(uint64(conf.Port), 10),
} }
@ -260,9 +264,6 @@ func New(chain Ledger, conf config.RPC, coreServer *network.Server,
} }
} }
if orc != nil {
orc.SetBroadcaster(broadcaster.New(orc.MainCfg, log))
}
protoCfg := chain.GetConfig() protoCfg := chain.GetConfig()
if conf.SessionEnabled { if conf.SessionEnabled {
if conf.SessionExpirationTime <= 0 { if conf.SessionExpirationTime <= 0 {
@ -274,6 +275,10 @@ func New(chain Ledger, conf config.RPC, coreServer *network.Server,
log.Info("SessionPoolSize is not set or wrong, setting default value", zap.Int("SessionPoolSize", defaultSessionPoolSize)) log.Info("SessionPoolSize is not set or wrong, setting default value", zap.Int("SessionPoolSize", defaultSessionPoolSize))
} }
} }
var oracleWrapped = new(atomic.Value)
if orc != nil {
oracleWrapped.Store(&orc)
}
return Server{ return Server{
Server: httpServer, Server: httpServer,
chain: chain, chain: chain,
@ -283,7 +288,7 @@ func New(chain Ledger, conf config.RPC, coreServer *network.Server,
stateRootEnabled: protoCfg.StateRootInHeader, stateRootEnabled: protoCfg.StateRootInHeader,
coreServer: coreServer, coreServer: coreServer,
log: log, log: log,
oracle: orc, oracle: oracleWrapped,
https: tlsServer, https: tlsServer,
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
started: atomic.NewBool(false), started: atomic.NewBool(false),
@ -399,6 +404,11 @@ func (s *Server) Shutdown() {
<-s.executionCh <-s.executionCh
} }
// SetOracleHandler allows to update oracle handler used by the Server.
func (s *Server) SetOracleHandler(orc OracleHandler) {
s.oracle.Store(&orc)
}
func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) { func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) {
req := params.NewRequest() req := params.NewRequest()
@ -2327,7 +2337,8 @@ func getRelayResult(err error, hash util.Uint256) (interface{}, *neorpc.Error) {
} }
func (s *Server) submitOracleResponse(ps params.Params) (interface{}, *neorpc.Error) { func (s *Server) submitOracleResponse(ps params.Params) (interface{}, *neorpc.Error) {
if s.oracle == nil { oracle := s.oracle.Load().(*OracleHandler)
if oracle == nil || *oracle == nil {
return nil, neorpc.NewRPCError("Oracle is not enabled", "") return nil, neorpc.NewRPCError("Oracle is not enabled", "")
} }
var pub *keys.PublicKey var pub *keys.PublicKey
@ -2354,7 +2365,7 @@ func (s *Server) submitOracleResponse(ps params.Params) (interface{}, *neorpc.Er
if !pub.Verify(msgSig, hash.Sha256(data).BytesBE()) { if !pub.Verify(msgSig, hash.Sha256(data).BytesBE()) {
return nil, neorpc.NewRPCError("Invalid request signature", "") return nil, neorpc.NewRPCError("Invalid request signature", "")
} }
s.oracle.AddResponse(pub, uint64(reqID), txSig) (*oracle).AddResponse(pub, uint64(reqID), txSig)
return json.RawMessage([]byte("{}")), nil return json.RawMessage([]byte("{}")), nil
} }

View file

@ -71,6 +71,7 @@ func (s *service) Shutdown() {
if !s.started.CAS(true, false) { if !s.started.CAS(true, false) {
return return
} }
s.log.Info("stopping state validation service")
close(s.stopCh) close(s.stopCh)
<-s.done <-s.done
} }