diff --git a/lib/oauthutil/oauthutil.go b/lib/oauthutil/oauthutil.go index 2f3d5278f..7c63e8bf3 100644 --- a/lib/oauthutil/oauthutil.go +++ b/lib/oauthutil/oauthutil.go @@ -706,7 +706,7 @@ func configSetup(ctx context.Context, id, name string, m configmap.Mapper, oauth // Prepare webserver server := newAuthServer(opt, bindAddress, state, authURL) - err = server.Init() + err = server.Init(ctx) if err != nil { return "", fmt.Errorf("failed to start auth webserver: %w", err) } @@ -758,6 +758,7 @@ type authServer struct { authURL string server *http.Server result chan *AuthResult + quit chan struct{} } // newAuthServer makes the webserver for collecting auth @@ -768,6 +769,7 @@ func newAuthServer(opt *Options, bindAddress, state, authURL string) *authServer bindAddress: bindAddress, authURL: authURL, // http://host/auth redirects to here result: make(chan *AuthResult, 1), + quit: make(chan struct{}), } } @@ -830,15 +832,32 @@ func (s *authServer) handleAuth(w http.ResponseWriter, req *http.Request) { } // Init gets the internal web server ready to receive config details -func (s *authServer) Init() error { +// +// The web server will listen until ctx is cancelled or the Stop() +// method is called +func (s *authServer) Init(ctx context.Context) error { fs.Debugf(nil, "Starting auth server on %s", s.bindAddress) mux := http.NewServeMux() s.server = &http.Server{ - Addr: s.bindAddress, - Handler: mux, + Addr: s.bindAddress, + Handler: mux, + BaseContext: func(net.Listener) context.Context { return ctx }, } s.server.SetKeepAlivesEnabled(false) + // Error the server if the context is cancelled + go func() { + select { + case <-ctx.Done(): + s.result <- &AuthResult{ + Name: "Cancelled", + Description: ctx.Err().Error(), + Err: ctx.Err(), + } + case <-s.quit: + } + }() + mux.HandleFunc("/auth", func(w http.ResponseWriter, req *http.Request) { state := req.FormValue("state") if state != s.state { @@ -852,7 +871,8 @@ func (s *authServer) Init() error { mux.HandleFunc("/", s.handleAuth) var err error - s.listener, err = net.Listen("tcp", s.bindAddress) + var lc net.ListenConfig + s.listener, err = lc.Listen(ctx, "tcp", s.bindAddress) if err != nil { return err } @@ -869,6 +889,7 @@ func (s *authServer) Serve() { func (s *authServer) Stop() { fs.Debugf(nil, "Closing auth server") close(s.result) + close(s.quit) _ = s.listener.Close() // close the server