diff --git a/pkg/network/address.go b/pkg/network/address.go index 784e8058e5..11ffd19049 100644 --- a/pkg/network/address.go +++ b/pkg/network/address.go @@ -55,14 +55,19 @@ func (a *Address) FromString(s string) error { a.ma, err = multiaddr.NewMultiaddr(s) if err != nil { - var u uri + var ( + host string + hasTLS bool + ) + host, hasTLS, err = parseURI(s) + if err != nil { + host = s + } - u.parse(s) - - s, err = multiaddrStringFromHostAddr(u.host) + s, err = multiaddrStringFromHostAddr(host) if err == nil { a.ma, err = multiaddr.NewMultiaddr(s) - if err == nil && u.tls { + if err == nil && hasTLS { a.ma = a.ma.Encapsulate(tls) } } @@ -71,30 +76,41 @@ func (a *Address) FromString(s string) error { return err } -type uri struct { - host string - tls bool -} +const ( + grpcScheme = "grpc" + grpcTLSScheme = "grpcs" +) -const grpcTLSScheme = "grpcs" - -func (u *uri) parse(s string) { +// parseURI parses s as address and returns a host and a flag +// indicating that TLS is enabled. If multiaddress is provided +// the argument is returned unchanged. +func parseURI(s string) (string, bool, error) { // TODO: code is copy-pasted from client.WithURIAddress function. // Would be nice to share the code. uri, err := url.ParseRequestURI(s) - isURI := err == nil - - if isURI && uri.Host != "" { - u.host = uri.Host - } else { - u.host = s + if err != nil { + return s, false, nil } // check if passed string was parsed correctly // URIs that do not start with a slash after the scheme are interpreted as: // `scheme:opaque` => if `opaque` is not empty, then it is supposed that URI // is in `host:port` format - u.tls = isURI && uri.Opaque == "" && uri.Scheme == grpcTLSScheme + if uri.Host == "" { + uri.Host = uri.Scheme + uri.Scheme = grpcScheme // assume GRPC by default + if uri.Opaque != "" { + uri.Host = net.JoinHostPort(uri.Host, uri.Opaque) + } + } + + switch uri.Scheme { + case grpcTLSScheme, grpcScheme: + default: + return "", false, fmt.Errorf("unsupported scheme: %s", uri.Scheme) + } + + return uri.Host, uri.Scheme == grpcTLSScheme, nil } // multiaddrStringFromHostAddr converts "localhost:8080" to "/dns4/localhost/tcp/8080" diff --git a/pkg/network/address_test.go b/pkg/network/address_test.go index 91148a9e3f..c96d632a3c 100644 --- a/pkg/network/address_test.go +++ b/pkg/network/address_test.go @@ -29,6 +29,17 @@ func TestAddressFromString(t *testing.T) { require.Equal(t, testcase.exp, addr.ma, testcase.inp) } }) + t.Run("invalid addresses", func(t *testing.T) { + testCases := []string{ + "wtf://example.com:123", // wrong scheme + "grpc://example.com", // missing port + } + + var addr Address + for _, tc := range testCases { + require.Error(t, addr.FromString(tc)) + } + }) } func TestAddress_HostAddrString(t *testing.T) {