From 552b1526c6821a84daab48cbc7f5456ae215d6c4 Mon Sep 17 00:00:00 2001 From: Sebastiaan van Stijn Date: Wed, 9 Nov 2022 15:20:23 +0100 Subject: [PATCH 1/3] reference: check for prefix instead of splitting, and use consts - use strings.HasPrefix() to check for the prefix we're interested in instead of doing a strings.Split() without limits. This makes the code both easier to read, and prevents potential situations where we end up with a long slice. - use consts for defaults; these should never be modified, so better to use consts for them to indicate they're fixed values. Signed-off-by: Sebastiaan van Stijn --- reference/normalize.go | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/reference/normalize.go b/reference/normalize.go index 9a23ffa9d..47de5850a 100644 --- a/reference/normalize.go +++ b/reference/normalize.go @@ -7,10 +7,10 @@ import ( "github.com/opencontainers/go-digest" ) -var ( +const ( legacyDefaultDomain = "index.docker.io" defaultDomain = "docker.io" - officialRepoName = "library" + officialRepoPrefix = "library/" defaultTag = "latest" ) @@ -96,8 +96,13 @@ func splitDockerDomain(name string) (domain, remainder string) { if domain == legacyDefaultDomain { domain = defaultDomain } + // TODO(thaJeztah): this check may be too strict, as it assumes the + // "library/" namespace does not have nested namespaces. While this + // is true (currently), technically it would be possible for Docker + // Hub to use those (e.g. "library/distros/ubuntu:latest"). + // See https://github.com/distribution/distribution/pull/3769#issuecomment-1302031785. if domain == defaultDomain && !strings.ContainsRune(remainder, '/') { - remainder = officialRepoName + "/" + remainder + remainder = officialRepoPrefix + remainder } return } @@ -117,8 +122,15 @@ func familiarizeName(named namedRepository) repository { if repo.domain == defaultDomain { repo.domain = "" // Handle official repositories which have the pattern "library/" - if split := strings.Split(repo.path, "/"); len(split) == 2 && split[0] == officialRepoName { - repo.path = split[1] + if strings.HasPrefix(repo.path, officialRepoPrefix) { + // TODO(thaJeztah): this check may be too strict, as it assumes the + // "library/" namespace does not have nested namespaces. While this + // is true (currently), technically it would be possible for Docker + // Hub to use those (e.g. "library/distros/ubuntu:latest"). + // See https://github.com/distribution/distribution/pull/3769#issuecomment-1302031785. + if remainder := strings.TrimPrefix(repo.path, officialRepoPrefix); !strings.ContainsRune(remainder, '/') { + repo.path = remainder + } } } return repo From 3b391d32908f3523b525ac9e5785a24a9f488919 Mon Sep 17 00:00:00 2001 From: Sebastiaan van Stijn Date: Wed, 2 Nov 2022 20:32:03 +0100 Subject: [PATCH 2/3] replace strings.Split(N) for strings.Cut() or alternatives Go 1.18 and up now provides a strings.Cut() which is better suited for splitting key/value pairs (and similar constructs), and performs better: ```go func BenchmarkSplit(b *testing.B) { b.ReportAllocs() data := []string{"12hello=world", "12hello=", "12=hello", "12hello"} for i := 0; i < b.N; i++ { for _, s := range data { _ = strings.SplitN(s, "=", 2)[0] } } } func BenchmarkCut(b *testing.B) { b.ReportAllocs() data := []string{"12hello=world", "12hello=", "12=hello", "12hello"} for i := 0; i < b.N; i++ { for _, s := range data { _, _, _ = strings.Cut(s, "=") } } } ``` BenchmarkSplit BenchmarkSplit-10 8244206 128.0 ns/op 128 B/op 4 allocs/op BenchmarkCut BenchmarkCut-10 54411998 21.80 ns/op 0 B/op 0 allocs/op While looking at occurrences of `strings.Split()`, I also updated some for alternatives, or added some constraints; - for cases where an specific number of items is expected, I used `strings.SplitN()` with a suitable limit. This prevents (theoretical) unlimited splits. - in some cases it we were using `strings.Split()`, but _actually_ were trying to match a prefix; for those I replaced the code to just match (and/or strip) the prefix. Signed-off-by: Sebastiaan van Stijn --- configuration/parser.go | 8 +- context/http.go | 94 +++++++++--------------- registry/api/v2/urls.go | 4 +- registry/auth/token/accesscontroller.go | 7 +- registry/auth/token/token.go | 4 +- registry/client/repository.go | 9 ++- registry/handlers/helpers.go | 11 ++- registry/handlers/hooks.go | 5 +- registry/storage/driver/storagedriver.go | 4 +- registry/storage/driver/swift/swift.go | 6 +- 10 files changed, 60 insertions(+), 92 deletions(-) diff --git a/configuration/parser.go b/configuration/parser.go index 2b389f1f3..6d6451800 100644 --- a/configuration/parser.go +++ b/configuration/parser.go @@ -23,7 +23,7 @@ func MajorMinorVersion(major, minor uint) Version { } func (version Version) major() (uint, error) { - majorPart := strings.Split(string(version), ".")[0] + majorPart, _, _ := strings.Cut(string(version), ".") major, err := strconv.ParseUint(majorPart, 10, 0) return uint(major), err } @@ -35,7 +35,7 @@ func (version Version) Major() uint { } func (version Version) minor() (uint, error) { - minorPart := strings.Split(string(version), ".")[1] + _, minorPart, _ := strings.Cut(string(version), ".") minor, err := strconv.ParseUint(minorPart, 10, 0) return uint(minor), err } @@ -89,8 +89,8 @@ func NewParser(prefix string, parseInfos []VersionedParseInfo) *Parser { } for _, env := range os.Environ() { - envParts := strings.SplitN(env, "=", 2) - p.env = append(p.env, envVar{envParts[0], envParts[1]}) + k, v, _ := strings.Cut(env, "=") + p.env = append(p.env, envVar{k, v}) } // We must sort the environment variables lexically by name so that diff --git a/context/http.go b/context/http.go index bd3ebd955..68f977ce0 100644 --- a/context/http.go +++ b/context/http.go @@ -32,12 +32,10 @@ func parseIP(ipStr string) net.IP { // account proxy headers. func RemoteAddr(r *http.Request) string { if prior := r.Header.Get("X-Forwarded-For"); prior != "" { - proxies := strings.Split(prior, ",") - if len(proxies) > 0 { - remoteAddr := strings.Trim(proxies[0], " ") - if parseIP(remoteAddr) != nil { - return remoteAddr - } + remoteAddr, _, _ := strings.Cut(prior, ",") + remoteAddr = strings.Trim(remoteAddr, " ") + if parseIP(remoteAddr) != nil { + return remoteAddr } } // X-Real-Ip is less supported, but worth checking in the @@ -189,49 +187,37 @@ type httpRequestContext struct { // "request.". For example, r.RequestURI func (ctx *httpRequestContext) Value(key interface{}) interface{} { if keyStr, ok := key.(string); ok { - if keyStr == "http.request" { + switch keyStr { + case "http.request": return ctx.r - } - - if !strings.HasPrefix(keyStr, "http.request.") { - goto fallback - } - - parts := strings.Split(keyStr, ".") - - if len(parts) != 3 { - goto fallback - } - - switch parts[2] { - case "uri": + case "http.request.uri": return ctx.r.RequestURI - case "remoteaddr": + case "http.request.remoteaddr": return RemoteAddr(ctx.r) - case "method": + case "http.request.method": return ctx.r.Method - case "host": + case "http.request.host": return ctx.r.Host - case "referer": + case "http.request.referer": referer := ctx.r.Referer() if referer != "" { return referer } - case "useragent": + case "http.request.useragent": return ctx.r.UserAgent() - case "id": + case "http.request.id": return ctx.id - case "startedat": + case "http.request.startedat": return ctx.startedAt - case "contenttype": - ct := ctx.r.Header.Get("Content-Type") - if ct != "" { + case "http.request.contenttype": + if ct := ctx.r.Header.Get("Content-Type"); ct != "" { return ct } + default: + // no match; fall back to standard behavior below } } -fallback: return ctx.Context.Value(key) } @@ -245,10 +231,9 @@ func (ctx *muxVarsContext) Value(key interface{}) interface{} { if keyStr == "vars" { return ctx.vars } - - keyStr = strings.TrimPrefix(keyStr, "vars.") - - if v, ok := ctx.vars[keyStr]; ok { + // TODO(thaJeztah): this considers "vars.FOO" and "FOO" to be equal. + // We need to check if that's intentional (could be a bug). + if v, ok := ctx.vars[strings.TrimPrefix(keyStr, "vars.")]; ok { return v } } @@ -300,36 +285,25 @@ func (irw *instrumentedResponseWriter) Flush() { func (irw *instrumentedResponseWriter) Value(key interface{}) interface{} { if keyStr, ok := key.(string); ok { - if keyStr == "http.response" { + switch keyStr { + case "http.response": return irw - } - - if !strings.HasPrefix(keyStr, "http.response.") { - goto fallback - } - - parts := strings.Split(keyStr, ".") - - if len(parts) != 3 { - goto fallback - } - - irw.mu.Lock() - defer irw.mu.Unlock() - - switch parts[2] { - case "written": + case "http.response.written": + irw.mu.Lock() + defer irw.mu.Unlock() return irw.written - case "status": + case "http.response.status": + irw.mu.Lock() + defer irw.mu.Unlock() return irw.status - case "contenttype": - contentType := irw.Header().Get("Content-Type") - if contentType != "" { - return contentType + case "http.response.contenttype": + if ct := irw.Header().Get("Content-Type"); ct != "" { + return ct } + default: + // no match; fall back to standard behavior below } } -fallback: return irw.Context.Value(key) } diff --git a/registry/api/v2/urls.go b/registry/api/v2/urls.go index c4fdf4153..f4aa90954 100644 --- a/registry/api/v2/urls.go +++ b/registry/api/v2/urls.go @@ -80,8 +80,8 @@ func NewURLBuilderFromRequest(r *http.Request, relative bool) *URLBuilder { // comma-separated list of hosts, to which each proxy appends the // requested host. We want to grab the first from this comma-separated // list. - hosts := strings.SplitN(forwardedHost, ",", 2) - host = strings.TrimSpace(hosts[0]) + host, _, _ = strings.Cut(forwardedHost, ",") + host = strings.TrimSpace(host) } } diff --git a/registry/auth/token/accesscontroller.go b/registry/auth/token/accesscontroller.go index e56a4ccdd..26a18a979 100644 --- a/registry/auth/token/accesscontroller.go +++ b/registry/auth/token/accesscontroller.go @@ -247,15 +247,12 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth. return nil, err } - parts := strings.Split(req.Header.Get("Authorization"), " ") - - if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { + prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ") + if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") { challenge.err = ErrTokenRequired return nil, challenge } - rawToken := parts[1] - token, err := NewToken(rawToken) if err != nil { challenge.err = err diff --git a/registry/auth/token/token.go b/registry/auth/token/token.go index fd195b521..da0167002 100644 --- a/registry/auth/token/token.go +++ b/registry/auth/token/token.go @@ -83,7 +83,9 @@ type VerifyOptions struct { // NewToken parses the given raw token string // and constructs an unverified JSON Web Token. func NewToken(rawToken string) (*Token, error) { - parts := strings.Split(rawToken, TokenSeparator) + // We expect 3 parts, but limit the split to 4 to detect cases where + // the token contains too many (or too few) separators. + parts := strings.SplitN(rawToken, TokenSeparator, 4) if len(parts) != 3 { return nil, ErrMalformedToken } diff --git a/registry/client/repository.go b/registry/client/repository.go index 10c6fe3f0..9e3489388 100644 --- a/registry/client/repository.go +++ b/registry/client/repository.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/url" + "path" "strconv" "strings" "time" @@ -239,8 +240,8 @@ func (t *tags) All(ctx context.Context) ([]string, error) { } tags = append(tags, tagsResponse.Tags...) if link := resp.Header.Get("Link"); link != "" { - linkURLStr := strings.Trim(strings.Split(link, ";")[0], "<>") - linkURL, err := url.Parse(linkURLStr) + firsLink, _, _ := strings.Cut(link, ";") + linkURL, err := url.Parse(strings.Trim(firsLink, "<>")) if err != nil { return tags, err } @@ -808,8 +809,8 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO // TODO(dmcgowan): Check for invalid UUID uuid := resp.Header.Get("Docker-Upload-UUID") if uuid == "" { - parts := strings.Split(resp.Header.Get("Location"), "/") - uuid = parts[len(parts)-1] + // uuid is expected to be the last path element + _, uuid = path.Split(resp.Header.Get("Location")) } if uuid == "" { return nil, errors.New("cannot retrieve docker upload UUID") diff --git a/registry/handlers/helpers.go b/registry/handlers/helpers.go index a872fb3c1..d70306fd7 100644 --- a/registry/handlers/helpers.go +++ b/registry/handlers/helpers.go @@ -68,19 +68,18 @@ func copyFullPayload(ctx context.Context, responseWriter http.ResponseWriter, r return nil } -func parseContentRange(cr string) (int64, int64, error) { - ranges := strings.Split(cr, "-") - if len(ranges) != 2 { +func parseContentRange(cr string) (start int64, end int64, err error) { + rStart, rEnd, ok := strings.Cut(cr, "-") + if !ok { return -1, -1, fmt.Errorf("invalid content range format, %s", cr) } - start, err := strconv.ParseInt(ranges[0], 10, 64) + start, err = strconv.ParseInt(rStart, 10, 64) if err != nil { return -1, -1, err } - end, err := strconv.ParseInt(ranges[1], 10, 64) + end, err = strconv.ParseInt(rEnd, 10, 64) if err != nil { return -1, -1, err } - return start, end, nil } diff --git a/registry/handlers/hooks.go b/registry/handlers/hooks.go index 4a2b1be84..76b39aedf 100644 --- a/registry/handlers/hooks.go +++ b/registry/handlers/hooks.go @@ -18,11 +18,10 @@ type logHook struct { // Fire forwards an error to LogHook func (hook *logHook) Fire(entry *logrus.Entry) error { - addr := strings.Split(hook.Mail.Addr, ":") - if len(addr) != 2 { + host, _, ok := strings.Cut(hook.Mail.Addr, ":") + if !ok || host == "" { return errors.New("invalid Mail Address") } - host := addr[0] subject := fmt.Sprintf("[%s] %s: %s", entry.Level, host, entry.Message) html := ` diff --git a/registry/storage/driver/storagedriver.go b/registry/storage/driver/storagedriver.go index bcbeda5fd..d573e6176 100644 --- a/registry/storage/driver/storagedriver.go +++ b/registry/storage/driver/storagedriver.go @@ -17,14 +17,14 @@ type Version string // Major returns the major (primary) component of a version. func (version Version) Major() uint { - majorPart := strings.Split(string(version), ".")[0] + majorPart, _, _ := strings.Cut(string(version), ".") major, _ := strconv.ParseUint(majorPart, 10, 0) return uint(major) } // Minor returns the minor (secondary) component of a version. func (version Version) Minor() uint { - minorPart := strings.Split(string(version), ".")[1] + _, minorPart, _ := strings.Cut(string(version), ".") minor, _ := strconv.ParseUint(minorPart, 10, 0) return uint(minor) } diff --git a/registry/storage/driver/swift/swift.go b/registry/storage/driver/swift/swift.go index b5c4dcd4d..4becc7284 100644 --- a/registry/storage/driver/swift/swift.go +++ b/registry/storage/driver/swift/swift.go @@ -763,11 +763,7 @@ func chunkFilenames(slice []string, maxSize int) (chunks [][]string, err error) } func parseManifest(manifest string) (container string, prefix string) { - components := strings.SplitN(manifest, "/", 2) - container = components[0] - if len(components) > 1 { - prefix = components[1] - } + container, prefix, _ = strings.Cut(manifest, "/") return container, prefix } From 842d4c04f50fbad367ef3d2ebcffa66a982f973d Mon Sep 17 00:00:00 2001 From: Sebastiaan van Stijn Date: Thu, 3 Nov 2022 00:50:54 +0100 Subject: [PATCH 3/3] cloudfront: use strings.Equalfold() Minor optimization :) Signed-off-by: Sebastiaan van Stijn --- registry/storage/driver/middleware/cloudfront/s3filter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/registry/storage/driver/middleware/cloudfront/s3filter.go b/registry/storage/driver/middleware/cloudfront/s3filter.go index 6b02d29de..1ff9829ee 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter.go @@ -94,7 +94,7 @@ func (s *awsIPs) tryUpdate() error { regionAllowed := false if len(s.awsRegion) > 0 { for _, ar := range s.awsRegion { - if strings.ToLower(region) == ar { + if strings.EqualFold(region, ar) { regionAllowed = true break }