diff --git a/pkg/services/oracle/network.go b/pkg/services/oracle/network.go index 505ac9670..3f07393e1 100644 --- a/pkg/services/oracle/network.go +++ b/pkg/services/oracle/network.go @@ -1,8 +1,12 @@ package oracle import ( - "errors" + "fmt" "net" + "net/http" + "syscall" + + "github.com/nspcc-dev/neo-go/pkg/config" ) // reservedCIDRs is a list of ip addresses for private networks. @@ -31,17 +35,6 @@ func init() { } } -func resolveAndCheck(network string, address string) (*net.IPAddr, error) { - ip, err := net.ResolveIPAddr(network, address) - if err != nil { - return nil, err - } - if isReserved(ip.IP) { - return nil, errors.New("IP is not global unicast") - } - return ip, nil -} - func isReserved(ip net.IP) bool { if !ip.IsGlobalUnicast() { return true @@ -53,3 +46,51 @@ func isReserved(ip net.IP) bool { } return false } + +func getDefaultClient(cfg config.OracleConfiguration) *http.Client { + d := &net.Dialer{} + if !cfg.AllowPrivateHost { + // Control is used after request URI is resolved and network connection (network + // file descriptor) is created, but right before the moment listening/dialing + // is started. + // `address` represents resolved IP address in the format of ip:port. `address` + // is presented in its final (resolved) form that was used directly for network + // connection establishing. + // Control is called for each item in the set of IP addresses got from request + // URI resolving. The first network connection with address that passes Control + // function will be used for further request processing. Network connection + // with address that failed Control will be ignored. If all the connections + // fail Control then the most relevant error (the one from the first address) + // will be returned after `Client.Do`. + d.Control = func(network, address string, c syscall.RawConn) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("%w: failed to split address %s: %s", ErrRestrictedRedirect, address, err) + } + ip := net.ParseIP(host) + if ip == nil { + return fmt.Errorf("%w: failed to parse IP address %s", ErrRestrictedRedirect, address) + } + if isReserved(ip) { + return fmt.Errorf("%w: IP is not global unicast", ErrRestrictedRedirect) + } + return nil + } + } + var client http.Client + client.Transport = &http.Transport{ + DisableKeepAlives: true, + // Do not set DialTLSContext, so that DialContext will be used to establish the + // connection. After that TLS connection will be added to a persistent connection + // by standard library code and handshaking will be performed. + DialContext: d.DialContext, + } + client.Timeout = cfg.RequestTimeout + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) >= maxRedirections { // from https://github.com/neo-project/neo-modules/pull/694 + return fmt.Errorf("%w: %d redirections are reached", ErrRestrictedRedirect, maxRedirections) + } + return nil + } + return &client +} diff --git a/pkg/services/oracle/network_test.go b/pkg/services/oracle/network_test.go index 2e1791c5d..37d623939 100644 --- a/pkg/services/oracle/network_test.go +++ b/pkg/services/oracle/network_test.go @@ -1,9 +1,13 @@ package oracle import ( + "errors" "net" + "strings" "testing" + "time" + "github.com/nspcc-dev/neo-go/pkg/config" "github.com/stretchr/testify/require" ) @@ -16,3 +20,30 @@ func TestIsReserved(t *testing.T) { require.False(t, isReserved(net.IPv4(8, 8, 8, 8))) } + +func TestDefaultClient_RestrictedRedirectErr(t *testing.T) { + cfg := config.OracleConfiguration{ + AllowPrivateHost: false, + RequestTimeout: time.Second, + } + cl := getDefaultClient(cfg) + + testCases := []string{ + "http://localhost:8080", + "http://localhost", + "https://localhost:443", + "https://" + net.IPv4zero.String(), + "https://" + net.IPv4(10, 0, 0, 1).String(), + "https://" + net.IPv4(192, 168, 0, 1).String(), + "https://[" + net.IPv6interfacelocalallnodes.String() + "]", + "https://[" + net.IPv6loopback.String() + "]", + } + for _, c := range testCases { + t.Run(c, func(t *testing.T) { + _, err := cl.Get(c) + require.Error(t, err) + require.True(t, errors.Is(err, ErrRestrictedRedirect), err) + require.True(t, strings.Contains(err.Error(), "IP is not global unicast"), err) + }) + } +} diff --git a/pkg/services/oracle/oracle.go b/pkg/services/oracle/oracle.go index 82a062f75..5165ba5c5 100644 --- a/pkg/services/oracle/oracle.go +++ b/pkg/services/oracle/oracle.go @@ -1,10 +1,7 @@ package oracle import ( - "context" "errors" - "fmt" - "net" "net/http" "sync" "time" @@ -172,32 +169,7 @@ func NewOracle(cfg Config) (*Oracle, error) { o.OnTransaction = func(*transaction.Transaction) error { return nil } } if o.Client == nil { - var client http.Client - client.Transport = &http.Transport{ - DisableKeepAlives: true, - // Do not set DialTLSContext, so that DialContext will be used to establish the - // connection. After that TLS connection will be added to a persistent connection - // by standard library code and handshaking will be performed. - DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - if !o.MainCfg.AllowPrivateHost { - ip, err := resolveAndCheck(network, address) - if err != nil { - return nil, fmt.Errorf("%w: address %s failed validation: %s", ErrRestrictedRedirect, address, err) - } - network = ip.Network() - address = ip.IP.String() - } - return net.Dial(network, address) - }, - } - client.Timeout = o.MainCfg.RequestTimeout - client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - if len(via) >= maxRedirections { // from https://github.com/neo-project/neo-modules/pull/694 - return fmt.Errorf("%w: %d redirections are reached", ErrRestrictedRedirect, maxRedirections) - } - return nil - } - o.Client = &client + o.Client = getDefaultClient(o.MainCfg) } return o, nil }