diff --git a/pkg/core/oracle_test.go b/pkg/core/oracle_test.go index 22f575fa6..68e63d85c 100644 --- a/pkg/core/oracle_test.go +++ b/pkg/core/oracle_test.go @@ -8,7 +8,6 @@ import ( gio "io" "io/ioutil" "net/http" - "net/url" "os" "path" "path/filepath" @@ -178,7 +177,7 @@ func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain, return res.Container } -func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config { +func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string, returnOracleRedirectionErrOn func(address string) bool) oracle.Config { return oracle.Config{ Log: zaptest.NewLogger(t), Network: netmode.UnitTestNet, @@ -191,7 +190,7 @@ func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config }, }, Chain: bc, - Client: newDefaultHTTPClient(), + Client: newDefaultHTTPClient(returnOracleRedirectionErrOn), } } @@ -202,15 +201,11 @@ func getTestOracle(t *testing.T, bc *Blockchain, walletPath, pass string) ( chan *transaction.Transaction) { m := make(map[uint64]*responseWithSig) ch := make(chan *transaction.Transaction, 5) - orcCfg := getOracleConfig(t, bc, walletPath, pass) + orcCfg := getOracleConfig(t, bc, walletPath, pass, func(address string) bool { + return strings.HasPrefix(address, "https://private") + }) orcCfg.ResponseHandler = &saveToMapBroadcaster{m: m} orcCfg.OnTransaction = saveTxToChan(ch) - orcCfg.URIValidator = func(u *url.URL) error { - if strings.HasPrefix(u.Host, "private") { - return errors.New("private network") - } - return nil - } orc, err := oracle.NewOracle(orcCfg) require.NoError(t, err) @@ -255,10 +250,10 @@ func TestCreateResponseTx(t *testing.T) { func TestOracle_InvalidWallet(t *testing.T) { bc := newTestChain(t) - _, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid")) + _, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid", nil)) require.Error(t, err) - _, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one")) + _, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one", nil)) require.NoError(t, err) } @@ -551,7 +546,8 @@ type ( // httpClient implements oracle.HTTPClient with // mocked URL or responses. httpClient struct { - responses map[string]testResponse + returnOracleRedirectionErrOn func(address string) bool + responses map[string]testResponse } testResponse struct { @@ -563,6 +559,9 @@ type ( // Get implements oracle.HTTPClient interface. func (c *httpClient) Do(req *http.Request) (*http.Response, error) { + if c.returnOracleRedirectionErrOn != nil && c.returnOracleRedirectionErrOn(req.URL.String()) { + return nil, fmt.Errorf("%w: private network", oracle.ErrRestrictedRedirect) + } resp, ok := c.responses[req.URL.String()] if ok { return &http.Response{ @@ -576,8 +575,9 @@ func (c *httpClient) Do(req *http.Request) (*http.Response, error) { return nil, errors.New("request failed") } -func newDefaultHTTPClient() oracle.HTTPClient { +func newDefaultHTTPClient(returnOracleRedirectionErrOn func(address string) bool) oracle.HTTPClient { return &httpClient{ + returnOracleRedirectionErrOn: returnOracleRedirectionErrOn, responses: map[string]testResponse{ "https://get.1234": { code: http.StatusOK, diff --git a/pkg/services/oracle/network.go b/pkg/services/oracle/network.go index 3ab8e0705..3f07393e1 100644 --- a/pkg/services/oracle/network.go +++ b/pkg/services/oracle/network.go @@ -1,9 +1,12 @@ package oracle import ( - "errors" + "fmt" "net" - "net/url" + "net/http" + "syscall" + + "github.com/nspcc-dev/neo-go/pkg/config" ) // reservedCIDRs is a list of ip addresses for private networks. @@ -32,17 +35,6 @@ func init() { } } -func defaultURIValidator(u *url.URL) error { - ip, err := net.ResolveIPAddr("ip", u.Hostname()) - if err != nil { - return err - } - if isReserved(ip.IP) { - return errors.New("IP is not global unicast") - } - return nil -} - func isReserved(ip net.IP) bool { if !ip.IsGlobalUnicast() { return true @@ -54,3 +46,51 @@ func isReserved(ip net.IP) bool { } return false } + +func getDefaultClient(cfg config.OracleConfiguration) *http.Client { + d := &net.Dialer{} + if !cfg.AllowPrivateHost { + // Control is used after request URI is resolved and network connection (network + // file descriptor) is created, but right before the moment listening/dialing + // is started. + // `address` represents resolved IP address in the format of ip:port. `address` + // is presented in its final (resolved) form that was used directly for network + // connection establishing. + // Control is called for each item in the set of IP addresses got from request + // URI resolving. The first network connection with address that passes Control + // function will be used for further request processing. Network connection + // with address that failed Control will be ignored. If all the connections + // fail Control then the most relevant error (the one from the first address) + // will be returned after `Client.Do`. + d.Control = func(network, address string, c syscall.RawConn) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("%w: failed to split address %s: %s", ErrRestrictedRedirect, address, err) + } + ip := net.ParseIP(host) + if ip == nil { + return fmt.Errorf("%w: failed to parse IP address %s", ErrRestrictedRedirect, address) + } + if isReserved(ip) { + return fmt.Errorf("%w: IP is not global unicast", ErrRestrictedRedirect) + } + return nil + } + } + var client http.Client + client.Transport = &http.Transport{ + DisableKeepAlives: true, + // Do not set DialTLSContext, so that DialContext will be used to establish the + // connection. After that TLS connection will be added to a persistent connection + // by standard library code and handshaking will be performed. + DialContext: d.DialContext, + } + client.Timeout = cfg.RequestTimeout + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) >= maxRedirections { // from https://github.com/neo-project/neo-modules/pull/694 + return fmt.Errorf("%w: %d redirections are reached", ErrRestrictedRedirect, maxRedirections) + } + return nil + } + return &client +} diff --git a/pkg/services/oracle/network_test.go b/pkg/services/oracle/network_test.go index 2e1791c5d..37d623939 100644 --- a/pkg/services/oracle/network_test.go +++ b/pkg/services/oracle/network_test.go @@ -1,9 +1,13 @@ package oracle import ( + "errors" "net" + "strings" "testing" + "time" + "github.com/nspcc-dev/neo-go/pkg/config" "github.com/stretchr/testify/require" ) @@ -16,3 +20,30 @@ func TestIsReserved(t *testing.T) { require.False(t, isReserved(net.IPv4(8, 8, 8, 8))) } + +func TestDefaultClient_RestrictedRedirectErr(t *testing.T) { + cfg := config.OracleConfiguration{ + AllowPrivateHost: false, + RequestTimeout: time.Second, + } + cl := getDefaultClient(cfg) + + testCases := []string{ + "http://localhost:8080", + "http://localhost", + "https://localhost:443", + "https://" + net.IPv4zero.String(), + "https://" + net.IPv4(10, 0, 0, 1).String(), + "https://" + net.IPv4(192, 168, 0, 1).String(), + "https://[" + net.IPv6interfacelocalallnodes.String() + "]", + "https://[" + net.IPv6loopback.String() + "]", + } + for _, c := range testCases { + t.Run(c, func(t *testing.T) { + _, err := cl.Get(c) + require.Error(t, err) + require.True(t, errors.Is(err, ErrRestrictedRedirect), err) + require.True(t, strings.Contains(err.Error(), "IP is not global unicast"), err) + }) + } +} diff --git a/pkg/services/oracle/oracle.go b/pkg/services/oracle/oracle.go index b63c96ebe..5165ba5c5 100644 --- a/pkg/services/oracle/oracle.go +++ b/pkg/services/oracle/oracle.go @@ -3,7 +3,6 @@ package oracle import ( "errors" "net/http" - "net/url" "sync" "time" @@ -80,7 +79,6 @@ type ( Chain Ledger ResponseHandler Broadcaster OnTransaction TxCallback - URIValidator URIValidator } // HTTPClient is an interface capable of doing oracle requests. @@ -99,8 +97,6 @@ type ( // TxCallback executes on new transactions when they are ready to be pooled. TxCallback = func(tx *transaction.Transaction) error - // URIValidator is used to check if provided URL is valid. - URIValidator = func(*url.URL) error ) const ( @@ -112,8 +108,15 @@ const ( // defaultRefreshInterval is default timeout for the failed request to be reprocessed. defaultRefreshInterval = time.Minute * 3 + + // maxRedirections is the number of allowed redirections for Oracle HTTPS request. + maxRedirections = 5 ) +// ErrRestrictedRedirect is returned when redirection to forbidden address occurs +// during Oracle response creation. +var ErrRestrictedRedirect = errors.New("oracle request redirection error") + // NewOracle returns new oracle instance. func NewOracle(cfg Config) (*Oracle, error) { o := &Oracle{ @@ -159,20 +162,14 @@ func NewOracle(cfg Config) (*Oracle, error) { return nil, errors.New("no wallet account could be unlocked") } - if o.Client == nil { - var client http.Client - client.Transport = &http.Transport{DisableKeepAlives: true} - client.Timeout = o.MainCfg.RequestTimeout - o.Client = &client - } if o.ResponseHandler == nil { o.ResponseHandler = defaultResponseHandler{} } if o.OnTransaction == nil { o.OnTransaction = func(*transaction.Transaction) error { return nil } } - if o.URIValidator == nil { - o.URIValidator = defaultURIValidator + if o.Client == nil { + o.Client = getDefaultClient(o.MainCfg) } return o, nil } diff --git a/pkg/services/oracle/request.go b/pkg/services/oracle/request.go index 360d84276..469888db4 100644 --- a/pkg/services/oracle/request.go +++ b/pkg/services/oracle/request.go @@ -120,14 +120,6 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error { } else { switch u.Scheme { case "https": - if !o.MainCfg.AllowPrivateHost { - err = o.URIValidator(u) - if err != nil { - o.Log.Warn("forbidden oracle request", zap.String("url", req.Req.URL)) - resp.Code = transaction.Forbidden - break - } - } httpReq, err := http.NewRequest("GET", req.Req.URL, nil) if err != nil { o.Log.Warn("failed to create http request", zap.String("url", req.Req.URL), zap.Error(err)) @@ -138,8 +130,12 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error { httpReq.Header.Set("Content-Type", "application/json") r, err := o.Client.Do(httpReq) if err != nil { - o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err)) - resp.Code = transaction.Error + if errors.Is(err, ErrRestrictedRedirect) { + resp.Code = transaction.Forbidden + } else { + resp.Code = transaction.Error + } + o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err), zap.Stringer("code", resp.Code)) break } switch r.StatusCode {