rpcbroadcaster: properly stop broadcaster

Drain channels, wait for everything to stop.
This commit is contained in:
Roman Khimov 2022-07-01 23:04:54 +03:00
parent c096f32a32
commit 649fe58550
3 changed files with 40 additions and 3 deletions

View file

@ -14,6 +14,7 @@ type RPCBroadcaster struct {
Responses chan request.RawParams Responses chan request.RawParams
close chan struct{} close chan struct{}
finished chan struct{}
sendTimeout time.Duration sendTimeout time.Duration
} }
@ -23,6 +24,7 @@ func NewRPCBroadcaster(log *zap.Logger, sendTimeout time.Duration) *RPCBroadcast
Clients: make(map[string]*RPCClient), Clients: make(map[string]*RPCClient),
Log: log, Log: log,
close: make(chan struct{}), close: make(chan struct{}),
finished: make(chan struct{}),
Responses: make(chan request.RawParams), Responses: make(chan request.RawParams),
sendTimeout: sendTimeout, sendTimeout: sendTimeout,
} }
@ -33,10 +35,11 @@ func (r *RPCBroadcaster) Run() {
for _, c := range r.Clients { for _, c := range r.Clients {
go c.run() go c.run()
} }
run:
for { for {
select { select {
case <-r.close: case <-r.close:
return break run
case ps := <-r.Responses: case ps := <-r.Responses:
for _, c := range r.Clients { for _, c := range r.Clients {
select { select {
@ -47,9 +50,31 @@ func (r *RPCBroadcaster) Run() {
} }
} }
} }
for _, c := range r.Clients {
<-c.finished
}
drain:
for {
select {
case <-r.Responses:
default:
break drain
}
}
close(r.Responses)
close(r.finished)
}
// SendParams sends a request using all clients if the broadcaster is active.
func (r *RPCBroadcaster) SendParams(params request.RawParams) {
select {
case <-r.close:
case r.Responses <- params:
}
} }
// Shutdown implements oracle.Broadcaster. // Shutdown implements oracle.Broadcaster.
func (r *RPCBroadcaster) Shutdown() { func (r *RPCBroadcaster) Shutdown() {
close(r.close) close(r.close)
<-r.finished
} }

View file

@ -14,6 +14,7 @@ type RPCClient struct {
client *client.Client client *client.Client
addr string addr string
close chan struct{} close chan struct{}
finished chan struct{}
responses chan request.RawParams responses chan request.RawParams
log *zap.Logger log *zap.Logger
sendTimeout time.Duration sendTimeout time.Duration
@ -28,6 +29,7 @@ func (r *RPCBroadcaster) NewRPCClient(addr string, method SendMethod, timeout ti
return &RPCClient{ return &RPCClient{
addr: addr, addr: addr,
close: r.close, close: r.close,
finished: make(chan struct{}),
responses: ch, responses: ch,
log: r.Log.With(zap.String("address", addr)), log: r.Log.With(zap.String("address", addr)),
sendTimeout: timeout, sendTimeout: timeout,
@ -41,10 +43,11 @@ func (c *RPCClient) run() {
DialTimeout: c.sendTimeout, DialTimeout: c.sendTimeout,
RequestTimeout: c.sendTimeout, RequestTimeout: c.sendTimeout,
}) })
run:
for { for {
select { select {
case <-c.close: case <-c.close:
return break run
case ps := <-c.responses: case ps := <-c.responses:
if c.client == nil { if c.client == nil {
var err error var err error
@ -63,4 +66,13 @@ func (c *RPCClient) run() {
} }
} }
} }
drain:
for {
select {
case <-c.responses:
default:
break drain
}
}
close(c.finished)
} }

View file

@ -51,7 +51,7 @@ func (r *oracleBroadcaster) SendResponse(priv *keys.PrivateKey, resp *transactio
base64.StdEncoding.EncodeToString(txSig), base64.StdEncoding.EncodeToString(txSig),
base64.StdEncoding.EncodeToString(msgSig), base64.StdEncoding.EncodeToString(msgSig),
) )
r.Responses <- params r.SendParams(params)
} }
// GetMessage returns data which is signed upon sending response by RPC. // GetMessage returns data which is signed upon sending response by RPC.