services: check Oracle response redirections

1. Move redirections check to the tcp level. Manually resolve request address
and create connection for the first suitable resolved address.
2. Remove URIValidator. Redirections checks are set in the custom http client,
so the user should take care of validation by himself when customizing the
client.
This commit is contained in:
AnnaShaleva 2022-03-02 20:22:26 +03:00 committed by Anna Shaleva
parent 26b76ed858
commit 537de18ac3
4 changed files with 62 additions and 42 deletions

View file

@ -8,7 +8,6 @@ import (
gio "io"
"io/ioutil"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
@ -178,7 +177,7 @@ func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain,
return res.Container
}
func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config {
func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string, returnOracleRedirectionErrOn func(address string) bool) oracle.Config {
return oracle.Config{
Log: zaptest.NewLogger(t),
Network: netmode.UnitTestNet,
@ -191,7 +190,7 @@ func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config
},
},
Chain: bc,
Client: newDefaultHTTPClient(),
Client: newDefaultHTTPClient(returnOracleRedirectionErrOn),
}
}
@ -202,15 +201,11 @@ func getTestOracle(t *testing.T, bc *Blockchain, walletPath, pass string) (
chan *transaction.Transaction) {
m := make(map[uint64]*responseWithSig)
ch := make(chan *transaction.Transaction, 5)
orcCfg := getOracleConfig(t, bc, walletPath, pass)
orcCfg := getOracleConfig(t, bc, walletPath, pass, func(address string) bool {
return strings.HasPrefix(address, "https://private")
})
orcCfg.ResponseHandler = &saveToMapBroadcaster{m: m}
orcCfg.OnTransaction = saveTxToChan(ch)
orcCfg.URIValidator = func(u *url.URL) error {
if strings.HasPrefix(u.Host, "private") {
return errors.New("private network")
}
return nil
}
orc, err := oracle.NewOracle(orcCfg)
require.NoError(t, err)
@ -255,10 +250,10 @@ func TestCreateResponseTx(t *testing.T) {
func TestOracle_InvalidWallet(t *testing.T) {
bc := newTestChain(t)
_, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid"))
_, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid", nil))
require.Error(t, err)
_, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one"))
_, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one", nil))
require.NoError(t, err)
}
@ -551,7 +546,8 @@ type (
// httpClient implements oracle.HTTPClient with
// mocked URL or responses.
httpClient struct {
responses map[string]testResponse
returnOracleRedirectionErrOn func(address string) bool
responses map[string]testResponse
}
testResponse struct {
@ -563,6 +559,9 @@ type (
// Get implements oracle.HTTPClient interface.
func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
if c.returnOracleRedirectionErrOn != nil && c.returnOracleRedirectionErrOn(req.URL.String()) {
return nil, fmt.Errorf("%w: private network", oracle.ErrRestrictedRedirect)
}
resp, ok := c.responses[req.URL.String()]
if ok {
return &http.Response{
@ -576,8 +575,9 @@ func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
return nil, errors.New("request failed")
}
func newDefaultHTTPClient() oracle.HTTPClient {
func newDefaultHTTPClient(returnOracleRedirectionErrOn func(address string) bool) oracle.HTTPClient {
return &httpClient{
returnOracleRedirectionErrOn: returnOracleRedirectionErrOn,
responses: map[string]testResponse{
"https://get.1234": {
code: http.StatusOK,

View file

@ -3,7 +3,6 @@ package oracle
import (
"errors"
"net"
"net/url"
)
// reservedCIDRs is a list of ip addresses for private networks.
@ -32,15 +31,15 @@ func init() {
}
}
func defaultURIValidator(u *url.URL) error {
ip, err := net.ResolveIPAddr("ip", u.Hostname())
func resolveAndCheck(network string, address string) (*net.IPAddr, error) {
ip, err := net.ResolveIPAddr(network, address)
if err != nil {
return err
return nil, err
}
if isReserved(ip.IP) {
return errors.New("IP is not global unicast")
return nil, errors.New("IP is not global unicast")
}
return nil
return ip, nil
}
func isReserved(ip net.IP) bool {

View file

@ -1,9 +1,11 @@
package oracle
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"sync"
"time"
@ -80,7 +82,6 @@ type (
Chain Ledger
ResponseHandler Broadcaster
OnTransaction TxCallback
URIValidator URIValidator
}
// HTTPClient is an interface capable of doing oracle requests.
@ -99,8 +100,6 @@ type (
// TxCallback executes on new transactions when they are ready to be pooled.
TxCallback = func(tx *transaction.Transaction) error
// URIValidator is used to check if provided URL is valid.
URIValidator = func(*url.URL) error
)
const (
@ -112,8 +111,15 @@ const (
// defaultRefreshInterval is default timeout for the failed request to be reprocessed.
defaultRefreshInterval = time.Minute * 3
// maxRedirections is the number of allowed redirections for Oracle HTTPS request.
maxRedirections = 5
)
// ErrRestrictedRedirect is returned when redirection to forbidden address occurs
// during Oracle response creation.
var ErrRestrictedRedirect = errors.New("oracle request redirection error")
// NewOracle returns new oracle instance.
func NewOracle(cfg Config) (*Oracle, error) {
o := &Oracle{
@ -159,20 +165,39 @@ func NewOracle(cfg Config) (*Oracle, error) {
return nil, errors.New("no wallet account could be unlocked")
}
if o.Client == nil {
var client http.Client
client.Transport = &http.Transport{DisableKeepAlives: true}
client.Timeout = o.MainCfg.RequestTimeout
o.Client = &client
}
if o.ResponseHandler == nil {
o.ResponseHandler = defaultResponseHandler{}
}
if o.OnTransaction == nil {
o.OnTransaction = func(*transaction.Transaction) error { return nil }
}
if o.URIValidator == nil {
o.URIValidator = defaultURIValidator
if o.Client == nil {
var client http.Client
client.Transport = &http.Transport{
DisableKeepAlives: true,
// Do not set DialTLSContext, so that DialContext will be used to establish the
// connection. After that TLS connection will be added to a persistent connection
// by standard library code and handshaking will be performed.
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
if !o.MainCfg.AllowPrivateHost {
ip, err := resolveAndCheck(network, address)
if err != nil {
return nil, fmt.Errorf("%w: address %s failed validation: %s", ErrRestrictedRedirect, address, err)
}
network = ip.Network()
address = ip.IP.String()
}
return net.Dial(network, address)
},
}
client.Timeout = o.MainCfg.RequestTimeout
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirections { // from https://github.com/neo-project/neo-modules/pull/694
return fmt.Errorf("%w: %d redirections are reached", ErrRestrictedRedirect, maxRedirections)
}
return nil
}
o.Client = &client
}
return o, nil
}

View file

@ -120,14 +120,6 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
} else {
switch u.Scheme {
case "https":
if !o.MainCfg.AllowPrivateHost {
err = o.URIValidator(u)
if err != nil {
o.Log.Warn("forbidden oracle request", zap.String("url", req.Req.URL))
resp.Code = transaction.Forbidden
break
}
}
httpReq, err := http.NewRequest("GET", req.Req.URL, nil)
if err != nil {
o.Log.Warn("failed to create http request", zap.String("url", req.Req.URL), zap.Error(err))
@ -138,8 +130,12 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
httpReq.Header.Set("Content-Type", "application/json")
r, err := o.Client.Do(httpReq)
if err != nil {
o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err))
resp.Code = transaction.Error
if errors.Is(err, ErrRestrictedRedirect) {
resp.Code = transaction.Forbidden
} else {
resp.Code = transaction.Error
}
o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err), zap.Stringer("code", resp.Code))
break
}
switch r.StatusCode {