diff --git a/amazonclouddrive/amazonclouddrive.go b/amazonclouddrive/amazonclouddrive.go index 175f051c0..f7c251851 100644 --- a/amazonclouddrive/amazonclouddrive.go +++ b/amazonclouddrive/amazonclouddrive.go @@ -84,12 +84,13 @@ func init() { // Fs represents a remote acd server type Fs struct { - name string // name of this remote - c *acd.Client // the connection to the acd server - noAuthClient *http.Client // unauthenticated http client - root string // the path we are working on - dirCache *dircache.DirCache // Map of directory path to directory id - pacer *pacer.Pacer // pacer for API calls + name string // name of this remote + c *acd.Client // the connection to the acd server + noAuthClient *http.Client // unauthenticated http client + root string // the path we are working on + dirCache *dircache.DirCache // Map of directory path to directory id + pacer *pacer.Pacer // pacer for API calls + ts *oauthutil.TokenSource // token source for oauth } // Object describes a acd object @@ -140,14 +141,19 @@ var retryErrorCodes = []int{ // shouldRetry returns a boolean as to whether this resp and err // deserve to be retried. It returns the err as a convenience -func shouldRetry(resp *http.Response, err error) (bool, error) { +func (f *Fs) shouldRetry(resp *http.Response, err error) (bool, error) { + if resp != nil && resp.StatusCode == 401 { + f.ts.Invalidate() + fs.Log(f, "401 error received - invalidating token") + return true, err + } return fs.ShouldRetry(err) || fs.ShouldRetryHTTP(resp, retryErrorCodes), err } // NewFs constructs an Fs from the path, container:path func NewFs(name, root string) (fs.Fs, error) { root = parsePath(root) - oAuthClient, err := oauthutil.NewClient(name, acdConfig) + oAuthClient, ts, err := oauthutil.NewClient(name, acdConfig) if err != nil { log.Fatalf("Failed to configure amazon cloud drive: %v", err) } @@ -160,13 +166,14 @@ func NewFs(name, root string) (fs.Fs, error) { c: c, pacer: pacer.New().SetMinSleep(minSleep).SetPacer(pacer.AmazonCloudDrivePacer), noAuthClient: fs.Config.Client(), + ts: ts, } // Update endpoints var resp *http.Response err = f.pacer.Call(func() (bool, error) { _, resp, err = f.c.Account.GetEndpoints() - return shouldRetry(resp, err) + return f.shouldRetry(resp, err) }) if err != nil { return nil, fmt.Errorf("Failed to get endpoints: %v", err) @@ -176,7 +183,7 @@ func NewFs(name, root string) (fs.Fs, error) { var rootInfo *acd.Folder err = f.pacer.Call(func() (bool, error) { rootInfo, resp, err = f.c.Nodes.GetRoot() - return shouldRetry(resp, err) + return f.shouldRetry(resp, err) }) if err != nil || rootInfo.Id == nil { return nil, fmt.Errorf("Failed to get root: %v", err) @@ -245,7 +252,7 @@ func (f *Fs) FindLeaf(pathID, leaf string) (pathIDOut string, found bool, err er var subFolder *acd.Folder err = f.pacer.Call(func() (bool, error) { subFolder, resp, err = folder.GetFolder(leaf) - return shouldRetry(resp, err) + return f.shouldRetry(resp, err) }) if err != nil { if err == acd.ErrorNodeNotFound { @@ -272,7 +279,7 @@ func (f *Fs) CreateDir(pathID, leaf string) (newID string, err error) { var info *acd.Folder err = f.pacer.Call(func() (bool, error) { info, resp, err = folder.CreateFolder(leaf) - return shouldRetry(resp, err) + return f.shouldRetry(resp, err) }) if err != nil { //fmt.Printf("...Error %v\n", err) @@ -314,7 +321,7 @@ func (f *Fs) listAll(dirID string, title string, directoriesOnly bool, filesOnly var resp *http.Response err = f.pacer.CallNoRetry(func() (bool, error) { nodes, resp, err = f.c.Nodes.GetNodes(&opts) - return shouldRetry(resp, err) + return f.shouldRetry(resp, err) }) if err != nil { return false, err @@ -424,7 +431,7 @@ func (f *Fs) Put(in io.Reader, src fs.ObjectInfo) (fs.Object, error) { } else { info, resp, err = folder.PutSized(in, size, leaf) } - return shouldRetry(resp, err) + return f.shouldRetry(resp, err) }) if err != nil { return nil, err @@ -479,7 +486,7 @@ func (f *Fs) purgeCheck(check bool) error { var resp *http.Response err = f.pacer.Call(func() (bool, error) { resp, err = node.Trash() - return shouldRetry(resp, err) + return f.shouldRetry(resp, err) }) if err != nil { return err @@ -593,7 +600,7 @@ func (o *Object) readMetaData() (err error) { var info *acd.File err = o.fs.pacer.Call(func() (bool, error) { info, resp, err = folder.GetFile(leaf) - return shouldRetry(resp, err) + return o.fs.shouldRetry(resp, err) }) if err != nil { fs.Debug(o, "Failed to read info: %s", err) @@ -647,7 +654,7 @@ func (o *Object) Open() (in io.ReadCloser, err error) { } else { in, resp, err = file.OpenTempURL(o.fs.noAuthClient) } - return shouldRetry(resp, err) + return o.fs.shouldRetry(resp, err) }) return in, err } @@ -667,7 +674,7 @@ func (o *Object) Update(in io.Reader, src fs.ObjectInfo) error { } else { info, resp, err = file.Overwrite(in) } - return shouldRetry(resp, err) + return o.fs.shouldRetry(resp, err) }) if err != nil { return err @@ -682,7 +689,7 @@ func (o *Object) Remove() error { var err error err = o.fs.pacer.Call(func() (bool, error) { resp, err = o.info.Trash() - return shouldRetry(resp, err) + return o.fs.shouldRetry(resp, err) }) return err } diff --git a/drive/drive.go b/drive/drive.go index bbd586156..c18deea30 100644 --- a/drive/drive.go +++ b/drive/drive.go @@ -278,7 +278,7 @@ func NewFs(name, path string) (fs.Fs, error) { return nil, fmt.Errorf("drive: chunk size can't be less than 256k - was %v", chunkSize) } - oAuthClient, err := oauthutil.NewClient(name, driveConfig) + oAuthClient, _, err := oauthutil.NewClient(name, driveConfig) if err != nil { log.Fatalf("Failed to configure drive: %v", err) } diff --git a/googlecloudstorage/googlecloudstorage.go b/googlecloudstorage/googlecloudstorage.go index e77ff0862..f02ff381b 100644 --- a/googlecloudstorage/googlecloudstorage.go +++ b/googlecloudstorage/googlecloudstorage.go @@ -215,7 +215,7 @@ func NewFs(name, root string) (fs.Fs, error) { log.Fatalf("Failed configuring Google Cloud Storage Service Account: %v", err) } } else { - oAuthClient, err = oauthutil.NewClient(name, storageConfig) + oAuthClient, _, err = oauthutil.NewClient(name, storageConfig) if err != nil { log.Fatalf("Failed to configure Google Cloud Storage: %v", err) } diff --git a/hubic/hubic.go b/hubic/hubic.go index 3ea3584e5..7668ab769 100644 --- a/hubic/hubic.go +++ b/hubic/hubic.go @@ -142,7 +142,7 @@ func (f *Fs) getCredentials() (err error) { // NewFs constructs an Fs from the path, container:path func NewFs(name, root string) (fs.Fs, error) { - client, err := oauthutil.NewClient(name, oauthConfig) + client, _, err := oauthutil.NewClient(name, oauthConfig) if err != nil { return nil, fmt.Errorf("Failed to configure Hubic: %v", err) } diff --git a/oauthutil/oauthutil.go b/oauthutil/oauthutil.go index c16d94277..56212ae05 100644 --- a/oauthutil/oauthutil.go +++ b/oauthutil/oauthutil.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "strings" + "sync" "time" "github.com/ncw/rclone/fs" @@ -104,11 +105,14 @@ func putToken(name string, token *oauth2.Token) error { return nil } -// tokenSource stores updated tokens in the config file -type tokenSource struct { - Name string - TokenSource oauth2.TokenSource - OldToken oauth2.Token +// TokenSource stores updated tokens in the config file +type TokenSource struct { + mu sync.Mutex + name string + tokenSource oauth2.TokenSource + token *oauth2.Token + config *oauth2.Config + ctx context.Context } // Token returns a token or an error. @@ -116,13 +120,23 @@ type tokenSource struct { // The returned Token must not be modified. // // This saves the token in the config file if it has changed -func (ts *tokenSource) Token() (*oauth2.Token, error) { - token, err := ts.TokenSource.Token() +func (ts *TokenSource) Token() (*oauth2.Token, error) { + ts.mu.Lock() + defer ts.mu.Unlock() + + // Make a new token source if required + if ts.tokenSource == nil { + ts.tokenSource = ts.config.TokenSource(ts.ctx, ts.token) + } + + token, err := ts.tokenSource.Token() if err != nil { return nil, err } - if *token != ts.OldToken { - err = putToken(ts.Name, token) + changed := *token != *ts.token + ts.token = token + if changed { + err = putToken(ts.name, token) if err != nil { return nil, err } @@ -130,8 +144,15 @@ func (ts *tokenSource) Token() (*oauth2.Token, error) { return token, nil } +// Invalidate invalidates the token +func (ts *TokenSource) Invalidate() { + ts.mu.Lock() + ts.token.AccessToken = "" + ts.mu.Unlock() +} + // Check interface satisfied -var _ oauth2.TokenSource = (*tokenSource)(nil) +var _ oauth2.TokenSource = (*TokenSource)(nil) // Context returns a context with our HTTP Client baked in for oauth2 func Context() context.Context { @@ -157,12 +178,12 @@ func overrideCredentials(name string, config *oauth2.Config) bool { } // NewClient gets a token from the config file and configures a Client -// with it -func NewClient(name string, config *oauth2.Config) (*http.Client, error) { +// with it. It returns the client and a TokenSource which Invalidate may need to be called on +func NewClient(name string, config *oauth2.Config) (*http.Client, *TokenSource, error) { overrideCredentials(name, config) token, err := getToken(name) if err != nil { - return nil, err + return nil, nil, err } // Set our own http client in the context @@ -170,12 +191,13 @@ func NewClient(name string, config *oauth2.Config) (*http.Client, error) { // Wrap the TokenSource in our TokenSource which saves changed // tokens in the config file - ts := &tokenSource{ - Name: name, - OldToken: *token, - TokenSource: config.TokenSource(ctx, token), + ts := &TokenSource{ + name: name, + token: token, + config: config, + ctx: ctx, } - return oauth2.NewClient(ctx, ts), nil + return oauth2.NewClient(ctx, ts), ts, nil } diff --git a/onedrive/onedrive.go b/onedrive/onedrive.go index 1c3d44b95..f4ceed2b0 100644 --- a/onedrive/onedrive.go +++ b/onedrive/onedrive.go @@ -170,7 +170,7 @@ func errorHandler(resp *http.Response) error { // NewFs constructs an Fs from the path, container:path func NewFs(name, root string) (fs.Fs, error) { root = parsePath(root) - oAuthClient, err := oauthutil.NewClient(name, oauthConfig) + oAuthClient, _, err := oauthutil.NewClient(name, oauthConfig) if err != nil { log.Fatalf("Failed to configure One Drive: %v", err) }