storage/driver: replace URLFor method

Several storage drivers and storage middlewares need to introspect the
client HTTP request in order to construct content-redirect URLs. The
request is indirectly passed into the driver interface method URLFor()
through the context argument, which is bad practice. The request should
be passed in as an explicit argument as the method is only called from
request handlers.

Replace the URLFor() method with a RedirectURL() method which takes an
HTTP request as a parameter instead of a context. Drop the options
argument from URLFor() as in practice it only ever encoded the request
method, which can now be fetched directly from the request. No URLFor()
callers ever passed in an "expiry" option, either.

Signed-off-by: Cory Snider <csnider@mirantis.com>
This commit is contained in:
Cory Snider 2023-10-24 15:49:47 -04:00
parent 868faeec67
commit f089932de0
16 changed files with 111 additions and 174 deletions

View file

@ -20,7 +20,7 @@ type blobServer struct {
driver driver.StorageDriver
statter distribution.BlobStatter
pathFn func(dgst digest.Digest) (string, error)
redirect bool // allows disabling URLFor redirects
redirect bool // allows disabling RedirectURL redirects
}
func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error {
@ -35,19 +35,16 @@ func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *h
}
if bs.redirect {
redirectURL, err := bs.driver.URLFor(ctx, path, map[string]interface{}{"method": r.Method})
switch err.(type) {
case nil:
// Redirect to storage URL.
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
return err
case driver.ErrUnsupportedMethod:
// Fallback to serving the content directly.
default:
// Some unexpected error.
redirectURL, err := bs.driver.RedirectURL(r, path)
if err != nil {
return err
}
if redirectURL != "" {
// Redirect to storage URL.
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
return nil
}
// Fallback to serving the content directly.
}
br, err := newFileReader(ctx, bs.driver, path, desc.Size)

View file

@ -8,6 +8,7 @@ import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
@ -286,7 +287,7 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) {
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
sourceBlobURL, err := d.URLFor(ctx, sourcePath, nil)
sourceBlobURL, err := d.signBlobURL(ctx, sourcePath)
if err != nil {
return err
}
@ -366,18 +367,15 @@ func (d *driver) Delete(ctx context.Context, path string) error {
return nil
}
// URLFor returns a publicly accessible URL for the blob stored at given path
// RedirectURL returns a publicly accessible URL for the blob stored at given path
// for specified duration by making use of Azure Storage Shared Access Signatures (SAS).
// See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx for more info.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
func (d *driver) RedirectURL(req *http.Request, path string) (string, error) {
return d.signBlobURL(req.Context(), path)
}
func (d *driver) signBlobURL(ctx context.Context, path string) (string, error) {
expiresTime := time.Now().UTC().Add(20 * time.Minute) // default expiration
expires, ok := options["expiry"]
if ok {
t, ok := expires.(time.Time)
if ok {
expiresTime = t
}
}
blobName := d.blobName(path)
blobRef := d.client.NewBlobClient(blobName)
return d.azClient.SignBlobURL(ctx, blobRef.URL(), expiresTime)

View file

@ -40,6 +40,7 @@ package base
import (
"context"
"io"
"net/http"
"time"
"github.com/distribution/distribution/v3/internal/dcontext"
@ -208,18 +209,18 @@ func (base *Base) Delete(ctx context.Context, path string) error {
return err
}
// URLFor wraps URLFor of underlying storage driver.
func (base *Base) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.URLFor(%q)", base.Name(), path)
// RedirectURL wraps RedirectURL of the underlying storage driver.
func (base *Base) RedirectURL(r *http.Request, path string) (string, error) {
ctx, done := dcontext.WithTrace(r.Context())
defer done("%s.RedirectURL(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) {
return "", storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
start := time.Now()
str, e := base.StorageDriver.URLFor(ctx, path, options)
storageAction.WithValues(base.Name(), "URLFor").UpdateSince(start)
str, e := base.StorageDriver.RedirectURL(r.WithContext(ctx), path)
storageAction.WithValues(base.Name(), "RedirectURL").UpdateSince(start)
return str, base.setDriverName(e)
}

View file

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"net/http"
"reflect"
"strconv"
"sync"
@ -172,13 +173,11 @@ func (r *regulator) Delete(ctx context.Context, path string) error {
return r.StorageDriver.Delete(ctx, path)
}
// URLFor returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options.
// May return an ErrUnsupportedMethod in certain StorageDriver
// implementations.
func (r *regulator) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
// RedirectURL returns a URL which may be used to retrieve the content stored at
// the given path.
func (r *regulator) RedirectURL(req *http.Request, path string) (string, error) {
r.enter()
defer r.exit()
return r.StorageDriver.URLFor(ctx, path, options)
return r.StorageDriver.RedirectURL(req, path)
}

View file

@ -6,6 +6,7 @@ import (
"context"
"fmt"
"io"
"net/http"
"os"
"path"
"time"
@ -282,10 +283,9 @@ func (d *driver) Delete(ctx context.Context, subPath string) error {
return err
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
return "", storagedriver.ErrUnsupportedMethod{}
// RedirectURL returns a URL which may be used to retrieve the content stored at the given path.
func (d *driver) RedirectURL(*http.Request, string) (string, error) {
return "", nil
}
// Walk traverses a filesystem defined within driver, starting

View file

@ -810,40 +810,24 @@ func storageCopyObject(ctx context.Context, srcBucket, srcName string, destBucke
return attrs, err
}
// URLFor returns a URL which may be used to retrieve the content stored at
// RedirectURL returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options.
// Returns ErrUnsupportedMethod if this driver has no privateKey
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
func (d *driver) RedirectURL(r *http.Request, path string) (string, error) {
if d.privateKey == nil {
return "", storagedriver.ErrUnsupportedMethod{}
return "", nil
}
name := d.pathToKey(path)
methodString := http.MethodGet
method, ok := options["method"]
if ok {
methodString, ok = method.(string)
if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) {
return "", storagedriver.ErrUnsupportedMethod{}
}
}
expiresTime := time.Now().Add(20 * time.Minute)
expires, ok := options["expiry"]
if ok {
et, ok := expires.(time.Time)
if ok {
expiresTime = et
}
if r.Method != http.MethodGet && r.Method != http.MethodHead {
return "", nil
}
opts := &storage.SignedURLOptions{
GoogleAccessID: d.email,
PrivateKey: d.privateKey,
Method: methodString,
Expires: expiresTime,
Method: r.Method,
Expires: time.Now().Add(20 * time.Minute),
}
return storage.SignedURL(d.bucket, name, opts)
return storage.SignedURL(d.bucket, d.pathToKey(path), opts)
}
// Walk traverses a filesystem defined within driver, starting

View file

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"net/http"
"sync"
"time"
@ -236,10 +237,9 @@ func (d *driver) Delete(ctx context.Context, path string) error {
}
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
return "", storagedriver.ErrUnsupportedMethod{}
// RedirectURL returns a URL which may be used to retrieve the content stored at the given path.
func (d *driver) RedirectURL(*http.Request, string) (string, error) {
return "", nil
}
// Walk traverses a filesystem defined within driver, starting

View file

@ -7,6 +7,7 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"net/http"
"net/url"
"os"
"strings"
@ -195,18 +196,18 @@ type S3BucketKeyer interface {
S3BucketKey(path string) string
}
// URLFor attempts to find a url which may be used to retrieve the file at the given path.
// RedirectURL attempts to find a url which may be used to retrieve the file at the given path.
// Returns an error if the file cannot be found.
func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
func (lh *cloudFrontStorageMiddleware) RedirectURL(r *http.Request, path string) (string, error) {
// TODO(endophage): currently only supports S3
keyer, ok := lh.StorageDriver.(S3BucketKeyer)
if !ok {
dcontext.GetLogger(ctx).Warn("the CloudFront middleware does not support this backend storage driver")
return lh.StorageDriver.URLFor(ctx, path, options)
dcontext.GetLogger(r.Context()).Warn("the CloudFront middleware does not support this backend storage driver")
return lh.StorageDriver.RedirectURL(r, path)
}
if eligibleForS3(ctx, lh.awsIPs) {
return lh.StorageDriver.URLFor(ctx, path, options)
if eligibleForS3(r, lh.awsIPs) {
return lh.StorageDriver.RedirectURL(r, path)
}
// Get signed cloudfront url.

View file

@ -184,11 +184,7 @@ func (s *awsIPs) contains(ip net.IP) bool {
// parseIPFromRequest attempts to extract the ip address of the
// client that made the request
func parseIPFromRequest(ctx context.Context) (net.IP, error) {
request, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
func parseIPFromRequest(request *http.Request) (net.IP, error) {
ipStr := requestutil.RemoteIP(request)
ip := net.ParseIP(ipStr)
if ip == nil {
@ -200,25 +196,20 @@ func parseIPFromRequest(ctx context.Context) (net.IP, error) {
// eligibleForS3 checks if a request is eligible for using S3 directly
// Return true only when the IP belongs to a specific aws region and user-agent is docker
func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool {
func eligibleForS3(request *http.Request, awsIPs *awsIPs) bool {
if awsIPs != nil && awsIPs.initialized {
if addr, err := parseIPFromRequest(ctx); err == nil {
request, err := dcontext.GetRequest(ctx)
if err != nil {
dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err)
} else {
loggerField := map[interface{}]interface{}{
"user-client": request.UserAgent(),
"ip": requestutil.RemoteIP(request),
}
if awsIPs.contains(addr) {
dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront")
return true
}
dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront")
if addr, err := parseIPFromRequest(request); err == nil {
loggerField := map[interface{}]interface{}{
"user-client": request.UserAgent(),
"ip": requestutil.RemoteIP(request),
}
if awsIPs.contains(addr) {
dcontext.GetLoggerWithFields(request.Context(), loggerField).Info("request from the allowed AWS region, skipping CloudFront")
return true
}
dcontext.GetLoggerWithFields(request.Context(), loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront")
} else {
dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront")
dcontext.GetLogger(request.Context()).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront")
}
}
return false

View file

@ -1,7 +1,6 @@
package middleware
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
@ -11,8 +10,6 @@ import (
"reflect" // used as a replacement for testify
"testing"
"time"
"github.com/distribution/distribution/v3/internal/dcontext"
)
// Rather than pull in all of testify
@ -269,29 +266,22 @@ func TestEligibleForS3(t *testing.T) {
}},
initialized: true,
}
empty := context.TODO()
makeContext := func(ip string) context.Context {
req := &http.Request{
RemoteAddr: ip,
}
return dcontext.WithRequest(empty, req)
}
tests := []struct {
Context context.Context
Expected bool
RemoteAddr string
Expected bool
}{
{Context: empty, Expected: false},
{Context: makeContext("192.168.1.2"), Expected: true},
{Context: makeContext("192.168.0.2"), Expected: false},
{RemoteAddr: "", Expected: false},
{RemoteAddr: "192.168.1.2", Expected: true},
{RemoteAddr: "192.168.0.2", Expected: false},
}
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf("Client IP = %v", tc.Context.Value("http.request.ip")), func(t *testing.T) {
t.Run(fmt.Sprintf("Client IP = %v", tc.RemoteAddr), func(t *testing.T) {
t.Parallel()
assertEqual(t, tc.Expected, eligibleForS3(tc.Context, ips))
req := &http.Request{RemoteAddr: tc.RemoteAddr}
assertEqual(t, tc.Expected, eligibleForS3(req, ips))
})
}
}
@ -305,29 +295,22 @@ func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) {
}},
initialized: false,
}
empty := context.TODO()
makeContext := func(ip string) context.Context {
req := &http.Request{
RemoteAddr: ip,
}
return dcontext.WithRequest(empty, req)
}
tests := []struct {
Context context.Context
Expected bool
RemoteAddr string
Expected bool
}{
{Context: empty, Expected: false},
{Context: makeContext("192.168.1.2"), Expected: false},
{Context: makeContext("192.168.0.2"), Expected: false},
{RemoteAddr: "", Expected: false},
{RemoteAddr: "192.168.1.2", Expected: false},
{RemoteAddr: "192.168.0.2", Expected: false},
}
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf("Client IP = %v", tc.Context.Value("http.request.ip")), func(t *testing.T) {
t.Run(fmt.Sprintf("Client IP = %v", tc.RemoteAddr), func(t *testing.T) {
t.Parallel()
assertEqual(t, tc.Expected, eligibleForS3(tc.Context, ips))
req := &http.Request{RemoteAddr: tc.RemoteAddr}
assertEqual(t, tc.Expected, eligibleForS3(req, ips))
})
}
}

View file

@ -1,8 +1,8 @@
package middleware
import (
"context"
"fmt"
"net/http"
"net/url"
"path"
@ -42,7 +42,7 @@ func newRedirectStorageMiddleware(sd storagedriver.StorageDriver, options map[st
return &redirectStorageMiddleware{StorageDriver: sd, scheme: u.Scheme, host: u.Host, basePath: u.Path}, nil
}
func (r *redirectStorageMiddleware) URLFor(ctx context.Context, urlPath string, options map[string]interface{}) (string, error) {
func (r *redirectStorageMiddleware) RedirectURL(_ *http.Request, urlPath string) (string, error) {
if r.basePath != "" {
urlPath = path.Join(r.basePath, urlPath)
}

View file

@ -1,7 +1,6 @@
package middleware
import (
"context"
"testing"
"gopkg.in/check.v1"
@ -37,7 +36,7 @@ func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
c.Assert(m.scheme, check.Equals, "https")
c.Assert(m.host, check.Equals, "example.com:5443")
url, err := middleware.URLFor(context.TODO(), "/rick/data", nil)
url, err := middleware.RedirectURL(nil, "/rick/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com:5443/rick/data")
}
@ -53,7 +52,7 @@ func (s *MiddlewareSuite) TestHTTP(c *check.C) {
c.Assert(m.scheme, check.Equals, "http")
c.Assert(m.host, check.Equals, "example.com")
url, err := middleware.URLFor(context.TODO(), "morty/data", nil)
url, err := middleware.RedirectURL(nil, "morty/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "http://example.com/morty/data")
}
@ -71,12 +70,12 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
c.Assert(m.host, check.Equals, "example.com")
c.Assert(m.basePath, check.Equals, "/path")
// call URLFor() with no leading slash
url, err := middleware.URLFor(context.TODO(), "morty/data", nil)
// call RedirectURL() with no leading slash
url, err := middleware.RedirectURL(nil, "morty/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data")
// call URLFor() with leading slash
url, err = middleware.URLFor(context.TODO(), "/morty/data", nil)
// call RedirectURL() with leading slash
url, err = middleware.RedirectURL(nil, "/morty/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data")
@ -91,12 +90,12 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
c.Assert(m.host, check.Equals, "example.com")
c.Assert(m.basePath, check.Equals, "/path/")
// call URLFor() with no leading slash
url, err = middleware.URLFor(context.TODO(), "morty/data", nil)
// call RedirectURL() with no leading slash
url, err = middleware.RedirectURL(nil, "morty/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data")
// call URLFor() with leading slash
url, err = middleware.URLFor(context.TODO(), "/morty/data", nil)
// call RedirectURL() with leading slash
url, err = middleware.RedirectURL(nil, "/morty/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data")
}

View file

@ -1036,30 +1036,13 @@ func (d *driver) Delete(ctx context.Context, path string) error {
return nil
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
methodString := http.MethodGet
method, ok := options["method"]
if ok {
methodString, ok = method.(string)
if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) {
return "", storagedriver.ErrUnsupportedMethod{}
}
}
// RedirectURL returns a URL which may be used to retrieve the content stored at the given path.
func (d *driver) RedirectURL(r *http.Request, path string) (string, error) {
expiresIn := 20 * time.Minute
expires, ok := options["expiry"]
if ok {
et, ok := expires.(time.Time)
if ok {
expiresIn = time.Until(et)
}
}
var req *request.Request
switch methodString {
switch r.Method {
case http.MethodGet:
req, _ = d.S3.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String(d.Bucket),
@ -1071,7 +1054,7 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int
Key: aws.String(d.s3Path(path)),
})
default:
panic("unreachable")
return "", nil
}
return req.Presign(expiresIn)

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
@ -92,11 +93,10 @@ type StorageDriver interface {
// Delete recursively deletes all objects stored at "path" and its subpaths.
Delete(ctx context.Context, path string) error
// URLFor returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options.
// May return an ErrUnsupportedMethod in certain StorageDriver
// implementations.
URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error)
// RedirectURL returns a URL which the client of the request r may use
// to retrieve the content stored at path. Returning the empty string
// signals that the request may not be redirected.
RedirectURL(r *http.Request, path string) (string, error)
// Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file.

View file

@ -8,6 +8,7 @@ import (
"io"
"math/rand"
"net/http"
"net/http/httptest"
"os"
"path"
"sort"
@ -733,9 +734,9 @@ func (suite *DriverSuite) TestDelete(c *check.C) {
c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true)
}
// TestURLFor checks that the URLFor method functions properly, but only if it
// is implemented
func (suite *DriverSuite) TestURLFor(c *check.C) {
// TestRedirectURL checks that the RedirectURL method functions properly,
// but only if it is implemented
func (suite *DriverSuite) TestRedirectURL(c *check.C) {
filename := randomPath(32)
contents := randomContents(32)
@ -744,8 +745,8 @@ func (suite *DriverSuite) TestURLFor(c *check.C) {
err := suite.StorageDriver.PutContent(suite.ctx, filename, contents)
c.Assert(err, check.IsNil)
url, err := suite.StorageDriver.URLFor(suite.ctx, filename, nil)
if _, ok := err.(storagedriver.ErrUnsupportedMethod); ok {
url, err := suite.StorageDriver.RedirectURL(httptest.NewRequest(http.MethodGet, filename, nil), filename)
if url == "" && err == nil {
return
}
c.Assert(err, check.IsNil)
@ -758,8 +759,8 @@ func (suite *DriverSuite) TestURLFor(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(read, check.DeepEquals, contents)
url, err = suite.StorageDriver.URLFor(suite.ctx, filename, map[string]interface{}{"method": http.MethodHead})
if _, ok := err.(storagedriver.ErrUnsupportedMethod); ok {
url, err = suite.StorageDriver.RedirectURL(httptest.NewRequest(http.MethodHead, filename, nil), filename)
if url == "" && err == nil {
return
}
c.Assert(err, check.IsNil)

View file

@ -34,7 +34,7 @@ type manifestURLs struct {
type RegistryOption func(*registry) error
// EnableRedirect is a functional option for NewRegistry. It causes the backend
// blob server to attempt using (StorageDriver).URLFor to serve all blobs.
// blob server to attempt using (StorageDriver).RedirectURL to serve all blobs.
func EnableRedirect(registry *registry) error {
registry.blobServer.redirect = true
return nil
@ -102,7 +102,7 @@ func BlobDescriptorCacheProvider(blobDescriptorCacheProvider cache.BlobDescripto
// NewRegistry creates a new registry instance from the provided driver. The
// resulting registry may be shared by multiple goroutines but is cheap to
// allocate. If the Redirect option is specified, the backend blob server will
// attempt to use (StorageDriver).URLFor to serve all blobs.
// attempt to use (StorageDriver).RedirectURL to serve all blobs.
func NewRegistry(ctx context.Context, driver storagedriver.StorageDriver, options ...RegistryOption) (distribution.Namespace, error) {
// create global statter
statter := &blobStatter{