diff --git a/cli/server/server_test.go b/cli/server/server_test.go index f13d4664b..529bb0878 100644 --- a/cli/server/server_test.go +++ b/cli/server/server_test.go @@ -330,7 +330,9 @@ func TestConfigureAddresses(t *testing.T) { cfg := &config.ApplicationConfiguration{ Address: defaultAddress, RPC: config.RPC{ - Address: customAddress, + BasicService: config.BasicService{ + Address: customAddress, + }, }, } configureAddresses(cfg) diff --git a/internal/testserdes/testing.go b/internal/testserdes/testing.go index 5bc7ad6d5..3cec4801e 100644 --- a/internal/testserdes/testing.go +++ b/internal/testserdes/testing.go @@ -7,6 +7,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) // MarshalUnmarshalJSON checks if the expected stays the same after @@ -18,6 +19,15 @@ func MarshalUnmarshalJSON(t *testing.T, expected, actual interface{}) { require.Equal(t, expected, actual) } +// MarshalUnmarshalYAML checks if the expected stays the same after +// marshal/unmarshal via YAML. +func MarshalUnmarshalYAML(t *testing.T, expected, actual interface{}) { + data, err := yaml.Marshal(expected) + require.NoError(t, err) + require.NoError(t, yaml.Unmarshal(data, actual)) + require.Equal(t, expected, actual) +} + // EncodeDecodeBinary checks if the expected stays the same after // serializing/deserializing via io.Serializable methods. func EncodeDecodeBinary(t *testing.T, expected, actual io.Serializable) { diff --git a/pkg/config/application_config_test.go b/pkg/config/application_config_test.go index 8b5428e8f..a675ae65b 100644 --- a/pkg/config/application_config_test.go +++ b/pkg/config/application_config_test.go @@ -20,3 +20,12 @@ func TestApplicationConfigurationEquals(t *testing.T) { require.NoError(t, err) require.False(t, cfg1.ApplicationConfiguration.EqualsButServices(&cfg2.ApplicationConfiguration)) } + +// TestApplicationConfiguration_UnmarshalRPCBasicService is aimed to check that BasicService +// config of RPC service can be properly unmarshalled. +func TestApplicationConfiguration_UnmarshalRPCBasicService(t *testing.T) { + cfg, err := LoadFile(filepath.Join("..", "..", "config", "protocol.mainnet.yml")) + require.NoError(t, err) + require.True(t, cfg.ApplicationConfiguration.RPC.Enabled) + require.Equal(t, uint16(10332), cfg.ApplicationConfiguration.RPC.Port) +} diff --git a/pkg/config/basic_service.go b/pkg/config/basic_service.go index 3e2137b96..71840ed36 100644 --- a/pkg/config/basic_service.go +++ b/pkg/config/basic_service.go @@ -1,8 +1,19 @@ package config -// BasicService is used for simple services like Pprof or Prometheus monitoring. +import ( + "net" + "strconv" +) + +// BasicService is used as a simple base for node services like Pprof, RPC or +// Prometheus monitoring. type BasicService struct { Enabled bool `yaml:"Enabled"` Address string `yaml:"Address"` - Port string `yaml:"Port"` + Port uint16 `yaml:"Port"` +} + +// FormatAddress returns the full service's address in the form of "address:port". +func (s BasicService) FormatAddress() string { + return net.JoinHostPort(s.Address, strconv.FormatUint(uint64(s.Port), 10)) } diff --git a/pkg/config/basic_service_test.go b/pkg/config/basic_service_test.go new file mode 100644 index 000000000..7661bca69 --- /dev/null +++ b/pkg/config/basic_service_test.go @@ -0,0 +1,19 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBasicService_FormatAddress(t *testing.T) { + for expected, tc := range map[string]BasicService{ + "localhost:10332": {Address: "localhost", Port: 10332}, + "127.0.0.1:0": {Address: "127.0.0.1"}, + ":0": {}, + } { + t.Run(expected, func(t *testing.T) { + require.Equal(t, expected, tc.FormatAddress()) + }) + } +} diff --git a/pkg/config/rpc_config.go b/pkg/config/rpc_config.go index 4411a7fee..44df2aa38 100644 --- a/pkg/config/rpc_config.go +++ b/pkg/config/rpc_config.go @@ -7,9 +7,8 @@ import ( type ( // RPC is an RPC service configuration information. RPC struct { - Address string `yaml:"Address"` - Enabled bool `yaml:"Enabled"` - EnableCORSWorkaround bool `yaml:"EnableCORSWorkaround"` + BasicService `yaml:",inline"` + EnableCORSWorkaround bool `yaml:"EnableCORSWorkaround"` // MaxGasInvoke is the maximum amount of GAS which // can be spent during an RPC call. MaxGasInvoke fixedn.Fixed8 `yaml:"MaxGasInvoke"` @@ -17,7 +16,6 @@ type ( MaxFindResultItems int `yaml:"MaxFindResultItems"` MaxNEP11Tokens int `yaml:"MaxNEP11Tokens"` MaxWebSocketClients int `yaml:"MaxWebSocketClients"` - Port uint16 `yaml:"Port"` SessionEnabled bool `yaml:"SessionEnabled"` SessionExpirationTime int `yaml:"SessionExpirationTime"` SessionBackedByMPT bool `yaml:"SessionBackedByMPT"` @@ -28,10 +26,8 @@ type ( // TLS describes SSL/TLS configuration. TLS struct { - Address string `yaml:"Address"` - CertFile string `yaml:"CertFile"` - Enabled bool `yaml:"Enabled"` - Port uint16 `yaml:"Port"` - KeyFile string `yaml:"KeyFile"` + BasicService `yaml:",inline"` + CertFile string `yaml:"CertFile"` + KeyFile string `yaml:"KeyFile"` } ) diff --git a/pkg/services/metrics/pprof.go b/pkg/services/metrics/pprof.go index 6c3bbb202..343982bdc 100644 --- a/pkg/services/metrics/pprof.go +++ b/pkg/services/metrics/pprof.go @@ -26,7 +26,7 @@ func NewPprofService(cfg config.BasicService, log *zap.Logger) *Service { return &Service{ Server: &http.Server{ - Addr: cfg.Address + ":" + cfg.Port, + Addr: cfg.FormatAddress(), Handler: handler, }, config: cfg, diff --git a/pkg/services/metrics/prometheus.go b/pkg/services/metrics/prometheus.go index b00efc8b3..209c13d7c 100644 --- a/pkg/services/metrics/prometheus.go +++ b/pkg/services/metrics/prometheus.go @@ -19,7 +19,7 @@ func NewPrometheusService(cfg config.BasicService, log *zap.Logger) *Service { return &Service{ Server: &http.Server{ - Addr: cfg.Address + ":" + cfg.Port, + Addr: cfg.FormatAddress(), Handler: promhttp.Handler(), }, config: cfg, diff --git a/pkg/services/rpcsrv/server.go b/pkg/services/rpcsrv/server.go index 1a1a37ea3..98af82b01 100644 --- a/pkg/services/rpcsrv/server.go +++ b/pkg/services/rpcsrv/server.go @@ -255,13 +255,13 @@ var invalidBlockHeightError = func(index int, height int) *neorpc.Error { func New(chain Ledger, conf config.RPC, coreServer *network.Server, orc OracleHandler, log *zap.Logger, errChan chan error) Server { httpServer := &http.Server{ - Addr: conf.Address + ":" + strconv.FormatUint(uint64(conf.Port), 10), + Addr: conf.FormatAddress(), } var tlsServer *http.Server if cfg := conf.TLSConfig; cfg.Enabled { tlsServer = &http.Server{ - Addr: net.JoinHostPort(cfg.Address, strconv.FormatUint(uint64(cfg.Port), 10)), + Addr: cfg.FormatAddress(), } }