diff --git a/ca/client.go b/ca/client.go index d080bfc2..47116245 100644 --- a/ca/client.go +++ b/ca/client.go @@ -153,16 +153,42 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { }) } -// parseEndpoint parses and validates the given endpoint +// parseEndpoint parses and validates the given endpoint. It supports general +// URLs like https://ca.smallstep.com[:port][/path], and incomplete URLs like +// ca.smallstep.com[:port][/path]. func parseEndpoint(endpoint string) (*url.URL, error) { u, err := url.Parse(endpoint) if err != nil { return nil, errors.Wrapf(err, "error parsing endpoint '%s'", endpoint) } - if u.Scheme == "" || u.Host == "" { - return nil, errors.Errorf("error parsing endpoint: url '%s' is not valid", endpoint) + + // URLs are generally parsed as: + // [scheme:][//[userinfo@]host][/]path[?query][#fragment] + // But URLs that do not start with a slash after the scheme are interpreted as + // scheme:opaque[?query][#fragment] + if u.Opaque == "" { + if u.Scheme == "" { + u.Scheme = "https" + } + if u.Host == "" { + // endpoint looks like ca.smallstep.com or ca.smallstep.com/1.0/sign + if u.Path != "" { + parts := strings.SplitN(u.Path, "/", 2) + u.Host = parts[0] + if len(parts) == 2 { + u.Path = parts[1] + } else { + u.Path = "" + } + return parseEndpoint(u.String()) + } + return nil, errors.Errorf("error parsing endpoint: url '%s' is not valid", endpoint) + } + return u, nil } - return u, nil + // scheme:opaque[?query][#fragment] + // endpoint looks like ca.smallstep.com:443 or ca.smallstep.com:443/1.0/sign + return parseEndpoint("https://" + endpoint) } // ProvisionerOption is the type of options passed to the Provisioner method. diff --git a/ca/client_test.go b/ca/client_test.go index 612a08da..6d5cd22a 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "reflect" "testing" "time" @@ -510,3 +511,41 @@ func TestClient_ProvisionerKey(t *testing.T) { }) } } + +func Test_parseEndpoint(t *testing.T) { + expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"} + expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"} + type args struct { + endpoint string + } + tests := []struct { + name string + args args + want *url.URL + wantErr bool + }{ + {"ok", args{"https://ca.smallstep.com"}, expected1, false}, + {"ok no scheme", args{"//ca.smallstep.com"}, expected1, false}, + {"ok only host", args{"ca.smallstep.com"}, expected1, false}, + {"ok no bars", args{"https://ca.smallstep.com"}, expected1, false}, + {"ok schema, host and path", args{"https://ca.smallstep.com/1.0/sign"}, expected2, false}, + {"ok no bars with path", args{"https://ca.smallstep.com/1.0/sign"}, expected2, false}, + {"ok host and path", args{"ca.smallstep.com/1.0/sign"}, expected2, false}, + {"ok host and port", args{"ca.smallstep.com:443"}, &url.URL{Scheme: "https", Host: "ca.smallstep.com:443"}, false}, + {"ok host, path and port", args{"ca.smallstep.com:443/1.0/sign"}, &url.URL{Scheme: "https", Host: "ca.smallstep.com:443", Path: "/1.0/sign"}, false}, + {"fail bad url", args{"://ca.smallstep.com"}, nil, true}, + {"fail no host", args{"https://"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseEndpoint(tt.args.endpoint) + if (err != nil) != tt.wantErr { + t.Errorf("parseEndpoint() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseEndpoint() = %v, want %v", got, tt.want) + } + }) + } +}