replace strings.Split(N) for strings.Cut() or alternatives

Go 1.18 and up now provides a strings.Cut() which is better suited for
splitting key/value pairs (and similar constructs), and performs better:

```go
func BenchmarkSplit(b *testing.B) {
	b.ReportAllocs()
	data := []string{"12hello=world", "12hello=", "12=hello", "12hello"}
	for i := 0; i < b.N; i++ {
		for _, s := range data {
			_ = strings.SplitN(s, "=", 2)[0]
		}
	}
}

func BenchmarkCut(b *testing.B) {
	b.ReportAllocs()
	data := []string{"12hello=world", "12hello=", "12=hello", "12hello"}
	for i := 0; i < b.N; i++ {
		for _, s := range data {
			_, _, _ = strings.Cut(s, "=")
		}
	}
}
```

    BenchmarkSplit
    BenchmarkSplit-10    	 8244206	       128.0 ns/op	     128 B/op	       4 allocs/op
    BenchmarkCut
    BenchmarkCut-10      	54411998	        21.80 ns/op	       0 B/op	       0 allocs/op

While looking at occurrences of `strings.Split()`, I also updated some for alternatives,
or added some constraints;

- for cases where an specific number of items is expected, I used `strings.SplitN()`
  with a suitable limit. This prevents (theoretical) unlimited splits.
- in some cases it we were using `strings.Split()`, but _actually_ were trying to match
  a prefix; for those I replaced the code to just match (and/or strip) the prefix.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
This commit is contained in:
Sebastiaan van Stijn 2022-11-02 20:32:03 +01:00
parent 552b1526c6
commit 3b391d3290
No known key found for this signature in database
GPG key ID: 76698F39D527CE8C
10 changed files with 60 additions and 92 deletions

View file

@ -23,7 +23,7 @@ func MajorMinorVersion(major, minor uint) Version {
} }
func (version Version) major() (uint, error) { func (version Version) major() (uint, error) {
majorPart := strings.Split(string(version), ".")[0] majorPart, _, _ := strings.Cut(string(version), ".")
major, err := strconv.ParseUint(majorPart, 10, 0) major, err := strconv.ParseUint(majorPart, 10, 0)
return uint(major), err return uint(major), err
} }
@ -35,7 +35,7 @@ func (version Version) Major() uint {
} }
func (version Version) minor() (uint, error) { func (version Version) minor() (uint, error) {
minorPart := strings.Split(string(version), ".")[1] _, minorPart, _ := strings.Cut(string(version), ".")
minor, err := strconv.ParseUint(minorPart, 10, 0) minor, err := strconv.ParseUint(minorPart, 10, 0)
return uint(minor), err return uint(minor), err
} }
@ -89,8 +89,8 @@ func NewParser(prefix string, parseInfos []VersionedParseInfo) *Parser {
} }
for _, env := range os.Environ() { for _, env := range os.Environ() {
envParts := strings.SplitN(env, "=", 2) k, v, _ := strings.Cut(env, "=")
p.env = append(p.env, envVar{envParts[0], envParts[1]}) p.env = append(p.env, envVar{k, v})
} }
// We must sort the environment variables lexically by name so that // We must sort the environment variables lexically by name so that

View file

