forked from TrueCloudLab/distribution
b4dc4f3474
Signed-off-by: Cory Snider <csnider@mirantis.com>
229 lines
6.7 KiB
Go
229 lines
6.7 KiB
Go
// Package middleware provides a CloudFront wrapper for storage drivers.
//
// N.B. It currently only works with S3, not arbitrary sites.
|
|
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go/service/cloudfront/sign"
|
|
dcontext "github.com/distribution/distribution/v3/context"
|
|
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
|
|
storagemiddleware "github.com/distribution/distribution/v3/registry/storage/driver/middleware"
|
|
)
|
|
|
|
// cloudFrontStorageMiddleware provides a simple implementation of layerHandler that
|
|
// constructs temporary signed CloudFront URLs from the storagedriver layer URL,
|
|
// then issues HTTP Temporary Redirects to this CloudFront content URL.
|
|
type cloudFrontStorageMiddleware struct {
|
|
storagedriver.StorageDriver
|
|
awsIPs *awsIPs
|
|
urlSigner *sign.URLSigner
|
|
baseURL string
|
|
duration time.Duration
|
|
}
|
|
|
|
var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{}
|
|
|
|
// newCloudFrontLayerHandler constructs and returns a new CloudFront
|
|
// LayerHandler implementation.
|
|
//
|
|
// Required options:
|
|
//
|
|
// - baseurl
|
|
// - privatekey
|
|
// - keypairid
|
|
//
|
|
// Optional options:
|
|
//
|
|
// - ipFilteredBy
|
|
// - awsregion
|
|
// - ipfilteredby: valid value "none|aws|awsregion". "none", do not filter any IP,
|
|
// default value. "aws", only aws IP goes to S3 directly. "awsregion", only
|
|
// regions listed in awsregion options goes to S3 directly
|
|
// - awsregion: a comma separated string of AWS regions.
|
|
func newCloudFrontStorageMiddleware(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
|
|
// parse baseurl
|
|
base, ok := options["baseurl"]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no baseurl provided")
|
|
}
|
|
baseURL, ok := base.(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("baseurl must be a string")
|
|
}
|
|
if !strings.Contains(baseURL, "://") {
|
|
baseURL = "https://" + baseURL
|
|
}
|
|
if !strings.HasSuffix(baseURL, "/") {
|
|
baseURL += "/"
|
|
}
|
|
if _, err := url.Parse(baseURL); err != nil {
|
|
return nil, fmt.Errorf("invalid baseurl: %v", err)
|
|
}
|
|
|
|
// parse privatekey to get pkPath
|
|
pk, ok := options["privatekey"]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no privatekey provided")
|
|
}
|
|
pkPath, ok := pk.(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("privatekey must be a string")
|
|
}
|
|
|
|
// parse keypairid
|
|
kpid, ok := options["keypairid"]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no keypairid provided")
|
|
}
|
|
keypairID, ok := kpid.(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("keypairid must be a string")
|
|
}
|
|
|
|
// get urlSigner from the file specified in pkPath
|
|
pkBytes, err := os.ReadFile(pkPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read privatekey file: %s", err)
|
|
}
|
|
|
|
block, _ := pem.Decode(pkBytes)
|
|
if block == nil {
|
|
return nil, fmt.Errorf("failed to decode private key as an rsa private key")
|
|
}
|
|
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
urlSigner := sign.NewURLSigner(keypairID, privateKey)
|
|
|
|
// parse duration
|
|
duration := 20 * time.Minute
|
|
if d, ok := options["duration"]; ok {
|
|
switch d := d.(type) {
|
|
case time.Duration:
|
|
duration = d
|
|
case string:
|
|
dur, err := time.ParseDuration(d)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid duration: %s", err)
|
|
}
|
|
duration = dur
|
|
}
|
|
}
|
|
|
|
// parse updatefrequency
|
|
updateFrequency := defaultUpdateFrequency
|
|
// #2447 introduced a typo. Support it for backward compatibility.
|
|
if _, ok := options["updatefrenquency"]; ok {
|
|
options["updatefrequency"] = options["updatefrenquency"]
|
|
dcontext.GetLogger(context.Background()).Warn("cloudfront updatefrenquency is deprecated. Please use updatefrequency")
|
|
}
|
|
if u, ok := options["updatefrequency"]; ok {
|
|
switch u := u.(type) {
|
|
case time.Duration:
|
|
updateFrequency = u
|
|
case string:
|
|
updateFreq, err := time.ParseDuration(u)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid updatefrequency: %s", err)
|
|
}
|
|
updateFrequency = updateFreq
|
|
}
|
|
}
|
|
|
|
// parse iprangesurl
|
|
ipRangesURL := defaultIPRangesURL
|
|
if i, ok := options["iprangesurl"]; ok {
|
|
if iprangeurl, ok := i.(string); ok {
|
|
ipRangesURL = iprangeurl
|
|
} else {
|
|
return nil, fmt.Errorf("iprangesurl must be a string")
|
|
}
|
|
}
|
|
|
|
// parse ipfilteredby
|
|
var awsIPs *awsIPs
|
|
if i, ok := options["ipfilteredby"]; ok {
|
|
if ipFilteredBy, ok := i.(string); ok {
|
|
switch strings.ToLower(strings.TrimSpace(ipFilteredBy)) {
|
|
case "", "none":
|
|
awsIPs = nil
|
|
case "aws":
|
|
awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
case "awsregion":
|
|
var awsRegion []string
|
|
if i, ok := options["awsregion"]; ok {
|
|
if regions, ok := i.(string); ok {
|
|
for _, awsRegions := range strings.Split(regions, ",") {
|
|
awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions)))
|
|
}
|
|
awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, awsRegion)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions")
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("awsRegion is not defined")
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("ipfilteredby only allows a string the following value: none|aws|awsregion")
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("ipfilteredby only allows a string with the following value: none|aws|awsregion")
|
|
}
|
|
}
|
|
|
|
return &cloudFrontStorageMiddleware{
|
|
StorageDriver: storageDriver,
|
|
urlSigner: urlSigner,
|
|
baseURL: baseURL,
|
|
duration: duration,
|
|
awsIPs: awsIPs,
|
|
}, nil
|
|
}
|
|
|
|
// S3BucketKeyer is implemented by storage drivers that can translate a
// storage path into the S3 bucket key that AWS CloudFront should serve and
// cache.
type S3BucketKeyer interface {
	// S3BucketKey returns the S3 bucket key corresponding to path.
	S3BucketKey(path string) string
}
|
|
|
|
// URLFor attempts to find a url which may be used to retrieve the file at the given path.
|
|
// Returns an error if the file cannot be found.
|
|
func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
|
|
// TODO(endophage): currently only supports S3
|
|
keyer, ok := lh.StorageDriver.(S3BucketKeyer)
|
|
if !ok {
|
|
dcontext.GetLogger(ctx).Warn("the CloudFront middleware does not support this backend storage driver")
|
|
return lh.StorageDriver.URLFor(ctx, path, options)
|
|
}
|
|
|
|
if eligibleForS3(ctx, lh.awsIPs) {
|
|
return lh.StorageDriver.URLFor(ctx, path, options)
|
|
}
|
|
|
|
// Get signed cloudfront url.
|
|
cfURL, err := lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return cfURL, nil
|
|
}
|
|
|
|
// init registers the cloudfront layerHandler backend.
|
|
func init() {
|
|
storagemiddleware.Register("cloudfront", newCloudFrontStorageMiddleware)
|
|
}
|