lib/structs: factor reflection based structure manipulation into a library
This commit is contained in:
parent
6807b0e42f
commit
b5c654a100
4 changed files with 171 additions and 52 deletions
|
@ -12,11 +12,11 @@ import (
|
|||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httputil"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rclone/rclone/fs"
|
||||
"github.com/rclone/rclone/lib/structs"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
@ -92,25 +92,6 @@ func (c *timeoutConn) Write(b []byte) (n int, err error) {
|
|||
return c.readOrWrite(c.Conn.Write, b)
|
||||
}
|
||||
|
||||
// setDefaults for a from b
|
||||
//
|
||||
// Copy the public members from b to a. We can't just use a struct
|
||||
// copy as Transport contains a private mutex.
|
||||
func setDefaults(a, b interface{}) {
|
||||
pt := reflect.TypeOf(a)
|
||||
t := pt.Elem()
|
||||
va := reflect.ValueOf(a).Elem()
|
||||
vb := reflect.ValueOf(b).Elem()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
aField := va.Field(i)
|
||||
// Set a from b if it is public
|
||||
if aField.CanSet() {
|
||||
bField := vb.Field(i)
|
||||
aField.Set(bField)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dial with context and timeouts
|
||||
func dialContextTimeout(ctx context.Context, network, address string, ci *fs.ConfigInfo) (net.Conn, error) {
|
||||
dialer := NewDialer(ci)
|
||||
|
@ -134,7 +115,7 @@ func NewTransportCustom(ci *fs.ConfigInfo, customize func(*http.Transport)) http
|
|||
// Start with a sensible set of defaults then override.
|
||||
// This also means we get new stuff when it gets added to go
|
||||
t := new(http.Transport)
|
||||
setDefaults(t, http.DefaultTransport.(*http.Transport))
|
||||
structs.SetDefaults(t, http.DefaultTransport.(*http.Transport))
|
||||
t.Proxy = http.ProxyFromEnvironment
|
||||
t.MaxIdleConnsPerHost = 2 * (ci.Checkers + ci.Transfers + 1)
|
||||
t.MaxIdleConns = 2 * t.MaxIdleConnsPerHost
|
||||
|
|
|
@ -1,42 +1,11 @@
|
|||
package fshttp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// returns the "%p" representation of the thing passed in
|
||||
func ptr(p interface{}) string {
|
||||
return fmt.Sprintf("%p", p)
|
||||
}
|
||||
|
||||
func TestSetDefaults(t *testing.T) {
|
||||
old := http.DefaultTransport.(*http.Transport)
|
||||
newT := new(http.Transport)
|
||||
setDefaults(newT, old)
|
||||
// Can't use assert.Equal or reflect.DeepEqual for this as it has functions in
|
||||
// Check functions by comparing the "%p" representations of them
|
||||
assert.Equal(t, ptr(old.Proxy), ptr(newT.Proxy), "when checking .Proxy")
|
||||
assert.Equal(t, ptr(old.DialContext), ptr(newT.DialContext), "when checking .DialContext")
|
||||
// Check the other public fields
|
||||
assert.Equal(t, ptr(old.Dial), ptr(newT.Dial), "when checking .Dial")
|
||||
assert.Equal(t, ptr(old.DialTLS), ptr(newT.DialTLS), "when checking .DialTLS")
|
||||
assert.Equal(t, old.TLSClientConfig, newT.TLSClientConfig, "when checking .TLSClientConfig")
|
||||
assert.Equal(t, old.TLSHandshakeTimeout, newT.TLSHandshakeTimeout, "when checking .TLSHandshakeTimeout")
|
||||
assert.Equal(t, old.DisableKeepAlives, newT.DisableKeepAlives, "when checking .DisableKeepAlives")
|
||||
assert.Equal(t, old.DisableCompression, newT.DisableCompression, "when checking .DisableCompression")
|
||||
assert.Equal(t, old.MaxIdleConns, newT.MaxIdleConns, "when checking .MaxIdleConns")
|
||||
assert.Equal(t, old.MaxIdleConnsPerHost, newT.MaxIdleConnsPerHost, "when checking .MaxIdleConnsPerHost")
|
||||
assert.Equal(t, old.IdleConnTimeout, newT.IdleConnTimeout, "when checking .IdleConnTimeout")
|
||||
assert.Equal(t, old.ResponseHeaderTimeout, newT.ResponseHeaderTimeout, "when checking .ResponseHeaderTimeout")
|
||||
assert.Equal(t, old.ExpectContinueTimeout, newT.ExpectContinueTimeout, "when checking .ExpectContinueTimeout")
|
||||
assert.Equal(t, old.TLSNextProto, newT.TLSNextProto, "when checking .TLSNextProto")
|
||||
assert.Equal(t, old.MaxResponseHeaderBytes, newT.MaxResponseHeaderBytes, "when checking .MaxResponseHeaderBytes")
|
||||
}
|
||||
|
||||
func TestCleanAuth(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
in string
|
||||
|
|
57
lib/structs/structs.go
Normal file
57
lib/structs/structs.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
// Package structs is for manipulating structures with reflection
|
||||
package structs
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// SetFrom sets the public members of a from b
|
||||
//
|
||||
// a and b should be pointers to structs
|
||||
//
|
||||
// a can be a different type from b
|
||||
//
|
||||
// Only the Fields which have the same name and assignable type on a
|
||||
// and b will be set.
|
||||
//
|
||||
// This is useful for copying between almost identical structures that
|
||||
// are requently present in auto generated code for cloud storage
|
||||
// interfaces.
|
||||
func SetFrom(a, b interface{}) {
|
||||
ta := reflect.TypeOf(a).Elem()
|
||||
tb := reflect.TypeOf(b).Elem()
|
||||
va := reflect.ValueOf(a).Elem()
|
||||
vb := reflect.ValueOf(b).Elem()
|
||||
for i := 0; i < tb.NumField(); i++ {
|
||||
bField := vb.Field(i)
|
||||
tbField := tb.Field(i)
|
||||
name := tbField.Name
|
||||
aField := va.FieldByName(name)
|
||||
taField, found := ta.FieldByName(name)
|
||||
if found && aField.IsValid() && bField.IsValid() && aField.CanSet() && tbField.Type.AssignableTo(taField.Type) {
|
||||
aField.Set(bField)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetDefaults for a from b
|
||||
//
|
||||
// a and b should be pointers to the same kind of struct
|
||||
//
|
||||
// This copies the public members only from b to a. This is useful if
|
||||
// you can't just use a struct copy because it contains a private
|
||||
// mutex, eg as http.Transport.
|
||||
func SetDefaults(a, b interface{}) {
|
||||
pt := reflect.TypeOf(a)
|
||||
t := pt.Elem()
|
||||
va := reflect.ValueOf(a).Elem()
|
||||
vb := reflect.ValueOf(b).Elem()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
aField := va.Field(i)
|
||||
// Set a from b if it is public
|
||||
if aField.CanSet() {
|
||||
bField := vb.Field(i)
|
||||
aField.Set(bField)
|
||||
}
|
||||
}
|
||||
}
|
112
lib/structs/structs_test.go
Normal file
112
lib/structs/structs_test.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package structs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// returns the "%p" representation of the thing passed in
|
||||
func ptr(p interface{}) string {
|
||||
return fmt.Sprintf("%p", p)
|
||||
}
|
||||
|
||||
func TestSetDefaults(t *testing.T) {
|
||||
old := http.DefaultTransport.(*http.Transport)
|
||||
newT := new(http.Transport)
|
||||
SetDefaults(newT, old)
|
||||
// Can't use assert.Equal or reflect.DeepEqual for this as it has functions in
|
||||
// Check functions by comparing the "%p" representations of them
|
||||
assert.Equal(t, ptr(old.Proxy), ptr(newT.Proxy), "when checking .Proxy")
|
||||
assert.Equal(t, ptr(old.DialContext), ptr(newT.DialContext), "when checking .DialContext")
|
||||
// Check the other public fields
|
||||
assert.Equal(t, ptr(old.Dial), ptr(newT.Dial), "when checking .Dial")
|
||||
assert.Equal(t, ptr(old.DialTLS), ptr(newT.DialTLS), "when checking .DialTLS")
|
||||
assert.Equal(t, old.TLSClientConfig, newT.TLSClientConfig, "when checking .TLSClientConfig")
|
||||
assert.Equal(t, old.TLSHandshakeTimeout, newT.TLSHandshakeTimeout, "when checking .TLSHandshakeTimeout")
|
||||
assert.Equal(t, old.DisableKeepAlives, newT.DisableKeepAlives, "when checking .DisableKeepAlives")
|
||||
assert.Equal(t, old.DisableCompression, newT.DisableCompression, "when checking .DisableCompression")
|
||||
assert.Equal(t, old.MaxIdleConns, newT.MaxIdleConns, "when checking .MaxIdleConns")
|
||||
assert.Equal(t, old.MaxIdleConnsPerHost, newT.MaxIdleConnsPerHost, "when checking .MaxIdleConnsPerHost")
|
||||
assert.Equal(t, old.IdleConnTimeout, newT.IdleConnTimeout, "when checking .IdleConnTimeout")
|
||||
assert.Equal(t, old.ResponseHeaderTimeout, newT.ResponseHeaderTimeout, "when checking .ResponseHeaderTimeout")
|
||||
assert.Equal(t, old.ExpectContinueTimeout, newT.ExpectContinueTimeout, "when checking .ExpectContinueTimeout")
|
||||
assert.Equal(t, old.TLSNextProto, newT.TLSNextProto, "when checking .TLSNextProto")
|
||||
assert.Equal(t, old.MaxResponseHeaderBytes, newT.MaxResponseHeaderBytes, "when checking .MaxResponseHeaderBytes")
|
||||
}
|
||||
|
||||
type aType struct {
|
||||
Matching string
|
||||
OnlyA string
|
||||
MatchingInt int
|
||||
DifferentType string
|
||||
}
|
||||
|
||||
type bType struct {
|
||||
Matching string
|
||||
OnlyB string
|
||||
MatchingInt int
|
||||
DifferentType int
|
||||
Unused string
|
||||
}
|
||||
|
||||
func TestSetFrom(t *testing.T) {
|
||||
a := aType{
|
||||
Matching: "a",
|
||||
OnlyA: "onlyA",
|
||||
MatchingInt: 1,
|
||||
DifferentType: "suprise",
|
||||
}
|
||||
|
||||
b := bType{
|
||||
Matching: "b",
|
||||
OnlyB: "onlyB",
|
||||
MatchingInt: 2,
|
||||
DifferentType: 7,
|
||||
Unused: "Ha",
|
||||
}
|
||||
bBefore := b
|
||||
|
||||
SetFrom(&a, &b)
|
||||
|
||||
assert.Equal(t, aType{
|
||||
Matching: "b",
|
||||
OnlyA: "onlyA",
|
||||
MatchingInt: 2,
|
||||
DifferentType: "suprise",
|
||||
}, a)
|
||||
|
||||
assert.Equal(t, bBefore, b)
|
||||
}
|
||||
|
||||
func TestSetFromReversed(t *testing.T) {
|
||||
a := aType{
|
||||
Matching: "a",
|
||||
OnlyA: "onlyA",
|
||||
MatchingInt: 1,
|
||||
DifferentType: "suprise",
|
||||
}
|
||||
aBefore := a
|
||||
|
||||
b := bType{
|
||||
Matching: "b",
|
||||
OnlyB: "onlyB",
|
||||
MatchingInt: 2,
|
||||
DifferentType: 7,
|
||||
Unused: "Ha",
|
||||
}
|
||||
|
||||
SetFrom(&b, &a)
|
||||
|
||||
assert.Equal(t, bType{
|
||||
Matching: "a",
|
||||
OnlyB: "onlyB",
|
||||
MatchingInt: 1,
|
||||
DifferentType: 7,
|
||||
Unused: "Ha",
|
||||
}, b)
|
||||
|
||||
assert.Equal(t, aBefore, a)
|
||||
}
|
Loading…
Reference in a new issue