From c39153756a01a83c4722a45e0e830e4004967813 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Wed, 9 Aug 2023 15:14:06 +0300 Subject: [PATCH] rpcsrv: carefully store Oracle service And simplify atomic service value stored by RPC server. Oracle service can either be an untyped nil or be the proper non-nil *oracle.Oracle. Otherwise `submitoracleresponse` RPC handler doesn't work properly. Signed-off-by: Anna Shaleva --- cli/server/server.go | 9 ++++++++- pkg/services/rpcsrv/server.go | 14 ++++++++------ pkg/services/rpcsrv/server_helper_test.go | 10 +++++----- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/cli/server/server.go b/cli/server/server.go index b55982a4e..8aba05454 100644 --- a/cli/server/server.go +++ b/cli/server/server.go @@ -359,7 +359,14 @@ func resetDB(ctx *cli.Context) error { return nil } -func mkOracle(config config.OracleConfiguration, magic netmode.Magic, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (*oracle.Oracle, error) { +// oracleService is an interface representing Oracle service with network.Service +// capabilities and ability to submit oracle responses. +type oracleService interface { + rpcsrv.OracleHandler + network.Service +} + +func mkOracle(config config.OracleConfiguration, magic netmode.Magic, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (oracleService, error) { if !config.Enabled { return nil, nil } diff --git a/pkg/services/rpcsrv/server.go b/pkg/services/rpcsrv/server.go index 5fcbf656c..fd1ecf6aa 100644 --- a/pkg/services/rpcsrv/server.go +++ b/pkg/services/rpcsrv/server.go @@ -254,7 +254,8 @@ var invalidBlockHeightError = func(index int, height int) *neorpc.Error { return neorpc.NewRPCError("Invalid block height", fmt.Sprintf("param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height)) } -// New creates a new Server struct. +// New creates a new Server struct. Pay attention that orc is expected to be either +// untyped nil or non-nil structure implementing OracleHandler interface. func New(chain Ledger, conf config.RPC, coreServer *network.Server, orc OracleHandler, log *zap.Logger, errChan chan<- error) Server { addrs := conf.GetAddresses() @@ -293,7 +294,7 @@ func New(chain Ledger, conf config.RPC, coreServer *network.Server, } var oracleWrapped = new(atomic.Value) if orc != nil { - oracleWrapped.Store(&orc) + oracleWrapped.Store(orc) } var wsOriginChecker func(*http.Request) bool if conf.EnableCORSWorkaround { @@ -445,7 +446,7 @@ func (s *Server) Shutdown() { // SetOracleHandler allows to update oracle handler used by the Server. func (s *Server) SetOracleHandler(orc OracleHandler) { - s.oracle.Store(&orc) + s.oracle.Store(orc) } func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) { @@ -2461,10 +2462,11 @@ func getRelayResult(err error, hash util.Uint256) (any, *neorpc.Error) { } func (s *Server) submitOracleResponse(ps params.Params) (any, *neorpc.Error) { - oracle := s.oracle.Load().(*OracleHandler) - if oracle == nil || *oracle == nil { + oraclePtr := s.oracle.Load() + if oraclePtr == nil { return nil, neorpc.NewRPCError("Oracle is not enabled", "") } + oracle := oraclePtr.(OracleHandler) var pub *keys.PublicKey pubBytes, err := ps.Value(0).GetBytesBase64() if err == nil { @@ -2489,7 +2491,7 @@ func (s *Server) submitOracleResponse(ps params.Params) (any, *neorpc.Error) { if !pub.Verify(msgSig, hash.Sha256(data).BytesBE()) { return nil, neorpc.NewRPCError("Invalid request signature", "") } - (*oracle).AddResponse(pub, uint64(reqID), txSig) + oracle.AddResponse(pub, uint64(reqID), txSig) return json.RawMessage([]byte("{}")), nil } diff --git a/pkg/services/rpcsrv/server_helper_test.go b/pkg/services/rpcsrv/server_helper_test.go index eb11f75d7..893ef1dcb 100644 --- a/pkg/services/rpcsrv/server_helper_test.go +++ b/pkg/services/rpcsrv/server_helper_test.go @@ -29,7 +29,7 @@ const ( notaryPass = "one" ) -func getUnitTestChain(t testing.TB, enableOracle bool, enableNotary bool, disableIteratorSessions bool) (*core.Blockchain, *oracle.Oracle, config.Config, *zap.Logger) { +func getUnitTestChain(t testing.TB, enableOracle bool, enableNotary bool, disableIteratorSessions bool) (*core.Blockchain, OracleHandler, config.Config, *zap.Logger) { return getUnitTestChainWithCustomConfig(t, enableOracle, enableNotary, func(cfg *config.Config) { if disableIteratorSessions { cfg.ApplicationConfiguration.RPC.SessionEnabled = false @@ -56,7 +56,7 @@ func getUnitTestChain(t testing.TB, enableOracle bool, enableNotary bool, disabl } }) } -func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNotary bool, customCfg func(configuration *config.Config)) (*core.Blockchain, *oracle.Oracle, config.Config, *zap.Logger) { +func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNotary bool, customCfg func(configuration *config.Config)) (*core.Blockchain, OracleHandler, config.Config, *zap.Logger) { net := netmode.UnitTestNet configPath := "../../../config" cfg, err := config.Load(configPath, net) @@ -70,7 +70,7 @@ func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNot chain, err := core.NewBlockchain(memoryStore, cfg.Blockchain(), logger) require.NoError(t, err, "could not create chain") - var orc *oracle.Oracle + var orc OracleHandler if enableOracle { orc, err = oracle.NewOracle(oracle.Config{ Log: logger, @@ -79,7 +79,7 @@ func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNot Chain: chain, }) require.NoError(t, err) - chain.SetOracle(orc) + chain.SetOracle(orc.(*oracle.Oracle)) } go chain.Run() @@ -115,7 +115,7 @@ func initClearServerWithServices(t testing.TB, needOracle bool, needNotary bool, return wrapUnitTestChain(t, chain, orc, cfg, logger) } -func wrapUnitTestChain(t testing.TB, chain *core.Blockchain, orc *oracle.Oracle, cfg config.Config, logger *zap.Logger) (*core.Blockchain, *Server, *httptest.Server) { +func wrapUnitTestChain(t testing.TB, chain *core.Blockchain, orc OracleHandler, cfg config.Config, logger *zap.Logger) (*core.Blockchain, *Server, *httptest.Server) { serverConfig, err := network.NewServerConfig(cfg) require.NoError(t, err) serverConfig.UserAgent = fmt.Sprintf(config.UserAgentFormat, "0.98.6-test")