From 537de18ac35d271766390cd6604d5dcb3ba73fce Mon Sep 17 00:00:00 2001 From: AnnaShaleva Date: Wed, 2 Mar 2022 20:22:26 +0300 Subject: [PATCH] services: check Oracle response redirections 1. Move redirections check to the tcp level. Manually resolve request address and create connection for the first suitable resolved address. 2. Remove URIValidator. Redirections checks are set in the custom http client, so the user should take care of validation by himself when customizing the client. --- pkg/core/oracle_test.go | 28 +++++++++---------- pkg/services/oracle/network.go | 11 ++++---- pkg/services/oracle/oracle.go | 49 +++++++++++++++++++++++++--------- pkg/services/oracle/request.go | 16 +++++------ 4 files changed, 62 insertions(+), 42 deletions(-) diff --git a/pkg/core/oracle_test.go b/pkg/core/oracle_test.go index 22f575fa6..68e63d85c 100644 --- a/pkg/core/oracle_test.go +++ b/pkg/core/oracle_test.go @@ -8,7 +8,6 @@ import ( gio "io" "io/ioutil" "net/http" - "net/url" "os" "path" "path/filepath" @@ -178,7 +177,7 @@ func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain, return res.Container } -func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config { +func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string, returnOracleRedirectionErrOn func(address string) bool) oracle.Config { return oracle.Config{ Log: zaptest.NewLogger(t), Network: netmode.UnitTestNet, @@ -191,7 +190,7 @@ func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config }, }, Chain: bc, - Client: newDefaultHTTPClient(), + Client: newDefaultHTTPClient(returnOracleRedirectionErrOn), } } @@ -202,15 +201,11 @@ func getTestOracle(t *testing.T, bc *Blockchain, walletPath, pass string) ( chan *transaction.Transaction) { m := make(map[uint64]*responseWithSig) ch := make(chan *transaction.Transaction, 5) - orcCfg := getOracleConfig(t, bc, walletPath, pass) + orcCfg := getOracleConfig(t, bc, walletPath, pass, func(address string) bool { + return strings.HasPrefix(address, "https://private") + }) orcCfg.ResponseHandler = &saveToMapBroadcaster{m: m} orcCfg.OnTransaction = saveTxToChan(ch) - orcCfg.URIValidator = func(u *url.URL) error { - if strings.HasPrefix(u.Host, "private") { - return errors.New("private network") - } - return nil - } orc, err := oracle.NewOracle(orcCfg) require.NoError(t, err) @@ -255,10 +250,10 @@ func TestCreateResponseTx(t *testing.T) { func TestOracle_InvalidWallet(t *testing.T) { bc := newTestChain(t) - _, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid")) + _, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid", nil)) require.Error(t, err) - _, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one")) + _, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one", nil)) require.NoError(t, err) } @@ -551,7 +546,8 @@ type ( // httpClient implements oracle.HTTPClient with // mocked URL or responses. httpClient struct { - responses map[string]testResponse + returnOracleRedirectionErrOn func(address string) bool + responses map[string]testResponse } testResponse struct { @@ -563,6 +559,9 @@ type ( // Get implements oracle.HTTPClient interface. func (c *httpClient) Do(req *http.Request) (*http.Response, error) { + if c.returnOracleRedirectionErrOn != nil && c.returnOracleRedirectionErrOn(req.URL.String()) { + return nil, fmt.Errorf("%w: private network", oracle.ErrRestrictedRedirect) + } resp, ok := c.responses[req.URL.String()] if ok { return &http.Response{ @@ -576,8 +575,9 @@ func (c *httpClient) Do(req *http.Request) (*http.Response, error) { return nil, errors.New("request failed") } -func newDefaultHTTPClient() oracle.HTTPClient { +func newDefaultHTTPClient(returnOracleRedirectionErrOn func(address string) bool) oracle.HTTPClient { return &httpClient{ + returnOracleRedirectionErrOn: returnOracleRedirectionErrOn, responses: map[string]testResponse{ "https://get.1234": { code: http.StatusOK, diff --git a/pkg/services/oracle/network.go b/pkg/services/oracle/network.go index 3ab8e0705..505ac9670 100644 --- a/pkg/services/oracle/network.go +++ b/pkg/services/oracle/network.go @@ -3,7 +3,6 @@ package oracle import ( "errors" "net" - "net/url" ) // reservedCIDRs is a list of ip addresses for private networks. @@ -32,15 +31,15 @@ func init() { } } -func defaultURIValidator(u *url.URL) error { - ip, err := net.ResolveIPAddr("ip", u.Hostname()) +func resolveAndCheck(network string, address string) (*net.IPAddr, error) { + ip, err := net.ResolveIPAddr(network, address) if err != nil { - return err + return nil, err } if isReserved(ip.IP) { - return errors.New("IP is not global unicast") + return nil, errors.New("IP is not global unicast") } - return nil + return ip, nil } func isReserved(ip net.IP) bool { diff --git a/pkg/services/oracle/oracle.go b/pkg/services/oracle/oracle.go index b63c96ebe..82a062f75 100644 --- a/pkg/services/oracle/oracle.go +++ b/pkg/services/oracle/oracle.go @@ -1,9 +1,11 @@ package oracle import ( + "context" "errors" + "fmt" + "net" "net/http" - "net/url" "sync" "time" @@ -80,7 +82,6 @@ type ( Chain Ledger ResponseHandler Broadcaster OnTransaction TxCallback - URIValidator URIValidator } // HTTPClient is an interface capable of doing oracle requests. @@ -99,8 +100,6 @@ type ( // TxCallback executes on new transactions when they are ready to be pooled. TxCallback = func(tx *transaction.Transaction) error - // URIValidator is used to check if provided URL is valid. - URIValidator = func(*url.URL) error ) const ( @@ -112,8 +111,15 @@ const ( // defaultRefreshInterval is default timeout for the failed request to be reprocessed. defaultRefreshInterval = time.Minute * 3 + + // maxRedirections is the number of allowed redirections for Oracle HTTPS request. + maxRedirections = 5 ) +// ErrRestrictedRedirect is returned when redirection to forbidden address occurs +// during Oracle response creation. +var ErrRestrictedRedirect = errors.New("oracle request redirection error") + // NewOracle returns new oracle instance. func NewOracle(cfg Config) (*Oracle, error) { o := &Oracle{ @@ -159,20 +165,39 @@ func NewOracle(cfg Config) (*Oracle, error) { return nil, errors.New("no wallet account could be unlocked") } - if o.Client == nil { - var client http.Client - client.Transport = &http.Transport{DisableKeepAlives: true} - client.Timeout = o.MainCfg.RequestTimeout - o.Client = &client - } if o.ResponseHandler == nil { o.ResponseHandler = defaultResponseHandler{} } if o.OnTransaction == nil { o.OnTransaction = func(*transaction.Transaction) error { return nil } } - if o.URIValidator == nil { - o.URIValidator = defaultURIValidator + if o.Client == nil { + var client http.Client + client.Transport = &http.Transport{ + DisableKeepAlives: true, + // Do not set DialTLSContext, so that DialContext will be used to establish the + // connection. After that TLS connection will be added to a persistent connection + // by standard library code and handshaking will be performed. + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + if !o.MainCfg.AllowPrivateHost { + ip, err := resolveAndCheck(network, address) + if err != nil { + return nil, fmt.Errorf("%w: address %s failed validation: %s", ErrRestrictedRedirect, address, err) + } + network = ip.Network() + address = ip.IP.String() + } + return net.Dial(network, address) + }, + } + client.Timeout = o.MainCfg.RequestTimeout + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) >= maxRedirections { // from https://github.com/neo-project/neo-modules/pull/694 + return fmt.Errorf("%w: %d redirections are reached", ErrRestrictedRedirect, maxRedirections) + } + return nil + } + o.Client = &client } return o, nil } diff --git a/pkg/services/oracle/request.go b/pkg/services/oracle/request.go index 360d84276..469888db4 100644 --- a/pkg/services/oracle/request.go +++ b/pkg/services/oracle/request.go @@ -120,14 +120,6 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error { } else { switch u.Scheme { case "https": - if !o.MainCfg.AllowPrivateHost { - err = o.URIValidator(u) - if err != nil { - o.Log.Warn("forbidden oracle request", zap.String("url", req.Req.URL)) - resp.Code = transaction.Forbidden - break - } - } httpReq, err := http.NewRequest("GET", req.Req.URL, nil) if err != nil { o.Log.Warn("failed to create http request", zap.String("url", req.Req.URL), zap.Error(err)) @@ -138,8 +130,12 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error { httpReq.Header.Set("Content-Type", "application/json") r, err := o.Client.Do(httpReq) if err != nil { - o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err)) - resp.Code = transaction.Error + if errors.Is(err, ErrRestrictedRedirect) { + resp.Code = transaction.Forbidden + } else { + resp.Code = transaction.Error + } + o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err), zap.Stringer("code", resp.Code)) break } switch r.StatusCode {