@ -32,14 +32,12 @@ func parseIP(ipStr string) net.IP {
// account proxy headers. // account proxy headers.
func RemoteAddr(r *http.Request) string { func RemoteAddr(r *http.Request) string {
if prior := r.Header.Get("X-Forwarded-For"); prior != "" { if prior := r.Header.Get("X-Forwarded-For"); prior != "" {
proxies := strings.Split(prior, ",") remoteAddr, _, _ := strings.Cut(prior, ",")
if len(proxies) > 0 { remoteAddr = strings.Trim(remoteAddr, " ")
remoteAddr := strings.Trim(proxies[0], " ")
if parseIP(remoteAddr) != nil { if parseIP(remoteAddr) != nil {
return remoteAddr return remoteAddr
} }
} }
}
// X-Real-Ip is less supported, but worth checking in the // X-Real-Ip is less supported, but worth checking in the
// absence of X-Forwarded-For // absence of X-Forwarded-For
if realIP := r.Header.Get("X-Real-Ip"); realIP != "" { if realIP := r.Header.Get("X-Real-Ip"); realIP != "" {
@ -189,49 +187,37 @@ type httpRequestContext struct {
// "request.<component>". For example, r.RequestURI // "request.<component>". For example, r.RequestURI
func (ctx *httpRequestContext) Value(key interface{}) interface{} { func (ctx *httpRequestContext) Value(key interface{}) interface{} {
if keyStr, ok := key.(string); ok { if keyStr, ok := key.(string); ok {
if keyStr == "http.request" { switch keyStr {
case "http.request":
return ctx.r return ctx.r
} case "http.request.uri":
if !strings.HasPrefix(keyStr, "http.request.") {
goto fallback
}
parts := strings.Split(keyStr, ".")
if len(parts) != 3 {
goto fallback
}
switch parts[2] {
case "uri":
return ctx.r.RequestURI return ctx.r.RequestURI
case "remoteaddr": case "http.request.remoteaddr":
return RemoteAddr(ctx.r) return RemoteAddr(ctx.r)
case "method": case "http.request.method":
return ctx.r.Method return ctx.r.Method
case "host": case "http.request.host":
return ctx.r.Host return ctx.r.Host
case "referer": case "http.request.referer":
referer := ctx.r.Referer() referer := ctx.r.Referer()
if referer != "" { if referer != "" {
return referer return referer
} }
case "useragent": case "http.request.useragent":
return ctx.r.UserAgent() return ctx.r.UserAgent()
case "id": case "http.request.id":
return ctx.id return ctx.id
case "startedat": case "http.request.startedat":
return ctx.startedAt return ctx.startedAt
case "contenttype": case "http.request.contenttype":
ct := ctx.r.Header.Get("Content-Type") if ct := ctx.r.Header.Get("Content-Type"); ct != "" {
if ct != "" {
return ct return ct
} }
default:
// no match; fall back to standard behavior below
} }
} }
fallback:
return ctx.Context.Value(key) return ctx.Context.Value(key)
} }
@ -245,10 +231,9 @@ func (ctx *muxVarsContext) Value(key interface{}) interface{} {
if keyStr == "vars" { if keyStr == "vars" {
return ctx.vars return ctx.vars
} }
// TODO(thaJeztah): this considers "vars.FOO" and "FOO" to be equal.
keyStr = strings.TrimPrefix(keyStr, "vars.") // We need to check if that's intentional (could be a bug).
if v, ok := ctx.vars[strings.TrimPrefix(keyStr, "vars.")]; ok {
if v, ok := ctx.vars[keyStr]; ok {
return v return v
} }
} }
@ -300,36 +285,25 @@ func (irw *instrumentedResponseWriter) Flush() {
func (irw *instrumentedResponseWriter) Value(key interface{}) interface{} { func (irw *instrumentedResponseWriter) Value(key interface{}) interface{} {
if keyStr, ok := key.(string); ok { if keyStr, ok := key.(string); ok {
if keyStr == "http.response" { switch keyStr {
case "http.response":
return irw return irw
} case "http.response.written":
if !strings.HasPrefix(keyStr, "http.response.") {
goto fallback
}
parts := strings.Split(keyStr, ".")
if len(parts) != 3 {
goto fallback
}
irw.mu.Lock() irw.mu.Lock()
defer irw.mu.Unlock() defer irw.mu.Unlock()
switch parts[2] {
case "written":
return irw.written return irw.written
case "status": case "http.response.status":
irw.mu.Lock()
defer irw.mu.Unlock()
return irw.status return irw.status
case "contenttype": case "http.response.contenttype":
contentType := irw.Header().Get("Content-Type") if ct := irw.Header().Get("Content-Type"); ct != "" {
if contentType != "" { return ct
return contentType
} }
default:
// no match; fall back to standard behavior below
} }
} }
fallback:
return irw.Context.Value(key) return irw.Context.Value(key)
} }

View file

@ -80,8 +80,8 @@ func NewURLBuilderFromRequest(r *http.Request, relative bool) *URLBuilder {
// comma-separated list of hosts, to which each proxy appends the // comma-separated list of hosts, to which each proxy appends the
// requested host. We want to grab the first from this comma-separated // requested host. We want to grab the first from this comma-separated
// list. // list.
hosts := strings.SplitN(forwardedHost, ",", 2) host, _, _ = strings.Cut(forwardedHost, ",")
host = strings.TrimSpace(hosts[0]) host = strings.TrimSpace(host)
} }
} }

View file

@ -247,15 +247,12 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
return nil, err return nil, err
} }
parts := strings.Split(req.Header.Get("Authorization"), " ") prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ")
if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") {
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
challenge.err = ErrTokenRequired challenge.err = ErrTokenRequired
return nil, challenge return nil, challenge
} }
rawToken := parts[1]
token, err := NewToken(rawToken) token, err := NewToken(rawToken)
if err != nil { if err != nil {
challenge.err = err challenge.err = err

View file

@ -83,7 +83,9 @@ type VerifyOptions struct {
// NewToken parses the given raw token string // NewToken parses the given raw token string
// and constructs an unverified JSON Web Token. // and constructs an unverified JSON Web Token.
func NewToken(rawToken string) (*Token, error) { func NewToken(rawToken string) (*Token, error) {
parts := strings.Split(rawToken, TokenSeparator) // We expect 3 parts, but limit the split to 4 to detect cases where
// the token contains too many (or too few) separators.
parts := strings.SplitN(rawToken, TokenSeparator, 4)
if len(parts) != 3 { if len(parts) != 3 {
return nil, ErrMalformedToken return nil, ErrMalformedToken
} }

View file

@ -9,6 +9,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"path"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -239,8 +240,8 @@ func (t *tags) All(ctx context.Context) ([]string, error) {
} }
tags = append(tags, tagsResponse.Tags...) tags = append(tags, tagsResponse.Tags...)
if link := resp.Header.Get("Link"); link != "" { if link := resp.Header.Get("Link"); link != "" {
linkURLStr := strings.Trim(strings.Split(link, ";")[0], "<>") firsLink, _, _ := strings.Cut(link, ";")
linkURL, err := url.Parse(linkURLStr) linkURL, err := url.Parse(strings.Trim(firsLink, "<>"))
if err != nil { if err != nil {
return tags, err return tags, err
} }
@ -808,8 +809,8 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
// TODO(dmcgowan): Check for invalid UUID // TODO(dmcgowan): Check for invalid UUID
uuid := resp.Header.Get("Docker-Upload-UUID") uuid := resp.Header.Get("Docker-Upload-UUID")
if uuid == "" { if uuid == "" {
parts := strings.Split(resp.Header.Get("Location"), "/") // uuid is expected to be the last path element
uuid = parts[len(parts)-1] _, uuid = path.Split(resp.Header.Get("Location"))
} }
if uuid == "" { if uuid == "" {
return nil, errors.New("cannot retrieve docker upload UUID") return nil, errors.New("cannot retrieve docker upload UUID")

View file

@ -68,19 +68,18 @@ func copyFullPayload(ctx context.Context, responseWriter http.ResponseWriter, r
return nil return nil
} }
func parseContentRange(cr string) (int64, int64, error) { func parseContentRange(cr string) (start int64, end int64, err error) {
ranges := strings.Split(cr, "-") rStart, rEnd, ok := strings.Cut(cr, "-")
if len(ranges) != 2 { if !ok {
return -1, -1, fmt.Errorf("invalid content range format, %s", cr) return -1, -1, fmt.Errorf("invalid content range format, %s", cr)
} }
start, err := strconv.ParseInt(ranges[0], 10, 64) start, err = strconv.ParseInt(rStart, 10, 64)
if err != nil { if err != nil {
return -1, -1, err return -1, -1, err
} }
end, err := strconv.ParseInt(ranges[1], 10, 64) end, err = strconv.ParseInt(rEnd, 10, 64)
if err != nil { if err != nil {
return -1, -1, err return -1, -1, err
} }
return start, end, nil return start, end, nil
} }

View file

@ -18,11 +18,10 @@ type logHook struct {
// Fire forwards an error to LogHook // Fire forwards an error to LogHook
func (hook *logHook) Fire(entry *logrus.Entry) error { func (hook *logHook) Fire(entry *logrus.Entry) error {
addr := strings.Split(hook.Mail.Addr, ":") host, _, ok := strings.Cut(hook.Mail.Addr, ":")
if len(addr) != 2 { if !ok || host == "" {
return errors.New("invalid Mail Address") return errors.New("invalid Mail Address")
} }
host := addr[0]
subject := fmt.Sprintf("[%s] %s: %s", entry.Level, host, entry.Message) subject := fmt.Sprintf("[%s] %s: %s", entry.Level, host, entry.Message)
html := ` html := `

View file

@ -17,14 +17,14 @@ type Version string
// Major returns the major (primary) component of a version. // Major returns the major (primary) component of a version.
func (version Version) Major() uint { func (version Version) Major() uint {
majorPart := strings.Split(string(version), ".")[0] majorPart, _, _ := strings.Cut(string(version), ".")
major, _ := strconv.ParseUint(majorPart, 10, 0) major, _ := strconv.ParseUint(majorPart, 10, 0)
return uint(major) return uint(major)
} }
// Minor returns the minor (secondary) component of a version. // Minor returns the minor (secondary) component of a version.
func (version Version) Minor() uint { func (version Version) Minor() uint {
minorPart := strings.Split(string(version), ".")[1] _, minorPart, _ := strings.Cut(string(version), ".")
minor, _ := strconv.ParseUint(minorPart, 10, 0) minor, _ := strconv.ParseUint(minorPart, 10, 0)
return uint(minor) return uint(minor)
} }

View file

@ -763,11 +763,7 @@ func chunkFilenames(slice []string, maxSize int) (chunks [][]string, err error)
} }
func parseManifest(manifest string) (container string, prefix string) { func parseManifest(manifest string) (container string, prefix string) {
components := strings.SplitN(manifest, "/", 2) container, prefix, _ = strings.Cut(manifest, "/")
container = components[0]
if len(components) > 1 {
prefix = components[1]
}
return container, prefix return container, prefix
} }