Merge pull request #2383 from nspcc-dev/oracle-redirection
services: check Oracle response redirections
This commit is contained in:
commit
6ece74a7c7
5 changed files with 113 additions and 49 deletions
|
@ -8,7 +8,6 @@ import (
|
||||||
gio "io"
|
gio "io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -178,7 +177,7 @@ func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain,
|
||||||
return res.Container
|
return res.Container
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config {
|
func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string, returnOracleRedirectionErrOn func(address string) bool) oracle.Config {
|
||||||
return oracle.Config{
|
return oracle.Config{
|
||||||
Log: zaptest.NewLogger(t),
|
Log: zaptest.NewLogger(t),
|
||||||
Network: netmode.UnitTestNet,
|
Network: netmode.UnitTestNet,
|
||||||
|
@ -191,7 +190,7 @@ func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Chain: bc,
|
Chain: bc,
|
||||||
Client: newDefaultHTTPClient(),
|
Client: newDefaultHTTPClient(returnOracleRedirectionErrOn),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,15 +201,11 @@ func getTestOracle(t *testing.T, bc *Blockchain, walletPath, pass string) (
|
||||||
chan *transaction.Transaction) {
|
chan *transaction.Transaction) {
|
||||||
m := make(map[uint64]*responseWithSig)
|
m := make(map[uint64]*responseWithSig)
|
||||||
ch := make(chan *transaction.Transaction, 5)
|
ch := make(chan *transaction.Transaction, 5)
|
||||||
orcCfg := getOracleConfig(t, bc, walletPath, pass)
|
orcCfg := getOracleConfig(t, bc, walletPath, pass, func(address string) bool {
|
||||||
|
return strings.HasPrefix(address, "https://private")
|
||||||
|
})
|
||||||
orcCfg.ResponseHandler = &saveToMapBroadcaster{m: m}
|
orcCfg.ResponseHandler = &saveToMapBroadcaster{m: m}
|
||||||
orcCfg.OnTransaction = saveTxToChan(ch)
|
orcCfg.OnTransaction = saveTxToChan(ch)
|
||||||
orcCfg.URIValidator = func(u *url.URL) error {
|
|
||||||
if strings.HasPrefix(u.Host, "private") {
|
|
||||||
return errors.New("private network")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
orc, err := oracle.NewOracle(orcCfg)
|
orc, err := oracle.NewOracle(orcCfg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -255,10 +250,10 @@ func TestCreateResponseTx(t *testing.T) {
|
||||||
func TestOracle_InvalidWallet(t *testing.T) {
|
func TestOracle_InvalidWallet(t *testing.T) {
|
||||||
bc := newTestChain(t)
|
bc := newTestChain(t)
|
||||||
|
|
||||||
_, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid"))
|
_, err := oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "invalid", nil))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
_, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one"))
|
_, err = oracle.NewOracle(getOracleConfig(t, bc, "./testdata/oracle1.json", "one", nil))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -551,7 +546,8 @@ type (
|
||||||
// httpClient implements oracle.HTTPClient with
|
// httpClient implements oracle.HTTPClient with
|
||||||
// mocked URL or responses.
|
// mocked URL or responses.
|
||||||
httpClient struct {
|
httpClient struct {
|
||||||
responses map[string]testResponse
|
returnOracleRedirectionErrOn func(address string) bool
|
||||||
|
responses map[string]testResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
testResponse struct {
|
testResponse struct {
|
||||||
|
@ -563,6 +559,9 @@ type (
|
||||||
|
|
||||||
// Get implements oracle.HTTPClient interface.
|
// Get implements oracle.HTTPClient interface.
|
||||||
func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
|
func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
|
||||||
|
if c.returnOracleRedirectionErrOn != nil && c.returnOracleRedirectionErrOn(req.URL.String()) {
|
||||||
|
return nil, fmt.Errorf("%w: private network", oracle.ErrRestrictedRedirect)
|
||||||
|
}
|
||||||
resp, ok := c.responses[req.URL.String()]
|
resp, ok := c.responses[req.URL.String()]
|
||||||
if ok {
|
if ok {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
|
@ -576,8 +575,9 @@ func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
|
||||||
return nil, errors.New("request failed")
|
return nil, errors.New("request failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDefaultHTTPClient() oracle.HTTPClient {
|
func newDefaultHTTPClient(returnOracleRedirectionErrOn func(address string) bool) oracle.HTTPClient {
|
||||||
return &httpClient{
|
return &httpClient{
|
||||||
|
returnOracleRedirectionErrOn: returnOracleRedirectionErrOn,
|
||||||
responses: map[string]testResponse{
|
responses: map[string]testResponse{
|
||||||
"https://get.1234": {
|
"https://get.1234": {
|
||||||
code: http.StatusOK,
|
code: http.StatusOK,
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
package oracle
|
package oracle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/http"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// reservedCIDRs is a list of ip addresses for private networks.
|
// reservedCIDRs is a list of ip addresses for private networks.
|
||||||
|
@ -32,17 +35,6 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultURIValidator(u *url.URL) error {
|
|
||||||
ip, err := net.ResolveIPAddr("ip", u.Hostname())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if isReserved(ip.IP) {
|
|
||||||
return errors.New("IP is not global unicast")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isReserved(ip net.IP) bool {
|
func isReserved(ip net.IP) bool {
|
||||||
if !ip.IsGlobalUnicast() {
|
if !ip.IsGlobalUnicast() {
|
||||||
return true
|
return true
|
||||||
|
@ -54,3 +46,51 @@ func isReserved(ip net.IP) bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getDefaultClient(cfg config.OracleConfiguration) *http.Client {
|
||||||
|
d := &net.Dialer{}
|
||||||
|
if !cfg.AllowPrivateHost {
|
||||||
|
// Control is used after request URI is resolved and network connection (network
|
||||||
|
// file descriptor) is created, but right before the moment listening/dialing
|
||||||
|
// is started.
|
||||||
|
// `address` represents resolved IP address in the format of ip:port. `address`
|
||||||
|
// is presented in its final (resolved) form that was used directly for network
|
||||||
|
// connection establishing.
|
||||||
|
// Control is called for each item in the set of IP addresses got from request
|
||||||
|
// URI resolving. The first network connection with address that passes Control
|
||||||
|
// function will be used for further request processing. Network connection
|
||||||
|
// with address that failed Control will be ignored. If all the connections
|
||||||
|
// fail Control then the most relevant error (the one from the first address)
|
||||||
|
// will be returned after `Client.Do`.
|
||||||
|
d.Control = func(network, address string, c syscall.RawConn) error {
|
||||||
|
host, _, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: failed to split address %s: %s", ErrRestrictedRedirect, address, err)
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
if ip == nil {
|
||||||
|
return fmt.Errorf("%w: failed to parse IP address %s", ErrRestrictedRedirect, address)
|
||||||
|
}
|
||||||
|
if isReserved(ip) {
|
||||||
|
return fmt.Errorf("%w: IP is not global unicast", ErrRestrictedRedirect)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var client http.Client
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
DisableKeepAlives: true,
|
||||||
|
// Do not set DialTLSContext, so that DialContext will be used to establish the
|
||||||
|
// connection. After that TLS connection will be added to a persistent connection
|
||||||
|
// by standard library code and handshaking will be performed.
|
||||||
|
DialContext: d.DialContext,
|
||||||
|
}
|
||||||
|
client.Timeout = cfg.RequestTimeout
|
||||||
|
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||||
|
if len(via) >= maxRedirections { // from https://github.com/neo-project/neo-modules/pull/694
|
||||||
|
return fmt.Errorf("%w: %d redirections are reached", ErrRestrictedRedirect, maxRedirections)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &client
|
||||||
|
}
|
||||||
|
|
|
@ -1,9 +1,13 @@
|
||||||
package oracle
|
package oracle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/config"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,3 +20,30 @@ func TestIsReserved(t *testing.T) {
|
||||||
|
|
||||||
require.False(t, isReserved(net.IPv4(8, 8, 8, 8)))
|
require.False(t, isReserved(net.IPv4(8, 8, 8, 8)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultClient_RestrictedRedirectErr(t *testing.T) {
|
||||||
|
cfg := config.OracleConfiguration{
|
||||||
|
AllowPrivateHost: false,
|
||||||
|
RequestTimeout: time.Second,
|
||||||
|
}
|
||||||
|
cl := getDefaultClient(cfg)
|
||||||
|
|
||||||
|
testCases := []string{
|
||||||
|
"http://localhost:8080",
|
||||||
|
"http://localhost",
|
||||||
|
"https://localhost:443",
|
||||||
|
"https://" + net.IPv4zero.String(),
|
||||||
|
"https://" + net.IPv4(10, 0, 0, 1).String(),
|
||||||
|
"https://" + net.IPv4(192, 168, 0, 1).String(),
|
||||||
|
"https://[" + net.IPv6interfacelocalallnodes.String() + "]",
|
||||||
|
"https://[" + net.IPv6loopback.String() + "]",
|
||||||
|
}
|
||||||
|
for _, c := range testCases {
|
||||||
|
t.Run(c, func(t *testing.T) {
|
||||||
|
_, err := cl.Get(c)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, errors.Is(err, ErrRestrictedRedirect), err)
|
||||||
|
require.True(t, strings.Contains(err.Error(), "IP is not global unicast"), err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package oracle
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -80,7 +79,6 @@ type (
|
||||||
Chain Ledger
|
Chain Ledger
|
||||||
ResponseHandler Broadcaster
|
ResponseHandler Broadcaster
|
||||||
OnTransaction TxCallback
|
OnTransaction TxCallback
|
||||||
URIValidator URIValidator
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPClient is an interface capable of doing oracle requests.
|
// HTTPClient is an interface capable of doing oracle requests.
|
||||||
|
@ -99,8 +97,6 @@ type (
|
||||||
|
|
||||||
// TxCallback executes on new transactions when they are ready to be pooled.
|
// TxCallback executes on new transactions when they are ready to be pooled.
|
||||||
TxCallback = func(tx *transaction.Transaction) error
|
TxCallback = func(tx *transaction.Transaction) error
|
||||||
// URIValidator is used to check if provided URL is valid.
|
|
||||||
URIValidator = func(*url.URL) error
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -112,8 +108,15 @@ const (
|
||||||
|
|
||||||
// defaultRefreshInterval is default timeout for the failed request to be reprocessed.
|
// defaultRefreshInterval is default timeout for the failed request to be reprocessed.
|
||||||
defaultRefreshInterval = time.Minute * 3
|
defaultRefreshInterval = time.Minute * 3
|
||||||
|
|
||||||
|
// maxRedirections is the number of allowed redirections for Oracle HTTPS request.
|
||||||
|
maxRedirections = 5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrRestrictedRedirect is returned when redirection to forbidden address occurs
|
||||||
|
// during Oracle response creation.
|
||||||
|
var ErrRestrictedRedirect = errors.New("oracle request redirection error")
|
||||||
|
|
||||||
// NewOracle returns new oracle instance.
|
// NewOracle returns new oracle instance.
|
||||||
func NewOracle(cfg Config) (*Oracle, error) {
|
func NewOracle(cfg Config) (*Oracle, error) {
|
||||||
o := &Oracle{
|
o := &Oracle{
|
||||||
|
@ -159,20 +162,14 @@ func NewOracle(cfg Config) (*Oracle, error) {
|
||||||
return nil, errors.New("no wallet account could be unlocked")
|
return nil, errors.New("no wallet account could be unlocked")
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.Client == nil {
|
|
||||||
var client http.Client
|
|
||||||
client.Transport = &http.Transport{DisableKeepAlives: true}
|
|
||||||
client.Timeout = o.MainCfg.RequestTimeout
|
|
||||||
o.Client = &client
|
|
||||||
}
|
|
||||||
if o.ResponseHandler == nil {
|
if o.ResponseHandler == nil {
|
||||||
o.ResponseHandler = defaultResponseHandler{}
|
o.ResponseHandler = defaultResponseHandler{}
|
||||||
}
|
}
|
||||||
if o.OnTransaction == nil {
|
if o.OnTransaction == nil {
|
||||||
o.OnTransaction = func(*transaction.Transaction) error { return nil }
|
o.OnTransaction = func(*transaction.Transaction) error { return nil }
|
||||||
}
|
}
|
||||||
if o.URIValidator == nil {
|
if o.Client == nil {
|
||||||
o.URIValidator = defaultURIValidator
|
o.Client = getDefaultClient(o.MainCfg)
|
||||||
}
|
}
|
||||||
return o, nil
|
return o, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -120,14 +120,6 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
|
||||||
} else {
|
} else {
|
||||||
switch u.Scheme {
|
switch u.Scheme {
|
||||||
case "https":
|
case "https":
|
||||||
if !o.MainCfg.AllowPrivateHost {
|
|
||||||
err = o.URIValidator(u)
|
|
||||||
if err != nil {
|
|
||||||
o.Log.Warn("forbidden oracle request", zap.String("url", req.Req.URL))
|
|
||||||
resp.Code = transaction.Forbidden
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
httpReq, err := http.NewRequest("GET", req.Req.URL, nil)
|
httpReq, err := http.NewRequest("GET", req.Req.URL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
o.Log.Warn("failed to create http request", zap.String("url", req.Req.URL), zap.Error(err))
|
o.Log.Warn("failed to create http request", zap.String("url", req.Req.URL), zap.Error(err))
|
||||||
|
@ -138,8 +130,12 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
r, err := o.Client.Do(httpReq)
|
r, err := o.Client.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err))
|
if errors.Is(err, ErrRestrictedRedirect) {
|
||||||
resp.Code = transaction.Error
|
resp.Code = transaction.Forbidden
|
||||||
|
} else {
|
||||||
|
resp.Code = transaction.Error
|
||||||
|
}
|
||||||
|
o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err), zap.Stringer("code", resp.Code))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
switch r.StatusCode {
|
switch r.StatusCode {
|
||||||
|
|
Loading…
Add table
Reference in a new issue