forked from TrueCloudLab/neoneo-go
services: check Oracle response redirections
1. Move redirections check to the tcp level. Manually resolve request address and create connection for the first suitable resolved address. 2. Remove URIValidator. Redirections checks are set in the custom http client, so the user should take care of validation by himself when customizing the client.
This commit is contained in:
parent
26b76ed858
commit
537de18ac3
4 changed files with 62 additions and 42 deletions
|
@ -8,7 +8,6 @@ import (
|
||||||
gio "io"
|
gio "io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -178,7 +177,7 @@ func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain,
|
||||||
return res.Container
|
return res.Container
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config {
|
func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string, returnOracleRedirectionErrOn func(address string) bool) oracle.Config {
|
||||||
return oracle.Config{
|
return oracle.Config{
|
||||||
Log: zaptest.NewLogger(t),
|
Log: zaptest.NewLogger(t),
|
||||||
Network: netmode.UnitTestNet,
|
Network: netmode.UnitTestNet,
|
||||||
|
@ -191,7 +190,7 @@ func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Chain: bc,
|
Chain: bc,
|
||||||
Client: newDefaultHTTPClient(),
|
Client: newDefaultHTTPClient(returnOracleRedirectionErrOn),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,15 +201,11 @@ func getTestOracle(t *testing.T, bc *Blockchain, walletPath, pass string) (
|
||||||
chan *transaction.Transaction) {
|
chan *transaction.Transaction) {
|
||||||
m := make(map[uint64]*responseWithSig)
|
m := make(map[uint64]*responseWithSig)
|
||||||
ch := make(chan *transaction.Transaction, 5)
|
ch := make(chan *transaction.Transaction, 5)
|
||||||
orcCfg := getOracleConfig(t, bc, walletPath, pass)
|
orcCfg := getOracleConfig(t, bc, walletPath, pass, func(address string) bool {
|
||||||
|
return strings.HasPrefix(address, "https://private")
|
||||||
|
})
|
||||||
orcCfg.ResponseHandler = &saveToMapBroadcaster{m: m}
|
orcCfg.ResponseHandler = &saveToMapBroadcaster{m: m}
|
||||||
orcCfg.OnTransaction = saveTxToChan(ch)
|
orcCfg.OnTransaction = saveTxToChan(ch)
|
||||||
orcCfg.URIValidator = func(u *url.URL) error {
|
|
||||||
if strings.HasPrefix(u.Host, "private") {
|
|
||||||
return errors.New("private network")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
orc, err := oracle.NewOracle(orcCfg)
|
orc, err := oracle.NewOracle(orcCfg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -255,10 +250,10 @@ func TestCreateResponseTx(t *testing.T) {
|
||||||
func TestOracle_InvalidWallet(t *testing.T) {
|
func TestOracle_InvalidWallet(t *testing.T) {
|
||||||
bc := newTestChain(t)
|
bc := newTestChain(t)
|
||||||
|
|
||||||
_, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid"))
|
_, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid", nil))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
_, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one"))
|
_, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one", nil))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -551,6 +546,7 @@ type (
|
||||||
// httpClient implements oracle.HTTPClient with
|
// httpClient implements oracle.HTTPClient with
|
||||||
// mocked URL or responses.
|
// mocked URL or responses.
|
||||||
httpClient struct {
|
httpClient struct {
|
||||||
|
returnOracleRedirectionErrOn func(address string) bool
|
||||||
responses map[string]testResponse
|
responses map[string]testResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -563,6 +559,9 @@ type (
|
||||||
|
|
||||||
// Get implements oracle.HTTPClient interface.
|
// Get implements oracle.HTTPClient interface.
|
||||||
func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
|
func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
|
||||||
|
if c.returnOracleRedirectionErrOn != nil && c.returnOracleRedirectionErrOn(req.URL.String()) {
|
||||||
|
return nil, fmt.Errorf("%w: private network", oracle.ErrRestrictedRedirect)
|
||||||
|
}
|
||||||
resp, ok := c.responses[req.URL.String()]
|
resp, ok := c.responses[req.URL.String()]
|
||||||
if ok {
|
if ok {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
|
@ -576,8 +575,9 @@ func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
|
||||||
return nil, errors.New("request failed")
|
return nil, errors.New("request failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDefaultHTTPClient() oracle.HTTPClient {
|
func newDefaultHTTPClient(returnOracleRedirectionErrOn func(address string) bool) oracle.HTTPClient {
|
||||||
return &httpClient{
|
return &httpClient{
|
||||||
|
returnOracleRedirectionErrOn: returnOracleRedirectionErrOn,
|
||||||
responses: map[string]testResponse{
|
responses: map[string]testResponse{
|
||||||
"https://get.1234": {
|
"https://get.1234": {
|
||||||
code: http.StatusOK,
|
code: http.StatusOK,
|
||||||
|
|
|
@ -3,7 +3,6 @@ package oracle
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// reservedCIDRs is a list of ip addresses for private networks.
|
// reservedCIDRs is a list of ip addresses for private networks.
|
||||||
|
@ -32,15 +31,15 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultURIValidator(u *url.URL) error {
|
func resolveAndCheck(network string, address string) (*net.IPAddr, error) {
|
||||||
ip, err := net.ResolveIPAddr("ip", u.Hostname())
|
ip, err := net.ResolveIPAddr(network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
if isReserved(ip.IP) {
|
if isReserved(ip.IP) {
|
||||||
return errors.New("IP is not global unicast")
|
return nil, errors.New("IP is not global unicast")
|
||||||
}
|
}
|
||||||
return nil
|
return ip, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isReserved(ip net.IP) bool {
|
func isReserved(ip net.IP) bool {
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
package oracle
|
package oracle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -80,7 +82,6 @@ type (
|
||||||
Chain Ledger
|
Chain Ledger
|
||||||
ResponseHandler Broadcaster
|
ResponseHandler Broadcaster
|
||||||
OnTransaction TxCallback
|
OnTransaction TxCallback
|
||||||
URIValidator URIValidator
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPClient is an interface capable of doing oracle requests.
|
// HTTPClient is an interface capable of doing oracle requests.
|
||||||
|
@ -99,8 +100,6 @@ type (
|
||||||
|
|
||||||
// TxCallback executes on new transactions when they are ready to be pooled.
|
// TxCallback executes on new transactions when they are ready to be pooled.
|
||||||
TxCallback = func(tx *transaction.Transaction) error
|
TxCallback = func(tx *transaction.Transaction) error
|
||||||
// URIValidator is used to check if provided URL is valid.
|
|
||||||
URIValidator = func(*url.URL) error
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -112,8 +111,15 @@ const (
|
||||||
|
|
||||||
// defaultRefreshInterval is default timeout for the failed request to be reprocessed.
|
// defaultRefreshInterval is default timeout for the failed request to be reprocessed.
|
||||||
defaultRefreshInterval = time.Minute * 3
|
defaultRefreshInterval = time.Minute * 3
|
||||||
|
|
||||||
|
// maxRedirections is the number of allowed redirections for Oracle HTTPS request.
|
||||||
|
maxRedirections = 5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrRestrictedRedirect is returned when redirection to forbidden address occurs
|
||||||
|
// during Oracle response creation.
|
||||||
|
var ErrRestrictedRedirect = errors.New("oracle request redirection error")
|
||||||
|
|
||||||
// NewOracle returns new oracle instance.
|
// NewOracle returns new oracle instance.
|
||||||
func NewOracle(cfg Config) (*Oracle, error) {
|
func NewOracle(cfg Config) (*Oracle, error) {
|
||||||
o := &Oracle{
|
o := &Oracle{
|
||||||
|
@ -159,20 +165,39 @@ func NewOracle(cfg Config) (*Oracle, error) {
|
||||||
return nil, errors.New("no wallet account could be unlocked")
|
return nil, errors.New("no wallet account could be unlocked")
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.Client == nil {
|
|
||||||
var client http.Client
|
|
||||||
client.Transport = &http.Transport{DisableKeepAlives: true}
|
|
||||||
client.Timeout = o.MainCfg.RequestTimeout
|
|
||||||
o.Client = &client
|
|
||||||
}
|
|
||||||
if o.ResponseHandler == nil {
|
if o.ResponseHandler == nil {
|
||||||
o.ResponseHandler = defaultResponseHandler{}
|
o.ResponseHandler = defaultResponseHandler{}
|
||||||
}
|
}
|
||||||
if o.OnTransaction == nil {
|
if o.OnTransaction == nil {
|
||||||
o.OnTransaction = func(*transaction.Transaction) error { return nil }
|
o.OnTransaction = func(*transaction.Transaction) error { return nil }
|
||||||
}
|
}
|
||||||
if o.URIValidator == nil {
|
if o.Client == nil {
|
||||||
o.URIValidator = defaultURIValidator
|
var client http.Client
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
DisableKeepAlives: true,
|
||||||
|
// Do not set DialTLSContext, so that DialContext will be used to establish the
|
||||||
|
// connection. After that TLS connection will be added to a persistent connection
|
||||||
|
// by standard library code and handshaking will be performed.
|
||||||
|
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
if !o.MainCfg.AllowPrivateHost {
|
||||||
|
ip, err := resolveAndCheck(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: address %s failed validation: %s", ErrRestrictedRedirect, address, err)
|
||||||
|
}
|
||||||
|
network = ip.Network()
|
||||||
|
address = ip.IP.String()
|
||||||
|
}
|
||||||
|
return net.Dial(network, address)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client.Timeout = o.MainCfg.RequestTimeout
|
||||||
|
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||||
|
if len(via) >= maxRedirections { // from https://github.com/neo-project/neo-modules/pull/694
|
||||||
|
return fmt.Errorf("%w: %d redirections are reached", ErrRestrictedRedirect, maxRedirections)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
o.Client = &client
|
||||||
}
|
}
|
||||||
return o, nil
|
return o, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -120,14 +120,6 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
|
||||||
} else {
|
} else {
|
||||||
switch u.Scheme {
|
switch u.Scheme {
|
||||||
case "https":
|
case "https":
|
||||||
if !o.MainCfg.AllowPrivateHost {
|
|
||||||
err = o.URIValidator(u)
|
|
||||||
if err != nil {
|
|
||||||
o.Log.Warn("forbidden oracle request", zap.String("url", req.Req.URL))
|
|
||||||
resp.Code = transaction.Forbidden
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
httpReq, err := http.NewRequest("GET", req.Req.URL, nil)
|
httpReq, err := http.NewRequest("GET", req.Req.URL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
o.Log.Warn("failed to create http request", zap.String("url", req.Req.URL), zap.Error(err))
|
o.Log.Warn("failed to create http request", zap.String("url", req.Req.URL), zap.Error(err))
|
||||||
|
@ -138,8 +130,12 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
r, err := o.Client.Do(httpReq)
|
r, err := o.Client.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err))
|
if errors.Is(err, ErrRestrictedRedirect) {
|
||||||
|
resp.Code = transaction.Forbidden
|
||||||
|
} else {
|
||||||
resp.Code = transaction.Error
|
resp.Code = transaction.Error
|
||||||
|
}
|
||||||
|
o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err), zap.Stringer("code", resp.Code))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
switch r.StatusCode {
|
switch r.StatusCode {
|
||||||
|
|
Loading…
Reference in a new issue