From 9157226e7bf3dbe440fa1563ff662cd3b9b39e34 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Tue, 24 Oct 2023 13:39:55 -0400 Subject: [PATCH] Extract request utilities into its own package The RemoteAddr and RemoteIP functions operate on *http.Request values, not contexts. They have very low cohesion with the rest of the package. Signed-off-by: Cory Snider --- internal/dcontext/http.go | 47 +----------- internal/dcontext/http_test.go | 70 ----------------- internal/requestutil/util.go | 51 +++++++++++++ internal/requestutil/util_test.go | 76 +++++++++++++++++++ notifications/bridge.go | 4 +- .../driver/middleware/cloudfront/s3filter.go | 5 +- 6 files changed, 134 insertions(+), 119 deletions(-) create mode 100644 internal/requestutil/util.go create mode 100644 internal/requestutil/util_test.go diff --git a/internal/dcontext/http.go b/internal/dcontext/http.go index 69c29b74..df068f13 100644 --- a/internal/dcontext/http.go +++ b/internal/dcontext/http.go @@ -3,15 +3,14 @@ package dcontext import ( "context" "errors" - "net" "net/http" "strings" "sync" "time" + "github.com/distribution/distribution/v3/internal/requestutil" "github.com/google/uuid" "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" ) // Common errors used with this package. @@ -20,48 +19,6 @@ var ( ErrNoResponseWriterContext = errors.New("no http response in context") ) -func parseIP(ipStr string) net.IP { - ip := net.ParseIP(ipStr) - if ip == nil { - log.Warnf("invalid remote IP address: %q", ipStr) - } - return ip -} - -// RemoteAddr extracts the remote address of the request, taking into -// account proxy headers. -func RemoteAddr(r *http.Request) string { - if prior := r.Header.Get("X-Forwarded-For"); prior != "" { - remoteAddr, _, _ := strings.Cut(prior, ",") - remoteAddr = strings.Trim(remoteAddr, " ") - if parseIP(remoteAddr) != nil { - return remoteAddr - } - } - // X-Real-Ip is less supported, but worth checking in the - // absence of X-Forwarded-For - if realIP := r.Header.Get("X-Real-Ip"); realIP != "" { - if parseIP(realIP) != nil { - return realIP - } - } - - return r.RemoteAddr -} - -// RemoteIP extracts the remote IP of the request, taking into -// account proxy headers. -func RemoteIP(r *http.Request) string { - addr := RemoteAddr(r) - - // Try parsing it as "IP:port" - if ip, _, err := net.SplitHostPort(addr); err == nil { - return ip - } - - return addr -} - // WithRequest places the request on the context. The context of the request // is assigned a unique id, available at "http.request.id". The request itself // is available at "http.request". Other common attributes are available under @@ -193,7 +150,7 @@ func (ctx *httpRequestContext) Value(key interface{}) interface{} { case "http.request.uri": return ctx.r.RequestURI case "http.request.remoteaddr": - return RemoteAddr(ctx.r) + return requestutil.RemoteAddr(ctx.r) case "http.request.method": return ctx.r.Method case "http.request.host": diff --git a/internal/dcontext/http_test.go b/internal/dcontext/http_test.go index 9d1069d2..99c47bcd 100644 --- a/internal/dcontext/http_test.go +++ b/internal/dcontext/http_test.go @@ -2,9 +2,6 @@ package dcontext import ( "net/http" - "net/http/httptest" - "net/http/httputil" - "net/url" "reflect" "testing" "time" @@ -219,70 +216,3 @@ func TestWithVars(t *testing.T) { } } } - -// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test -// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten -// at the transport layer to 127.0.0.1: . However, as the X-Forwarded-For header -// just contains the IP address, it is different enough for testing. -func TestRemoteAddr(t *testing.T) { - var expectedRemote string - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer r.Body.Close() - - if r.RemoteAddr == expectedRemote { - t.Errorf("Unexpected matching remote addresses") - } - - actualRemote := RemoteAddr(r) - if expectedRemote != actualRemote { - t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote) - } - - w.WriteHeader(200) - })) - - defer backend.Close() - backendURL, err := url.Parse(backend.URL) - if err != nil { - t.Fatal(err) - } - - proxy := httputil.NewSingleHostReverseProxy(backendURL) - frontend := httptest.NewServer(proxy) - defer frontend.Close() - - // X-Forwarded-For set by proxy - expectedRemote = "127.0.0.1" - proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil) - if err != nil { - t.Fatal(err) - } - - resp, err := http.DefaultClient.Do(proxyReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - // RemoteAddr in X-Real-Ip - getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil) - if err != nil { - t.Fatal(err) - } - - expectedRemote = "1.2.3.4" - getReq.Header["X-Real-ip"] = []string{expectedRemote} - resp, err = http.DefaultClient.Do(getReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - // Valid X-Real-Ip and invalid X-Forwarded-For - getReq.Header["X-forwarded-for"] = []string{"1.2.3"} - resp, err = http.DefaultClient.Do(getReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() -} diff --git a/internal/requestutil/util.go b/internal/requestutil/util.go new file mode 100644 index 00000000..099e3454 --- /dev/null +++ b/internal/requestutil/util.go @@ -0,0 +1,51 @@ +package requestutil + +import ( + "net" + "net/http" + "strings" + + log "github.com/sirupsen/logrus" +) + +func parseIP(ipStr string) net.IP { + ip := net.ParseIP(ipStr) + if ip == nil { + log.Warnf("invalid remote IP address: %q", ipStr) + } + return ip +} + +// RemoteAddr extracts the remote address of the request, taking into +// account proxy headers. +func RemoteAddr(r *http.Request) string { + if prior := r.Header.Get("X-Forwarded-For"); prior != "" { + remoteAddr, _, _ := strings.Cut(prior, ",") + remoteAddr = strings.Trim(remoteAddr, " ") + if parseIP(remoteAddr) != nil { + return remoteAddr + } + } + // X-Real-Ip is less supported, but worth checking in the + // absence of X-Forwarded-For + if realIP := r.Header.Get("X-Real-Ip"); realIP != "" { + if parseIP(realIP) != nil { + return realIP + } + } + + return r.RemoteAddr +} + +// RemoteIP extracts the remote IP of the request, taking into +// account proxy headers. +func RemoteIP(r *http.Request) string { + addr := RemoteAddr(r) + + // Try parsing it as "IP:port" + if ip, _, err := net.SplitHostPort(addr); err == nil { + return ip + } + + return addr +} diff --git a/internal/requestutil/util_test.go b/internal/requestutil/util_test.go new file mode 100644 index 00000000..fc33527f --- /dev/null +++ b/internal/requestutil/util_test.go @@ -0,0 +1,76 @@ +package requestutil + +import ( + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "testing" +) + +// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test +// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten +// at the transport layer to 127.0.0.1: . However, as the X-Forwarded-For header +// just contains the IP address, it is different enough for testing. +func TestRemoteAddr(t *testing.T) { + var expectedRemote string + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + if r.RemoteAddr == expectedRemote { + t.Errorf("Unexpected matching remote addresses") + } + + actualRemote := RemoteAddr(r) + if expectedRemote != actualRemote { + t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote) + } + + w.WriteHeader(200) + })) + + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxy := httputil.NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxy) + defer frontend.Close() + + // X-Forwarded-For set by proxy + expectedRemote = "127.0.0.1" + proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil) + if err != nil { + t.Fatal(err) + } + + resp, err := http.DefaultClient.Do(proxyReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // RemoteAddr in X-Real-Ip + getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil) + if err != nil { + t.Fatal(err) + } + + expectedRemote = "1.2.3.4" + getReq.Header["X-Real-ip"] = []string{expectedRemote} + resp, err = http.DefaultClient.Do(getReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // Valid X-Real-Ip and invalid X-Forwarded-For + getReq.Header["X-forwarded-for"] = []string{"1.2.3"} + resp, err = http.DefaultClient.Do(getReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() +} diff --git a/notifications/bridge.go b/notifications/bridge.go index 133153de..8b594774 100644 --- a/notifications/bridge.go +++ b/notifications/bridge.go @@ -5,7 +5,7 @@ import ( "time" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/internal/dcontext" + "github.com/distribution/distribution/v3/internal/requestutil" "github.com/distribution/reference" events "github.com/docker/go-events" "github.com/google/uuid" @@ -49,7 +49,7 @@ func NewBridge(ub URLBuilder, source SourceRecord, actor ActorRecord, request Re func NewRequestRecord(id string, r *http.Request) RequestRecord { return RequestRecord{ ID: id, - Addr: dcontext.RemoteAddr(r), + Addr: requestutil.RemoteAddr(r), Host: r.Host, Method: r.Method, UserAgent: r.UserAgent(), diff --git a/registry/storage/driver/middleware/cloudfront/s3filter.go b/registry/storage/driver/middleware/cloudfront/s3filter.go index 7a23bcbd..c7ddd6f5 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter.go @@ -12,6 +12,7 @@ import ( "time" "github.com/distribution/distribution/v3/internal/dcontext" + "github.com/distribution/distribution/v3/internal/requestutil" ) const ( @@ -188,7 +189,7 @@ func parseIPFromRequest(ctx context.Context) (net.IP, error) { if err != nil { return nil, err } - ipStr := dcontext.RemoteIP(request) + ipStr := requestutil.RemoteIP(request) ip := net.ParseIP(ipStr) if ip == nil { return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr) @@ -208,7 +209,7 @@ func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool { } else { loggerField := map[interface{}]interface{}{ "user-client": request.UserAgent(), - "ip": dcontext.RemoteIP(request), + "ip": requestutil.RemoteIP(request), } if awsIPs.contains(addr) { dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront")