forked from TrueCloudLab/rclone
f471a7e3f5
Cookies are handled by an in-memory cookiejar in the fshttp module for the entire session. One useful scenario is an HTTP storage system where the index server adds an authentication cookie while redirecting to a CDN for the actual files. It can also be helpful to reuse fshttp in other storage systems that require cookies.
319 lines
8.5 KiB
Go
319 lines
8.5 KiB
Go
// Package fshttp contains the common http parts of the config, Transport and Client
|
|
package fshttp
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"net"
|
|
"net/http"
|
|
"net/http/cookiejar"
|
|
"net/http/httputil"
|
|
"reflect"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ncw/rclone/fs"
|
|
"golang.org/x/net/publicsuffix"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// Banner lines written to the debug log around dumped HTTP traffic.
const (
	// separatorReq is logged before and after a dumped request.
	separatorReq = ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
	// separatorResp is logged before and after a dumped response.
	separatorResp = "<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
)
|
|
|
|
var (
|
|
transport http.RoundTripper
|
|
noTransport sync.Once
|
|
tpsBucket *rate.Limiter // for limiting number of http transactions per second
|
|
cookieJar, _ = cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
|
|
)
|
|
|
|
// StartHTTPTokenBucket starts the token bucket if necessary
|
|
func StartHTTPTokenBucket() {
|
|
if fs.Config.TPSLimit > 0 {
|
|
tpsBurst := fs.Config.TPSLimitBurst
|
|
if tpsBurst < 1 {
|
|
tpsBurst = 1
|
|
}
|
|
tpsBucket = rate.NewLimiter(rate.Limit(fs.Config.TPSLimit), tpsBurst)
|
|
fs.Infof(nil, "Starting HTTP transaction limiter: max %g transactions/s with burst %d", fs.Config.TPSLimit, tpsBurst)
|
|
}
|
|
}
|
|
|
|
// timeoutConn is a net.Conn whose deadline is pushed forward after every
// successful Read or Write, giving the connection an idle timeout.
type timeoutConn struct {
	// Conn is the wrapped connection.
	net.Conn
	// timeout is the idle period added to the deadline; zero disables nudging.
	timeout time.Duration
}
|
|
|
|
// create a timeoutConn using the timeout
|
|
func newTimeoutConn(conn net.Conn, timeout time.Duration) (c *timeoutConn, err error) {
|
|
c = &timeoutConn{
|
|
Conn: conn,
|
|
timeout: timeout,
|
|
}
|
|
err = c.nudgeDeadline()
|
|
return
|
|
}
|
|
|
|
// Nudge the deadline for an idle timeout on by c.timeout if non-zero
|
|
func (c *timeoutConn) nudgeDeadline() (err error) {
|
|
if c.timeout == 0 {
|
|
return nil
|
|
}
|
|
when := time.Now().Add(c.timeout)
|
|
return c.Conn.SetDeadline(when)
|
|
}
|
|
|
|
// readOrWrite bytes doing idle timeouts
|
|
func (c *timeoutConn) readOrWrite(f func([]byte) (int, error), b []byte) (n int, err error) {
|
|
n, err = f(b)
|
|
// Don't nudge if no bytes or an error
|
|
if n == 0 || err != nil {
|
|
return
|
|
}
|
|
// Nudge the deadline on successful Read or Write
|
|
err = c.nudgeDeadline()
|
|
return
|
|
}
|
|
|
|
// Read bytes doing idle timeouts
|
|
func (c *timeoutConn) Read(b []byte) (n int, err error) {
|
|
return c.readOrWrite(c.Conn.Read, b)
|
|
}
|
|
|
|
// Write bytes doing idle timeouts
|
|
func (c *timeoutConn) Write(b []byte) (n int, err error) {
|
|
return c.readOrWrite(c.Conn.Write, b)
|
|
}
|
|
|
|
// setDefaults copies every settable (exported) field of the struct that b
// points to into the struct that a points to.
//
// Both arguments must be pointers to structs with the same layout. A plain
// struct assignment can't be used instead because http.Transport contains
// a private mutex.
func setDefaults(a, b interface{}) {
	dst := reflect.ValueOf(a).Elem()
	src := reflect.ValueOf(b).Elem()
	for i, n := 0, dst.NumField(); i < n; i++ {
		// CanSet is false for unexported fields, which are left alone.
		if field := dst.Field(i); field.CanSet() {
			field.Set(src.Field(i))
		}
	}
}
|
|
|
|
// dial with context and timeouts
|
|
func dialContextTimeout(ctx context.Context, network, address string, ci *fs.ConfigInfo) (net.Conn, error) {
|
|
dialer := NewDialer(ci)
|
|
c, err := dialer.DialContext(ctx, network, address)
|
|
if err != nil {
|
|
return c, err
|
|
}
|
|
return newTimeoutConn(c, ci.Timeout)
|
|
}
|
|
|
|
// NewTransport returns an http.RoundTripper with the correct timeouts
|
|
func NewTransport(ci *fs.ConfigInfo) http.RoundTripper {
|
|
noTransport.Do(func() {
|
|
// Start with a sensible set of defaults then override.
|
|
// This also means we get new stuff when it gets added to go
|
|
t := new(http.Transport)
|
|
setDefaults(t, http.DefaultTransport.(*http.Transport))
|
|
t.Proxy = http.ProxyFromEnvironment
|
|
t.MaxIdleConnsPerHost = 2 * (ci.Checkers + ci.Transfers + 1)
|
|
t.MaxIdleConns = 2 * t.MaxIdleConnsPerHost
|
|
t.TLSHandshakeTimeout = ci.ConnectTimeout
|
|
t.ResponseHeaderTimeout = ci.Timeout
|
|
t.TLSClientConfig = &tls.Config{InsecureSkipVerify: ci.InsecureSkipVerify}
|
|
t.DisableCompression = ci.NoGzip
|
|
t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
return dialContextTimeout(ctx, network, addr, ci)
|
|
}
|
|
t.IdleConnTimeout = 60 * time.Second
|
|
t.ExpectContinueTimeout = ci.ConnectTimeout
|
|
// Wrap that http.Transport in our own transport
|
|
transport = newTransport(ci, t)
|
|
})
|
|
return transport
|
|
}
|
|
|
|
// NewClient returns an http.Client with the correct timeouts
|
|
func NewClient(ci *fs.ConfigInfo) *http.Client {
|
|
transport := &http.Client{
|
|
Transport: NewTransport(ci),
|
|
}
|
|
if ci.Cookie {
|
|
transport.Jar = cookieJar
|
|
}
|
|
return transport
|
|
}
|
|
|
|
// Transport is a our http Transport which wraps an http.Transport
|
|
// * Sets the User Agent
|
|
// * Does logging
|
|
type Transport struct {
|
|
*http.Transport
|
|
dump fs.DumpFlags
|
|
filterRequest func(req *http.Request)
|
|
userAgent string
|
|
}
|
|
|
|
// newTransport wraps the http.Transport passed in and logs all
|
|
// roundtrips including the body if logBody is set.
|
|
func newTransport(ci *fs.ConfigInfo, transport *http.Transport) *Transport {
|
|
return &Transport{
|
|
Transport: transport,
|
|
dump: ci.Dump,
|
|
userAgent: ci.UserAgent,
|
|
}
|
|
}
|
|
|
|
// SetRequestFilter sets a filter to be used on each request
|
|
func (t *Transport) SetRequestFilter(f func(req *http.Request)) {
|
|
t.filterRequest = f
|
|
}
|
|
|
|
// State for the once-per-server clock check.
var (
	// checkedHostMu protects checkedHost.
	checkedHostMu sync.RWMutex
	// checkedHost is the set of servers whose time we have already checked.
	checkedHost = make(map[string]struct{}, 1)
)
|
|
|
|
// Check the server time is the same as ours, once for each server
|
|
func checkServerTime(req *http.Request, resp *http.Response) {
|
|
host := req.URL.Host
|
|
if req.Host != "" {
|
|
host = req.Host
|
|
}
|
|
checkedHostMu.RLock()
|
|
_, ok := checkedHost[host]
|
|
checkedHostMu.RUnlock()
|
|
if ok {
|
|
return
|
|
}
|
|
dateString := resp.Header.Get("Date")
|
|
if dateString == "" {
|
|
return
|
|
}
|
|
date, err := http.ParseTime(dateString)
|
|
if err != nil {
|
|
fs.Debugf(nil, "Couldn't parse Date: from server %s: %q: %v", host, dateString, err)
|
|
return
|
|
}
|
|
dt := time.Since(date)
|
|
const window = 5 * 60 * time.Second
|
|
if dt > window || dt < -window {
|
|
fs.Logf(nil, "Time may be set wrong - time from %q is %v different from this computer", host, dt)
|
|
}
|
|
checkedHostMu.Lock()
|
|
checkedHost[host] = struct{}{}
|
|
checkedHostMu.Unlock()
|
|
}
|
|
|
|
// cleanAuth redacts the first occurrence of the header prefix authBuf
// (for example "Authorization: ") found within the first 4k of buf.
//
// Up to four bytes of the header value are overwritten with 'X' and the
// rest of the line, up to but not including the next '\n', is cut out.
// buf is modified in place and the returned slice shares its storage. If
// authBuf is not found, buf is returned unchanged.
func cleanAuth(buf, authBuf []byte) []byte {
	// Only look for the header in the leading part of the buffer
	const searchLimit = 4096
	head := buf
	if len(head) > searchLimit {
		head = head[:searchLimit]
	}
	start := bytes.Index(head, authBuf)
	if start < 0 {
		return buf
	}

	// pos walks over the header value, masking at most 4 bytes of it
	pos := start + len(authBuf)
	maskEnd := pos + 4
	if maskEnd > len(buf) {
		maskEnd = len(buf)
	}
	for pos < maskEnd && buf[pos] != '\n' {
		buf[pos] = 'X'
		pos++
	}

	// Drop whatever remains of the line, keeping the '\n' itself
	eol := bytes.IndexByte(buf[pos:], '\n')
	if eol < 0 {
		return buf[:pos]
	}
	kept := copy(buf[pos:], buf[pos+eol:])
	return buf[:pos+kept]
}
|
|
|
|
// authBufs lists the header prefixes whose values cleanAuths redacts
// from dumped requests.
var authBufs = [][]byte{
	// Standard HTTP authorization header
	[]byte("Authorization: "),
	// Token header
	[]byte("X-Auth-Token: "),
}
|
|
|
|
// cleanAuths gets rid of all the possible Auth headers
|
|
func cleanAuths(buf []byte) []byte {
|
|
for _, authBuf := range authBufs {
|
|
buf = cleanAuth(buf, authBuf)
|
|
}
|
|
return buf
|
|
}
|
|
|
|
// RoundTrip implements the RoundTripper interface.
|
|
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
|
// Get transactions per second token first if limiting
|
|
if tpsBucket != nil {
|
|
tbErr := tpsBucket.Wait(req.Context())
|
|
if tbErr != nil {
|
|
fs.Errorf(nil, "HTTP token bucket error: %v", err)
|
|
}
|
|
}
|
|
// Force user agent
|
|
req.Header.Set("User-Agent", t.userAgent)
|
|
// Filter the request if required
|
|
if t.filterRequest != nil {
|
|
t.filterRequest(req)
|
|
}
|
|
// Logf request
|
|
if t.dump&(fs.DumpHeaders|fs.DumpBodies|fs.DumpAuth|fs.DumpRequests|fs.DumpResponses) != 0 {
|
|
buf, _ := httputil.DumpRequestOut(req, t.dump&(fs.DumpBodies|fs.DumpRequests) != 0)
|
|
if t.dump&fs.DumpAuth == 0 {
|
|
buf = cleanAuths(buf)
|
|
}
|
|
fs.Debugf(nil, "%s", separatorReq)
|
|
fs.Debugf(nil, "%s (req %p)", "HTTP REQUEST", req)
|
|
fs.Debugf(nil, "%s", string(buf))
|
|
fs.Debugf(nil, "%s", separatorReq)
|
|
}
|
|
// Do round trip
|
|
resp, err = t.Transport.RoundTrip(req)
|
|
// Logf response
|
|
if t.dump&(fs.DumpHeaders|fs.DumpBodies|fs.DumpAuth|fs.DumpRequests|fs.DumpResponses) != 0 {
|
|
fs.Debugf(nil, "%s", separatorResp)
|
|
fs.Debugf(nil, "%s (req %p)", "HTTP RESPONSE", req)
|
|
if err != nil {
|
|
fs.Debugf(nil, "Error: %v", err)
|
|
} else {
|
|
buf, _ := httputil.DumpResponse(resp, t.dump&(fs.DumpBodies|fs.DumpResponses) != 0)
|
|
fs.Debugf(nil, "%s", string(buf))
|
|
}
|
|
fs.Debugf(nil, "%s", separatorResp)
|
|
}
|
|
if err == nil {
|
|
checkServerTime(req, resp)
|
|
}
|
|
return resp, err
|
|
}
|
|
|
|
// NewDialer creates a net.Dialer structure with Timeout, Keepalive
|
|
// and LocalAddr set from rclone flags.
|
|
func NewDialer(ci *fs.ConfigInfo) *net.Dialer {
|
|
dialer := &net.Dialer{
|
|
Timeout: ci.ConnectTimeout,
|
|
KeepAlive: 30 * time.Second,
|
|
}
|
|
if ci.BindAddr != nil {
|
|
dialer.LocalAddr = &net.TCPAddr{IP: ci.BindAddr}
|
|
}
|
|
return dialer
|
|
}
|