services: check Oracle response redirections

1. Move redirections check to the tcp level. Manually resolve request address
and create connection for the first suitable resolved address.
2. Remove URIValidator. Redirections checks are set in the custom http client,
so the user should take care of validation by himself when customizing the
client.
This commit is contained in:
AnnaShaleva 2022-03-02 20:22:26 +03:00 committed by Anna Shaleva
parent 26b76ed858
commit 537de18ac3
4 changed files with 62 additions and 42 deletions

View file

@ -8,7 +8,6 @@ import (
gio "io" gio "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
@ -178,7 +177,7 @@ func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain,
return res.Container return res.Container
} }
func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config { func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string, returnOracleRedirectionErrOn func(address string) bool) oracle.Config {
return oracle.Config{ return oracle.Config{
Log: zaptest.NewLogger(t), Log: zaptest.NewLogger(t),
Network: netmode.UnitTestNet, Network: netmode.UnitTestNet,
@ -191,7 +190,7 @@ func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config
}, },
}, },
Chain: bc, Chain: bc,
Client: newDefaultHTTPClient(), Client: newDefaultHTTPClient(returnOracleRedirectionErrOn),
} }
} }
@ -202,15 +201,11 @@ func getTestOracle(t *testing.T, bc *Blockchain, walletPath, pass string) (
chan *transaction.Transaction) { chan *transaction.Transaction) {
m := make(map[uint64]*responseWithSig) m := make(map[uint64]*responseWithSig)
ch := make(chan *transaction.Transaction, 5) ch := make(chan *transaction.Transaction, 5)
orcCfg := getOracleConfig(t, bc, walletPath, pass) orcCfg := getOracleConfig(t, bc, walletPath, pass, func(address string) bool {
return strings.HasPrefix(address, "https://private")
})
orcCfg.ResponseHandler = &saveToMapBroadcaster{m: m} orcCfg.ResponseHandler = &saveToMapBroadcaster{m: m}
orcCfg.OnTransaction = saveTxToChan(ch) orcCfg.OnTransaction = saveTxToChan(ch)
orcCfg.URIValidator = func(u *url.URL) error {
if strings.HasPrefix(u.Host, "private") {
return errors.New("private network")
}
return nil
}
orc, err := oracle.NewOracle(orcCfg) orc, err := oracle.NewOracle(orcCfg)
require.NoError(t, err) require.NoError(t, err)
@ -255,10 +250,10 @@ func TestCreateResponseTx(t *testing.T) {
func TestOracle_InvalidWallet(t *testing.T) { func TestOracle_InvalidWallet(t *testing.T) {
bc := newTestChain(t) bc := newTestChain(t)
_, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid")) _, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid", nil))
require.Error(t, err) require.Error(t, err)
_, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one")) _, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one", nil))
require.NoError(t, err) require.NoError(t, err)
} }
@ -551,6 +546,7 @@ type (
// httpClient implements oracle.HTTPClient with // httpClient implements oracle.HTTPClient with
// mocked URL or responses. // mocked URL or responses.
httpClient struct { httpClient struct {
returnOracleRedirectionErrOn func(address string) bool
responses map[string]testResponse responses map[string]testResponse
} }
@ -563,6 +559,9 @@ type (
// Get implements oracle.HTTPClient interface. // Get implements oracle.HTTPClient interface.
func (c *httpClient) Do(req *http.Request) (*http.Response, error) { func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
if c.returnOracleRedirectionErrOn != nil && c.returnOracleRedirectionErrOn(req.URL.String()) {
return nil, fmt.Errorf("%w: private network", oracle.ErrRestrictedRedirect)
}
resp, ok := c.responses[req.URL.String()] resp, ok := c.responses[req.URL.String()]
if ok { if ok {
return &http.Response{ return &http.Response{
@ -576,8 +575,9 @@ func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
return nil, errors.New("request failed") return nil, errors.New("request failed")
} }
func newDefaultHTTPClient() oracle.HTTPClient { func newDefaultHTTPClient(returnOracleRedirectionErrOn func(address string) bool) oracle.HTTPClient {
return &httpClient{ return &httpClient{
returnOracleRedirectionErrOn: returnOracleRedirectionErrOn,
responses: map[string]testResponse{ responses: map[string]testResponse{
"https://get.1234": { "https://get.1234": {
code: http.StatusOK, code: http.StatusOK,

View file

@ -3,7 +3,6 @@ package oracle
import ( import (
"errors" "errors"
"net" "net"
"net/url"
) )
// reservedCIDRs is a list of ip addresses for private networks. // reservedCIDRs is a list of ip addresses for private networks.
@ -32,15 +31,15 @@ func init() {
} }
} }
func defaultURIValidator(u *url.URL) error { func resolveAndCheck(network string, address string) (*net.IPAddr, error) {
ip, err := net.ResolveIPAddr("ip", u.Hostname()) ip, err := net.ResolveIPAddr(network, address)
if err != nil { if err != nil {
return err return nil, err
} }
if isReserved(ip.IP) { if isReserved(ip.IP) {
return errors.New("IP is not global unicast") return nil, errors.New("IP is not global unicast")
} }
return nil return ip, nil
} }
func isReserved(ip net.IP) bool { func isReserved(ip net.IP) bool {

View file

@ -1,9 +1,11 @@
package oracle package oracle
import ( import (
"context"
"errors" "errors"
"fmt"
"net"
"net/http" "net/http"
"net/url"
"sync" "sync"
"time" "time"
@ -80,7 +82,6 @@ type (
Chain Ledger Chain Ledger
ResponseHandler Broadcaster ResponseHandler Broadcaster
OnTransaction TxCallback OnTransaction TxCallback
URIValidator URIValidator
} }
// HTTPClient is an interface capable of doing oracle requests. // HTTPClient is an interface capable of doing oracle requests.
@ -99,8 +100,6 @@ type (
// TxCallback executes on new transactions when they are ready to be pooled. // TxCallback executes on new transactions when they are ready to be pooled.
TxCallback = func(tx *transaction.Transaction) error TxCallback = func(tx *transaction.Transaction) error
// URIValidator is used to check if provided URL is valid.
URIValidator = func(*url.URL) error
) )
const ( const (
@ -112,8 +111,15 @@ const (
// defaultRefreshInterval is default timeout for the failed request to be reprocessed. // defaultRefreshInterval is default timeout for the failed request to be reprocessed.
defaultRefreshInterval = time.Minute * 3 defaultRefreshInterval = time.Minute * 3
// maxRedirections is the number of allowed redirections for Oracle HTTPS request.
maxRedirections = 5
) )
// ErrRestrictedRedirect is returned when redirection to forbidden address occurs
// during Oracle response creation.
var ErrRestrictedRedirect = errors.New("oracle request redirection error")
// NewOracle returns new oracle instance. // NewOracle returns new oracle instance.
func NewOracle(cfg Config) (*Oracle, error) { func NewOracle(cfg Config) (*Oracle, error) {
o := &Oracle{ o := &Oracle{
@ -159,20 +165,39 @@ func NewOracle(cfg Config) (*Oracle, error) {
return nil, errors.New("no wallet account could be unlocked") return nil, errors.New("no wallet account could be unlocked")
} }
if o.Client == nil {
var client http.Client
client.Transport = &http.Transport{DisableKeepAlives: true}
client.Timeout = o.MainCfg.RequestTimeout
o.Client = &client
}
if o.ResponseHandler == nil { if o.ResponseHandler == nil {
o.ResponseHandler = defaultResponseHandler{} o.ResponseHandler = defaultResponseHandler{}
} }
if o.OnTransaction == nil { if o.OnTransaction == nil {
o.OnTransaction = func(*transaction.Transaction) error { return nil } o.OnTransaction = func(*transaction.Transaction) error { return nil }
} }
if o.URIValidator == nil { if o.Client == nil {
o.URIValidator = defaultURIValidator var client http.Client
client.Transport = &http.Transport{
DisableKeepAlives: true,
// Do not set DialTLSContext, so that DialContext will be used to establish the
// connection. After that TLS connection will be added to a persistent connection
// by standard library code and handshaking will be performed.
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
if !o.MainCfg.AllowPrivateHost {
ip, err := resolveAndCheck(network, address)
if err != nil {
return nil, fmt.Errorf("%w: address %s failed validation: %s", ErrRestrictedRedirect, address, err)
}
network = ip.Network()
address = ip.IP.String()
}
return net.Dial(network, address)
},
}
client.Timeout = o.MainCfg.RequestTimeout
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirections { // from https://github.com/neo-project/neo-modules/pull/694
return fmt.Errorf("%w: %d redirections are reached", ErrRestrictedRedirect, maxRedirections)
}
return nil
}
o.Client = &client
} }
return o, nil return o, nil
} }

View file

@ -120,14 +120,6 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
} else { } else {
switch u.Scheme { switch u.Scheme {
case "https": case "https":
if !o.MainCfg.AllowPrivateHost {
err = o.URIValidator(u)
if err != nil {
o.Log.Warn("forbidden oracle request", zap.String("url", req.Req.URL))
resp.Code = transaction.Forbidden
break
}
}
httpReq, err := http.NewRequest("GET", req.Req.URL, nil) httpReq, err := http.NewRequest("GET", req.Req.URL, nil)
if err != nil { if err != nil {
o.Log.Warn("failed to create http request", zap.String("url", req.Req.URL), zap.Error(err)) o.Log.Warn("failed to create http request", zap.String("url", req.Req.URL), zap.Error(err))
@ -138,8 +130,12 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Content-Type", "application/json")
r, err := o.Client.Do(httpReq) r, err := o.Client.Do(httpReq)
if err != nil { if err != nil {
o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err)) if errors.Is(err, ErrRestrictedRedirect) {
resp.Code = transaction.Forbidden
} else {
resp.Code = transaction.Error resp.Code = transaction.Error
}
o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err), zap.Stringer("code", resp.Code))
break break
} }
switch r.StatusCode { switch r.StatusCode {