oauthutil: shut down the oauth webserver when the context closes

This patch ensures that we pass the context from the CreateBackend
call all the way to the creation of the oauth web server.

This means that when the command has finished the webserver will
definitely be shut down, and if the user abandons it (eg via an rc
call timing out or being cancelled) then it will be shut down too.
This commit is contained in:
Nick Craig-Wood 2024-10-17 10:04:33 +01:00
parent 175aa07cdd
commit 7c705e0efa

View file

@ -706,7 +706,7 @@ func configSetup(ctx context.Context, id, name string, m configmap.Mapper, oauth
// Prepare webserver
server := newAuthServer(opt, bindAddress, state, authURL)
err = server.Init()
err = server.Init(ctx)
if err != nil {
return "", fmt.Errorf("failed to start auth webserver: %w", err)
}
@ -758,6 +758,7 @@ type authServer struct {
authURL string
server *http.Server
result chan *AuthResult
quit chan struct{}
}
// newAuthServer makes the webserver for collecting auth
@ -768,6 +769,7 @@ func newAuthServer(opt *Options, bindAddress, state, authURL string) *authServer
bindAddress: bindAddress,
authURL: authURL, // http://host/auth redirects to here
result: make(chan *AuthResult, 1),
quit: make(chan struct{}),
}
}
@ -830,15 +832,32 @@ func (s *authServer) handleAuth(w http.ResponseWriter, req *http.Request) {
}
// Init gets the internal web server ready to receive config details
func (s *authServer) Init() error {
//
// The web server will listen until ctx is cancelled or the Stop()
// method is called
func (s *authServer) Init(ctx context.Context) error {
fs.Debugf(nil, "Starting auth server on %s", s.bindAddress)
mux := http.NewServeMux()
s.server = &http.Server{
Addr: s.bindAddress,
Handler: mux,
Addr: s.bindAddress,
Handler: mux,
BaseContext: func(net.Listener) context.Context { return ctx },
}
s.server.SetKeepAlivesEnabled(false)
// Error the server if the context is cancelled
go func() {
select {
case <-ctx.Done():
s.result <- &AuthResult{
Name: "Cancelled",
Description: ctx.Err().Error(),
Err: ctx.Err(),
}
case <-s.quit:
}
}()
mux.HandleFunc("/auth", func(w http.ResponseWriter, req *http.Request) {
state := req.FormValue("state")
if state != s.state {
@ -852,7 +871,8 @@ func (s *authServer) Init() error {
mux.HandleFunc("/", s.handleAuth)
var err error
s.listener, err = net.Listen("tcp", s.bindAddress)
var lc net.ListenConfig
s.listener, err = lc.Listen(ctx, "tcp", s.bindAddress)
if err != nil {
return err
}
@ -869,6 +889,7 @@ func (s *authServer) Serve() {
func (s *authServer) Stop() {
fs.Debugf(nil, "Closing auth server")
close(s.result)
close(s.quit)
_ = s.listener.Close()
// close the server