From 70902b4051d9eda30d3ac02c541d9d8cab45a556 Mon Sep 17 00:00:00 2001 From: Nick Craig-Wood Date: Tue, 23 Feb 2016 21:16:13 +0000 Subject: [PATCH] Make rest Set methods safe for concurrent calling --- rest/rest.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/rest/rest.go b/rest/rest.go index 20ac23a3c..7acf35fed 100644 --- a/rest/rest.go +++ b/rest/rest.go @@ -1,4 +1,6 @@ // Package rest implements a simple REST wrapper +// +// All methods are safe for concurrent calling. package rest import ( @@ -8,12 +10,14 @@ import ( "io" "io/ioutil" "net/http" + "sync" "github.com/ncw/rclone/fs" ) // Client contains the info to sustain the API type Client struct { + mu sync.RWMutex c *http.Client rootURL string errorHandler func(resp *http.Response) error @@ -45,18 +49,24 @@ func defaultErrorHandler(resp *http.Response) (err error) { // SetErrorHandler sets the handler to decode an error response when // the HTTP status code is not 2xx. The handler should close resp.Body. func (api *Client) SetErrorHandler(fn func(resp *http.Response) error) *Client { + api.mu.Lock() + defer api.mu.Unlock() api.errorHandler = fn return api } // SetRoot sets the default root URL func (api *Client) SetRoot(RootURL string) *Client { + api.mu.Lock() + defer api.mu.Unlock() api.rootURL = RootURL return api } // SetHeader sets a header for all requests func (api *Client) SetHeader(key, value string) *Client { + api.mu.Lock() + defer api.mu.Unlock() api.headers[key] = value return api } @@ -89,6 +99,8 @@ func DecodeJSON(resp *http.Response, result interface{}) (err error) { // // it will return resp if at all possible, even if err is set func (api *Client) Call(opts *Opts) (resp *http.Response, err error) { + api.mu.RLock() + defer api.mu.RUnlock() if opts == nil { return nil, fmt.Errorf("call() called with nil opts") } @@ -127,12 +139,16 @@ func (api *Client) Call(opts *Opts) (resp *http.Response, err error) { } // Now set the headers for k, v := range headers { - req.Header.Add(k, v) + if v != "" { + req.Header.Add(k, v) + } } if opts.UserName != "" || opts.Password != "" { req.SetBasicAuth(opts.UserName, opts.Password) } + api.mu.RUnlock() resp, err = api.c.Do(req) + api.mu.RLock() if err != nil { return nil, err }