diff --git a/challenge/http01/http_challenge_server.go b/challenge/http01/http_challenge_server.go index 23af6da1..f69f5ac1 100644 --- a/challenge/http01/http_challenge_server.go +++ b/challenge/http01/http_challenge_server.go @@ -2,9 +2,11 @@ package http01 import ( "fmt" + "io/fs" "net" "net/http" "net/textproto" + "os" "strings" "github.com/go-acme/lego/v4/log" @@ -14,8 +16,11 @@ import ( // It may be instantiated without using the NewProviderServer function if // you want only to use the default values. type ProviderServer struct { - iface string - port string + address string + network string // must be valid argument to net.Listen + + socketMode fs.FileMode + matcher domainMatcher done chan bool listener net.Listener @@ -29,24 +34,34 @@ func NewProviderServer(iface, port string) *ProviderServer { port = "80" } - return &ProviderServer{iface: iface, port: port, matcher: &hostMatcher{}} + return &ProviderServer{network: "tcp", address: net.JoinHostPort(iface, port), matcher: &hostMatcher{}} +} + +func NewUnixProviderServer(socketPath string, mode fs.FileMode) *ProviderServer { + return &ProviderServer{network: "unix", address: socketPath, socketMode: mode, matcher: &hostMatcher{}} } // Present starts a web server and makes the token available at `ChallengePath(token)` for web requests. func (s *ProviderServer) Present(domain, token, keyAuth string) error { var err error - s.listener, err = net.Listen("tcp", s.GetAddress()) + s.listener, err = net.Listen(s.network, s.GetAddress()) if err != nil { return fmt.Errorf("could not start HTTP server for challenge: %w", err) } + if s.network == "unix" { + if err = os.Chmod(s.address, s.socketMode); err != nil { + return fmt.Errorf("chmod %s: %w", s.address, err) + } + } + s.done = make(chan bool) go s.serve(domain, token, keyAuth) return nil } func (s *ProviderServer) GetAddress() string { - return net.JoinHostPort(s.iface, s.port) + return s.address } // CleanUp closes the HTTP server and removes the token from `ChallengePath(token)`. @@ -85,7 +100,7 @@ func (s *ProviderServer) SetProxyHeader(headerName string) { func (s *ProviderServer) serve(domain, token, keyAuth string) { path := ChallengePath(token) - // The incoming request must will be validated to prevent DNS rebind attacks. + // The incoming request will be validated to prevent DNS rebind attacks. // We only respond with the keyAuth, when we're receiving a GET requests with // the "Host" header matching the domain (the latter is configurable though SetProxyHeader). mux := http.NewServeMux() @@ -99,7 +114,7 @@ func (s *ProviderServer) serve(domain, token, keyAuth string) { } log.Infof("[%s] Served key authentication", domain) } else { - log.Warnf("Received request for domain %s with method %s but the domain did not match any challenge. Please ensure your are passing the %s header properly.", r.Host, r.Method, s.matcher.name()) + log.Warnf("Received request for domain %s with method %s but the domain did not match any challenge. Please ensure you are passing the %s header properly.", r.Host, r.Method, s.matcher.name()) _, err := w.Write([]byte("TEST")) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/challenge/http01/http_challenge_test.go b/challenge/http01/http_challenge_test.go index 9866b2c1..fa42f374 100644 --- a/challenge/http01/http_challenge_test.go +++ b/challenge/http01/http_challenge_test.go @@ -1,12 +1,18 @@ package http01 import ( + "context" "crypto/rand" "crypto/rsa" "fmt" "io" + "io/fs" + "net" "net/http" "net/textproto" + "os" + "path/filepath" + "runtime" "testing" "github.com/go-acme/lego/v4/acme" @@ -17,6 +23,50 @@ import ( "github.com/stretchr/testify/require" ) +func TestProviderServer_GetAddress(t *testing.T) { + dir := t.TempDir() + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + + sock := filepath.Join(dir, "var", "run", "test") + + testCases := []struct { + desc string + server *ProviderServer + expected string + }{ + { + desc: "TCP default address", + server: NewProviderServer("", ""), + expected: ":80", + }, + { + desc: "TCP with explicit port", + server: NewProviderServer("", "8080"), + expected: ":8080", + }, + { + desc: "TCP with host and port", + server: NewProviderServer("localhost", "8080"), + expected: "localhost:8080", + }, + { + desc: "UDS socket", + server: NewUnixProviderServer(sock, fs.ModeSocket|0o666), + expected: sock, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + address := test.server.GetAddress() + assert.Equal(t, test.expected, address) + }) + } +} + func TestChallenge(t *testing.T) { _, apiURL := tester.SetupFakeAPI(t) @@ -69,6 +119,75 @@ func TestChallenge(t *testing.T) { require.NoError(t, err) } +func TestChallengeUnix(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("only for UNIX systems") + } + + _, apiURL := tester.SetupFakeAPI(t) + + dir := t.TempDir() + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + + socket := filepath.Join(dir, "lego-challenge-test.sock") + + providerServer := NewUnixProviderServer(socket, fs.ModeSocket|0o666) + + validate := func(_ *api.Core, _ string, chlng acme.Challenge) error { + // any uri will do, as we hijack the dial + uri := "http://localhost" + ChallengePath(chlng.Token) + + client := &http.Client{Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", socket) + }, + }} + + resp, err := client.Get(uri) + if err != nil { + return err + } + + defer resp.Body.Close() + + if want := "text/plain"; resp.Header.Get("Content-Type") != want { + t.Errorf("Get(%q) Content-Type: got %q, want %q", uri, resp.Header.Get("Content-Type"), want) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + bodyStr := string(body) + + if bodyStr != chlng.KeyAuthorization { + t.Errorf("Get(%q) Body: got %q, want %q", uri, bodyStr, chlng.KeyAuthorization) + } + + return nil + } + + privateKey, err := rsa.GenerateKey(rand.Reader, 512) + require.NoError(t, err, "Could not generate test key") + + core, err := api.New(http.DefaultClient, "lego-test", apiURL+"/dir", "", privateKey) + require.NoError(t, err) + + solver := NewChallenge(core, validate, providerServer) + + authz := acme.Authorization{ + Identifier: acme.Identifier{ + Value: "localhost", + }, + Challenges: []acme.Challenge{ + {Type: challenge.HTTP01.String(), Token: "http1"}, + }, + } + + err = solver.Solve(authz) + require.NoError(t, err) +} + func TestChallengeInvalidPort(t *testing.T) { _, apiURL := tester.SetupFakeAPI(t)