diff --git a/pkg/core/oracle_test.go b/pkg/core/oracle_test.go index 22f575fa6..68e63d85c 100644 --- a/pkg/core/oracle_test.go +++ b/pkg/core/oracle_test.go @@ -8,7 +8,6 @@ import ( gio "io" "io/ioutil" "net/http" - "net/url" "os" "path" "path/filepath" @@ -178,7 +177,7 @@ func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain, return res.Container } -func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config { +func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string, returnOracleRedirectionErrOn func(address string) bool) oracle.Config { return oracle.Config{ Log: zaptest.NewLogger(t), Network: netmode.UnitTestNet, @@ -191,7 +190,7 @@ func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config }, }, Chain: bc, - Client: newDefaultHTTPClient(), + Client: newDefaultHTTPClient(returnOracleRedirectionErrOn), } } @@ -202,15 +201,11 @@ func getTestOracle(t *testing.T, bc *Blockchain, walletPath, pass string) ( chan *transaction.Transaction) { m := make(map[uint64]*responseWithSig) ch := make(chan *transaction.Transaction, 5) - orcCfg := getOracleConfig(t, bc, walletPath, pass) + orcCfg := getOracleConfig(t, bc, walletPath, pass, func(address string) bool { + return strings.HasPrefix(address, "https://private") + }) orcCfg.ResponseHandler = &saveToMapBroadcaster{m: m} orcCfg.OnTransaction = saveTxToChan(ch) - orcCfg.URIValidator = func(u *url.URL) error { - if strings.HasPrefix(u.Host, "private") { - return errors.New("private network") - } - return nil - } orc, err := oracle.NewOracle(orcCfg) require.NoError(t, err) @@ -255,10 +250,10 @@ func TestCreateResponseTx(t *testing.T) { func TestOracle_InvalidWallet(t *testing.T) { bc := newTestChain(t) - _, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid")) + _, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid", nil)) require.Error(t, err) - _, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one")) + _, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one", nil)) require.NoError(t, err) } @@ -551,7 +546,8 @@ type ( // httpClient implements oracle.HTTPClient with // mocked URL or responses. httpClient struct { - responses map[string]testResponse + returnOracleRedirectionErrOn func(address string) bool + responses map[string]testResponse } testResponse struct { @@ -563,6 +559,9 @@ type ( // Get implements oracle.HTTPClient interface. func (c *httpClient) Do(req *http.Request) (*http.Response, error) { + if c.returnOracleRedirectionErrOn != nil && c.returnOracleRedirectionErrOn(req.URL.String()) { + return nil, fmt.Errorf("%w: private network", oracle.ErrRestrictedRedirect) + } resp, ok := c.responses[req.URL.String()] if ok { return &http.Response{ @@ -576,8 +575,9 @@ func (c *httpClient) Do(req *http.Request) (*http.Response, error) { return nil, errors.New("request failed") } -func newDefaultHTTPClient() oracle.HTTPClient { +func newDefaultHTTPClient(returnOracleRedirectionErrOn func(address string) bool) oracle.HTTPClient { return &httpClient{ + returnOracleRedirectionErrOn: returnOracleRedirectionErrOn, responses: map[string]testResponse{ "https://get.1234": { code: http.StatusOK, diff --git a/pkg/services/oracle/network.go b/pkg/services/oracle/network.go index 3ab8e0705..505ac9670 100644 --- a/pkg/services/oracle/network.go +++ b/pkg/services/oracle/network.go @@ -3,7 +3,6 @@ package oracle import ( "errors" "net" - "net/url" ) // reservedCIDRs is a list of ip addresses for private networks. @@ -32,15 +31,15 @@ func init() { } } -func defaultURIValidator(u *url.URL) error { - ip, err := net.ResolveIPAddr("ip", u.Hostname()) +func resolveAndCheck(network string, address string) (*net.IPAddr, error) { + ip, err := net.ResolveIPAddr(network, address) if err != nil { - return err + return nil, err } if isReserved(ip.IP) { - return errors.New("IP is not global unicast") + return nil, errors.New("IP is not global unicast") } - return nil + return ip, nil } func isReserved(ip net.IP) bool { diff --git a/pkg/services/oracle/oracle.go b/pkg/services/oracle/oracle.go index b63c96ebe..82a062f75 100644 --- a/pkg/services/oracle/oracle.go +++ b/pkg/services/oracle/oracle.go @@ -1,9 +1,11 @@ package oracle import ( + "context" "errors" + "fmt" + "net" "net/http" - "net/url" "sync" "time" @@ -80,7 +82,6 @@ type ( Chain Ledger ResponseHandler Broadcaster OnTransaction TxCallback - URIValidator URIValidator } // HTTPClient is an interface capable of doing oracle requests. @@ -99,8 +100,6 @@ type ( // TxCallback executes on new transactions when they are ready to be pooled. TxCallback = func(tx *transaction.Transaction) error - // URIValidator is used to check if provided URL is valid. - URIValidator = func(*url.URL) error ) const ( @@ -112,8 +111,15 @@ const ( // defaultRefreshInterval is default timeout for the failed request to be reprocessed. defaultRefreshInterval = time.Minute * 3 + + // maxRedirections is the number of allowed redirections for Oracle HTTPS request. + maxRedirections = 5 ) +// ErrRestrictedRedirect is returned when redirection to forbidden address occurs +// during Oracle response creation. +var ErrRestrictedRedirect = errors.New("oracle request redirection error") + // NewOracle returns new oracle instance. func NewOracle(cfg Config) (*Oracle, error) { o := &Oracle{ @@ -159,20 +165,39 @@ func NewOracle(cfg Config) (*Oracle, error) { return nil, errors.New("no wallet account could be unlocked") } - if o.Client == nil { - var client http.Client - client.Transport = &http.Transport{DisableKeepAlives: true} - client.Timeout = o.MainCfg.RequestTimeout - o.Client = &client - } if o.ResponseHandler == nil { o.ResponseHandler = defaultResponseHandler{} } if o.OnTransaction == nil { o.OnTransaction = func(*transaction.Transaction) error { return nil } } - if o.URIValidator == nil { - o.URIValidator = defaultURIValidator + if o.Client == nil { + var client http.Client + client.Transport = &http.Transport{ + DisableKeepAlives: true, + // Do not set DialTLSContext, so that DialContext will be used to establish the + // connection. After that TLS connection will be added to a persistent connection + // by standard library code and handshaking will be performed. + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + if !o.MainCfg.AllowPrivateHost { + ip, err := resolveAndCheck(network, address) + if err != nil { + return nil, fmt.Errorf("%w: address %s failed validation: %s", ErrRestrictedRedirect, address, err) + } + network = ip.Network() + address = ip.IP.String() + } + return net.Dial(network, address) + }, + } + client.Timeout = o.MainCfg.RequestTimeout + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) >= maxRedirections { // from https://github.com/neo-project/neo-modules/pull/694 + return fmt.Errorf("%w: %d redirections are reached", ErrRestrictedRedirect, maxRedirections) + } + return nil + } + o.Client = &client } return o, nil } diff --git a/pkg/services/oracle/request.go b/pkg/services/oracle/request.go index 360d84276..469888db4 100644 --- a/pkg/services/oracle/request.go +++ b/pkg/services/oracle/request.go @@ -120,14 +120,6 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error { } else { switch u.Scheme { case "https": - if !o.MainCfg.AllowPrivateHost { - err = o.URIValidator(u) - if err != nil { - o.Log.Warn("forbidden oracle request", zap.String("url", req.Req.URL)) - resp.Code = transaction.Forbidden - break - } - } httpReq, err := http.NewRequest("GET", req.Req.URL, nil) if err != nil { o.Log.Warn("failed to create http request", zap.String("url", req.Req.URL), zap.Error(err)) @@ -138,8 +130,12 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error { httpReq.Header.Set("Content-Type", "application/json") r, err := o.Client.Do(httpReq) if err != nil { - o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err)) - resp.Code = transaction.Error + if errors.Is(err, ErrRestrictedRedirect) { + resp.Code = transaction.Forbidden + } else { + resp.Code = transaction.Error + } + o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err), zap.Stringer("code", resp.Code)) break } switch r.StatusCode {