forked from TrueCloudLab/neoneo-go
rpcsrv: carefully store Oracle service
And simplify atomic service value stored by RPC server. Oracle service can either be an untyped nil or be the proper non-nil *oracle.Oracle. Otherwise `submitoracleresponse` RPC handler doesn't work properly. Signed-off-by: Anna Shaleva <shaleva.ann@nspcc.ru>
This commit is contained in:
parent
6c1240d023
commit
c39153756a
3 changed files with 21 additions and 12 deletions
|
@ -359,7 +359,14 @@ func resetDB(ctx *cli.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func mkOracle(config config.OracleConfiguration, magic netmode.Magic, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (*oracle.Oracle, error) {
|
// oracleService is an interface representing Oracle service with network.Service
|
||||||
|
// capabilities and ability to submit oracle responses.
|
||||||
|
type oracleService interface {
|
||||||
|
rpcsrv.OracleHandler
|
||||||
|
network.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
func mkOracle(config config.OracleConfiguration, magic netmode.Magic, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (oracleService, error) {
|
||||||
if !config.Enabled {
|
if !config.Enabled {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -254,7 +254,8 @@ var invalidBlockHeightError = func(index int, height int) *neorpc.Error {
|
||||||
return neorpc.NewRPCError("Invalid block height", fmt.Sprintf("param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height))
|
return neorpc.NewRPCError("Invalid block height", fmt.Sprintf("param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height))
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Server struct.
|
// New creates a new Server struct. Pay attention that orc is expected to be either
|
||||||
|
// untyped nil or non-nil structure implementing OracleHandler interface.
|
||||||
func New(chain Ledger, conf config.RPC, coreServer *network.Server,
|
func New(chain Ledger, conf config.RPC, coreServer *network.Server,
|
||||||
orc OracleHandler, log *zap.Logger, errChan chan<- error) Server {
|
orc OracleHandler, log *zap.Logger, errChan chan<- error) Server {
|
||||||
addrs := conf.GetAddresses()
|
addrs := conf.GetAddresses()
|
||||||
|
@ -293,7 +294,7 @@ func New(chain Ledger, conf config.RPC, coreServer *network.Server,
|
||||||
}
|
}
|
||||||
var oracleWrapped = new(atomic.Value)
|
var oracleWrapped = new(atomic.Value)
|
||||||
if orc != nil {
|
if orc != nil {
|
||||||
oracleWrapped.Store(&orc)
|
oracleWrapped.Store(orc)
|
||||||
}
|
}
|
||||||
var wsOriginChecker func(*http.Request) bool
|
var wsOriginChecker func(*http.Request) bool
|
||||||
if conf.EnableCORSWorkaround {
|
if conf.EnableCORSWorkaround {
|
||||||
|
@ -445,7 +446,7 @@ func (s *Server) Shutdown() {
|
||||||
|
|
||||||
// SetOracleHandler allows to update oracle handler used by the Server.
|
// SetOracleHandler allows to update oracle handler used by the Server.
|
||||||
func (s *Server) SetOracleHandler(orc OracleHandler) {
|
func (s *Server) SetOracleHandler(orc OracleHandler) {
|
||||||
s.oracle.Store(&orc)
|
s.oracle.Store(orc)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) {
|
func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) {
|
||||||
|
@ -2461,10 +2462,11 @@ func getRelayResult(err error, hash util.Uint256) (any, *neorpc.Error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) submitOracleResponse(ps params.Params) (any, *neorpc.Error) {
|
func (s *Server) submitOracleResponse(ps params.Params) (any, *neorpc.Error) {
|
||||||
oracle := s.oracle.Load().(*OracleHandler)
|
oraclePtr := s.oracle.Load()
|
||||||
if oracle == nil || *oracle == nil {
|
if oraclePtr == nil {
|
||||||
return nil, neorpc.NewRPCError("Oracle is not enabled", "")
|
return nil, neorpc.NewRPCError("Oracle is not enabled", "")
|
||||||
}
|
}
|
||||||
|
oracle := oraclePtr.(OracleHandler)
|
||||||
var pub *keys.PublicKey
|
var pub *keys.PublicKey
|
||||||
pubBytes, err := ps.Value(0).GetBytesBase64()
|
pubBytes, err := ps.Value(0).GetBytesBase64()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -2489,7 +2491,7 @@ func (s *Server) submitOracleResponse(ps params.Params) (any, *neorpc.Error) {
|
||||||
if !pub.Verify(msgSig, hash.Sha256(data).BytesBE()) {
|
if !pub.Verify(msgSig, hash.Sha256(data).BytesBE()) {
|
||||||
return nil, neorpc.NewRPCError("Invalid request signature", "")
|
return nil, neorpc.NewRPCError("Invalid request signature", "")
|
||||||
}
|
}
|
||||||
(*oracle).AddResponse(pub, uint64(reqID), txSig)
|
oracle.AddResponse(pub, uint64(reqID), txSig)
|
||||||
return json.RawMessage([]byte("{}")), nil
|
return json.RawMessage([]byte("{}")), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ const (
|
||||||
notaryPass = "one"
|
notaryPass = "one"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getUnitTestChain(t testing.TB, enableOracle bool, enableNotary bool, disableIteratorSessions bool) (*core.Blockchain, *oracle.Oracle, config.Config, *zap.Logger) {
|
func getUnitTestChain(t testing.TB, enableOracle bool, enableNotary bool, disableIteratorSessions bool) (*core.Blockchain, OracleHandler, config.Config, *zap.Logger) {
|
||||||
return getUnitTestChainWithCustomConfig(t, enableOracle, enableNotary, func(cfg *config.Config) {
|
return getUnitTestChainWithCustomConfig(t, enableOracle, enableNotary, func(cfg *config.Config) {
|
||||||
if disableIteratorSessions {
|
if disableIteratorSessions {
|
||||||
cfg.ApplicationConfiguration.RPC.SessionEnabled = false
|
cfg.ApplicationConfiguration.RPC.SessionEnabled = false
|
||||||
|
@ -56,7 +56,7 @@ func getUnitTestChain(t testing.TB, enableOracle bool, enableNotary bool, disabl
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNotary bool, customCfg func(configuration *config.Config)) (*core.Blockchain, *oracle.Oracle, config.Config, *zap.Logger) {
|
func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNotary bool, customCfg func(configuration *config.Config)) (*core.Blockchain, OracleHandler, config.Config, *zap.Logger) {
|
||||||
net := netmode.UnitTestNet
|
net := netmode.UnitTestNet
|
||||||
configPath := "../../../config"
|
configPath := "../../../config"
|
||||||
cfg, err := config.Load(configPath, net)
|
cfg, err := config.Load(configPath, net)
|
||||||
|
@ -70,7 +70,7 @@ func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNot
|
||||||
chain, err := core.NewBlockchain(memoryStore, cfg.Blockchain(), logger)
|
chain, err := core.NewBlockchain(memoryStore, cfg.Blockchain(), logger)
|
||||||
require.NoError(t, err, "could not create chain")
|
require.NoError(t, err, "could not create chain")
|
||||||
|
|
||||||
var orc *oracle.Oracle
|
var orc OracleHandler
|
||||||
if enableOracle {
|
if enableOracle {
|
||||||
orc, err = oracle.NewOracle(oracle.Config{
|
orc, err = oracle.NewOracle(oracle.Config{
|
||||||
Log: logger,
|
Log: logger,
|
||||||
|
@ -79,7 +79,7 @@ func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNot
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
chain.SetOracle(orc)
|
chain.SetOracle(orc.(*oracle.Oracle))
|
||||||
}
|
}
|
||||||
|
|
||||||
go chain.Run()
|
go chain.Run()
|
||||||
|
@ -115,7 +115,7 @@ func initClearServerWithServices(t testing.TB, needOracle bool, needNotary bool,
|
||||||
return wrapUnitTestChain(t, chain, orc, cfg, logger)
|
return wrapUnitTestChain(t, chain, orc, cfg, logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
func wrapUnitTestChain(t testing.TB, chain *core.Blockchain, orc *oracle.Oracle, cfg config.Config, logger *zap.Logger) (*core.Blockchain, *Server, *httptest.Server) {
|
func wrapUnitTestChain(t testing.TB, chain *core.Blockchain, orc OracleHandler, cfg config.Config, logger *zap.Logger) (*core.Blockchain, *Server, *httptest.Server) {
|
||||||
serverConfig, err := network.NewServerConfig(cfg)
|
serverConfig, err := network.NewServerConfig(cfg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
serverConfig.UserAgent = fmt.Sprintf(config.UserAgentFormat, "0.98.6-test")
|
serverConfig.UserAgent = fmt.Sprintf(config.UserAgentFormat, "0.98.6-test")
|
||||||
|
|
Loading…
Reference in a new issue