diff --git a/pkg/services/helpers/rpcbroadcaster/broadcaster.go b/pkg/services/helpers/rpcbroadcaster/broadcaster.go index b43225325..e030dcb22 100644 --- a/pkg/services/helpers/rpcbroadcaster/broadcaster.go +++ b/pkg/services/helpers/rpcbroadcaster/broadcaster.go @@ -14,6 +14,7 @@ type RPCBroadcaster struct { Responses chan request.RawParams close chan struct{} + finished chan struct{} sendTimeout time.Duration } @@ -23,6 +24,7 @@ func NewRPCBroadcaster(log *zap.Logger, sendTimeout time.Duration) *RPCBroadcast Clients: make(map[string]*RPCClient), Log: log, close: make(chan struct{}), + finished: make(chan struct{}), Responses: make(chan request.RawParams), sendTimeout: sendTimeout, } @@ -33,10 +35,11 @@ func (r *RPCBroadcaster) Run() { for _, c := range r.Clients { go c.run() } +run: for { select { case <-r.close: - return + break run case ps := <-r.Responses: for _, c := range r.Clients { select { @@ -47,9 +50,31 @@ func (r *RPCBroadcaster) Run() { } } } + for _, c := range r.Clients { + <-c.finished + } +drain: + for { + select { + case <-r.Responses: + default: + break drain + } + } + close(r.Responses) + close(r.finished) +} + +// SendParams sends a request using all clients if the broadcaster is active. +func (r *RPCBroadcaster) SendParams(params request.RawParams) { + select { + case <-r.close: + case r.Responses <- params: + } } // Shutdown implements oracle.Broadcaster. func (r *RPCBroadcaster) Shutdown() { close(r.close) + <-r.finished } diff --git a/pkg/services/helpers/rpcbroadcaster/client.go b/pkg/services/helpers/rpcbroadcaster/client.go index 403c726ea..1045312b0 100644 --- a/pkg/services/helpers/rpcbroadcaster/client.go +++ b/pkg/services/helpers/rpcbroadcaster/client.go @@ -14,6 +14,7 @@ type RPCClient struct { client *client.Client addr string close chan struct{} + finished chan struct{} responses chan request.RawParams log *zap.Logger sendTimeout time.Duration @@ -28,6 +29,7 @@ func (r *RPCBroadcaster) NewRPCClient(addr string, method SendMethod, timeout ti return &RPCClient{ addr: addr, close: r.close, + finished: make(chan struct{}), responses: ch, log: r.Log.With(zap.String("address", addr)), sendTimeout: timeout, @@ -41,10 +43,11 @@ func (c *RPCClient) run() { DialTimeout: c.sendTimeout, RequestTimeout: c.sendTimeout, }) +run: for { select { case <-c.close: - return + break run case ps := <-c.responses: if c.client == nil { var err error @@ -63,4 +66,13 @@ func (c *RPCClient) run() { } } } +drain: + for { + select { + case <-c.responses: + default: + break drain + } + } + close(c.finished) } diff --git a/pkg/services/oracle/broadcaster/oracle.go b/pkg/services/oracle/broadcaster/oracle.go index 29dd4a160..8a8d4c328 100644 --- a/pkg/services/oracle/broadcaster/oracle.go +++ b/pkg/services/oracle/broadcaster/oracle.go @@ -51,7 +51,7 @@ func (r *oracleBroadcaster) SendResponse(priv *keys.PrivateKey, resp *transactio base64.StdEncoding.EncodeToString(txSig), base64.StdEncoding.EncodeToString(msgSig), ) - r.Responses <- params + r.SendParams(params) } // GetMessage returns data which is signed upon sending response by RPC.