services: improve Oracle redirection check

Move IP check to later stage and do not resolve URI manually.
This commit is contained in:
Anna Shaleva 2022-03-04 14:09:43 +03:00
parent 537de18ac3
commit 5ace840cc7
3 changed files with 85 additions and 41 deletions

View file

@ -1,8 +1,12 @@
package oracle package oracle
import ( import (
"errors" "fmt"
"net" "net"
"net/http"
"syscall"
"github.com/nspcc-dev/neo-go/pkg/config"
) )
// reservedCIDRs is a list of ip addresses for private networks. // reservedCIDRs is a list of ip addresses for private networks.
@ -31,17 +35,6 @@ func init() {
} }
} }
func resolveAndCheck(network string, address string) (*net.IPAddr, error) {
ip, err := net.ResolveIPAddr(network, address)
if err != nil {
return nil, err
}
if isReserved(ip.IP) {
return nil, errors.New("IP is not global unicast")
}
return ip, nil
}
func isReserved(ip net.IP) bool { func isReserved(ip net.IP) bool {
if !ip.IsGlobalUnicast() { if !ip.IsGlobalUnicast() {
return true return true
@ -53,3 +46,51 @@ func isReserved(ip net.IP) bool {
} }
return false return false
} }
func getDefaultClient(cfg config.OracleConfiguration) *http.Client {
d := &net.Dialer{}
if !cfg.AllowPrivateHost {
// Control is used after request URI is resolved and network connection (network
// file descriptor) is created, but right before the moment listening/dialing
// is started.
// `address` represents resolved IP address in the format of ip:port. `address`
// is presented in its final (resolved) form that was used directly for network
// connection establishing.
// Control is called for each item in the set of IP addresses got from request
// URI resolving. The first network connection with address that passes Control
// function will be used for further request processing. Network connection
// with address that failed Control will be ignored. If all the connections
// fail Control then the most relevant error (the one from the first address)
// will be returned after `Client.Do`.
d.Control = func(network, address string, c syscall.RawConn) error {
host, _, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("%w: failed to split address %s: %s", ErrRestrictedRedirect, address, err)
}
ip := net.ParseIP(host)
if ip == nil {
return fmt.Errorf("%w: failed to parse IP address %s", ErrRestrictedRedirect, address)
}
if isReserved(ip) {
return fmt.Errorf("%w: IP is not global unicast", ErrRestrictedRedirect)
}
return nil
}
}
var client http.Client
client.Transport = &http.Transport{
DisableKeepAlives: true,
// Do not set DialTLSContext, so that DialContext will be used to establish the
// connection. After that TLS connection will be added to a persistent connection
// by standard library code and handshaking will be performed.
DialContext: d.DialContext,
}
client.Timeout = cfg.RequestTimeout
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirections { // from https://github.com/neo-project/neo-modules/pull/694
return fmt.Errorf("%w: %d redirections are reached", ErrRestrictedRedirect, maxRedirections)
}
return nil
}
return &client
}

View file

@ -1,9 +1,13 @@
package oracle package oracle
import ( import (
"errors"
"net" "net"
"strings"
"testing" "testing"
"time"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -16,3 +20,30 @@ func TestIsReserved(t *testing.T) {
require.False(t, isReserved(net.IPv4(8, 8, 8, 8))) require.False(t, isReserved(net.IPv4(8, 8, 8, 8)))
} }
func TestDefaultClient_RestrictedRedirectErr(t *testing.T) {
cfg := config.OracleConfiguration{
AllowPrivateHost: false,
RequestTimeout: time.Second,
}
cl := getDefaultClient(cfg)
testCases := []string{
"http://localhost:8080",
"http://localhost",
"https://localhost:443",
"https://" + net.IPv4zero.String(),
"https://" + net.IPv4(10, 0, 0, 1).String(),
"https://" + net.IPv4(192, 168, 0, 1).String(),
"https://[" + net.IPv6interfacelocalallnodes.String() + "]",
"https://[" + net.IPv6loopback.String() + "]",
}
for _, c := range testCases {
t.Run(c, func(t *testing.T) {
_, err := cl.Get(c)
require.Error(t, err)
require.True(t, errors.Is(err, ErrRestrictedRedirect), err)
require.True(t, strings.Contains(err.Error(), "IP is not global unicast"), err)
})
}
}

View file

@ -1,10 +1,7 @@
package oracle package oracle
import ( import (
"context"
"errors" "errors"
"fmt"
"net"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@ -172,32 +169,7 @@ func NewOracle(cfg Config) (*Oracle, error) {
o.OnTransaction = func(*transaction.Transaction) error { return nil } o.OnTransaction = func(*transaction.Transaction) error { return nil }
} }
if o.Client == nil { if o.Client == nil {
var client http.Client o.Client = getDefaultClient(o.MainCfg)
client.Transport = &http.Transport{
DisableKeepAlives: true,
// Do not set DialTLSContext, so that DialContext will be used to establish the
// connection. After that TLS connection will be added to a persistent connection
// by standard library code and handshaking will be performed.
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
if !o.MainCfg.AllowPrivateHost {
ip, err := resolveAndCheck(network, address)
if err != nil {
return nil, fmt.Errorf("%w: address %s failed validation: %s", ErrRestrictedRedirect, address, err)
}
network = ip.Network()
address = ip.IP.String()
}
return net.Dial(network, address)
},
}
client.Timeout = o.MainCfg.RequestTimeout
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirections { // from https://github.com/neo-project/neo-modules/pull/694
return fmt.Errorf("%w: %d redirections are reached", ErrRestrictedRedirect, maxRedirections)
}
return nil
}
o.Client = &client
} }
return o, nil return o, nil
} }