224 lines
5.9 KiB
Go
224 lines
5.9 KiB
Go
|
package middleware
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"io/ioutil"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
dcontext "github.com/docker/distribution/context"
|
||
|
)
|
||
|
|
||
|
const (
	// defaultIPRangesURL is the default location of the AWS-published
	// document listing AWS IP address ranges.
	defaultIPRangesURL = "https://ip-ranges.amazonaws.com/ip-ranges.json"

	// defaultUpdateFrequency is the default interval between refreshes of
	// the AWS IP ranges.
	defaultUpdateFrequency = 12 * time.Hour
)
|
||
|
|
||
|
// newAWSIPs returns a New awsIP object.
|
||
|
// If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified
|
||
|
func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs {
|
||
|
ips := &awsIPs{
|
||
|
host: host,
|
||
|
updateFrequency: updateFrequency,
|
||
|
awsRegion: awsRegion,
|
||
|
updaterStopChan: make(chan bool),
|
||
|
}
|
||
|
if err := ips.tryUpdate(); err != nil {
|
||
|
dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP")
|
||
|
}
|
||
|
go ips.updater()
|
||
|
return ips
|
||
|
}
|
||
|
|
||
|
// awsIPs tracks the AWS IP ranges published at host, keeping only the
// prefixes whose region appears in awsRegion (every region when awsRegion
// is empty).
type awsIPs struct {
	// Configuration, filled in at construction time by newAWSIPs.
	host            string        // URL of the ip-ranges document
	updateFrequency time.Duration // delay between background refreshes
	awsRegion       []string      // allowed regions; empty means any region

	// mutex guards the fields below it; tryUpdate replaces all three
	// together while holding the write lock.
	mutex       sync.RWMutex
	ipv4        []net.IPNet
	ipv6        []net.IPNet
	initialized bool // true once at least one refresh has succeeded

	// updaterStopChan delivers the stop signal to the updater goroutine,
	// which closes the channel when it exits.
	updaterStopChan chan bool
}
|
||
|
|
||
|
type awsIPResponse struct {
|
||
|
Prefixes []prefixEntry `json:"prefixes"`
|
||
|
V6Prefixes []prefixEntry `json:"ipv6_prefixes"`
|
||
|
}
|
||
|
|
||
|
// prefixEntry is one element of the "prefixes" or "ipv6_prefixes" arrays in
// the AWS ip-ranges document. tryUpdate reads IPV4Prefix from entries of the
// "prefixes" array and IPV6Prefix from entries of the "ipv6_prefixes" array.
type prefixEntry struct {
	IPV4Prefix string `json:"ip_prefix"`   // CIDR notation
	IPV6Prefix string `json:"ipv6_prefix"` // CIDR notation
	Region     string `json:"region"`      // matched against the allowed regions
	Service    string `json:"service"`     // decoded but not used for filtering here
}
|
||
|
|
||
|
func fetchAWSIPs(url string) (awsIPResponse, error) {
|
||
|
var response awsIPResponse
|
||
|
resp, err := http.Get(url)
|
||
|
if err != nil {
|
||
|
return response, err
|
||
|
}
|
||
|
if resp.StatusCode != 200 {
|
||
|
body, _ := ioutil.ReadAll(resp.Body)
|
||
|
return response, fmt.Errorf("failed to fetch network data. response = %s", body)
|
||
|
}
|
||
|
decoder := json.NewDecoder(resp.Body)
|
||
|
err = decoder.Decode(&response)
|
||
|
if err != nil {
|
||
|
return response, err
|
||
|
}
|
||
|
return response, nil
|
||
|
}
|
||
|
|
||
|
// tryUpdate attempts to download the new set of ip addresses.
|
||
|
// tryUpdate must be thread safe with contains
|
||
|
func (s *awsIPs) tryUpdate() error {
|
||
|
response, err := fetchAWSIPs(s.host)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
var ipv4 []net.IPNet
|
||
|
var ipv6 []net.IPNet
|
||
|
|
||
|
processAddress := func(output *[]net.IPNet, prefix string, region string) {
|
||
|
regionAllowed := false
|
||
|
if len(s.awsRegion) > 0 {
|
||
|
for _, ar := range s.awsRegion {
|
||
|
if strings.ToLower(region) == ar {
|
||
|
regionAllowed = true
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
regionAllowed = true
|
||
|
}
|
||
|
|
||
|
_, network, err := net.ParseCIDR(prefix)
|
||
|
if err != nil {
|
||
|
dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{
|
||
|
"cidr": prefix,
|
||
|
}).Error("unparseable cidr")
|
||
|
return
|
||
|
}
|
||
|
if regionAllowed {
|
||
|
*output = append(*output, *network)
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
for _, prefix := range response.Prefixes {
|
||
|
processAddress(&ipv4, prefix.IPV4Prefix, prefix.Region)
|
||
|
}
|
||
|
for _, prefix := range response.V6Prefixes {
|
||
|
processAddress(&ipv6, prefix.IPV6Prefix, prefix.Region)
|
||
|
}
|
||
|
s.mutex.Lock()
|
||
|
defer s.mutex.Unlock()
|
||
|
// Update each attr of awsips atomically.
|
||
|
s.ipv4 = ipv4
|
||
|
s.ipv6 = ipv6
|
||
|
s.initialized = true
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// This function is meant to be run in a background goroutine.
|
||
|
// It will periodically update the ips from aws.
|
||
|
func (s *awsIPs) updater() {
|
||
|
defer close(s.updaterStopChan)
|
||
|
for {
|
||
|
time.Sleep(s.updateFrequency)
|
||
|
select {
|
||
|
case <-s.updaterStopChan:
|
||
|
dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal")
|
||
|
return
|
||
|
default:
|
||
|
err := s.tryUpdate()
|
||
|
if err != nil {
|
||
|
dcontext.GetLogger(context.Background()).WithError(err).Error("git AWS IP")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// getCandidateNetworks returns either the ipv4 or ipv6 networks
|
||
|
// that were last read from aws. The networks returned
|
||
|
// have the same type as the ip address provided.
|
||
|
func (s *awsIPs) getCandidateNetworks(ip net.IP) []net.IPNet {
|
||
|
s.mutex.RLock()
|
||
|
defer s.mutex.RUnlock()
|
||
|
if ip.To4() != nil {
|
||
|
return s.ipv4
|
||
|
} else if ip.To16() != nil {
|
||
|
return s.ipv6
|
||
|
} else {
|
||
|
dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{
|
||
|
"ip": ip,
|
||
|
}).Error("unknown ip address format")
|
||
|
// assume mismatch, pass through cloudfront
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Contains determines whether the host is within aws.
|
||
|
func (s *awsIPs) contains(ip net.IP) bool {
|
||
|
networks := s.getCandidateNetworks(ip)
|
||
|
for _, network := range networks {
|
||
|
if network.Contains(ip) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// parseIPFromRequest attempts to extract the ip address of the
|
||
|
// client that made the request
|
||
|
func parseIPFromRequest(ctx context.Context) (net.IP, error) {
|
||
|
request, err := dcontext.GetRequest(ctx)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
ipStr := dcontext.RemoteIP(request)
|
||
|
ip := net.ParseIP(ipStr)
|
||
|
if ip == nil {
|
||
|
return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr)
|
||
|
}
|
||
|
|
||
|
return ip, nil
|
||
|
}
|
||
|
|
||
|
// eligibleForS3 checks if a request is eligible for using S3 directly
|
||
|
// Return true only when the IP belongs to a specific aws region and user-agent is docker
|
||
|
func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool {
|
||
|
if awsIPs != nil && awsIPs.initialized {
|
||
|
if addr, err := parseIPFromRequest(ctx); err == nil {
|
||
|
request, err := dcontext.GetRequest(ctx)
|
||
|
if err != nil {
|
||
|
dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err)
|
||
|
} else {
|
||
|
loggerField := map[interface{}]interface{}{
|
||
|
"user-client": request.UserAgent(),
|
||
|
"ip": dcontext.RemoteIP(request),
|
||
|
}
|
||
|
if awsIPs.contains(addr) {
|
||
|
dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront")
|
||
|
return true
|
||
|
}
|
||
|
dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront")
|
||
|
}
|
||
|
} else {
|
||
|
dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront")
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}
|