diff --git a/lib/oauthutil/oauthutil.go b/lib/oauthutil/oauthutil.go index cb3137baa..9ac665f5e 100644 --- a/lib/oauthutil/oauthutil.go +++ b/lib/oauthutil/oauthutil.go @@ -262,16 +262,22 @@ func NewClient(name string, oauthConfig *oauth2.Config) (*http.Client, *TokenSou // // It may run an internal webserver to receive the results func Config(id, name string, config *oauth2.Config, opts ...oauth2.AuthCodeOption) error { - return doConfig(id, name, config, true, opts) + return doConfig(id, name, nil, config, true, opts) } // ConfigNoOffline does the same as Config but does not pass the // "access_type=offline" parameter. func ConfigNoOffline(id, name string, config *oauth2.Config, opts ...oauth2.AuthCodeOption) error { - return doConfig(id, name, config, false, opts) + return doConfig(id, name, nil, config, false, opts) } -func doConfig(id, name string, oauthConfig *oauth2.Config, offline bool, opts []oauth2.AuthCodeOption) error { +// ConfigErrorCheck does the same as Config, but allows the backend to pass a error handling function +// This function gets called with the request made to rclone as a parameter if no code was found +func ConfigErrorCheck(id, name string, errorHandler func(*http.Request) AuthError, config *oauth2.Config, opts ...oauth2.AuthCodeOption) error { + return doConfig(id, name, errorHandler, config, true, opts) +} + +func doConfig(id, name string, errorHandler func(*http.Request) AuthError, oauthConfig *oauth2.Config, offline bool, opts []oauth2.AuthCodeOption) error { oauthConfig, changed := overrideCredentials(name, oauthConfig) automatic := config.FileGet(name, config.ConfigAutomatic) != "" @@ -350,12 +356,14 @@ func doConfig(id, name string, oauthConfig *oauth2.Config, offline bool, opts [] // Prepare webserver server := authServer{ - state: state, - bindAddress: bindAddress, - authURL: authURL, + state: state, + bindAddress: bindAddress, + authURL: authURL, + errorHandler: errorHandler, } if useWebServer { server.code = make(chan string, 1) + server.err = make(chan error, 1) go server.Start() defer server.Stop() authURL = "http://" + bindAddress + "/auth" @@ -371,9 +379,13 @@ func doConfig(id, name string, oauthConfig *oauth2.Config, offline bool, opts [] // Read the code, and exchange it for a token. fmt.Printf("Waiting for code...\n") authCode = <-server.code + authError := <-server.err if authCode != "" { fmt.Printf("Got code\n") } else { + if authError != nil { + return authError + } return errors.New("failed to get code") } } else { @@ -399,12 +411,22 @@ func doConfig(id, name string, oauthConfig *oauth2.Config, offline bool, opts [] // Local web server for collecting auth type authServer struct { - state string - listener net.Listener - bindAddress string - code chan string - authURL string - server *http.Server + state string + listener net.Listener + bindAddress string + code chan string + err chan error + authURL string + server *http.Server + errorHandler func(*http.Request) AuthError +} + +// AuthError gets returned by the backend's errorHandler function +type AuthError struct { + Name string + Description string + Code string + HelpURL string } // startWebServer runs an internal web server to receive config details @@ -428,6 +450,7 @@ func (s *authServer) Start() { w.Header().Set("Content-Type", "text/html") fs.Debugf(nil, "Received request on auth server") code := req.FormValue("code") + var err error if code != "" { state := req.FormValue("state") if state != s.state { @@ -443,11 +466,20 @@ func (s *authServer) Start() { } } else { fs.Debugf(nil, "No code found on request") + httpResponse := "

Failed!

\nNo code found returned by remote server." + if s.errorHandler != nil { + authError := s.errorHandler(req) + errorDesc := fmt.Sprintf("Error: %s\nCode: %s\nDescription: %s\nHelp: %s", + authError.Name, authError.Code, authError.Description, authError.HelpURL) + httpResponse += "\n

" + strings.Replace(errorDesc, "\n", "
", -1) + "

" + err = errors.New(errorDesc) + } w.WriteHeader(500) - _, _ = fmt.Fprintf(w, "

Failed!

\nNo code found returned by remote server.") + _, _ = fmt.Fprintf(w, httpResponse) } if s.code != nil { s.code <- code + s.err <- err } })