diff --git a/Gopkg.lock b/Gopkg.lock index 5efa5285..ae4e10d5 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -381,10 +381,19 @@ version = "v1.0.6" [[projects]] - digest = "1:b31059dac028ff111793a8345eacf0f99d0f1150ead34ebf32fdd2b7d54c2d45" + digest = "1:60a46e2410edbf02b419f833372dd1d24d7aa1b916a990a7370e792fada1eadd" + name = "github.com/stretchr/objx" + packages = ["."] + pruneopts = "NUT" + revision = "477a77ecc69700c7cdeb1fa9e129548e1c1c393c" + version = "v0.1.1" + +[[projects]] + digest = "1:f1e5a94fc8fde9a67a97106d7d7d386ad0b938b41fde2b110e70447590aea6e7" name = "github.com/stretchr/testify" packages = [ "assert", + "mock", "require", "suite", ] @@ -609,6 +618,7 @@ "github.com/sacloud/libsacloud/api", "github.com/sacloud/libsacloud/sacloud", "github.com/stretchr/testify/assert", + "github.com/stretchr/testify/mock", "github.com/stretchr/testify/require", "github.com/stretchr/testify/suite", "github.com/timewasted/linode", diff --git a/acme/dns_challenge.go b/acme/dns_challenge.go index f803d0a8..d9c252e7 100644 --- a/acme/dns_challenge.go +++ b/acme/dns_challenge.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "strings" + "sync" "time" "github.com/miekg/dns" @@ -18,8 +19,9 @@ type preCheckDNSFunc func(fqdn, value string) (bool, error) var ( // PreCheckDNS checks DNS propagation before notifying ACME that // the DNS challenge is ready. - PreCheckDNS preCheckDNSFunc = checkDNSPropagation - fqdnToZone = map[string]string{} + PreCheckDNS preCheckDNSFunc = checkDNSPropagation + fqdnToZone = map[string]string{} + muFqdnToZone sync.Mutex ) const defaultResolvConf = "/etc/resolv.conf" @@ -262,6 +264,9 @@ func lookupNameservers(fqdn string) ([]string, error) { // FindZoneByFqdn determines the zone apex for the given fqdn by recursing up the // domain labels until the nameserver returns a SOA record in the answer section. func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) { + muFqdnToZone.Lock() + defer muFqdnToZone.Unlock() + // Do we have it cached? if zone, ok := fqdnToZone[fqdn]; ok { return zone, nil diff --git a/platform/tester/env.go b/platform/tester/env.go new file mode 100644 index 00000000..b42da278 --- /dev/null +++ b/platform/tester/env.go @@ -0,0 +1,142 @@ +package tester + +import ( + "fmt" + "os" +) + +// EnvTest Environment variables manager for tests. +type EnvTest struct { + keys []string + values map[string]string + + liveTestHook func() bool + liveTestExtraHook func() bool + + domain string + domainKey string +} + +// NewEnvTest Creates an EnvTest. +func NewEnvTest(keys ...string) *EnvTest { + values := make(map[string]string) + for _, key := range keys { + value := os.Getenv(key) + if value != "" { + values[key] = value + } + } + + return &EnvTest{ + keys: keys, + values: values, + } +} + +// WithDomain Defines the name of the environment variable used to define the domain related to the DNS request. +// If the domain is defined, it was considered mandatory to define a test as a "live" test. +func (e *EnvTest) WithDomain(key string) *EnvTest { + e.domainKey = key + e.domain = os.Getenv(key) + return e +} + +// WithLiveTestRequirements Defines the environment variables required to define a test as a "live" test. +// Replaces the default behavior (all keys are required). +func (e *EnvTest) WithLiveTestRequirements(keys ...string) *EnvTest { + var countValuedVars int + for _, key := range keys { + if _, ok := e.values[key]; ok { + countValuedVars++ + } + } + + live := countValuedVars != 0 && len(keys) == countValuedVars + + e.liveTestHook = func() bool { + return live + } + + return e +} + +// WithLiveTestExtra Allows to define an additional condition to flag a test as "live" test. +// This does not replace the default behavior. +func (e *EnvTest) WithLiveTestExtra(extra func() bool) *EnvTest { + e.liveTestExtraHook = extra + return e +} + +// GetDomain Gets the domain value associated with the DNS challenge (linked to WithDomain method). +func (e *EnvTest) GetDomain() string { + return e.domain +} + +// IsLiveTest Checks whether environment variables allow running a "live" test. +func (e *EnvTest) IsLiveTest() bool { + liveTest := e.liveTestExtra() + + if e.liveTestHook != nil { + return liveTest && e.liveTestHook() + } + + liveTest = liveTest && len(e.values) == len(e.keys) + + if liveTest && len(e.domainKey) > 0 && len(e.domain) == 0 { + return false + } + + return liveTest +} + +// RestoreEnv Restores the environment variables to the initial state. +func (e *EnvTest) RestoreEnv() { + for key, value := range e.values { + os.Setenv(key, value) + } +} + +// ClearEnv Deletes all environment variables related to the test. +func (e *EnvTest) ClearEnv() { + for _, key := range e.keys { + os.Unsetenv(key) + } +} + +// GetValue Gets the stored value of an environment variable. +func (e *EnvTest) GetValue(key string) string { + return e.values[key] +} + +func (e *EnvTest) liveTestExtra() bool { + if e.liveTestExtraHook == nil { + return true + } + + return e.liveTestExtraHook() +} + +// Apply Sets/Unsets environment variables. +// Not related to the main environment variables. +func (e *EnvTest) Apply(envVars map[string]string) { + for key, value := range envVars { + if e.isManagedKey(key) { + if len(value) == 0 { + os.Unsetenv(key) + } else { + os.Setenv(key, value) + } + } else { + panic(fmt.Sprintf("Unauthorized action, the env var %s is not managed.", key)) + } + } +} + +func (e *EnvTest) isManagedKey(varName string) bool { + for _, key := range e.keys { + if key == varName { + return true + } + } + return false +} diff --git a/platform/tester/env_test.go b/platform/tester/env_test.go new file mode 100644 index 00000000..a17f0ab1 --- /dev/null +++ b/platform/tester/env_test.go @@ -0,0 +1,347 @@ +package tester_test + +import ( + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xenolf/lego/platform/tester" +) + +var ( + envNamespace = "LEGO_TEST_" + envVar01 = envNamespace + "01" + envVar02 = envNamespace + "02" + envVarDomain = envNamespace + "DOMAIN" +) + +func TestMain(m *testing.M) { + exitCode := m.Run() + clearEnv() + os.Exit(exitCode) +} + +func applyEnv(envVars map[string]string) { + for key, value := range envVars { + if len(value) == 0 { + os.Unsetenv(key) + } else { + os.Setenv(key, value) + } + } +} + +func clearEnv() { + environ := os.Environ() + for _, key := range environ { + if strings.HasPrefix(key, envNamespace) { + os.Unsetenv(strings.Split(key, "=")[0]) + } + } + os.Unsetenv("EXTRA_LEGO_TEST") +} + +func TestEnvTest(t *testing.T) { + testCases := []struct { + desc string + envVars map[string]string + envTestSetup func() *tester.EnvTest + expected func(t *testing.T, envTest *tester.EnvTest) + }{ + { + desc: "simple", + envVars: map[string]string{ + envVar01: "A", + envVar02: "B", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.True(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + { + desc: "missing env var", + envVars: map[string]string{ + envVar02: "B", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.False(t, envTest.IsLiveTest()) + assert.Equal(t, "", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + { + desc: "WithDomain", + envVars: map[string]string{ + envVar01: "A", + envVar02: "B", + envVarDomain: "D", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02).WithDomain(envVarDomain) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.True(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetValue(envVarDomain)) + assert.Equal(t, "D", envTest.GetDomain()) + }, + }, + { + desc: "WithDomain missing env var", + envVars: map[string]string{ + envVar01: "A", + envVarDomain: "D", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02).WithDomain(envVarDomain) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.False(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetValue(envVarDomain)) + assert.Equal(t, "D", envTest.GetDomain()) + }, + }, + { + desc: "WithDomain missing domain", + envVars: map[string]string{ + envVar01: "A", + envVar02: "B", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02).WithDomain(envVarDomain) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.False(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetValue(envVarDomain)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + { + desc: "WithLiveTestRequirements", + envVars: map[string]string{ + envVar01: "A", + envVar02: "B", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02).WithLiveTestRequirements(envVar02) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.True(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + { + desc: "WithLiveTestRequirements non required var missing", + envVars: map[string]string{ + envVar02: "B", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02).WithLiveTestRequirements(envVar02) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.True(t, envTest.IsLiveTest()) + assert.Equal(t, "", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + { + desc: "WithLiveTestRequirements required var missing", + envVars: map[string]string{ + envVar01: "A", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02).WithLiveTestRequirements(envVar02) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.False(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + { + desc: "WithLiveTestRequirements WithDomain", + envVars: map[string]string{ + envVar01: "A", + envVar02: "B", + envVarDomain: "D", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02). + WithDomain(envVarDomain). + WithLiveTestRequirements(envVar02) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.True(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetValue(envVarDomain)) + assert.Equal(t, "D", envTest.GetDomain()) + }, + }, + { + desc: "WithLiveTestRequirements WithDomain without domain", + envVars: map[string]string{ + envVar01: "A", + envVar02: "B", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02). + WithDomain(envVarDomain). + WithLiveTestRequirements(envVar02) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.True(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetValue(envVarDomain)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + { + desc: "WithLiveTestExtra true", + envVars: map[string]string{ + envVar01: "A", + envVar02: "B", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02). + WithLiveTestExtra(func() bool { return true }) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.True(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + { + desc: "WithLiveTestExtra false", + envVars: map[string]string{ + envVar01: "A", + envVar02: "B", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02). + WithLiveTestExtra(func() bool { return false }) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.False(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + { + desc: "WithLiveTestRequirements WithLiveTestExtra true", + envVars: map[string]string{ + envVar01: "A", + envVar02: "B", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02). + WithLiveTestRequirements(envVar02). + WithLiveTestExtra(func() bool { return true }) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.True(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + { + desc: "WithLiveTestRequirements WithLiveTestExtra false", + envVars: map[string]string{ + envVar01: "A", + envVar02: "B", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02). + WithLiveTestRequirements(envVar02). + WithLiveTestExtra(func() bool { return false }) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.False(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "B", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + { + desc: "WithLiveTestRequirements require env var missing WithLiveTestExtra true", + envVars: map[string]string{ + envVar01: "A", + }, + envTestSetup: func() *tester.EnvTest { + return tester.NewEnvTest(envVar01, envVar02). + WithLiveTestRequirements(envVar02). + WithLiveTestExtra(func() bool { return true }) + }, + expected: func(t *testing.T, envTest *tester.EnvTest) { + assert.False(t, envTest.IsLiveTest()) + assert.Equal(t, "A", envTest.GetValue(envVar01)) + assert.Equal(t, "", envTest.GetValue(envVar02)) + assert.Equal(t, "", envTest.GetDomain()) + }, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + defer clearEnv() + applyEnv(test.envVars) + + envTest := test.envTestSetup() + + test.expected(t, envTest) + }) + } +} + +func TestEnvTest_RestoreEnv(t *testing.T) { + os.Setenv(envVar01, "A") + os.Setenv(envVar02, "B") + + envTest := tester.NewEnvTest(envVar01, envVar02) + + clearEnv() + + envTest.RestoreEnv() + + assert.Equal(t, "A", os.Getenv(envVar01)) + assert.Equal(t, "B", os.Getenv(envVar02)) +} + +func TestEnvTest_ClearEnv(t *testing.T) { + os.Setenv(envVar01, "A") + os.Setenv(envVar02, "B") + os.Setenv("EXTRA_LEGO_TEST", "X") + + envTest := tester.NewEnvTest(envVar01, envVar02) + + envTest.ClearEnv() + + assert.Equal(t, "", os.Getenv(envVar01)) + assert.Equal(t, "", os.Getenv(envVar02)) + assert.Equal(t, "X", os.Getenv("EXTRA_LEGO_TEST")) +} diff --git a/providers/dns/acmedns/acmedns_test.go b/providers/dns/acmedns/acmedns_test.go index d1131b10..398dea95 100644 --- a/providers/dns/acmedns/acmedns_test.go +++ b/providers/dns/acmedns/acmedns_test.go @@ -4,10 +4,9 @@ import ( "errors" "testing" + "github.com/cpu/goacmedns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/cpu/goacmedns" ) var ( diff --git a/providers/dns/alidns/alidns_test.go b/providers/dns/alidns/alidns_test.go index d809cdbf..f0cb2e30 100644 --- a/providers/dns/alidns/alidns_test.go +++ b/providers/dns/alidns/alidns_test.go @@ -1,34 +1,17 @@ package alidns import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIKey string - envTestSecretKey string - envTestDomain string -) - -func init() { - envTestAPIKey = os.Getenv("ALICLOUD_ACCESS_KEY") - envTestSecretKey = os.Getenv("ALICLOUD_SECRET_KEY") - envTestDomain = os.Getenv("ALIDNS_DOMAIN") - - if len(envTestAPIKey) > 0 && len(envTestSecretKey) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("ALICLOUD_ACCESS_KEY", envTestAPIKey) - os.Setenv("ALICLOUD_SECRET_KEY", envTestSecretKey) -} +var envTest = tester.NewEnvTest( + "ALICLOUD_ACCESS_KEY", + "ALICLOUD_SECRET_KEY"). + WithDomain("ALICLOUD_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -71,14 +54,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -124,10 +103,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("ALICLOUD_ACCESS_KEY") - os.Unsetenv("ALICLOUD_SECRET_KEY") - config := NewDefaultConfig() config.APIKey = test.apiKey config.SecretKey = test.secretKey @@ -147,29 +122,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/auroradns/auroradns_test.go b/providers/dns/auroradns/auroradns_test.go index 1d9c8152..aa3c9817 100644 --- a/providers/dns/auroradns/auroradns_test.go +++ b/providers/dns/auroradns/auroradns_test.go @@ -5,28 +5,18 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - envTestUserID string - envTestKey string +var envTest = tester.NewEnvTest( + "AURORA_USER_ID", + "AURORA_KEY", ) -func init() { - envTestUserID = os.Getenv("AURORA_USER_ID") - envTestKey = os.Getenv("AURORA_KEY") -} - -func restoreEnv() { - os.Setenv("AURORA_USER_ID", envTestUserID) - os.Setenv("AURORA_KEY", envTestKey) -} - func setupTest() (*DNSProvider, *http.ServeMux, func()) { handler := http.NewServeMux() server := httptest.NewServer(handler) @@ -85,14 +75,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -142,10 +128,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("AURORA_USER_ID") - os.Unsetenv("AURORA_KEY") - config := NewDefaultConfig() config.UserID = test.userID config.Key = test.key diff --git a/providers/dns/azure/azure_test.go b/providers/dns/azure/azure_test.go index 1019ecc2..07a43f5e 100644 --- a/providers/dns/azure/azure_test.go +++ b/providers/dns/azure/azure_test.go @@ -1,43 +1,20 @@ package azure import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestClientID string - envTestClientSecret string - envTestSubscriptionID string - envTestTenantID string - envTestResourceGroup string - envTestDomain string -) - -func init() { - envTestClientID = os.Getenv("AZURE_CLIENT_ID") - envTestClientSecret = os.Getenv("AZURE_CLIENT_SECRET") - envTestSubscriptionID = os.Getenv("AZURE_SUBSCRIPTION_ID") - envTestTenantID = os.Getenv("AZURE_TENANT_ID") - envTestResourceGroup = os.Getenv("AZURE_RESOURCE_GROUP") - envTestDomain = os.Getenv("AZURE_DOMAIN") - - if len(envTestClientID) > 0 && len(envTestClientSecret) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("AZURE_CLIENT_ID", envTestClientID) - os.Setenv("AZURE_CLIENT_SECRET", envTestClientSecret) - os.Setenv("AZURE_SUBSCRIPTION_ID", envTestSubscriptionID) - os.Setenv("AZURE_TENANT_ID", envTestTenantID) - os.Setenv("AZURE_RESOURCE_GROUP", envTestResourceGroup) -} +var envTest = tester.NewEnvTest( + "AZURE_CLIENT_ID", + "AZURE_CLIENT_SECRET", + "AZURE_SUBSCRIPTION_ID", + "AZURE_TENANT_ID", + "AZURE_RESOURCE_GROUP"). + WithDomain("AZURE_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -125,14 +102,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -218,13 +191,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("AZURE_CLIENT_ID") - os.Unsetenv("AZURE_CLIENT_SECRET") - os.Unsetenv("AZURE_SUBSCRIPTION_ID") - os.Unsetenv("AZURE_TENANT_ID") - os.Unsetenv("AZURE_RESOURCE_GROUP") - config := NewDefaultConfig() config.ClientID = test.clientID config.ClientSecret = test.clientSecret @@ -246,29 +212,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/bluecat/bluecat_test.go b/providers/dns/bluecat/bluecat_test.go index 15649a8d..da46d0ea 100644 --- a/providers/dns/bluecat/bluecat_test.go +++ b/providers/dns/bluecat/bluecat_test.go @@ -1,48 +1,20 @@ package bluecat import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestServer string - envTestUserName string - envTestPassword string - envTestConfigName string - envTestDNSView string - envTestDomain string -) - -func init() { - envTestServer = os.Getenv("BLUECAT_SERVER_URL") - envTestUserName = os.Getenv("BLUECAT_USER_NAME") - envTestPassword = os.Getenv("BLUECAT_PASSWORD") - envTestDomain = os.Getenv("BLUECAT_DOMAIN") - envTestConfigName = os.Getenv("BLUECAT_CONFIG_NAME") - envTestDNSView = os.Getenv("BLUECAT_DNS_VIEW") - - if len(envTestServer) > 0 && - len(envTestDomain) > 0 && - len(envTestUserName) > 0 && - len(envTestPassword) > 0 && - len(envTestConfigName) > 0 && - len(envTestDNSView) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("BLUECAT_SERVER_URL", envTestServer) - os.Setenv("BLUECAT_USER_NAME", envTestUserName) - os.Setenv("BLUECAT_PASSWORD", envTestPassword) - os.Setenv("BLUECAT_CONFIG_NAME", envTestConfigName) - os.Setenv("BLUECAT_DNS_VIEW", envTestDNSView) -} +var envTest = tester.NewEnvTest( + "BLUECAT_SERVER_URL", + "BLUECAT_USER_NAME", + "BLUECAT_PASSWORD", + "BLUECAT_CONFIG_NAME", + "BLUECAT_DNS_VIEW"). + WithDomain("BLUECAT_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -130,14 +102,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -223,13 +191,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("BLUECAT_SERVER_URL") - os.Unsetenv("BLUECAT_USER_NAME") - os.Unsetenv("BLUECAT_PASSWORD") - os.Unsetenv("BLUECAT_CONFIG_NAME") - os.Unsetenv("BLUECAT_DNS_VIEW") - config := NewDefaultConfig() config.BaseURL = test.baseURL config.UserName = test.userName @@ -251,29 +212,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(time.Second * 1) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/cloudflare/cloudflare_test.go b/providers/dns/cloudflare/cloudflare_test.go index 5605cd03..4c7a0e66 100644 --- a/providers/dns/cloudflare/cloudflare_test.go +++ b/providers/dns/cloudflare/cloudflare_test.go @@ -1,34 +1,18 @@ package cloudflare import ( - "os" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestEmail string - envTestAPIKey string - envTestDomain string -) - -func init() { - envTestEmail = os.Getenv("CLOUDFLARE_EMAIL") - envTestAPIKey = os.Getenv("CLOUDFLARE_API_KEY") - envTestDomain = os.Getenv("CLOUDFLARE_DOMAIN") - if len(envTestEmail) > 0 && len(envTestAPIKey) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("CLOUDFLARE_EMAIL", envTestEmail) - os.Setenv("CLOUDFLARE_API_KEY", envTestAPIKey) -} +var envTest = tester.NewEnvTest( + "CLOUDFLARE_EMAIL", + "CLOUDFLARE_API_KEY"). + WithDomain("CLOUDFLARE_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -71,14 +55,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -124,10 +104,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("CLOUDFLARE_EMAIL") - os.Unsetenv("CLOUDFLARE_API_KEY") - config := NewDefaultConfig() config.AuthEmail = test.authEmail config.AuthKey = test.authKey @@ -146,36 +122,30 @@ func TestNewDNSProviderConfig(t *testing.T) { } } -func TestPresent(t *testing.T) { - if !liveTest { +func TestLivePresent(t *testing.T) { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - config := NewDefaultConfig() - config.AuthEmail = envTestEmail - config.AuthKey = envTestAPIKey - - provider, err := NewDNSProviderConfig(config) + envTest.RestoreEnv() + provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } -func TestCleanUp(t *testing.T) { - if !liveTest { +func TestLiveCleanUp(t *testing.T) { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - time.Sleep(time.Second * 2) - - config := NewDefaultConfig() - config.AuthEmail = envTestEmail - config.AuthKey = envTestAPIKey - - provider, err := NewDNSProviderConfig(config) + envTest.RestoreEnv() + provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.CleanUp(envTestDomain, "", "123d==") + time.Sleep(2 * time.Second) + + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/cloudxns/cloudxns_test.go b/providers/dns/cloudxns/cloudxns_test.go index 83c5d32f..16527cf1 100644 --- a/providers/dns/cloudxns/cloudxns_test.go +++ b/providers/dns/cloudxns/cloudxns_test.go @@ -1,34 +1,17 @@ package cloudxns import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIKey string - envTestSecretKey string - envTestDomain string -) - -func init() { - envTestAPIKey = os.Getenv("CLOUDXNS_API_KEY") - envTestSecretKey = os.Getenv("CLOUDXNS_SECRET_KEY") - envTestDomain = os.Getenv("CLOUDXNS_DOMAIN") - - if len(envTestAPIKey) > 0 && len(envTestSecretKey) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("CLOUDXNS_API_KEY", envTestAPIKey) - os.Setenv("CLOUDXNS_SECRET_KEY", envTestSecretKey) -} +var envTest = tester.NewEnvTest( + "CLOUDXNS_API_KEY", + "CLOUDXNS_SECRET_KEY"). + WithDomain("CLOUDXNS_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -71,14 +54,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -124,10 +103,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("CLOUDXNS_API_KEY") - os.Unsetenv("CLOUDXNS_SECRET_KEY") - config := NewDefaultConfig() config.APIKey = test.apiKey config.SecretKey = test.secretKey @@ -146,28 +121,30 @@ func TestNewDNSProviderConfig(t *testing.T) { } } -func TestPresent(t *testing.T) { - if !liveTest { +func TestLivePresent(t *testing.T) { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - provider, err := NewDNSProviderCredentials(envTestAPIKey, envTestSecretKey) + envTest.RestoreEnv() + provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } -func TestCleanUp(t *testing.T) { - if !liveTest { +func TestLiveCleanUp(t *testing.T) { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - time.Sleep(time.Second * 2) - - provider, err := NewDNSProviderCredentials(envTestAPIKey, envTestSecretKey) + envTest.RestoreEnv() + provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.CleanUp(envTestDomain, "", "123d==") + time.Sleep(2 * time.Second) + + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/digitalocean/digitalocean_test.go b/providers/dns/digitalocean/digitalocean_test.go index 89005de1..73733415 100644 --- a/providers/dns/digitalocean/digitalocean_test.go +++ b/providers/dns/digitalocean/digitalocean_test.go @@ -5,24 +5,14 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - envTestAuthToken string -) - -func init() { - envTestAuthToken = os.Getenv("DO_AUTH_TOKEN") -} - -func restoreEnv() { - os.Setenv("DO_AUTH_TOKEN", envTestAuthToken) -} +var envTest = tester.NewEnvTest("DO_AUTH_TOKEN") func setupTest() (*DNSProvider, *http.ServeMux, func()) { handler := http.NewServeMux() @@ -63,14 +53,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -104,10 +90,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("DO_AUTH_TOKEN") - os.Unsetenv("ALICLOUD_SECRET_KEY") - config := NewDefaultConfig() config.AuthToken = test.authToken diff --git a/providers/dns/dnsimple/dnsimple_test.go b/providers/dns/dnsimple/dnsimple_test.go index e5adaac2..08d502f5 100644 --- a/providers/dns/dnsimple/dnsimple_test.go +++ b/providers/dns/dnsimple/dnsimple_test.go @@ -8,35 +8,16 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xenolf/lego/acme" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestOauthToken string - envTestDomain string - envTestBaseURL string -) +const sandboxURL = "https://api.sandbox.fake.com" -func init() { - envTestOauthToken = os.Getenv("DNSIMPLE_OAUTH_TOKEN") - envTestDomain = os.Getenv("DNSIMPLE_DOMAIN") - envTestBaseURL = "https://api.sandbox.fake.com" - - if len(envTestOauthToken) > 0 && len(envTestDomain) > 0 { - baseURL := os.Getenv("DNSIMPLE_BASE_URL") - - if baseURL != "" { - envTestBaseURL = baseURL - } - - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("DNSIMPLE_OAUTH_TOKEN", envTestOauthToken) - os.Setenv("DNSIMPLE_BASE_URL", envTestBaseURL) -} +var envTest = tester.NewEnvTest( + "DNSIMPLE_OAUTH_TOKEN", + "DNSIMPLE_BASE_URL"). + WithDomain("DNSIMPLE_DOMAIN"). + WithLiveTestRequirements("DNSIMPLE_OAUTH_TOKEN", "DNSIMPLE_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -70,14 +51,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) if test.userAgent != "" { acme.UserAgent = test.userAgent @@ -132,10 +109,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("DNSIMPLE_OAUTH_TOKEN") - os.Unsetenv("DNSIMPLE_BASE_URL") - config := NewDefaultConfig() config.AccessToken = test.accessToken config.BaseURL = test.baseURL @@ -160,29 +133,39 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() + + if len(os.Getenv("DNSIMPLE_BASE_URL")) == 0 { + os.Setenv("DNSIMPLE_BASE_URL", sandboxURL) + } + provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() + + if len(os.Getenv("DNSIMPLE_BASE_URL")) == 0 { + os.Setenv("DNSIMPLE_BASE_URL", sandboxURL) + } + provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/dnsmadeeasy/dnsmadeeasy_test.go b/providers/dns/dnsmadeeasy/dnsmadeeasy_test.go index 5217a83c..0bc0dbe3 100644 --- a/providers/dns/dnsmadeeasy/dnsmadeeasy_test.go +++ b/providers/dns/dnsmadeeasy/dnsmadeeasy_test.go @@ -5,28 +5,16 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIKey string - envTestAPISecret string - envTestDomain string -) +var envTest = tester.NewEnvTest( + "DNSMADEEASY_API_KEY", + "DNSMADEEASY_API_SECRET"). + WithDomain("DNSMADEEASY_DOMAIN") func init() { - envTestAPIKey = os.Getenv("DNSMADEEASY_API_KEY") - envTestAPISecret = os.Getenv("DNSMADEEASY_API_SECRET") - envTestDomain = os.Getenv("DNSMADEEASY_DOMAIN") - os.Setenv("DNSMADEEASY_SANDBOX", "true") - - liveTest = len(envTestAPIKey) > 0 && len(envTestAPISecret) > 0 -} - -func restoreEnv() { - os.Setenv("DNSMADEEASY_API_KEY", envTestAPIKey) - os.Setenv("DNSMADEEASY_API_SECRET", envTestAPISecret) } func TestNewDNSProvider(t *testing.T) { @@ -70,14 +58,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -123,10 +107,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("DNSMADEEASY_API_KEY") - os.Unsetenv("DNSMADEEASY_API_SECRET") - config := NewDefaultConfig() config.APIKey = test.apiKey config.APISecret = test.apiSecret @@ -146,17 +126,17 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresentAndCleanup(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/dnspod/dnspod_test.go b/providers/dns/dnspod/dnspod_test.go index 8108ac21..5d235aaf 100644 --- a/providers/dns/dnspod/dnspod_test.go +++ b/providers/dns/dnspod/dnspod_test.go @@ -1,31 +1,15 @@ package dnspod import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIKey string - envTestDomain string -) - -func init() { - envTestAPIKey = os.Getenv("DNSPOD_API_KEY") - envTestDomain = os.Getenv("DNSPOD_DOMAIN") - - if len(envTestAPIKey) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("DNSPOD_API_KEY", envTestAPIKey) -} +var envTest = tester.NewEnvTest("DNSPOD_API_KEY"). + WithDomain("DNSPOD_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -50,14 +34,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -91,9 +71,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("DNSPOD_API_KEY") - config := NewDefaultConfig() config.LoginToken = test.loginToken @@ -112,29 +89,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/dreamhost/dreamhost_test.go b/providers/dns/dreamhost/dreamhost_test.go index f98354fb..fc22fe3d 100644 --- a/providers/dns/dreamhost/dreamhost_test.go +++ b/providers/dns/dreamhost/dreamhost_test.go @@ -4,37 +4,23 @@ import ( "fmt" "net/http" "net/http/httptest" - "os" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestDomain string - envTestAPIKey string +var envTest = tester.NewEnvTest("DREAMHOST_API_KEY"). + WithDomain("DREAMHOST_TEST_DOMAIN") +var ( fakeAPIKey = "asdf1234" fakeChallengeToken = "foobar" fakeKeyAuth = "w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI" ) -func init() { - envTestAPIKey = os.Getenv("DREAMHOST_API_KEY") - envTestDomain = os.Getenv("DREAMHOST_TEST_DOMAIN") - - if len(envTestAPIKey) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("DREAMHOST_API_KEY", envTestAPIKey) -} - func setupTest() (*DNSProvider, *http.ServeMux, func()) { handler := http.NewServeMux() server := httptest.NewServer(handler) @@ -74,14 +60,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -196,19 +178,19 @@ func TestDNSProvider_Cleanup(t *testing.T) { } func TestLivePresentAndCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/duckdns/duckdns_test.go b/providers/dns/duckdns/duckdns_test.go index d2aac729..25c2c686 100644 --- a/providers/dns/duckdns/duckdns_test.go +++ b/providers/dns/duckdns/duckdns_test.go @@ -1,30 +1,15 @@ package duckdns import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestToken string - envTestDomain string -) - -func init() { - envTestToken = os.Getenv("DUCKDNS_TOKEN") - envTestDomain = os.Getenv("DUCKDNS_DOMAIN") - if len(envTestToken) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("DUCKDNS_TOKEN", envTestToken) -} +var envTest = tester.NewEnvTest("DUCKDNS_TOKEN"). + WithDomain("DUCKDNS_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -49,14 +34,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -89,9 +70,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("DUCKDNS_TOKEN") - config := NewDefaultConfig() config.Token = test.token @@ -109,29 +87,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(10 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/dyn/dyn_test.go b/providers/dns/dyn/dyn_test.go index 47cf8f83..eebf0fd4 100644 --- a/providers/dns/dyn/dyn_test.go +++ b/providers/dns/dyn/dyn_test.go @@ -1,37 +1,18 @@ package dyn import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestCustomerName string - envTestUserName string - envTestPassword string - envTestDomain string -) - -func init() { - envTestCustomerName = os.Getenv("DYN_CUSTOMER_NAME") - envTestUserName = os.Getenv("DYN_USER_NAME") - envTestPassword = os.Getenv("DYN_PASSWORD") - envTestDomain = os.Getenv("DYN_DOMAIN") - - if len(envTestCustomerName) > 0 && len(envTestUserName) > 0 && len(envTestPassword) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("DYN_CUSTOMER_NAME", envTestCustomerName) - os.Setenv("DYN_USER_NAME", envTestUserName) - os.Setenv("DYN_PASSWORD", envTestPassword) -} +var envTest = tester.NewEnvTest( + "DYN_CUSTOMER_NAME", + "DYN_USER_NAME", + "DYN_PASSWORD"). + WithDomain("DYN_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -87,14 +68,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -152,9 +129,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("DNSPOD_API_KEY") - config := NewDefaultConfig() config.CustomerName = test.customerName config.Password = test.password @@ -174,29 +148,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/exec/exec_test.go b/providers/dns/exec/exec_test.go new file mode 100644 index 00000000..c7fe4c02 --- /dev/null +++ b/providers/dns/exec/exec_test.go @@ -0,0 +1,158 @@ +package exec + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/xenolf/lego/log" +) + +func TestDNSProvider_Present(t *testing.T) { + backupLogger := log.Logger + defer func() { + log.Logger = backupLogger + }() + + logRecorder := &LogRecorder{} + log.Logger = logRecorder + + type expected struct { + args string + error bool + } + + testCases := []struct { + desc string + config *Config + expected expected + }{ + { + desc: "Standard mode", + config: &Config{ + Program: "echo", + Mode: "", + }, + expected: expected{ + args: "present _acme-challenge.domain. pW9ZKG0xz_PCriK-nCMOjADy9eJcgGWIzkkj2fN4uZM 120\n", + }, + }, + { + desc: "program error", + config: &Config{ + Program: "ogellego", + Mode: "", + }, + expected: expected{error: true}, + }, + { + desc: "Raw mode", + config: &Config{ + Program: "echo", + Mode: "RAW", + }, + expected: expected{ + args: "present -- domain token keyAuth\n", + }, + }, + } + + var message string + logRecorder.On("Println", mock.Anything).Run(func(args mock.Arguments) { + message = args.String(0) + fmt.Fprintln(os.Stdout, "XXX", message) + }) + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + message = "" + + provider, err := NewDNSProviderConfig(test.config) + require.NoError(t, err) + + err = provider.Present("domain", "token", "keyAuth") + if test.expected.error { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, test.expected.args, message) + } + }) + } +} + +func TestDNSProvider_CleanUp(t *testing.T) { + backupLogger := log.Logger + defer func() { + log.Logger = backupLogger + }() + + logRecorder := &LogRecorder{} + log.Logger = logRecorder + + type expected struct { + args string + error bool + } + + testCases := []struct { + desc string + config *Config + expected expected + }{ + { + desc: "Standard mode", + config: &Config{ + Program: "echo", + Mode: "", + }, + expected: expected{ + args: "cleanup _acme-challenge.domain. pW9ZKG0xz_PCriK-nCMOjADy9eJcgGWIzkkj2fN4uZM 120\n", + }, + }, + { + desc: "program error", + config: &Config{ + Program: "ogellego", + Mode: "", + }, + expected: expected{error: true}, + }, + { + desc: "Raw mode", + config: &Config{ + Program: "echo", + Mode: "RAW", + }, + expected: expected{ + args: "cleanup -- domain token keyAuth\n", + }, + }, + } + + var message string + logRecorder.On("Println", mock.Anything).Run(func(args mock.Arguments) { + message = args.String(0) + fmt.Fprintln(os.Stdout, "XXX", message) + }) + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + message = "" + + provider, err := NewDNSProviderConfig(test.config) + require.NoError(t, err) + + err = provider.CleanUp("domain", "token", "keyAuth") + if test.expected.error { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, test.expected.args, message) + } + }) + } +} diff --git a/providers/dns/exec/log_mock_test.go b/providers/dns/exec/log_mock_test.go new file mode 100644 index 00000000..47935cc5 --- /dev/null +++ b/providers/dns/exec/log_mock_test.go @@ -0,0 +1,31 @@ +package exec + +import "github.com/stretchr/testify/mock" + +type LogRecorder struct { + mock.Mock +} + +func (*LogRecorder) Fatal(args ...interface{}) { + panic("implement me") +} + +func (*LogRecorder) Fatalln(args ...interface{}) { + panic("implement me") +} + +func (*LogRecorder) Fatalf(format string, args ...interface{}) { + panic("implement me") +} + +func (*LogRecorder) Print(args ...interface{}) { + panic("implement me") +} + +func (l *LogRecorder) Println(args ...interface{}) { + l.Called(args...) +} + +func (*LogRecorder) Printf(format string, args ...interface{}) { + panic("implement me") +} diff --git a/providers/dns/exoscale/exoscale_test.go b/providers/dns/exoscale/exoscale_test.go index 84056e72..6335a516 100644 --- a/providers/dns/exoscale/exoscale_test.go +++ b/providers/dns/exoscale/exoscale_test.go @@ -1,35 +1,18 @@ package exoscale import ( - "os" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIKey string - envTestAPISecret string - envTestDomain string -) - -func init() { - envTestAPISecret = os.Getenv("EXOSCALE_API_SECRET") - envTestAPIKey = os.Getenv("EXOSCALE_API_KEY") - envTestDomain = os.Getenv("EXOSCALE_DOMAIN") - - if len(envTestAPIKey) > 0 && len(envTestAPISecret) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("EXOSCALE_API_KEY", envTestAPIKey) - os.Setenv("EXOSCALE_API_SECRET", envTestAPISecret) -} +var envTest = tester.NewEnvTest( + "EXOSCALE_API_SECRET", + "EXOSCALE_API_KEY"). + WithDomain("EXOSCALE_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -72,14 +55,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -125,10 +104,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("EXOSCALE_API_KEY") - os.Unsetenv("EXOSCALE_API_SECRET") - config := NewDefaultConfig() config.APIKey = test.apiKey config.APISecret = test.apiSecret @@ -200,33 +175,33 @@ func TestDNSProvider_FindZoneAndRecordName(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) // Present Twice to handle create / update - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/fastdns/fastdns_test.go b/providers/dns/fastdns/fastdns_test.go index 30ff9aa0..2c04f965 100644 --- a/providers/dns/fastdns/fastdns_test.go +++ b/providers/dns/fastdns/fastdns_test.go @@ -1,41 +1,20 @@ package fastdns import ( - "os" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestHost string - envTestClientToken string - envTestClientSecret string - envTestAccessToken string - envTestDomain string -) - -func init() { - envTestHost = os.Getenv("AKAMAI_HOST") - envTestClientToken = os.Getenv("AKAMAI_CLIENT_TOKEN") - envTestClientSecret = os.Getenv("AKAMAI_CLIENT_SECRET") - envTestAccessToken = os.Getenv("AKAMAI_ACCESS_TOKEN") - envTestDomain = os.Getenv("AKAMAI_TEST_DOMAIN") - - if len(envTestHost) > 0 && len(envTestClientToken) > 0 && len(envTestClientSecret) > 0 && len(envTestAccessToken) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("AKAMAI_HOST", envTestHost) - os.Setenv("AKAMAI_CLIENT_TOKEN", envTestClientToken) - os.Setenv("AKAMAI_CLIENT_SECRET", envTestClientSecret) - os.Setenv("AKAMAI_ACCESS_TOKEN", envTestAccessToken) -} +var envTest = tester.NewEnvTest( + "AKAMAI_HOST", + "AKAMAI_CLIENT_TOKEN", + "AKAMAI_CLIENT_SECRET", + "AKAMAI_ACCESS_TOKEN"). + WithDomain("AKAMAI_TEST_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -106,14 +85,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -184,12 +159,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("AKAMAI_HOST") - os.Unsetenv("AKAMAI_CLIENT_TOKEN") - os.Unsetenv("AKAMAI_CLIENT_SECRET") - os.Unsetenv("AKAMAI_ACCESS_TOKEN") - config := NewDefaultConfig() config.ClientToken = test.clientToken config.ClientSecret = test.clientSecret @@ -264,38 +233,33 @@ func TestDNSProvider_findZoneAndRecordName(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) // Present Twice to handle create / update - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - time.Sleep(time.Second * 1) - - config := NewDefaultConfig() - config.Host = envTestHost - config.ClientToken = envTestClientToken - config.ClientSecret = envTestClientSecret - config.AccessToken = envTestAccessToken - - provider, err := NewDNSProviderConfig(config) + envTest.RestoreEnv() + provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.CleanUp(envTestDomain, "", "123d==") + time.Sleep(1 * time.Second) + + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/gandi/gandi_test.go b/providers/dns/gandi/gandi_test.go index cf1779e8..69ca8350 100644 --- a/providers/dns/gandi/gandi_test.go +++ b/providers/dns/gandi/gandi_test.go @@ -5,25 +5,15 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "os" "regexp" "strings" "testing" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - envTestAPIKey string -) - -func init() { - envTestAPIKey = os.Getenv("GANDI_API_KEY") -} - -func restoreEnv() { - os.Setenv("GANDI_API_KEY", envTestAPIKey) -} +var envTest = tester.NewEnvTest("GANDI_API_KEY") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -48,14 +38,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -90,9 +76,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("GANDI_API_KEY") - config := NewDefaultConfig() config.APIKey = test.apiKey diff --git a/providers/dns/gandiv5/gandiv5_test.go b/providers/dns/gandiv5/gandiv5_test.go index 693a0b82..cf0172aa 100644 --- a/providers/dns/gandiv5/gandiv5_test.go +++ b/providers/dns/gandiv5/gandiv5_test.go @@ -5,25 +5,15 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "os" "regexp" "testing" "github.com/stretchr/testify/require" "github.com/xenolf/lego/log" + "github.com/xenolf/lego/platform/tester" ) -var ( - envTestAPIKey string -) - -func init() { - envTestAPIKey = os.Getenv("GANDIV5_API_KEY") -} - -func restoreEnv() { - os.Setenv("GANDIV5_API_KEY", envTestAPIKey) -} +var envTest = tester.NewEnvTest("GANDIV5_API_KEY") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -48,14 +38,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -89,9 +75,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("GANDIV5_API_KEY") - config := NewDefaultConfig() config.APIKey = test.apiKey diff --git a/providers/dns/gcloud/googlecloud_test.go b/providers/dns/gcloud/googlecloud_test.go index e374396a..f58c697c 100644 --- a/providers/dns/gcloud/googlecloud_test.go +++ b/providers/dns/gcloud/googlecloud_test.go @@ -1,41 +1,25 @@ package gcloud import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" "golang.org/x/net/context" "golang.org/x/oauth2/google" "google.golang.org/api/dns/v1" ) -var ( - liveTest bool - envTestProject string - envTestServiceAccountFile string - envTestGoogleApplicationCredentials string - envTestDomain string -) - -func init() { - envTestProject = os.Getenv("GCE_PROJECT") - envTestServiceAccountFile = os.Getenv("GCE_SERVICE_ACCOUNT_FILE") - envTestGoogleApplicationCredentials = os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") - envTestDomain = os.Getenv("GCE_DOMAIN") - - _, err := google.DefaultClient(context.Background(), dns.NdevClouddnsReadwriteScope) - if err == nil && len(envTestProject) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("GCE_PROJECT", envTestProject) - os.Setenv("GCE_SERVICE_ACCOUNT_FILE", envTestServiceAccountFile) - os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", envTestGoogleApplicationCredentials) -} +var envTest = tester.NewEnvTest( + "GCE_PROJECT", + "GCE_SERVICE_ACCOUNT_FILE", + "GOOGLE_APPLICATION_CREDENTIALS"). + WithDomain("GCE_DOMAIN"). + WithLiveTestExtra(func() bool { + _, err := google.DefaultClient(context.Background(), dns.NdevClouddnsReadwriteScope) + return err == nil + }) func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -72,15 +56,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() + defer envTest.RestoreEnv() + envTest.ClearEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -115,9 +94,8 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("GCE_PROJECT") - os.Unsetenv("GCE_SERVICE_ACCOUNT_FILE") + defer envTest.RestoreEnv() + envTest.ClearEnv() config := NewDefaultConfig() config.Project = test.project @@ -137,43 +115,49 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - provider, err := NewDNSProviderCredentials(envTestProject) + envTest.RestoreEnv() + + provider, err := NewDNSProviderCredentials(envTest.GetValue("GCE_PROJECT")) require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLivePresentMultiple(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - provider, err := NewDNSProviderCredentials(envTestProject) + envTest.RestoreEnv() + + provider, err := NewDNSProviderCredentials(envTest.GetValue("GCE_PROJECT")) require.NoError(t, err) // Check that we're able to create multiple entries - err = provider.Present(envTestDomain, "1", "123d==") + err = provider.Present(envTest.GetDomain(), "1", "123d==") require.NoError(t, err) - err = provider.Present(envTestDomain, "2", "123d==") + err = provider.Present(envTest.GetDomain(), "2", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - provider, err := NewDNSProviderCredentials(envTestProject) + envTest.RestoreEnv() + + provider, err := NewDNSProviderCredentials(envTest.GetValue("GCE_PROJECT")) require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/glesys/glesys_test.go b/providers/dns/glesys/glesys_test.go index 7a0b9971..11afd479 100644 --- a/providers/dns/glesys/glesys_test.go +++ b/providers/dns/glesys/glesys_test.go @@ -1,32 +1,16 @@ package glesys import ( - "os" "testing" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIUser string - envTestAPIKey string - envTestDomain string -) - -func init() { - envTestAPIUser = os.Getenv("GLESYS_API_USER") - envTestAPIKey = os.Getenv("GLESYS_API_KEY") - - if len(envTestAPIUser) > 0 && len(envTestAPIKey) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("GLESYS_API_USER", envTestAPIUser) - os.Setenv("GLESYS_API_KEY", envTestAPIKey) -} +var envTest = tester.NewEnvTest( + "GLESYS_API_USER", + "GLESYS_API_KEY"). + WithDomain("GLESYS_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -69,14 +53,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -124,10 +104,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("GLESYS_API_USER") - os.Unsetenv("GLESYS_API_KEY") - config := NewDefaultConfig() config.APIKey = test.apiKey config.APIUser = test.apiUser @@ -147,27 +123,27 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/godaddy/godaddy_test.go b/providers/dns/godaddy/godaddy_test.go index d0c5bf2a..5f8f7e45 100644 --- a/providers/dns/godaddy/godaddy_test.go +++ b/providers/dns/godaddy/godaddy_test.go @@ -1,33 +1,16 @@ package godaddy import ( - "os" "testing" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIKey string - envTestAPISecret string - envTestDomain string -) - -func init() { - envTestAPIKey = os.Getenv("GODADDY_API_KEY") - envTestAPISecret = os.Getenv("GODADDY_API_SECRET") - envTestDomain = os.Getenv("GODADDY_DOMAIN") - - if len(envTestAPIKey) > 0 && len(envTestAPISecret) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("GODADDY_API_KEY", envTestAPIKey) - os.Setenv("GODADDY_API_SECRET", envTestAPISecret) -} +var envTest = tester.NewEnvTest( + "GODADDY_API_KEY", + "GODADDY_API_SECRET"). + WithDomain("GODADDY_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -70,14 +53,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -122,10 +101,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("GODADDY_API_KEY") - os.Unsetenv("GODADDY_API_SECRET") - config := NewDefaultConfig() config.APIKey = test.apiKey config.APISecret = test.apiSecret @@ -144,27 +119,27 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/hostingde/hostingde_test.go b/providers/dns/hostingde/hostingde_test.go index d1cd9662..5be0d535 100644 --- a/providers/dns/hostingde/hostingde_test.go +++ b/providers/dns/hostingde/hostingde_test.go @@ -1,33 +1,17 @@ package hostingde import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIKey string - envTestZone string - envTestDomain string -) - -func init() { - envTestAPIKey = os.Getenv("HOSTINGDE_API_KEY") - envTestZone = os.Getenv("HOSTINGDE_ZONE_NAME") - envTestDomain = os.Getenv("HOSTINGDE_DOMAIN") - if len(envTestZone) > 0 && len(envTestAPIKey) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("HOSTINGDE_ZONE_NAME", envTestZone) - os.Setenv("HOSTINGDE_API_KEY", envTestAPIKey) -} +var envTest = tester.NewEnvTest( + "HOSTINGDE_API_KEY", + "HOSTINGDE_ZONE_NAME"). + WithDomain("HOSTINGDE_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -70,14 +54,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -123,10 +103,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("HOSTINGDE_API_KEY") - os.Unsetenv("HOSTINGDE_ZONE_NAME") - config := NewDefaultConfig() config.APIKey = test.apiKey config.ZoneName = test.zoneName @@ -146,29 +122,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(2 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/iij/iij_test.go b/providers/dns/iij/iij_test.go index 76c3fd5b..07d730bb 100644 --- a/providers/dns/iij/iij_test.go +++ b/providers/dns/iij/iij_test.go @@ -1,38 +1,18 @@ package iij import ( - "os" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIAccessKey string - envTestAPISecretKeyEnv string - envTestDoServiceCodeEnv string - envTestDomain string -) - -func init() { - envTestAPIAccessKey = os.Getenv("IIJ_API_ACCESS_KEY") - envTestAPISecretKeyEnv = os.Getenv("IIJ_API_SECRET_KEY") - envTestDoServiceCodeEnv = os.Getenv("IIJ_DO_SERVICE_CODE") - - envTestDomain = os.Getenv("IIJ_API_TESTDOMAIN") - - if len(envTestAPIAccessKey) > 0 && len(envTestAPISecretKeyEnv) > 0 && len(envTestDoServiceCodeEnv) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("IIJ_API_ACCESS_KEY", envTestAPIAccessKey) - os.Setenv("IIJ_API_SECRET_KEY", envTestAPISecretKeyEnv) - os.Setenv("IIJ_DO_SERVICE_CODE", envTestDoServiceCodeEnv) -} +var envTest = tester.NewEnvTest( + "IIJ_API_ACCESS_KEY", + "IIJ_API_SECRET_KEY", + "IIJ_DO_SERVICE_CODE"). + WithDomain("IIJ_API_TESTDOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -88,14 +68,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -154,11 +130,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("IIJ_API_ACCESS_KEY") - os.Unsetenv("IIJ_API_SECRET_KEY") - os.Unsetenv("IIJ_DO_SERVICE_CODE") - config := NewDefaultConfig() config.AccessKey = test.accessKey config.SecretKey = test.secretKey @@ -232,27 +203,27 @@ func TestSplitDomain(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/lightsail/lightsail_integration_test.go b/providers/dns/lightsail/lightsail_integration_test.go index 729cc900..4eb79976 100644 --- a/providers/dns/lightsail/lightsail_integration_test.go +++ b/providers/dns/lightsail/lightsail_integration_test.go @@ -1,8 +1,6 @@ package lightsail import ( - "fmt" - "os" "testing" "github.com/aws/aws-sdk-go/aws" @@ -12,15 +10,16 @@ import ( ) func TestLiveTTL(t *testing.T) { - m, err := testGetAndPreCheck() - if err != nil { - t.Skip(err.Error()) + if !envTest.IsLiveTest() { + t.Skip("skipping live test") } + envTest.RestoreEnv() + provider, err := NewDNSProvider() require.NoError(t, err) - domain := m["lightsailDomain"] + domain := envTest.GetDomain() err = provider.Present(domain, "foo", "bar") require.NoError(t, err) @@ -50,24 +49,10 @@ func TestLiveTTL(t *testing.T) { entries := resp.Domain.DomainEntries for _, entry := range entries { - if *entry.Type == "TXT" && *entry.Name == fqdn { + if aws.StringValue(entry.Type) == "TXT" && aws.StringValue(entry.Name) == fqdn { return } } t.Fatalf("Could not find a TXT record for _acme-challenge.%s", domain) } - -func testGetAndPreCheck() (map[string]string, error) { - m := map[string]string{ - "lightsailKey": os.Getenv("AWS_ACCESS_KEY_ID"), - "lightsailSecret": os.Getenv("AWS_SECRET_ACCESS_KEY"), - "lightsailDomain": os.Getenv("DNS_ZONE"), - } - for _, v := range m { - if v == "" { - return nil, fmt.Errorf("AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and R53_DOMAIN are needed to run this test") - } - } - return m, nil -} diff --git a/providers/dns/lightsail/lightsail_test.go b/providers/dns/lightsail/lightsail_test.go index da53257c..2ce2fde9 100644 --- a/providers/dns/lightsail/lightsail_test.go +++ b/providers/dns/lightsail/lightsail_test.go @@ -10,25 +10,16 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/lightsail" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - envTestSecret string - envTestKey string - envTestZone string -) - -func init() { - envTestKey = os.Getenv("AWS_ACCESS_KEY_ID") - envTestSecret = os.Getenv("AWS_SECRET_ACCESS_KEY") -} - -func restoreEnv() { - os.Setenv("AWS_ACCESS_KEY_ID", envTestKey) - os.Setenv("AWS_SECRET_ACCESS_KEY", envTestSecret) - os.Setenv("AWS_REGION", "us-east-1") - os.Setenv("AWS_HOSTED_ZONE_ID", envTestZone) -} +var envTest = tester.NewEnvTest( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_REGION", + "AWS_HOSTED_ZONE_ID"). + WithDomain("DNS_ZONE"). + WithLiveTestRequirements("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "DNS_ZONE") func makeProvider(ts *httptest.Server) (*DNSProvider, error) { config := &aws.Config{ @@ -50,7 +41,9 @@ func makeProvider(ts *httptest.Server) (*DNSProvider, error) { } func TestCredentialsFromEnv(t *testing.T) { - defer restoreEnv() + defer envTest.RestoreEnv() + envTest.ClearEnv() + os.Setenv("AWS_ACCESS_KEY_ID", "123") os.Setenv("AWS_SECRET_ACCESS_KEY", "123") os.Setenv("AWS_REGION", "us-east-1") diff --git a/providers/dns/linode/linode_test.go b/providers/dns/linode/linode_test.go index 1b9d9365..f8dbad5d 100644 --- a/providers/dns/linode/linode_test.go +++ b/providers/dns/linode/linode_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" "github.com/timewasted/linode" "github.com/timewasted/linode/dns" + "github.com/xenolf/lego/platform/tester" ) type ( @@ -28,20 +29,7 @@ type ( MockResponseMap map[string]MockResponse ) -var ( - apiKey string - liveTest bool -) - -func init() { - apiKey = os.Getenv("LINODE_API_KEY") - - liveTest = len(apiKey) != 0 -} - -func restoreEnv() { - os.Setenv("LINODE_API_KEY", apiKey) -} +var envTest = tester.NewEnvTest("LINODE_API_KEY") func newMockServer(responses MockResponseMap) *httptest.Server { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -103,14 +91,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -144,9 +128,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("LINODE_API_KEY") - config := NewDefaultConfig() config.APIKey = test.apiKey @@ -165,7 +146,7 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestDNSProvider_Present(t *testing.T) { - defer restoreEnv() + defer envTest.RestoreEnv() os.Setenv("LINODE_API_KEY", "testing") p, err := NewDNSProvider() @@ -252,7 +233,7 @@ func TestDNSProvider_Present(t *testing.T) { } func TestDNSProvider_CleanUp(t *testing.T) { - defer restoreEnv() + defer envTest.RestoreEnv() os.Setenv("LINODE_API_KEY", "testing") p, err := NewDNSProvider() @@ -363,14 +344,14 @@ func TestDNSProvider_CleanUp(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("Skipping live test") } // TODO implement this test } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("Skipping live test") } // TODO implement this test diff --git a/providers/dns/linodev4/linodev4_test.go b/providers/dns/linodev4/linodev4_test.go index ab88de64..97e5591f 100644 --- a/providers/dns/linodev4/linodev4_test.go +++ b/providers/dns/linodev4/linodev4_test.go @@ -12,26 +12,14 @@ import ( "github.com/linode/linodego" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) type ( MockResponseMap map[string]interface{} ) -var ( - liveTest bool - envTestAPIToken string -) - -func init() { - envTestAPIToken = os.Getenv("LINODE_TOKEN") - - liveTest = len(envTestAPIToken) != 0 -} - -func restoreEnv() { - os.Setenv("LINODE_TOKEN", envTestAPIToken) -} +var envTest = tester.NewEnvTest("LINODE_TOKEN") func newMockServer(responses MockResponseMap) *httptest.Server { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -94,14 +82,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -135,9 +119,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("LINODE_TOKEN") - config := NewDefaultConfig() config.Token = test.apiKey @@ -156,7 +137,7 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestDNSProvider_Present(t *testing.T) { - defer restoreEnv() + defer envTest.RestoreEnv() os.Setenv("LINODE_TOKEN", "testing") p, err := NewDNSProvider() @@ -246,7 +227,7 @@ func TestDNSProvider_Present(t *testing.T) { } func TestDNSProvider_CleanUp(t *testing.T) { - defer restoreEnv() + defer envTest.RestoreEnv() os.Setenv("LINODE_TOKEN", "testing") p, err := NewDNSProvider() @@ -361,14 +342,14 @@ func TestDNSProvider_CleanUp(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("Skipping live test") } // TODO implement this test } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("Skipping live test") } // TODO implement this test diff --git a/providers/dns/namedotcom/namedotcom_test.go b/providers/dns/namedotcom/namedotcom_test.go index 144aa2ed..d2cae8a7 100644 --- a/providers/dns/namedotcom/namedotcom_test.go +++ b/providers/dns/namedotcom/namedotcom_test.go @@ -1,34 +1,17 @@ package namedotcom import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestUsername string - envTestAPIToken string - envTestDomain string -) - -func init() { - envTestUsername = os.Getenv("NAMECOM_USERNAME") - envTestAPIToken = os.Getenv("NAMECOM_API_TOKEN") - envTestDomain = os.Getenv("NAMEDOTCOM_DOMAIN") - - if len(envTestAPIToken) > 0 && len(envTestUsername) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("NAMECOM_USERNAME", envTestUsername) - os.Setenv("NAMECOM_API_TOKEN", envTestAPIToken) -} +var envTest = tester.NewEnvTest( + "NAMECOM_USERNAME", + "NAMECOM_API_TOKEN"). + WithDomain("NAMEDOTCOM_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -71,14 +54,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -126,10 +105,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("NAMECOM_USERNAME") - os.Unsetenv("NAMECOM_API_TOKEN") - config := NewDefaultConfig() config.Username = test.username config.APIToken = test.apiToken @@ -149,29 +124,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/netcup/client_test.go b/providers/dns/netcup/client_test.go index 8b6cdc4f..6fcf78cc 100644 --- a/providers/dns/netcup/client_test.go +++ b/providers/dns/netcup/client_test.go @@ -11,12 +11,17 @@ import ( ) func TestLiveClientAuth(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } // Setup - client := NewClient(envTestCustomerNumber, envTestAPIKey, envTestAPIPassword) + envTest.RestoreEnv() + + client := NewClient( + envTest.GetValue("NETCUP_CUSTOMER_NUMBER"), + envTest.GetValue("NETCUP_API_KEY"), + envTest.GetValue("NETCUP_API_PASSWORD")) for i := 1; i < 4; i++ { i := i @@ -34,17 +39,22 @@ func TestLiveClientAuth(t *testing.T) { } func TestLiveClientGetDnsRecords(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - client := NewClient(envTestCustomerNumber, envTestAPIKey, envTestAPIPassword) - // Setup + envTest.RestoreEnv() + + client := NewClient( + envTest.GetValue("NETCUP_CUSTOMER_NUMBER"), + envTest.GetValue("NETCUP_API_KEY"), + envTest.GetValue("NETCUP_API_PASSWORD")) + sessionID, err := client.Login() require.NoError(t, err) - fqdn, _, _ := acme.DNS01Record(envTestDomain, "123d==") + fqdn, _, _ := acme.DNS01Record(envTest.GetDomain(), "123d==") zone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameservers) require.NoError(t, err, "error finding DNSZone") @@ -61,17 +71,22 @@ func TestLiveClientGetDnsRecords(t *testing.T) { } func TestLiveClientUpdateDnsRecord(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } // Setup - client := NewClient(envTestCustomerNumber, envTestAPIKey, envTestAPIPassword) + envTest.RestoreEnv() + + client := NewClient( + envTest.GetValue("NETCUP_CUSTOMER_NUMBER"), + envTest.GetValue("NETCUP_API_KEY"), + envTest.GetValue("NETCUP_API_PASSWORD")) sessionID, err := client.Login() require.NoError(t, err) - fqdn, _, _ := acme.DNS01Record(envTestDomain, "123d==") + fqdn, _, _ := acme.DNS01Record(envTest.GetDomain(), "123d==") zone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameservers) require.NoError(t, err, fmt.Errorf("error finding DNSZone, %v", err)) @@ -100,7 +115,7 @@ func TestLiveClientUpdateDnsRecord(t *testing.T) { records[recordIdx].DeleteRecord = true // Tear down - err = client.UpdateDNSRecord(sessionID, envTestDomain, records[recordIdx]) + err = client.UpdateDNSRecord(sessionID, envTest.GetDomain(), records[recordIdx]) require.NoError(t, err, "Did not remove record! Please do so yourself.") err = client.Logout(sessionID) diff --git a/providers/dns/netcup/netcup_test.go b/providers/dns/netcup/netcup_test.go index 4460e0ca..d7c06294 100644 --- a/providers/dns/netcup/netcup_test.go +++ b/providers/dns/netcup/netcup_test.go @@ -2,37 +2,18 @@ package netcup import ( "fmt" - "os" "testing" "github.com/stretchr/testify/require" "github.com/xenolf/lego/acme" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestCustomerNumber string - envTestAPIKey string - envTestAPIPassword string - envTestDomain string -) - -func init() { - envTestCustomerNumber = os.Getenv("NETCUP_CUSTOMER_NUMBER") - envTestAPIKey = os.Getenv("NETCUP_API_KEY") - envTestAPIPassword = os.Getenv("NETCUP_API_PASSWORD") - envTestDomain = os.Getenv("NETCUP_DOMAIN") - - if len(envTestCustomerNumber) > 0 && len(envTestAPIKey) > 0 && len(envTestAPIPassword) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("NETCUP_CUSTOMER_NUMBER", envTestCustomerNumber) - os.Setenv("NETCUP_API_KEY", envTestAPIKey) - os.Setenv("NETCUP_API_PASSWORD", envTestAPIPassword) -} +var envTest = tester.NewEnvTest( + "NETCUP_CUSTOMER_NUMBER", + "NETCUP_API_KEY", + "NETCUP_API_PASSWORD"). + WithDomain("NETCUP_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -88,14 +69,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -154,11 +131,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("NETCUP_CUSTOMER_NUMBER") - os.Unsetenv("NETCUP_API_KEY") - os.Unsetenv("NETCUP_API_PASSWORD") - config := NewDefaultConfig() config.Customer = test.customer config.Key = test.key @@ -179,15 +151,15 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresentAndCleanup(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() p, err := NewDNSProvider() require.NoError(t, err) - fqdn, _, _ := acme.DNS01Record(envTestDomain, "123d==") + fqdn, _, _ := acme.DNS01Record(envTest.GetDomain(), "123d==") zone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameservers) require.NoError(t, err, "error finding DNSZone") diff --git a/providers/dns/nifcloud/nifcloud_test.go b/providers/dns/nifcloud/nifcloud_test.go index d6602b88..d8b7a255 100644 --- a/providers/dns/nifcloud/nifcloud_test.go +++ b/providers/dns/nifcloud/nifcloud_test.go @@ -1,34 +1,17 @@ package nifcloud import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAccessKey string - envTestSecretKey string - envTestDomain string -) - -func init() { - envTestAccessKey = os.Getenv("NIFCLOUD_ACCESS_KEY_ID") - envTestSecretKey = os.Getenv("NIFCLOUD_SECRET_ACCESS_KEY") - envTestDomain = os.Getenv("NIFCLOUD_DOMAIN") - - if len(envTestAccessKey) > 0 && len(envTestSecretKey) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("NIFCLOUD_ACCESS_KEY_ID", envTestAccessKey) - os.Setenv("NIFCLOUD_SECRET_ACCESS_KEY", envTestSecretKey) -} +var envTest = tester.NewEnvTest( + "NIFCLOUD_ACCESS_KEY_ID", + "NIFCLOUD_SECRET_ACCESS_KEY"). + WithDomain("NIFCLOUD_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -71,14 +54,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -124,10 +103,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("NIFCLOUD_ACCESS_KEY_ID") - os.Unsetenv("NIFCLOUD_SECRET_ACCESS_KEY") - config := NewDefaultConfig() config.AccessKey = test.accessKey config.SecretKey = test.secretKey @@ -147,29 +122,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/ns1/ns1_test.go b/providers/dns/ns1/ns1_test.go index f1889ae5..16e3548b 100644 --- a/providers/dns/ns1/ns1_test.go +++ b/providers/dns/ns1/ns1_test.go @@ -1,31 +1,16 @@ package ns1 import ( - "os" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIKey string - envTestDomain string -) - -func init() { - envTestAPIKey = os.Getenv("NS1_API_KEY") - envTestDomain = os.Getenv("NS1_DOMAIN") - if len(envTestAPIKey) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("NS1_API_KEY", envTestAPIKey) -} +var envTest = tester.NewEnvTest("NS1_API_KEY"). + WithDomain("NS1_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -50,14 +35,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -91,9 +72,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("NS1_API_KEY") - config := NewDefaultConfig() config.APIKey = test.apiKey @@ -163,29 +141,29 @@ func Test_getAuthZone(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/otc/mock_test.go b/providers/dns/otc/mock_test.go index 127dd5cc..8eab3563 100644 --- a/providers/dns/otc/mock_test.go +++ b/providers/dns/otc/mock_test.go @@ -10,39 +10,37 @@ import ( "github.com/stretchr/testify/assert" ) -var fakeOTCUserName = "test" -var fakeOTCPassword = "test" -var fakeOTCDomainName = "test" -var fakeOTCProjectName = "test" var fakeOTCToken = "62244bc21da68d03ebac94e6636ff01f" -// DNSMock mock -type DNSMock struct { +// DNSServerMock mock +type DNSServerMock struct { t *testing.T - Server *httptest.Server + server *httptest.Server Mux *http.ServeMux } -// NewDNSMock create a new DNSMock -func NewDNSMock(t *testing.T) *DNSMock { - return &DNSMock{ - t: t, +// NewDNSServerMock create a new DNSServerMock +func NewDNSServerMock(t *testing.T) *DNSServerMock { + mux := http.NewServeMux() + + return &DNSServerMock{ + t: t, + server: httptest.NewServer(mux), + Mux: mux, } } -// Setup creates the mock server -func (m *DNSMock) Setup() { - m.Mux = http.NewServeMux() - m.Server = httptest.NewServer(m.Mux) +func (m *DNSServerMock) GetServerURL() string { + return m.server.URL } // ShutdownServer creates the mock server -func (m *DNSMock) ShutdownServer() { - m.Server.Close() +func (m *DNSServerMock) ShutdownServer() { + m.server.Close() } // HandleAuthSuccessfully Handle auth successfully -func (m *DNSMock) HandleAuthSuccessfully() { +func (m *DNSServerMock) HandleAuthSuccessfully() { m.Mux.HandleFunc("/v3/auth/token", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Subject-Token", fakeOTCToken) @@ -64,99 +62,101 @@ func (m *DNSMock) HandleAuthSuccessfully() { ] } ] - }}`, m.Server.URL) + }}`, m.server.URL) }) } // HandleListZonesSuccessfully Handle list zones successfully -func (m *DNSMock) HandleListZonesSuccessfully() { +func (m *DNSServerMock) HandleListZonesSuccessfully() { m.Mux.HandleFunc("/v2/zones", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(m.t, r.Method, http.MethodGet) + assert.Equal(m.t, r.URL.Path, "/v2/zones") + assert.Equal(m.t, r.URL.RawQuery, "name=example.com.") + assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") + fmt.Fprintf(w, `{ "zones":[{ "id":"123123" }]} `) - assert.Equal(m.t, r.Method, http.MethodGet) - assert.Equal(m.t, r.URL.Path, "/v2/zones") - assert.Equal(m.t, r.URL.RawQuery, "name=example.com.") - assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") }) } // HandleListZonesEmpty Handle list zones empty -func (m *DNSMock) HandleListZonesEmpty() { +func (m *DNSServerMock) HandleListZonesEmpty() { m.Mux.HandleFunc("/v2/zones", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, `{ - "zones":[ - ]} - `) - assert.Equal(m.t, r.Method, http.MethodGet) assert.Equal(m.t, r.URL.Path, "/v2/zones") assert.Equal(m.t, r.URL.RawQuery, "name=example.com.") assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") + + fmt.Fprintf(w, `{ + "zones":[ + ]} + `) }) } // HandleDeleteRecordsetsSuccessfully Handle delete recordsets successfully -func (m *DNSMock) HandleDeleteRecordsetsSuccessfully() { +func (m *DNSServerMock) HandleDeleteRecordsetsSuccessfully() { m.Mux.HandleFunc("/v2/zones/123123/recordsets/321321", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(m.t, r.Method, http.MethodDelete) + assert.Equal(m.t, r.URL.Path, "/v2/zones/123123/recordsets/321321") + assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") + fmt.Fprintf(w, `{ "zones":[{ "id":"123123" }]} `) - - assert.Equal(m.t, r.Method, http.MethodDelete) - assert.Equal(m.t, r.URL.Path, "/v2/zones/123123/recordsets/321321") - assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") }) } // HandleListRecordsetsEmpty Handle list recordsets empty -func (m *DNSMock) HandleListRecordsetsEmpty() { +func (m *DNSServerMock) HandleListRecordsetsEmpty() { m.Mux.HandleFunc("/v2/zones/123123/recordsets", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(m.t, r.URL.Path, "/v2/zones/123123/recordsets") + assert.Equal(m.t, r.URL.RawQuery, "type=TXT&name=_acme-challenge.example.com.") + fmt.Fprintf(w, `{ "recordsets":[ ]} `) - - assert.Equal(m.t, r.URL.Path, "/v2/zones/123123/recordsets") - assert.Equal(m.t, r.URL.RawQuery, "type=TXT&name=_acme-challenge.example.com.") }) } // HandleListRecordsetsSuccessfully Handle list recordsets successfully -func (m *DNSMock) HandleListRecordsetsSuccessfully() { +func (m *DNSServerMock) HandleListRecordsetsSuccessfully() { m.Mux.HandleFunc("/v2/zones/123123/recordsets", func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodGet { + assert.Equal(m.t, r.URL.Path, "/v2/zones/123123/recordsets") + assert.Equal(m.t, r.URL.RawQuery, "type=TXT&name=_acme-challenge.example.com.") + assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") + fmt.Fprintf(w, `{ "recordsets":[{ "id":"321321" }]} `) + return + } - assert.Equal(m.t, r.URL.Path, "/v2/zones/123123/recordsets") - assert.Equal(m.t, r.URL.RawQuery, "type=TXT&name=_acme-challenge.example.com.") + if r.Method == http.MethodPost { + assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") - } else if r.Method == http.MethodPost { body, err := ioutil.ReadAll(r.Body) - assert.Nil(m.t, err) exceptedString := "{\"name\":\"_acme-challenge.example.com.\",\"description\":\"Added TXT record for ACME dns-01 challenge using lego client\",\"type\":\"TXT\",\"ttl\":300,\"records\":[\"\\\"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI\\\"\"]}" assert.Equal(m.t, string(body), exceptedString) - fmt.Fprintf(w, `{ "recordsets":[{ "id":"321321" }]} `) - - } else { - m.t.Errorf("Expected method to be 'GET' or 'POST' but got '%s'", r.Method) + return } - assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") + http.Error(w, fmt.Sprintf("Expected method to be 'GET' or 'POST' but got '%s'", r.Method), http.StatusBadRequest) }) } diff --git a/providers/dns/otc/otc_test.go b/providers/dns/otc/otc_test.go index a83694c0..001cae1d 100644 --- a/providers/dns/otc/otc_test.go +++ b/providers/dns/otc/otc_test.go @@ -5,125 +5,135 @@ import ( "os" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/xenolf/lego/platform/tester" ) -type TestSuite struct { +type OTCSuite struct { suite.Suite - Mock *DNSMock + Mock *DNSServerMock + envTest *tester.EnvTest } -func (s *TestSuite) TearDownSuite() { +func (s *OTCSuite) SetupTest() { + s.Mock = NewDNSServerMock(s.T()) + s.Mock.HandleAuthSuccessfully() + s.envTest = tester.NewEnvTest( + "OTC_DOMAIN_NAME", + "OTC_USER_NAME", + "OTC_PASSWORD", + "OTC_PROJECT_NAME", + "OTC_IDENTITY_ENDPOINT", + ) +} + +func (s *OTCSuite) TearDownTest() { + s.envTest.RestoreEnv() s.Mock.ShutdownServer() } -func (s *TestSuite) SetupTest() { - s.Mock = NewDNSMock(s.T()) - s.Mock.Setup() - s.Mock.HandleAuthSuccessfully() -} - func TestTestSuite(t *testing.T) { - suite.Run(t, new(TestSuite)) + suite.Run(t, new(OTCSuite)) } -func (s *TestSuite) createDNSProvider() (*DNSProvider, error) { - url := fmt.Sprintf("%s/v3/auth/token", s.Mock.Server.URL) - +func (s *OTCSuite) createDNSProvider() (*DNSProvider, error) { config := NewDefaultConfig() - config.UserName = fakeOTCUserName - config.Password = fakeOTCPassword - config.DomainName = fakeOTCDomainName - config.ProjectName = fakeOTCProjectName - config.IdentityEndpoint = url + config.UserName = "UserName" + config.Password = "Password" + config.DomainName = "DomainName" + config.ProjectName = "ProjectName" + config.IdentityEndpoint = fmt.Sprintf("%s/v3/auth/token", s.Mock.GetServerURL()) return NewDNSProviderConfig(config) } -func (s *TestSuite) TestLogin() { +func (s *OTCSuite) TestLogin() { provider, err := s.createDNSProvider() - require.NoError(s.T(), err) + s.Require().NoError(err) err = provider.loginRequest() - require.NoError(s.T(), err) - assert.Equal(s.T(), provider.baseURL, fmt.Sprintf("%s/v2", s.Mock.Server.URL)) - assert.Equal(s.T(), fakeOTCToken, provider.token) + s.Require().NoError(err) + + s.Equal(provider.baseURL, fmt.Sprintf("%s/v2", s.Mock.GetServerURL())) + s.Equal(fakeOTCToken, provider.token) } -func (s *TestSuite) TestLoginEnv() { - defer os.Clearenv() +func (s *OTCSuite) TestLoginEnv() { + s.envTest.ClearEnv() - os.Setenv("OTC_DOMAIN_NAME", "unittest1") - os.Setenv("OTC_USER_NAME", "unittest2") - os.Setenv("OTC_PASSWORD", "unittest3") - os.Setenv("OTC_PROJECT_NAME", "unittest4") - os.Setenv("OTC_IDENTITY_ENDPOINT", "unittest5") + s.envTest.Apply(map[string]string{ + "OTC_DOMAIN_NAME": "unittest1", + "OTC_USER_NAME": "unittest2", + "OTC_PASSWORD": "unittest3", + "OTC_PROJECT_NAME": "unittest4", + "OTC_IDENTITY_ENDPOINT": "unittest5", + }) provider, err := NewDNSProvider() - require.NoError(s.T(), err) - assert.Equal(s.T(), provider.config.DomainName, "unittest1") - assert.Equal(s.T(), provider.config.UserName, "unittest2") - assert.Equal(s.T(), provider.config.Password, "unittest3") - assert.Equal(s.T(), provider.config.ProjectName, "unittest4") - assert.Equal(s.T(), provider.config.IdentityEndpoint, "unittest5") + s.Require().NoError(err) + + s.Equal(provider.config.DomainName, "unittest1") + s.Equal(provider.config.UserName, "unittest2") + s.Equal(provider.config.Password, "unittest3") + s.Equal(provider.config.ProjectName, "unittest4") + s.Equal(provider.config.IdentityEndpoint, "unittest5") os.Setenv("OTC_IDENTITY_ENDPOINT", "") provider, err = NewDNSProvider() - require.NoError(s.T(), err) - assert.Equal(s.T(), provider.config.IdentityEndpoint, "https://iam.eu-de.otc.t-systems.com:443/v3/auth/tokens") + s.Require().NoError(err) + + s.Equal(provider.config.IdentityEndpoint, "https://iam.eu-de.otc.t-systems.com:443/v3/auth/tokens") } -func (s *TestSuite) TestLoginEnvEmpty() { - defer os.Clearenv() +func (s *OTCSuite) TestLoginEnvEmpty() { + s.envTest.ClearEnv() _, err := NewDNSProvider() - assert.EqualError(s.T(), err, "otc: some credentials information are missing: OTC_DOMAIN_NAME,OTC_USER_NAME,OTC_PASSWORD,OTC_PROJECT_NAME") + s.EqualError(err, "otc: some credentials information are missing: OTC_DOMAIN_NAME,OTC_USER_NAME,OTC_PASSWORD,OTC_PROJECT_NAME") } -func (s *TestSuite) TestDNSProvider_Present() { +func (s *OTCSuite) TestDNSProvider_Present() { s.Mock.HandleListZonesSuccessfully() s.Mock.HandleListRecordsetsSuccessfully() provider, err := s.createDNSProvider() - require.NoError(s.T(), err) + s.Require().NoError(err) err = provider.Present("example.com", "", "foobar") - require.NoError(s.T(), err) + s.Require().NoError(err) } -func (s *TestSuite) TestDNSProvider_Present_EmptyZone() { +func (s *OTCSuite) TestDNSProvider_Present_EmptyZone() { s.Mock.HandleListZonesEmpty() s.Mock.HandleListRecordsetsSuccessfully() provider, err := s.createDNSProvider() - require.NoError(s.T(), err) + s.Require().NoError(err) err = provider.Present("example.com", "", "foobar") - assert.NotNil(s.T(), err) + s.NotNil(err) } -func (s *TestSuite) TestDNSProvider_CleanUp() { +func (s *OTCSuite) TestDNSProvider_CleanUp() { s.Mock.HandleListZonesSuccessfully() s.Mock.HandleListRecordsetsSuccessfully() s.Mock.HandleDeleteRecordsetsSuccessfully() provider, err := s.createDNSProvider() - require.NoError(s.T(), err) + s.Require().NoError(err) err = provider.CleanUp("example.com", "", "foobar") - require.NoError(s.T(), err) + s.Require().NoError(err) } -func (s *TestSuite) TestDNSProvider_CleanUp_EmptyRecordset() { +func (s *OTCSuite) TestDNSProvider_CleanUp_EmptyRecordset() { s.Mock.HandleListZonesSuccessfully() s.Mock.HandleListRecordsetsEmpty() provider, err := s.createDNSProvider() - require.NoError(s.T(), err) + s.Require().NoError(err) err = provider.CleanUp("example.com", "", "foobar") - require.Error(s.T(), err) + s.Require().Error(err) } diff --git a/providers/dns/ovh/ovh_test.go b/providers/dns/ovh/ovh_test.go index c6af2fe6..574fad9c 100644 --- a/providers/dns/ovh/ovh_test.go +++ b/providers/dns/ovh/ovh_test.go @@ -1,38 +1,19 @@ package ovh import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIEndpoint string - envTestApplicationKey string - envTestApplicationSecret string - envTestConsumerKey string - envTestDomain string -) - -func init() { - envTestAPIEndpoint = os.Getenv("OVH_ENDPOINT") - envTestApplicationKey = os.Getenv("OVH_APPLICATION_KEY") - envTestApplicationSecret = os.Getenv("OVH_APPLICATION_SECRET") - envTestConsumerKey = os.Getenv("OVH_CONSUMER_KEY") - envTestDomain = os.Getenv("OVH_DOMAIN") - - liveTest = len(envTestAPIEndpoint) > 0 && len(envTestApplicationKey) > 0 && len(envTestApplicationSecret) > 0 && len(envTestConsumerKey) > 0 -} - -func restoreEnv() { - os.Setenv("OVH_ENDPOINT", envTestAPIEndpoint) - os.Setenv("OVH_APPLICATION_KEY", envTestApplicationKey) - os.Setenv("OVH_APPLICATION_SECRET", envTestApplicationSecret) - os.Setenv("OVH_CONSUMER_KEY", envTestConsumerKey) -} +var envTest = tester.NewEnvTest( + "OVH_ENDPOINT", + "OVH_APPLICATION_KEY", + "OVH_APPLICATION_SECRET", + "OVH_CONSUMER_KEY"). + WithDomain("OVH_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -113,14 +94,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -201,12 +178,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("OVH_ENDPOINT") - os.Unsetenv("OVH_APPLICATION_KEY") - os.Unsetenv("OVH_APPLICATION_SECRET") - os.Unsetenv("OVH_CONSUMER_KEY") - config := NewDefaultConfig() config.APIEndpoint = test.apiEndpoint config.ApplicationKey = test.applicationKey @@ -229,29 +200,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/pdns/pdns_test.go b/providers/dns/pdns/pdns_test.go index 392b7099..632582f1 100644 --- a/providers/dns/pdns/pdns_test.go +++ b/providers/dns/pdns/pdns_test.go @@ -2,33 +2,16 @@ package pdns import ( "net/url" - "os" "testing" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIURL *url.URL - envTestAPIKey string - envTestDomain string -) - -func init() { - envTestAPIURL, _ = url.Parse(os.Getenv("PDNS_API_URL")) - envTestAPIKey = os.Getenv("PDNS_API_KEY") - envTestDomain = os.Getenv("PDNS_DOMAIN") - - if len(envTestAPIURL.String()) > 0 && len(envTestAPIKey) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("PDNS_API_URL", envTestAPIURL.String()) - os.Setenv("PDNS_API_KEY", envTestAPIKey) -} +var envTest = tester.NewEnvTest( + "PDNS_API_URL", + "PDNS_API_KEY"). + WithDomain("PDNS_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -71,14 +54,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -130,10 +109,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("PDNS_API_KEY") - os.Unsetenv("PDNS_API_URL") - config := NewDefaultConfig() config.APIKey = test.apiKey config.Host = test.host @@ -152,17 +127,17 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresentAndCleanup(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/rackspace/rackspace_test.go b/providers/dns/rackspace/rackspace_test.go index f4fa288f..874e6721 100644 --- a/providers/dns/rackspace/rackspace_test.go +++ b/providers/dns/rackspace/rackspace_test.go @@ -5,31 +5,19 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "os" "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestUser string - envTestAPIKey string - envTestDomain string -) - -func init() { - envTestUser = os.Getenv("RACKSPACE_USER") - envTestAPIKey = os.Getenv("RACKSPACE_API_KEY") - envTestDomain = os.Getenv("RACKSPACE_DOMAIN") - - if len(envTestUser) > 0 && len(envTestAPIKey) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} +var envTest = tester.NewEnvTest( + "RACKSPACE_USER", + "RACKSPACE_API_KEY"). + WithDomain("RACKSPACE_DOMAIN") func TestNewDNSProviderConfig(t *testing.T) { config, tearDown := setupTest() @@ -72,38 +60,42 @@ func TestDNSProvider_CleanUp(t *testing.T) { } func TestLiveNewDNSProvider_ValidEnv(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) + assert.Contains(t, provider.cloudDNSEndpoint, "https://dns.api.rackspacecloud.com/v1.0/", "The endpoint URL should contain the base") } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "112233445566==") + err = provider.Present(envTest.GetDomain(), "", "112233445566==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - time.Sleep(time.Second * 15) - + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.CleanUp(envTestDomain, "", "112233445566==") + time.Sleep(15 * time.Second) + + err = provider.CleanUp(envTest.GetDomain(), "", "112233445566==") require.NoError(t, err) } diff --git a/providers/dns/route53/route53_integration_test.go b/providers/dns/route53/route53_integration_test.go index eede6e6c..acc301bc 100644 --- a/providers/dns/route53/route53_integration_test.go +++ b/providers/dns/route53/route53_integration_test.go @@ -7,19 +7,19 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/route53" "github.com/stretchr/testify/require" - "github.com/xenolf/lego/platform/config/env" ) func TestLiveTTL(t *testing.T) { - config, err := env.Get("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION", "R53_DOMAIN") - if err != nil { - t.Skip(err.Error()) + if !envTest.IsLiveTest() { + t.Skip("skipping live test") } + envTest.RestoreEnv() + provider, err := NewDNSProvider() require.NoError(t, err) - domain := config["R53_DOMAIN"] + domain := envTest.GetDomain() err = provider.Present(domain, "foo", "bar") require.NoError(t, err) diff --git a/providers/dns/route53/route53_test.go b/providers/dns/route53/route53_test.go index cc840685..228cfbfa 100644 --- a/providers/dns/route53/route53_test.go +++ b/providers/dns/route53/route53_test.go @@ -12,55 +12,20 @@ import ( "github.com/aws/aws-sdk-go/service/route53" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - envTestAwsSecretAccessKey string - envTestAwsAccessKeyID string - envTestAwsRegion string - envTestAwsHostedZoneID string - - envTestAwsMaxRetries string - envTestAwsTTL string - envTestAwsPropagationTimeout string - envTestAwsPollingInterval string -) - -func init() { - envTestAwsAccessKeyID = os.Getenv("AWS_ACCESS_KEY_ID") - envTestAwsSecretAccessKey = os.Getenv("AWS_SECRET_ACCESS_KEY") - envTestAwsRegion = os.Getenv("AWS_REGION") - envTestAwsHostedZoneID = os.Getenv("AWS_HOSTED_ZONE_ID") - - envTestAwsMaxRetries = os.Getenv("AWS_MAX_RETRIES") - envTestAwsTTL = os.Getenv("AWS_TTL") - envTestAwsPropagationTimeout = os.Getenv("AWS_PROPAGATION_TIMEOUT") - envTestAwsPollingInterval = os.Getenv("AWS_POLLING_INTERVAL") -} - -func restoreEnv() { - os.Setenv("AWS_ACCESS_KEY_ID", envTestAwsAccessKeyID) - os.Setenv("AWS_SECRET_ACCESS_KEY", envTestAwsSecretAccessKey) - os.Setenv("AWS_REGION", envTestAwsRegion) - os.Setenv("AWS_HOSTED_ZONE_ID", envTestAwsHostedZoneID) - - os.Setenv("AWS_MAX_RETRIES", envTestAwsMaxRetries) - os.Setenv("AWS_TTL", envTestAwsTTL) - os.Setenv("AWS_PROPAGATION_TIMEOUT", envTestAwsPropagationTimeout) - os.Setenv("AWS_POLLING_INTERVAL", envTestAwsPollingInterval) -} - -func cleanEnv() { - os.Unsetenv("AWS_ACCESS_KEY_ID") - os.Unsetenv("AWS_SECRET_ACCESS_KEY") - os.Unsetenv("AWS_REGION") - os.Unsetenv("AWS_HOSTED_ZONE_ID") - - os.Unsetenv("AWS_MAX_RETRIES") - os.Unsetenv("AWS_TTL") - os.Unsetenv("AWS_PROPAGATION_TIMEOUT") - os.Unsetenv("AWS_POLLING_INTERVAL") -} +var envTest = tester.NewEnvTest( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_REGION", + "AWS_HOSTED_ZONE_ID", + "AWS_MAX_RETRIES", + "AWS_TTL", + "AWS_PROPAGATION_TIMEOUT", + "AWS_POLLING_INTERVAL"). + WithDomain("R53_DOMAIN"). + WithLiveTestRequirements("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION", "R53_DOMAIN") func makeTestProvider(ts *httptest.Server) *DNSProvider { config := &aws.Config{ @@ -80,7 +45,9 @@ func makeTestProvider(ts *httptest.Server) *DNSProvider { } func Test_loadCredentials_FromEnv(t *testing.T) { - defer restoreEnv() + defer envTest.RestoreEnv() + envTest.ClearEnv() + os.Setenv("AWS_ACCESS_KEY_ID", "123") os.Setenv("AWS_SECRET_ACCESS_KEY", "456") os.Setenv("AWS_REGION", "us-east-1") @@ -105,7 +72,9 @@ func Test_loadCredentials_FromEnv(t *testing.T) { } func Test_loadRegion_FromEnv(t *testing.T) { - defer restoreEnv() + defer envTest.RestoreEnv() + envTest.ClearEnv() + os.Setenv("AWS_REGION", route53.CloudWatchRegionUsEast1) sess, err := session.NewSession(aws.NewConfig()) @@ -116,7 +85,8 @@ func Test_loadRegion_FromEnv(t *testing.T) { } func Test_getHostedZoneID_FromEnv(t *testing.T) { - defer restoreEnv() + defer envTest.RestoreEnv() + envTest.ClearEnv() expectedZoneID := "zoneID" @@ -132,7 +102,7 @@ func Test_getHostedZoneID_FromEnv(t *testing.T) { } func TestNewDefaultConfig(t *testing.T) { - defer restoreEnv() + defer envTest.RestoreEnv() testCases := []struct { desc string @@ -169,7 +139,7 @@ func TestNewDefaultConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - cleanEnv() + envTest.ClearEnv() for key, value := range test.envVars { os.Setenv(key, value) } @@ -195,6 +165,8 @@ func TestDNSProvider_Present(t *testing.T) { ts := newMockServer(t, mockResponses) defer ts.Close() + defer envTest.RestoreEnv() + envTest.ClearEnv() provider := makeTestProvider(ts) domain := "example.com" diff --git a/providers/dns/sakuracloud/sakuracloud_test.go b/providers/dns/sakuracloud/sakuracloud_test.go index 5de74b0c..531a6e19 100644 --- a/providers/dns/sakuracloud/sakuracloud_test.go +++ b/providers/dns/sakuracloud/sakuracloud_test.go @@ -1,34 +1,17 @@ package sakuracloud import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAccessToken string - envTestAccessSecret string - envTestDomain string -) - -func init() { - envTestAccessToken = os.Getenv("SAKURACLOUD_ACCESS_TOKEN") - envTestAccessSecret = os.Getenv("SAKURACLOUD_ACCESS_TOKEN_SECRET") - envTestDomain = os.Getenv("SAKURACLOUD_DOMAIN") - - if len(envTestAccessToken) > 0 && len(envTestAccessSecret) > 0 && len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("SAKURACLOUD_ACCESS_TOKEN", envTestAccessToken) - os.Setenv("SAKURACLOUD_ACCESS_TOKEN_SECRET", envTestAccessSecret) -} +var envTest = tester.NewEnvTest( + "SAKURACLOUD_ACCESS_TOKEN", + "SAKURACLOUD_ACCESS_TOKEN_SECRET"). + WithDomain("SAKURACLOUD_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -71,14 +54,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -124,10 +103,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("SAKURACLOUD_ACCESS_TOKEN") - os.Unsetenv("SAKURACLOUD_ACCESS_TOKEN_SECRET") - config := NewDefaultConfig() config.Token = test.token config.Secret = test.secret @@ -147,29 +122,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/stackpath/stackpath_test.go b/providers/dns/stackpath/stackpath_test.go index 54ecd47d..d3f3dd1a 100644 --- a/providers/dns/stackpath/stackpath_test.go +++ b/providers/dns/stackpath/stackpath_test.go @@ -4,42 +4,19 @@ import ( "net/http" "net/http/httptest" "net/url" - "os" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestClientID string - envTestClientSecret string - envTestStackID string - envTestDomain string -) - -func init() { - envTestClientID = os.Getenv("STACKPATH_CLIENT_ID") - envTestClientSecret = os.Getenv("STACKPATH_CLIENT_SECRET") - envTestStackID = os.Getenv("STACKPATH_STACK_ID") - envTestDomain = os.Getenv("STACKPATH_DOMAIN") - - if len(envTestClientID) > 0 && - len(envTestClientSecret) > 0 && - len(envTestStackID) > 0 && - len(envTestDomain) > 0 { - liveTest = true - } -} - -func restoreEnv() { - os.Setenv("STACKPATH_CLIENT_ID", envTestClientID) - os.Setenv("STACKPATH_CLIENT_SECRET", envTestClientSecret) - os.Setenv("STACKPATH_STACK_ID", envTestStackID) - os.Setenv("STACKPATH_DOMAIN", envTestDomain) -} +var envTest = tester.NewEnvTest( + "STACKPATH_CLIENT_ID", + "STACKPATH_CLIENT_SECRET", + "STACKPATH_STACK_ID"). + WithDomain("STACKPATH_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -95,14 +72,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -289,29 +262,29 @@ func TestDNSProvider_getZones(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/providers/dns/vegadns/vegadns_test.go b/providers/dns/vegadns/vegadns_test.go index 50af3172..69888d7e 100644 --- a/providers/dns/vegadns/vegadns_test.go +++ b/providers/dns/vegadns/vegadns_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) const testDomain = "example.com" @@ -101,20 +102,26 @@ var jsonMap = map[string]string{ "recordDeleted": `{"status": "ok"}`, } +var envTest = tester.NewEnvTest("SECRET_VEGADNS_KEY", "SECRET_VEGADNS_SECRET", "VEGADNS_URL") + type muxCallback func() *http.ServeMux func TestNewDNSProvider_Fail(t *testing.T) { - os.Setenv("VEGADNS_URL", "") + defer envTest.RestoreEnv() + envTest.ClearEnv() + _, err := NewDNSProvider() assert.Error(t, err, "VEGADNS_URL env missing") } func TestDNSProvider_TimeoutSuccess(t *testing.T) { + defer envTest.RestoreEnv() + envTest.ClearEnv() + ts, err := startTestServer(muxSuccess) require.NoError(t, err) defer ts.Close() - defer os.Clearenv() provider, err := NewDNSProvider() require.NoError(t, err) @@ -148,11 +155,13 @@ func TestDNSProvider_Present(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { + defer envTest.RestoreEnv() + envTest.ClearEnv() + ts, err := startTestServer(test.callback) require.NoError(t, err) defer ts.Close() - defer os.Clearenv() provider, err := NewDNSProvider() require.NoError(t, err) @@ -191,11 +200,13 @@ func TestDNSProvider_CleanUp(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { + defer envTest.RestoreEnv() + envTest.ClearEnv() + ts, err := startTestServer(test.callback) require.NoError(t, err) defer ts.Close() - defer os.Clearenv() provider, err := NewDNSProvider() require.NoError(t, err) diff --git a/providers/dns/vultr/vultr_test.go b/providers/dns/vultr/vultr_test.go index 2d7fe58c..f73586a4 100644 --- a/providers/dns/vultr/vultr_test.go +++ b/providers/dns/vultr/vultr_test.go @@ -1,29 +1,15 @@ package vultr import ( - "os" "testing" "time" "github.com/stretchr/testify/require" + "github.com/xenolf/lego/platform/tester" ) -var ( - liveTest bool - envTestAPIKey string - envTestDomain string -) - -func init() { - envTestAPIKey = os.Getenv("VULTR_API_KEY") - envTestDomain = os.Getenv("VULTR_TEST_DOMAIN") - - liveTest = len(envTestAPIKey) > 0 && len(envTestDomain) > 0 -} - -func restoreEnv() { - os.Setenv("VULTR_API_KEY", envTestAPIKey) -} +var envTest = tester.NewEnvTest("VULTR_API_KEY"). + WithDomain("VULTR_TEST_DOMAIN") func TestNewDNSProvider(t *testing.T) { testCases := []struct { @@ -48,14 +34,10 @@ func TestNewDNSProvider(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - for key, value := range test.envVars { - if len(value) == 0 { - os.Unsetenv(key) - } else { - os.Setenv(key, value) - } - } + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.envVars) p, err := NewDNSProvider() @@ -89,9 +71,6 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - defer restoreEnv() - os.Unsetenv("VULTR_API_KEY") - config := NewDefaultConfig() config.APIKey = test.apiKey @@ -110,29 +89,29 @@ func TestNewDNSProviderConfig(t *testing.T) { } func TestLivePresent(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) - err = provider.Present(envTestDomain, "", "123d==") + err = provider.Present(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } func TestLiveCleanUp(t *testing.T) { - if !liveTest { + if !envTest.IsLiveTest() { t.Skip("skipping live test") } - restoreEnv() + envTest.RestoreEnv() provider, err := NewDNSProvider() require.NoError(t, err) time.Sleep(1 * time.Second) - err = provider.CleanUp(envTestDomain, "", "123d==") + err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } diff --git a/vendor/github.com/stretchr/objx/LICENSE b/vendor/github.com/stretchr/objx/LICENSE new file mode 100644 index 00000000..44d4d9d5 --- /dev/null +++ b/vendor/github.com/stretchr/objx/LICENSE @@ -0,0 +1,22 @@ +The MIT License + +Copyright (c) 2014 Stretchr, Inc. +Copyright (c) 2017-2018 objx contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/stretchr/objx/accessors.go b/vendor/github.com/stretchr/objx/accessors.go new file mode 100644 index 00000000..204356a2 --- /dev/null +++ b/vendor/github.com/stretchr/objx/accessors.go @@ -0,0 +1,148 @@ +package objx + +import ( + "regexp" + "strconv" + "strings" +) + +// arrayAccesRegexString is the regex used to extract the array number +// from the access path +const arrayAccesRegexString = `^(.+)\[([0-9]+)\]$` + +// arrayAccesRegex is the compiled arrayAccesRegexString +var arrayAccesRegex = regexp.MustCompile(arrayAccesRegexString) + +// Get gets the value using the specified selector and +// returns it inside a new Obj object. +// +// If it cannot find the value, Get will return a nil +// value inside an instance of Obj. +// +// Get can only operate directly on map[string]interface{} and []interface. +// +// Example +// +// To access the title of the third chapter of the second book, do: +// +// o.Get("books[1].chapters[2].title") +func (m Map) Get(selector string) *Value { + rawObj := access(m, selector, nil, false) + return &Value{data: rawObj} +} + +// Set sets the value using the specified selector and +// returns the object on which Set was called. +// +// Set can only operate directly on map[string]interface{} and []interface +// +// Example +// +// To set the title of the third chapter of the second book, do: +// +// o.Set("books[1].chapters[2].title","Time to Go") +func (m Map) Set(selector string, value interface{}) Map { + access(m, selector, value, true) + return m +} + +// access accesses the object using the selector and performs the +// appropriate action. +func access(current, selector, value interface{}, isSet bool) interface{} { + switch selector.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + if array, ok := current.([]interface{}); ok { + index := intFromInterface(selector) + if index >= len(array) { + return nil + } + return array[index] + } + return nil + + case string: + selStr := selector.(string) + selSegs := strings.SplitN(selStr, PathSeparator, 2) + thisSel := selSegs[0] + index := -1 + var err error + + if strings.Contains(thisSel, "[") { + arrayMatches := arrayAccesRegex.FindStringSubmatch(thisSel) + if len(arrayMatches) > 0 { + // Get the key into the map + thisSel = arrayMatches[1] + + // Get the index into the array at the key + index, err = strconv.Atoi(arrayMatches[2]) + + if err != nil { + // This should never happen. If it does, something has gone + // seriously wrong. Panic. + panic("objx: Array index is not an integer. Must use array[int].") + } + } + } + if curMap, ok := current.(Map); ok { + current = map[string]interface{}(curMap) + } + // get the object in question + switch current.(type) { + case map[string]interface{}: + curMSI := current.(map[string]interface{}) + if len(selSegs) <= 1 && isSet { + curMSI[thisSel] = value + return nil + } + current = curMSI[thisSel] + default: + current = nil + } + // do we need to access the item of an array? + if index > -1 { + if array, ok := current.([]interface{}); ok { + if index < len(array) { + current = array[index] + } else { + current = nil + } + } + } + if len(selSegs) > 1 { + current = access(current, selSegs[1], value, isSet) + } + } + return current +} + +// intFromInterface converts an interface object to the largest +// representation of an unsigned integer using a type switch and +// assertions +func intFromInterface(selector interface{}) int { + var value int + switch selector.(type) { + case int: + value = selector.(int) + case int8: + value = int(selector.(int8)) + case int16: + value = int(selector.(int16)) + case int32: + value = int(selector.(int32)) + case int64: + value = int(selector.(int64)) + case uint: + value = int(selector.(uint)) + case uint8: + value = int(selector.(uint8)) + case uint16: + value = int(selector.(uint16)) + case uint32: + value = int(selector.(uint32)) + case uint64: + value = int(selector.(uint64)) + default: + return 0 + } + return value +} diff --git a/vendor/github.com/stretchr/objx/constants.go b/vendor/github.com/stretchr/objx/constants.go new file mode 100644 index 00000000..f9eb42a2 --- /dev/null +++ b/vendor/github.com/stretchr/objx/constants.go @@ -0,0 +1,13 @@ +package objx + +const ( + // PathSeparator is the character used to separate the elements + // of the keypath. + // + // For example, `location.address.city` + PathSeparator string = "." + + // SignatureSeparator is the character that is used to + // separate the Base64 string from the security signature. + SignatureSeparator = "_" +) diff --git a/vendor/github.com/stretchr/objx/conversions.go b/vendor/github.com/stretchr/objx/conversions.go new file mode 100644 index 00000000..5e020f31 --- /dev/null +++ b/vendor/github.com/stretchr/objx/conversions.go @@ -0,0 +1,108 @@ +package objx + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/url" +) + +// JSON converts the contained object to a JSON string +// representation +func (m Map) JSON() (string, error) { + result, err := json.Marshal(m) + if err != nil { + err = errors.New("objx: JSON encode failed with: " + err.Error()) + } + return string(result), err +} + +// MustJSON converts the contained object to a JSON string +// representation and panics if there is an error +func (m Map) MustJSON() string { + result, err := m.JSON() + if err != nil { + panic(err.Error()) + } + return result +} + +// Base64 converts the contained object to a Base64 string +// representation of the JSON string representation +func (m Map) Base64() (string, error) { + var buf bytes.Buffer + + jsonData, err := m.JSON() + if err != nil { + return "", err + } + + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + _, err = encoder.Write([]byte(jsonData)) + if err != nil { + return "", err + } + _ = encoder.Close() + + return buf.String(), nil +} + +// MustBase64 converts the contained object to a Base64 string +// representation of the JSON string representation and panics +// if there is an error +func (m Map) MustBase64() string { + result, err := m.Base64() + if err != nil { + panic(err.Error()) + } + return result +} + +// SignedBase64 converts the contained object to a Base64 string +// representation of the JSON string representation and signs it +// using the provided key. +func (m Map) SignedBase64(key string) (string, error) { + base64, err := m.Base64() + if err != nil { + return "", err + } + + sig := HashWithKey(base64, key) + return base64 + SignatureSeparator + sig, nil +} + +// MustSignedBase64 converts the contained object to a Base64 string +// representation of the JSON string representation and signs it +// using the provided key and panics if there is an error +func (m Map) MustSignedBase64(key string) string { + result, err := m.SignedBase64(key) + if err != nil { + panic(err.Error()) + } + return result +} + +/* + URL Query + ------------------------------------------------ +*/ + +// URLValues creates a url.Values object from an Obj. This +// function requires that the wrapped object be a map[string]interface{} +func (m Map) URLValues() url.Values { + vals := make(url.Values) + for k, v := range m { + //TODO: can this be done without sprintf? + vals.Set(k, fmt.Sprintf("%v", v)) + } + return vals +} + +// URLQuery gets an encoded URL query representing the given +// Obj. This function requires that the wrapped object be a +// map[string]interface{} +func (m Map) URLQuery() (string, error) { + return m.URLValues().Encode(), nil +} diff --git a/vendor/github.com/stretchr/objx/doc.go b/vendor/github.com/stretchr/objx/doc.go new file mode 100644 index 00000000..6d6af1a8 --- /dev/null +++ b/vendor/github.com/stretchr/objx/doc.go @@ -0,0 +1,66 @@ +/* +Objx - Go package for dealing with maps, slices, JSON and other data. + +Overview + +Objx provides the `objx.Map` type, which is a `map[string]interface{}` that exposes +a powerful `Get` method (among others) that allows you to easily and quickly get +access to data within the map, without having to worry too much about type assertions, +missing data, default values etc. + +Pattern + +Objx uses a preditable pattern to make access data from within `map[string]interface{}` easy. +Call one of the `objx.` functions to create your `objx.Map` to get going: + + m, err := objx.FromJSON(json) + +NOTE: Any methods or functions with the `Must` prefix will panic if something goes wrong, +the rest will be optimistic and try to figure things out without panicking. + +Use `Get` to access the value you're interested in. You can use dot and array +notation too: + + m.Get("places[0].latlng") + +Once you have sought the `Value` you're interested in, you can use the `Is*` methods to determine its type. + + if m.Get("code").IsStr() { // Your code... } + +Or you can just assume the type, and use one of the strong type methods to extract the real value: + + m.Get("code").Int() + +If there's no value there (or if it's the wrong type) then a default value will be returned, +or you can be explicit about the default value. + + Get("code").Int(-1) + +If you're dealing with a slice of data as a value, Objx provides many useful methods for iterating, +manipulating and selecting that data. You can find out more by exploring the index below. + +Reading data + +A simple example of how to use Objx: + + // Use MustFromJSON to make an objx.Map from some JSON + m := objx.MustFromJSON(`{"name": "Mat", "age": 30}`) + + // Get the details + name := m.Get("name").Str() + age := m.Get("age").Int() + + // Get their nickname (or use their name if they don't have one) + nickname := m.Get("nickname").Str(name) + +Ranging + +Since `objx.Map` is a `map[string]interface{}` you can treat it as such. +For example, to `range` the data, do what you would expect: + + m := objx.MustFromJSON(json) + for key, value := range m { + // Your code... + } +*/ +package objx diff --git a/vendor/github.com/stretchr/objx/map.go b/vendor/github.com/stretchr/objx/map.go new file mode 100644 index 00000000..406bc892 --- /dev/null +++ b/vendor/github.com/stretchr/objx/map.go @@ -0,0 +1,190 @@ +package objx + +import ( + "encoding/base64" + "encoding/json" + "errors" + "io/ioutil" + "net/url" + "strings" +) + +// MSIConvertable is an interface that defines methods for converting your +// custom types to a map[string]interface{} representation. +type MSIConvertable interface { + // MSI gets a map[string]interface{} (msi) representing the + // object. + MSI() map[string]interface{} +} + +// Map provides extended functionality for working with +// untyped data, in particular map[string]interface (msi). +type Map map[string]interface{} + +// Value returns the internal value instance +func (m Map) Value() *Value { + return &Value{data: m} +} + +// Nil represents a nil Map. +var Nil = New(nil) + +// New creates a new Map containing the map[string]interface{} in the data argument. +// If the data argument is not a map[string]interface, New attempts to call the +// MSI() method on the MSIConvertable interface to create one. +func New(data interface{}) Map { + if _, ok := data.(map[string]interface{}); !ok { + if converter, ok := data.(MSIConvertable); ok { + data = converter.MSI() + } else { + return nil + } + } + return Map(data.(map[string]interface{})) +} + +// MSI creates a map[string]interface{} and puts it inside a new Map. +// +// The arguments follow a key, value pattern. +// +// +// Returns nil if any key argument is non-string or if there are an odd number of arguments. +// +// Example +// +// To easily create Maps: +// +// m := objx.MSI("name", "Mat", "age", 29, "subobj", objx.MSI("active", true)) +// +// // creates an Map equivalent to +// m := objx.Map{"name": "Mat", "age": 29, "subobj": objx.Map{"active": true}} +func MSI(keyAndValuePairs ...interface{}) Map { + newMap := Map{} + keyAndValuePairsLen := len(keyAndValuePairs) + if keyAndValuePairsLen%2 != 0 { + return nil + } + for i := 0; i < keyAndValuePairsLen; i = i + 2 { + key := keyAndValuePairs[i] + value := keyAndValuePairs[i+1] + + // make sure the key is a string + keyString, keyStringOK := key.(string) + if !keyStringOK { + return nil + } + newMap[keyString] = value + } + return newMap +} + +// ****** Conversion Constructors + +// MustFromJSON creates a new Map containing the data specified in the +// jsonString. +// +// Panics if the JSON is invalid. +func MustFromJSON(jsonString string) Map { + o, err := FromJSON(jsonString) + if err != nil { + panic("objx: MustFromJSON failed with error: " + err.Error()) + } + return o +} + +// FromJSON creates a new Map containing the data specified in the +// jsonString. +// +// Returns an error if the JSON is invalid. +func FromJSON(jsonString string) (Map, error) { + var data interface{} + err := json.Unmarshal([]byte(jsonString), &data) + if err != nil { + return Nil, err + } + return New(data), nil +} + +// FromBase64 creates a new Obj containing the data specified +// in the Base64 string. +// +// The string is an encoded JSON string returned by Base64 +func FromBase64(base64String string) (Map, error) { + decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(base64String)) + decoded, err := ioutil.ReadAll(decoder) + if err != nil { + return nil, err + } + return FromJSON(string(decoded)) +} + +// MustFromBase64 creates a new Obj containing the data specified +// in the Base64 string and panics if there is an error. +// +// The string is an encoded JSON string returned by Base64 +func MustFromBase64(base64String string) Map { + result, err := FromBase64(base64String) + if err != nil { + panic("objx: MustFromBase64 failed with error: " + err.Error()) + } + return result +} + +// FromSignedBase64 creates a new Obj containing the data specified +// in the Base64 string. +// +// The string is an encoded JSON string returned by SignedBase64 +func FromSignedBase64(base64String, key string) (Map, error) { + parts := strings.Split(base64String, SignatureSeparator) + if len(parts) != 2 { + return nil, errors.New("objx: Signed base64 string is malformed") + } + + sig := HashWithKey(parts[0], key) + if parts[1] != sig { + return nil, errors.New("objx: Signature for base64 data does not match") + } + return FromBase64(parts[0]) +} + +// MustFromSignedBase64 creates a new Obj containing the data specified +// in the Base64 string and panics if there is an error. +// +// The string is an encoded JSON string returned by Base64 +func MustFromSignedBase64(base64String, key string) Map { + result, err := FromSignedBase64(base64String, key) + if err != nil { + panic("objx: MustFromSignedBase64 failed with error: " + err.Error()) + } + return result +} + +// FromURLQuery generates a new Obj by parsing the specified +// query. +// +// For queries with multiple values, the first value is selected. +func FromURLQuery(query string) (Map, error) { + vals, err := url.ParseQuery(query) + if err != nil { + return nil, err + } + m := Map{} + for k, vals := range vals { + m[k] = vals[0] + } + return m, nil +} + +// MustFromURLQuery generates a new Obj by parsing the specified +// query. +// +// For queries with multiple values, the first value is selected. +// +// Panics if it encounters an error +func MustFromURLQuery(query string) Map { + o, err := FromURLQuery(query) + if err != nil { + panic("objx: MustFromURLQuery failed with error: " + err.Error()) + } + return o +} diff --git a/vendor/github.com/stretchr/objx/mutations.go b/vendor/github.com/stretchr/objx/mutations.go new file mode 100644 index 00000000..c3400a3f --- /dev/null +++ b/vendor/github.com/stretchr/objx/mutations.go @@ -0,0 +1,77 @@ +package objx + +// Exclude returns a new Map with the keys in the specified []string +// excluded. +func (m Map) Exclude(exclude []string) Map { + excluded := make(Map) + for k, v := range m { + if !contains(exclude, k) { + excluded[k] = v + } + } + return excluded +} + +// Copy creates a shallow copy of the Obj. +func (m Map) Copy() Map { + copied := Map{} + for k, v := range m { + copied[k] = v + } + return copied +} + +// Merge blends the specified map with a copy of this map and returns the result. +// +// Keys that appear in both will be selected from the specified map. +// This method requires that the wrapped object be a map[string]interface{} +func (m Map) Merge(merge Map) Map { + return m.Copy().MergeHere(merge) +} + +// MergeHere blends the specified map with this map and returns the current map. +// +// Keys that appear in both will be selected from the specified map. The original map +// will be modified. This method requires that +// the wrapped object be a map[string]interface{} +func (m Map) MergeHere(merge Map) Map { + for k, v := range merge { + m[k] = v + } + return m +} + +// Transform builds a new Obj giving the transformer a chance +// to change the keys and values as it goes. This method requires that +// the wrapped object be a map[string]interface{} +func (m Map) Transform(transformer func(key string, value interface{}) (string, interface{})) Map { + newMap := Map{} + for k, v := range m { + modifiedKey, modifiedVal := transformer(k, v) + newMap[modifiedKey] = modifiedVal + } + return newMap +} + +// TransformKeys builds a new map using the specified key mapping. +// +// Unspecified keys will be unaltered. +// This method requires that the wrapped object be a map[string]interface{} +func (m Map) TransformKeys(mapping map[string]string) Map { + return m.Transform(func(key string, value interface{}) (string, interface{}) { + if newKey, ok := mapping[key]; ok { + return newKey, value + } + return key, value + }) +} + +// Checks if a string slice contains a string +func contains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} diff --git a/vendor/github.com/stretchr/objx/security.go b/vendor/github.com/stretchr/objx/security.go new file mode 100644 index 00000000..692be8e2 --- /dev/null +++ b/vendor/github.com/stretchr/objx/security.go @@ -0,0 +1,12 @@ +package objx + +import ( + "crypto/sha1" + "encoding/hex" +) + +// HashWithKey hashes the specified string using the security key +func HashWithKey(data, key string) string { + d := sha1.Sum([]byte(data + ":" + key)) + return hex.EncodeToString(d[:]) +} diff --git a/vendor/github.com/stretchr/objx/tests.go b/vendor/github.com/stretchr/objx/tests.go new file mode 100644 index 00000000..d9e0b479 --- /dev/null +++ b/vendor/github.com/stretchr/objx/tests.go @@ -0,0 +1,17 @@ +package objx + +// Has gets whether there is something at the specified selector +// or not. +// +// If m is nil, Has will always return false. +func (m Map) Has(selector string) bool { + if m == nil { + return false + } + return !m.Get(selector).IsNil() +} + +// IsNil gets whether the data is nil or not. +func (v *Value) IsNil() bool { + return v == nil || v.data == nil +} diff --git a/vendor/github.com/stretchr/objx/type_specific_codegen.go b/vendor/github.com/stretchr/objx/type_specific_codegen.go new file mode 100644 index 00000000..202a91f8 --- /dev/null +++ b/vendor/github.com/stretchr/objx/type_specific_codegen.go @@ -0,0 +1,2501 @@ +package objx + +/* + Inter (interface{} and []interface{}) +*/ + +// Inter gets the value as a interface{}, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Inter(optionalDefault ...interface{}) interface{} { + if s, ok := v.data.(interface{}); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInter gets the value as a interface{}. +// +// Panics if the object is not a interface{}. +func (v *Value) MustInter() interface{} { + return v.data.(interface{}) +} + +// InterSlice gets the value as a []interface{}, returns the optionalDefault +// value or nil if the value is not a []interface{}. +func (v *Value) InterSlice(optionalDefault ...[]interface{}) []interface{} { + if s, ok := v.data.([]interface{}); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInterSlice gets the value as a []interface{}. +// +// Panics if the object is not a []interface{}. +func (v *Value) MustInterSlice() []interface{} { + return v.data.([]interface{}) +} + +// IsInter gets whether the object contained is a interface{} or not. +func (v *Value) IsInter() bool { + _, ok := v.data.(interface{}) + return ok +} + +// IsInterSlice gets whether the object contained is a []interface{} or not. +func (v *Value) IsInterSlice() bool { + _, ok := v.data.([]interface{}) + return ok +} + +// EachInter calls the specified callback for each object +// in the []interface{}. +// +// Panics if the object is the wrong type. +func (v *Value) EachInter(callback func(int, interface{}) bool) *Value { + for index, val := range v.MustInterSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInter uses the specified decider function to select items +// from the []interface{}. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInter(decider func(int, interface{}) bool) *Value { + var selected []interface{} + v.EachInter(func(index int, val interface{}) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInter uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]interface{}. +func (v *Value) GroupInter(grouper func(int, interface{}) string) *Value { + groups := make(map[string][]interface{}) + v.EachInter(func(index int, val interface{}) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]interface{}, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInter uses the specified function to replace each interface{}s +// by iterating each item. The data in the returned result will be a +// []interface{} containing the replaced items. +func (v *Value) ReplaceInter(replacer func(int, interface{}) interface{}) *Value { + arr := v.MustInterSlice() + replaced := make([]interface{}, len(arr)) + v.EachInter(func(index int, val interface{}) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInter uses the specified collector function to collect a value +// for each of the interface{}s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInter(collector func(int, interface{}) interface{}) *Value { + arr := v.MustInterSlice() + collected := make([]interface{}, len(arr)) + v.EachInter(func(index int, val interface{}) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + MSI (map[string]interface{} and []map[string]interface{}) +*/ + +// MSI gets the value as a map[string]interface{}, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) MSI(optionalDefault ...map[string]interface{}) map[string]interface{} { + if s, ok := v.data.(map[string]interface{}); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustMSI gets the value as a map[string]interface{}. +// +// Panics if the object is not a map[string]interface{}. +func (v *Value) MustMSI() map[string]interface{} { + return v.data.(map[string]interface{}) +} + +// MSISlice gets the value as a []map[string]interface{}, returns the optionalDefault +// value or nil if the value is not a []map[string]interface{}. +func (v *Value) MSISlice(optionalDefault ...[]map[string]interface{}) []map[string]interface{} { + if s, ok := v.data.([]map[string]interface{}); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustMSISlice gets the value as a []map[string]interface{}. +// +// Panics if the object is not a []map[string]interface{}. +func (v *Value) MustMSISlice() []map[string]interface{} { + return v.data.([]map[string]interface{}) +} + +// IsMSI gets whether the object contained is a map[string]interface{} or not. +func (v *Value) IsMSI() bool { + _, ok := v.data.(map[string]interface{}) + return ok +} + +// IsMSISlice gets whether the object contained is a []map[string]interface{} or not. +func (v *Value) IsMSISlice() bool { + _, ok := v.data.([]map[string]interface{}) + return ok +} + +// EachMSI calls the specified callback for each object +// in the []map[string]interface{}. +// +// Panics if the object is the wrong type. +func (v *Value) EachMSI(callback func(int, map[string]interface{}) bool) *Value { + for index, val := range v.MustMSISlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereMSI uses the specified decider function to select items +// from the []map[string]interface{}. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereMSI(decider func(int, map[string]interface{}) bool) *Value { + var selected []map[string]interface{} + v.EachMSI(func(index int, val map[string]interface{}) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupMSI uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]map[string]interface{}. +func (v *Value) GroupMSI(grouper func(int, map[string]interface{}) string) *Value { + groups := make(map[string][]map[string]interface{}) + v.EachMSI(func(index int, val map[string]interface{}) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]map[string]interface{}, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceMSI uses the specified function to replace each map[string]interface{}s +// by iterating each item. The data in the returned result will be a +// []map[string]interface{} containing the replaced items. +func (v *Value) ReplaceMSI(replacer func(int, map[string]interface{}) map[string]interface{}) *Value { + arr := v.MustMSISlice() + replaced := make([]map[string]interface{}, len(arr)) + v.EachMSI(func(index int, val map[string]interface{}) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectMSI uses the specified collector function to collect a value +// for each of the map[string]interface{}s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectMSI(collector func(int, map[string]interface{}) interface{}) *Value { + arr := v.MustMSISlice() + collected := make([]interface{}, len(arr)) + v.EachMSI(func(index int, val map[string]interface{}) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + ObjxMap ((Map) and [](Map)) +*/ + +// ObjxMap gets the value as a (Map), returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) ObjxMap(optionalDefault ...(Map)) Map { + if s, ok := v.data.((Map)); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return New(nil) +} + +// MustObjxMap gets the value as a (Map). +// +// Panics if the object is not a (Map). +func (v *Value) MustObjxMap() Map { + return v.data.((Map)) +} + +// ObjxMapSlice gets the value as a [](Map), returns the optionalDefault +// value or nil if the value is not a [](Map). +func (v *Value) ObjxMapSlice(optionalDefault ...[](Map)) [](Map) { + if s, ok := v.data.([](Map)); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustObjxMapSlice gets the value as a [](Map). +// +// Panics if the object is not a [](Map). +func (v *Value) MustObjxMapSlice() [](Map) { + return v.data.([](Map)) +} + +// IsObjxMap gets whether the object contained is a (Map) or not. +func (v *Value) IsObjxMap() bool { + _, ok := v.data.((Map)) + return ok +} + +// IsObjxMapSlice gets whether the object contained is a [](Map) or not. +func (v *Value) IsObjxMapSlice() bool { + _, ok := v.data.([](Map)) + return ok +} + +// EachObjxMap calls the specified callback for each object +// in the [](Map). +// +// Panics if the object is the wrong type. +func (v *Value) EachObjxMap(callback func(int, Map) bool) *Value { + for index, val := range v.MustObjxMapSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereObjxMap uses the specified decider function to select items +// from the [](Map). The object contained in the result will contain +// only the selected items. +func (v *Value) WhereObjxMap(decider func(int, Map) bool) *Value { + var selected [](Map) + v.EachObjxMap(func(index int, val Map) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupObjxMap uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][](Map). +func (v *Value) GroupObjxMap(grouper func(int, Map) string) *Value { + groups := make(map[string][](Map)) + v.EachObjxMap(func(index int, val Map) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([](Map), 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceObjxMap uses the specified function to replace each (Map)s +// by iterating each item. The data in the returned result will be a +// [](Map) containing the replaced items. +func (v *Value) ReplaceObjxMap(replacer func(int, Map) Map) *Value { + arr := v.MustObjxMapSlice() + replaced := make([](Map), len(arr)) + v.EachObjxMap(func(index int, val Map) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectObjxMap uses the specified collector function to collect a value +// for each of the (Map)s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectObjxMap(collector func(int, Map) interface{}) *Value { + arr := v.MustObjxMapSlice() + collected := make([]interface{}, len(arr)) + v.EachObjxMap(func(index int, val Map) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Bool (bool and []bool) +*/ + +// Bool gets the value as a bool, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Bool(optionalDefault ...bool) bool { + if s, ok := v.data.(bool); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return false +} + +// MustBool gets the value as a bool. +// +// Panics if the object is not a bool. +func (v *Value) MustBool() bool { + return v.data.(bool) +} + +// BoolSlice gets the value as a []bool, returns the optionalDefault +// value or nil if the value is not a []bool. +func (v *Value) BoolSlice(optionalDefault ...[]bool) []bool { + if s, ok := v.data.([]bool); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustBoolSlice gets the value as a []bool. +// +// Panics if the object is not a []bool. +func (v *Value) MustBoolSlice() []bool { + return v.data.([]bool) +} + +// IsBool gets whether the object contained is a bool or not. +func (v *Value) IsBool() bool { + _, ok := v.data.(bool) + return ok +} + +// IsBoolSlice gets whether the object contained is a []bool or not. +func (v *Value) IsBoolSlice() bool { + _, ok := v.data.([]bool) + return ok +} + +// EachBool calls the specified callback for each object +// in the []bool. +// +// Panics if the object is the wrong type. +func (v *Value) EachBool(callback func(int, bool) bool) *Value { + for index, val := range v.MustBoolSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereBool uses the specified decider function to select items +// from the []bool. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereBool(decider func(int, bool) bool) *Value { + var selected []bool + v.EachBool(func(index int, val bool) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupBool uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]bool. +func (v *Value) GroupBool(grouper func(int, bool) string) *Value { + groups := make(map[string][]bool) + v.EachBool(func(index int, val bool) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]bool, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceBool uses the specified function to replace each bools +// by iterating each item. The data in the returned result will be a +// []bool containing the replaced items. +func (v *Value) ReplaceBool(replacer func(int, bool) bool) *Value { + arr := v.MustBoolSlice() + replaced := make([]bool, len(arr)) + v.EachBool(func(index int, val bool) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectBool uses the specified collector function to collect a value +// for each of the bools in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectBool(collector func(int, bool) interface{}) *Value { + arr := v.MustBoolSlice() + collected := make([]interface{}, len(arr)) + v.EachBool(func(index int, val bool) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Str (string and []string) +*/ + +// Str gets the value as a string, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Str(optionalDefault ...string) string { + if s, ok := v.data.(string); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return "" +} + +// MustStr gets the value as a string. +// +// Panics if the object is not a string. +func (v *Value) MustStr() string { + return v.data.(string) +} + +// StrSlice gets the value as a []string, returns the optionalDefault +// value or nil if the value is not a []string. +func (v *Value) StrSlice(optionalDefault ...[]string) []string { + if s, ok := v.data.([]string); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustStrSlice gets the value as a []string. +// +// Panics if the object is not a []string. +func (v *Value) MustStrSlice() []string { + return v.data.([]string) +} + +// IsStr gets whether the object contained is a string or not. +func (v *Value) IsStr() bool { + _, ok := v.data.(string) + return ok +} + +// IsStrSlice gets whether the object contained is a []string or not. +func (v *Value) IsStrSlice() bool { + _, ok := v.data.([]string) + return ok +} + +// EachStr calls the specified callback for each object +// in the []string. +// +// Panics if the object is the wrong type. +func (v *Value) EachStr(callback func(int, string) bool) *Value { + for index, val := range v.MustStrSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereStr uses the specified decider function to select items +// from the []string. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereStr(decider func(int, string) bool) *Value { + var selected []string + v.EachStr(func(index int, val string) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupStr uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]string. +func (v *Value) GroupStr(grouper func(int, string) string) *Value { + groups := make(map[string][]string) + v.EachStr(func(index int, val string) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]string, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceStr uses the specified function to replace each strings +// by iterating each item. The data in the returned result will be a +// []string containing the replaced items. +func (v *Value) ReplaceStr(replacer func(int, string) string) *Value { + arr := v.MustStrSlice() + replaced := make([]string, len(arr)) + v.EachStr(func(index int, val string) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectStr uses the specified collector function to collect a value +// for each of the strings in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectStr(collector func(int, string) interface{}) *Value { + arr := v.MustStrSlice() + collected := make([]interface{}, len(arr)) + v.EachStr(func(index int, val string) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int (int and []int) +*/ + +// Int gets the value as a int, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int(optionalDefault ...int) int { + if s, ok := v.data.(int); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt gets the value as a int. +// +// Panics if the object is not a int. +func (v *Value) MustInt() int { + return v.data.(int) +} + +// IntSlice gets the value as a []int, returns the optionalDefault +// value or nil if the value is not a []int. +func (v *Value) IntSlice(optionalDefault ...[]int) []int { + if s, ok := v.data.([]int); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustIntSlice gets the value as a []int. +// +// Panics if the object is not a []int. +func (v *Value) MustIntSlice() []int { + return v.data.([]int) +} + +// IsInt gets whether the object contained is a int or not. +func (v *Value) IsInt() bool { + _, ok := v.data.(int) + return ok +} + +// IsIntSlice gets whether the object contained is a []int or not. +func (v *Value) IsIntSlice() bool { + _, ok := v.data.([]int) + return ok +} + +// EachInt calls the specified callback for each object +// in the []int. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt(callback func(int, int) bool) *Value { + for index, val := range v.MustIntSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt uses the specified decider function to select items +// from the []int. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt(decider func(int, int) bool) *Value { + var selected []int + v.EachInt(func(index int, val int) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int. +func (v *Value) GroupInt(grouper func(int, int) string) *Value { + groups := make(map[string][]int) + v.EachInt(func(index int, val int) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt uses the specified function to replace each ints +// by iterating each item. The data in the returned result will be a +// []int containing the replaced items. +func (v *Value) ReplaceInt(replacer func(int, int) int) *Value { + arr := v.MustIntSlice() + replaced := make([]int, len(arr)) + v.EachInt(func(index int, val int) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt uses the specified collector function to collect a value +// for each of the ints in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt(collector func(int, int) interface{}) *Value { + arr := v.MustIntSlice() + collected := make([]interface{}, len(arr)) + v.EachInt(func(index int, val int) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int8 (int8 and []int8) +*/ + +// Int8 gets the value as a int8, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int8(optionalDefault ...int8) int8 { + if s, ok := v.data.(int8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt8 gets the value as a int8. +// +// Panics if the object is not a int8. +func (v *Value) MustInt8() int8 { + return v.data.(int8) +} + +// Int8Slice gets the value as a []int8, returns the optionalDefault +// value or nil if the value is not a []int8. +func (v *Value) Int8Slice(optionalDefault ...[]int8) []int8 { + if s, ok := v.data.([]int8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt8Slice gets the value as a []int8. +// +// Panics if the object is not a []int8. +func (v *Value) MustInt8Slice() []int8 { + return v.data.([]int8) +} + +// IsInt8 gets whether the object contained is a int8 or not. +func (v *Value) IsInt8() bool { + _, ok := v.data.(int8) + return ok +} + +// IsInt8Slice gets whether the object contained is a []int8 or not. +func (v *Value) IsInt8Slice() bool { + _, ok := v.data.([]int8) + return ok +} + +// EachInt8 calls the specified callback for each object +// in the []int8. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt8(callback func(int, int8) bool) *Value { + for index, val := range v.MustInt8Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt8 uses the specified decider function to select items +// from the []int8. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt8(decider func(int, int8) bool) *Value { + var selected []int8 + v.EachInt8(func(index int, val int8) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt8 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int8. +func (v *Value) GroupInt8(grouper func(int, int8) string) *Value { + groups := make(map[string][]int8) + v.EachInt8(func(index int, val int8) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int8, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt8 uses the specified function to replace each int8s +// by iterating each item. The data in the returned result will be a +// []int8 containing the replaced items. +func (v *Value) ReplaceInt8(replacer func(int, int8) int8) *Value { + arr := v.MustInt8Slice() + replaced := make([]int8, len(arr)) + v.EachInt8(func(index int, val int8) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt8 uses the specified collector function to collect a value +// for each of the int8s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt8(collector func(int, int8) interface{}) *Value { + arr := v.MustInt8Slice() + collected := make([]interface{}, len(arr)) + v.EachInt8(func(index int, val int8) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int16 (int16 and []int16) +*/ + +// Int16 gets the value as a int16, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int16(optionalDefault ...int16) int16 { + if s, ok := v.data.(int16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt16 gets the value as a int16. +// +// Panics if the object is not a int16. +func (v *Value) MustInt16() int16 { + return v.data.(int16) +} + +// Int16Slice gets the value as a []int16, returns the optionalDefault +// value or nil if the value is not a []int16. +func (v *Value) Int16Slice(optionalDefault ...[]int16) []int16 { + if s, ok := v.data.([]int16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt16Slice gets the value as a []int16. +// +// Panics if the object is not a []int16. +func (v *Value) MustInt16Slice() []int16 { + return v.data.([]int16) +} + +// IsInt16 gets whether the object contained is a int16 or not. +func (v *Value) IsInt16() bool { + _, ok := v.data.(int16) + return ok +} + +// IsInt16Slice gets whether the object contained is a []int16 or not. +func (v *Value) IsInt16Slice() bool { + _, ok := v.data.([]int16) + return ok +} + +// EachInt16 calls the specified callback for each object +// in the []int16. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt16(callback func(int, int16) bool) *Value { + for index, val := range v.MustInt16Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt16 uses the specified decider function to select items +// from the []int16. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt16(decider func(int, int16) bool) *Value { + var selected []int16 + v.EachInt16(func(index int, val int16) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt16 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int16. +func (v *Value) GroupInt16(grouper func(int, int16) string) *Value { + groups := make(map[string][]int16) + v.EachInt16(func(index int, val int16) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int16, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt16 uses the specified function to replace each int16s +// by iterating each item. The data in the returned result will be a +// []int16 containing the replaced items. +func (v *Value) ReplaceInt16(replacer func(int, int16) int16) *Value { + arr := v.MustInt16Slice() + replaced := make([]int16, len(arr)) + v.EachInt16(func(index int, val int16) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt16 uses the specified collector function to collect a value +// for each of the int16s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt16(collector func(int, int16) interface{}) *Value { + arr := v.MustInt16Slice() + collected := make([]interface{}, len(arr)) + v.EachInt16(func(index int, val int16) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int32 (int32 and []int32) +*/ + +// Int32 gets the value as a int32, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int32(optionalDefault ...int32) int32 { + if s, ok := v.data.(int32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt32 gets the value as a int32. +// +// Panics if the object is not a int32. +func (v *Value) MustInt32() int32 { + return v.data.(int32) +} + +// Int32Slice gets the value as a []int32, returns the optionalDefault +// value or nil if the value is not a []int32. +func (v *Value) Int32Slice(optionalDefault ...[]int32) []int32 { + if s, ok := v.data.([]int32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt32Slice gets the value as a []int32. +// +// Panics if the object is not a []int32. +func (v *Value) MustInt32Slice() []int32 { + return v.data.([]int32) +} + +// IsInt32 gets whether the object contained is a int32 or not. +func (v *Value) IsInt32() bool { + _, ok := v.data.(int32) + return ok +} + +// IsInt32Slice gets whether the object contained is a []int32 or not. +func (v *Value) IsInt32Slice() bool { + _, ok := v.data.([]int32) + return ok +} + +// EachInt32 calls the specified callback for each object +// in the []int32. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt32(callback func(int, int32) bool) *Value { + for index, val := range v.MustInt32Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt32 uses the specified decider function to select items +// from the []int32. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt32(decider func(int, int32) bool) *Value { + var selected []int32 + v.EachInt32(func(index int, val int32) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt32 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int32. +func (v *Value) GroupInt32(grouper func(int, int32) string) *Value { + groups := make(map[string][]int32) + v.EachInt32(func(index int, val int32) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int32, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt32 uses the specified function to replace each int32s +// by iterating each item. The data in the returned result will be a +// []int32 containing the replaced items. +func (v *Value) ReplaceInt32(replacer func(int, int32) int32) *Value { + arr := v.MustInt32Slice() + replaced := make([]int32, len(arr)) + v.EachInt32(func(index int, val int32) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt32 uses the specified collector function to collect a value +// for each of the int32s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt32(collector func(int, int32) interface{}) *Value { + arr := v.MustInt32Slice() + collected := make([]interface{}, len(arr)) + v.EachInt32(func(index int, val int32) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int64 (int64 and []int64) +*/ + +// Int64 gets the value as a int64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int64(optionalDefault ...int64) int64 { + if s, ok := v.data.(int64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt64 gets the value as a int64. +// +// Panics if the object is not a int64. +func (v *Value) MustInt64() int64 { + return v.data.(int64) +} + +// Int64Slice gets the value as a []int64, returns the optionalDefault +// value or nil if the value is not a []int64. +func (v *Value) Int64Slice(optionalDefault ...[]int64) []int64 { + if s, ok := v.data.([]int64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt64Slice gets the value as a []int64. +// +// Panics if the object is not a []int64. +func (v *Value) MustInt64Slice() []int64 { + return v.data.([]int64) +} + +// IsInt64 gets whether the object contained is a int64 or not. +func (v *Value) IsInt64() bool { + _, ok := v.data.(int64) + return ok +} + +// IsInt64Slice gets whether the object contained is a []int64 or not. +func (v *Value) IsInt64Slice() bool { + _, ok := v.data.([]int64) + return ok +} + +// EachInt64 calls the specified callback for each object +// in the []int64. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt64(callback func(int, int64) bool) *Value { + for index, val := range v.MustInt64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt64 uses the specified decider function to select items +// from the []int64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt64(decider func(int, int64) bool) *Value { + var selected []int64 + v.EachInt64(func(index int, val int64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int64. +func (v *Value) GroupInt64(grouper func(int, int64) string) *Value { + groups := make(map[string][]int64) + v.EachInt64(func(index int, val int64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt64 uses the specified function to replace each int64s +// by iterating each item. The data in the returned result will be a +// []int64 containing the replaced items. +func (v *Value) ReplaceInt64(replacer func(int, int64) int64) *Value { + arr := v.MustInt64Slice() + replaced := make([]int64, len(arr)) + v.EachInt64(func(index int, val int64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt64 uses the specified collector function to collect a value +// for each of the int64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt64(collector func(int, int64) interface{}) *Value { + arr := v.MustInt64Slice() + collected := make([]interface{}, len(arr)) + v.EachInt64(func(index int, val int64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint (uint and []uint) +*/ + +// Uint gets the value as a uint, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint(optionalDefault ...uint) uint { + if s, ok := v.data.(uint); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint gets the value as a uint. +// +// Panics if the object is not a uint. +func (v *Value) MustUint() uint { + return v.data.(uint) +} + +// UintSlice gets the value as a []uint, returns the optionalDefault +// value or nil if the value is not a []uint. +func (v *Value) UintSlice(optionalDefault ...[]uint) []uint { + if s, ok := v.data.([]uint); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUintSlice gets the value as a []uint. +// +// Panics if the object is not a []uint. +func (v *Value) MustUintSlice() []uint { + return v.data.([]uint) +} + +// IsUint gets whether the object contained is a uint or not. +func (v *Value) IsUint() bool { + _, ok := v.data.(uint) + return ok +} + +// IsUintSlice gets whether the object contained is a []uint or not. +func (v *Value) IsUintSlice() bool { + _, ok := v.data.([]uint) + return ok +} + +// EachUint calls the specified callback for each object +// in the []uint. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint(callback func(int, uint) bool) *Value { + for index, val := range v.MustUintSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint uses the specified decider function to select items +// from the []uint. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint(decider func(int, uint) bool) *Value { + var selected []uint + v.EachUint(func(index int, val uint) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint. +func (v *Value) GroupUint(grouper func(int, uint) string) *Value { + groups := make(map[string][]uint) + v.EachUint(func(index int, val uint) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint uses the specified function to replace each uints +// by iterating each item. The data in the returned result will be a +// []uint containing the replaced items. +func (v *Value) ReplaceUint(replacer func(int, uint) uint) *Value { + arr := v.MustUintSlice() + replaced := make([]uint, len(arr)) + v.EachUint(func(index int, val uint) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint uses the specified collector function to collect a value +// for each of the uints in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint(collector func(int, uint) interface{}) *Value { + arr := v.MustUintSlice() + collected := make([]interface{}, len(arr)) + v.EachUint(func(index int, val uint) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint8 (uint8 and []uint8) +*/ + +// Uint8 gets the value as a uint8, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint8(optionalDefault ...uint8) uint8 { + if s, ok := v.data.(uint8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint8 gets the value as a uint8. +// +// Panics if the object is not a uint8. +func (v *Value) MustUint8() uint8 { + return v.data.(uint8) +} + +// Uint8Slice gets the value as a []uint8, returns the optionalDefault +// value or nil if the value is not a []uint8. +func (v *Value) Uint8Slice(optionalDefault ...[]uint8) []uint8 { + if s, ok := v.data.([]uint8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint8Slice gets the value as a []uint8. +// +// Panics if the object is not a []uint8. +func (v *Value) MustUint8Slice() []uint8 { + return v.data.([]uint8) +} + +// IsUint8 gets whether the object contained is a uint8 or not. +func (v *Value) IsUint8() bool { + _, ok := v.data.(uint8) + return ok +} + +// IsUint8Slice gets whether the object contained is a []uint8 or not. +func (v *Value) IsUint8Slice() bool { + _, ok := v.data.([]uint8) + return ok +} + +// EachUint8 calls the specified callback for each object +// in the []uint8. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint8(callback func(int, uint8) bool) *Value { + for index, val := range v.MustUint8Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint8 uses the specified decider function to select items +// from the []uint8. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint8(decider func(int, uint8) bool) *Value { + var selected []uint8 + v.EachUint8(func(index int, val uint8) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint8 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint8. +func (v *Value) GroupUint8(grouper func(int, uint8) string) *Value { + groups := make(map[string][]uint8) + v.EachUint8(func(index int, val uint8) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint8, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint8 uses the specified function to replace each uint8s +// by iterating each item. The data in the returned result will be a +// []uint8 containing the replaced items. +func (v *Value) ReplaceUint8(replacer func(int, uint8) uint8) *Value { + arr := v.MustUint8Slice() + replaced := make([]uint8, len(arr)) + v.EachUint8(func(index int, val uint8) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint8 uses the specified collector function to collect a value +// for each of the uint8s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint8(collector func(int, uint8) interface{}) *Value { + arr := v.MustUint8Slice() + collected := make([]interface{}, len(arr)) + v.EachUint8(func(index int, val uint8) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint16 (uint16 and []uint16) +*/ + +// Uint16 gets the value as a uint16, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint16(optionalDefault ...uint16) uint16 { + if s, ok := v.data.(uint16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint16 gets the value as a uint16. +// +// Panics if the object is not a uint16. +func (v *Value) MustUint16() uint16 { + return v.data.(uint16) +} + +// Uint16Slice gets the value as a []uint16, returns the optionalDefault +// value or nil if the value is not a []uint16. +func (v *Value) Uint16Slice(optionalDefault ...[]uint16) []uint16 { + if s, ok := v.data.([]uint16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint16Slice gets the value as a []uint16. +// +// Panics if the object is not a []uint16. +func (v *Value) MustUint16Slice() []uint16 { + return v.data.([]uint16) +} + +// IsUint16 gets whether the object contained is a uint16 or not. +func (v *Value) IsUint16() bool { + _, ok := v.data.(uint16) + return ok +} + +// IsUint16Slice gets whether the object contained is a []uint16 or not. +func (v *Value) IsUint16Slice() bool { + _, ok := v.data.([]uint16) + return ok +} + +// EachUint16 calls the specified callback for each object +// in the []uint16. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint16(callback func(int, uint16) bool) *Value { + for index, val := range v.MustUint16Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint16 uses the specified decider function to select items +// from the []uint16. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint16(decider func(int, uint16) bool) *Value { + var selected []uint16 + v.EachUint16(func(index int, val uint16) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint16 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint16. +func (v *Value) GroupUint16(grouper func(int, uint16) string) *Value { + groups := make(map[string][]uint16) + v.EachUint16(func(index int, val uint16) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint16, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint16 uses the specified function to replace each uint16s +// by iterating each item. The data in the returned result will be a +// []uint16 containing the replaced items. +func (v *Value) ReplaceUint16(replacer func(int, uint16) uint16) *Value { + arr := v.MustUint16Slice() + replaced := make([]uint16, len(arr)) + v.EachUint16(func(index int, val uint16) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint16 uses the specified collector function to collect a value +// for each of the uint16s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint16(collector func(int, uint16) interface{}) *Value { + arr := v.MustUint16Slice() + collected := make([]interface{}, len(arr)) + v.EachUint16(func(index int, val uint16) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint32 (uint32 and []uint32) +*/ + +// Uint32 gets the value as a uint32, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint32(optionalDefault ...uint32) uint32 { + if s, ok := v.data.(uint32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint32 gets the value as a uint32. +// +// Panics if the object is not a uint32. +func (v *Value) MustUint32() uint32 { + return v.data.(uint32) +} + +// Uint32Slice gets the value as a []uint32, returns the optionalDefault +// value or nil if the value is not a []uint32. +func (v *Value) Uint32Slice(optionalDefault ...[]uint32) []uint32 { + if s, ok := v.data.([]uint32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint32Slice gets the value as a []uint32. +// +// Panics if the object is not a []uint32. +func (v *Value) MustUint32Slice() []uint32 { + return v.data.([]uint32) +} + +// IsUint32 gets whether the object contained is a uint32 or not. +func (v *Value) IsUint32() bool { + _, ok := v.data.(uint32) + return ok +} + +// IsUint32Slice gets whether the object contained is a []uint32 or not. +func (v *Value) IsUint32Slice() bool { + _, ok := v.data.([]uint32) + return ok +} + +// EachUint32 calls the specified callback for each object +// in the []uint32. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint32(callback func(int, uint32) bool) *Value { + for index, val := range v.MustUint32Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint32 uses the specified decider function to select items +// from the []uint32. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint32(decider func(int, uint32) bool) *Value { + var selected []uint32 + v.EachUint32(func(index int, val uint32) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint32 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint32. +func (v *Value) GroupUint32(grouper func(int, uint32) string) *Value { + groups := make(map[string][]uint32) + v.EachUint32(func(index int, val uint32) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint32, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint32 uses the specified function to replace each uint32s +// by iterating each item. The data in the returned result will be a +// []uint32 containing the replaced items. +func (v *Value) ReplaceUint32(replacer func(int, uint32) uint32) *Value { + arr := v.MustUint32Slice() + replaced := make([]uint32, len(arr)) + v.EachUint32(func(index int, val uint32) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint32 uses the specified collector function to collect a value +// for each of the uint32s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint32(collector func(int, uint32) interface{}) *Value { + arr := v.MustUint32Slice() + collected := make([]interface{}, len(arr)) + v.EachUint32(func(index int, val uint32) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint64 (uint64 and []uint64) +*/ + +// Uint64 gets the value as a uint64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint64(optionalDefault ...uint64) uint64 { + if s, ok := v.data.(uint64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint64 gets the value as a uint64. +// +// Panics if the object is not a uint64. +func (v *Value) MustUint64() uint64 { + return v.data.(uint64) +} + +// Uint64Slice gets the value as a []uint64, returns the optionalDefault +// value or nil if the value is not a []uint64. +func (v *Value) Uint64Slice(optionalDefault ...[]uint64) []uint64 { + if s, ok := v.data.([]uint64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint64Slice gets the value as a []uint64. +// +// Panics if the object is not a []uint64. +func (v *Value) MustUint64Slice() []uint64 { + return v.data.([]uint64) +} + +// IsUint64 gets whether the object contained is a uint64 or not. +func (v *Value) IsUint64() bool { + _, ok := v.data.(uint64) + return ok +} + +// IsUint64Slice gets whether the object contained is a []uint64 or not. +func (v *Value) IsUint64Slice() bool { + _, ok := v.data.([]uint64) + return ok +} + +// EachUint64 calls the specified callback for each object +// in the []uint64. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint64(callback func(int, uint64) bool) *Value { + for index, val := range v.MustUint64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint64 uses the specified decider function to select items +// from the []uint64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint64(decider func(int, uint64) bool) *Value { + var selected []uint64 + v.EachUint64(func(index int, val uint64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint64. +func (v *Value) GroupUint64(grouper func(int, uint64) string) *Value { + groups := make(map[string][]uint64) + v.EachUint64(func(index int, val uint64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint64 uses the specified function to replace each uint64s +// by iterating each item. The data in the returned result will be a +// []uint64 containing the replaced items. +func (v *Value) ReplaceUint64(replacer func(int, uint64) uint64) *Value { + arr := v.MustUint64Slice() + replaced := make([]uint64, len(arr)) + v.EachUint64(func(index int, val uint64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint64 uses the specified collector function to collect a value +// for each of the uint64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint64(collector func(int, uint64) interface{}) *Value { + arr := v.MustUint64Slice() + collected := make([]interface{}, len(arr)) + v.EachUint64(func(index int, val uint64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uintptr (uintptr and []uintptr) +*/ + +// Uintptr gets the value as a uintptr, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uintptr(optionalDefault ...uintptr) uintptr { + if s, ok := v.data.(uintptr); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUintptr gets the value as a uintptr. +// +// Panics if the object is not a uintptr. +func (v *Value) MustUintptr() uintptr { + return v.data.(uintptr) +} + +// UintptrSlice gets the value as a []uintptr, returns the optionalDefault +// value or nil if the value is not a []uintptr. +func (v *Value) UintptrSlice(optionalDefault ...[]uintptr) []uintptr { + if s, ok := v.data.([]uintptr); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUintptrSlice gets the value as a []uintptr. +// +// Panics if the object is not a []uintptr. +func (v *Value) MustUintptrSlice() []uintptr { + return v.data.([]uintptr) +} + +// IsUintptr gets whether the object contained is a uintptr or not. +func (v *Value) IsUintptr() bool { + _, ok := v.data.(uintptr) + return ok +} + +// IsUintptrSlice gets whether the object contained is a []uintptr or not. +func (v *Value) IsUintptrSlice() bool { + _, ok := v.data.([]uintptr) + return ok +} + +// EachUintptr calls the specified callback for each object +// in the []uintptr. +// +// Panics if the object is the wrong type. +func (v *Value) EachUintptr(callback func(int, uintptr) bool) *Value { + for index, val := range v.MustUintptrSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUintptr uses the specified decider function to select items +// from the []uintptr. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUintptr(decider func(int, uintptr) bool) *Value { + var selected []uintptr + v.EachUintptr(func(index int, val uintptr) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUintptr uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uintptr. +func (v *Value) GroupUintptr(grouper func(int, uintptr) string) *Value { + groups := make(map[string][]uintptr) + v.EachUintptr(func(index int, val uintptr) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uintptr, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUintptr uses the specified function to replace each uintptrs +// by iterating each item. The data in the returned result will be a +// []uintptr containing the replaced items. +func (v *Value) ReplaceUintptr(replacer func(int, uintptr) uintptr) *Value { + arr := v.MustUintptrSlice() + replaced := make([]uintptr, len(arr)) + v.EachUintptr(func(index int, val uintptr) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUintptr uses the specified collector function to collect a value +// for each of the uintptrs in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUintptr(collector func(int, uintptr) interface{}) *Value { + arr := v.MustUintptrSlice() + collected := make([]interface{}, len(arr)) + v.EachUintptr(func(index int, val uintptr) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Float32 (float32 and []float32) +*/ + +// Float32 gets the value as a float32, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Float32(optionalDefault ...float32) float32 { + if s, ok := v.data.(float32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustFloat32 gets the value as a float32. +// +// Panics if the object is not a float32. +func (v *Value) MustFloat32() float32 { + return v.data.(float32) +} + +// Float32Slice gets the value as a []float32, returns the optionalDefault +// value or nil if the value is not a []float32. +func (v *Value) Float32Slice(optionalDefault ...[]float32) []float32 { + if s, ok := v.data.([]float32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustFloat32Slice gets the value as a []float32. +// +// Panics if the object is not a []float32. +func (v *Value) MustFloat32Slice() []float32 { + return v.data.([]float32) +} + +// IsFloat32 gets whether the object contained is a float32 or not. +func (v *Value) IsFloat32() bool { + _, ok := v.data.(float32) + return ok +} + +// IsFloat32Slice gets whether the object contained is a []float32 or not. +func (v *Value) IsFloat32Slice() bool { + _, ok := v.data.([]float32) + return ok +} + +// EachFloat32 calls the specified callback for each object +// in the []float32. +// +// Panics if the object is the wrong type. +func (v *Value) EachFloat32(callback func(int, float32) bool) *Value { + for index, val := range v.MustFloat32Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereFloat32 uses the specified decider function to select items +// from the []float32. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereFloat32(decider func(int, float32) bool) *Value { + var selected []float32 + v.EachFloat32(func(index int, val float32) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupFloat32 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]float32. +func (v *Value) GroupFloat32(grouper func(int, float32) string) *Value { + groups := make(map[string][]float32) + v.EachFloat32(func(index int, val float32) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]float32, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceFloat32 uses the specified function to replace each float32s +// by iterating each item. The data in the returned result will be a +// []float32 containing the replaced items. +func (v *Value) ReplaceFloat32(replacer func(int, float32) float32) *Value { + arr := v.MustFloat32Slice() + replaced := make([]float32, len(arr)) + v.EachFloat32(func(index int, val float32) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectFloat32 uses the specified collector function to collect a value +// for each of the float32s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectFloat32(collector func(int, float32) interface{}) *Value { + arr := v.MustFloat32Slice() + collected := make([]interface{}, len(arr)) + v.EachFloat32(func(index int, val float32) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Float64 (float64 and []float64) +*/ + +// Float64 gets the value as a float64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Float64(optionalDefault ...float64) float64 { + if s, ok := v.data.(float64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustFloat64 gets the value as a float64. +// +// Panics if the object is not a float64. +func (v *Value) MustFloat64() float64 { + return v.data.(float64) +} + +// Float64Slice gets the value as a []float64, returns the optionalDefault +// value or nil if the value is not a []float64. +func (v *Value) Float64Slice(optionalDefault ...[]float64) []float64 { + if s, ok := v.data.([]float64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustFloat64Slice gets the value as a []float64. +// +// Panics if the object is not a []float64. +func (v *Value) MustFloat64Slice() []float64 { + return v.data.([]float64) +} + +// IsFloat64 gets whether the object contained is a float64 or not. +func (v *Value) IsFloat64() bool { + _, ok := v.data.(float64) + return ok +} + +// IsFloat64Slice gets whether the object contained is a []float64 or not. +func (v *Value) IsFloat64Slice() bool { + _, ok := v.data.([]float64) + return ok +} + +// EachFloat64 calls the specified callback for each object +// in the []float64. +// +// Panics if the object is the wrong type. +func (v *Value) EachFloat64(callback func(int, float64) bool) *Value { + for index, val := range v.MustFloat64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereFloat64 uses the specified decider function to select items +// from the []float64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereFloat64(decider func(int, float64) bool) *Value { + var selected []float64 + v.EachFloat64(func(index int, val float64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupFloat64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]float64. +func (v *Value) GroupFloat64(grouper func(int, float64) string) *Value { + groups := make(map[string][]float64) + v.EachFloat64(func(index int, val float64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]float64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceFloat64 uses the specified function to replace each float64s +// by iterating each item. The data in the returned result will be a +// []float64 containing the replaced items. +func (v *Value) ReplaceFloat64(replacer func(int, float64) float64) *Value { + arr := v.MustFloat64Slice() + replaced := make([]float64, len(arr)) + v.EachFloat64(func(index int, val float64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectFloat64 uses the specified collector function to collect a value +// for each of the float64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectFloat64(collector func(int, float64) interface{}) *Value { + arr := v.MustFloat64Slice() + collected := make([]interface{}, len(arr)) + v.EachFloat64(func(index int, val float64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Complex64 (complex64 and []complex64) +*/ + +// Complex64 gets the value as a complex64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Complex64(optionalDefault ...complex64) complex64 { + if s, ok := v.data.(complex64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustComplex64 gets the value as a complex64. +// +// Panics if the object is not a complex64. +func (v *Value) MustComplex64() complex64 { + return v.data.(complex64) +} + +// Complex64Slice gets the value as a []complex64, returns the optionalDefault +// value or nil if the value is not a []complex64. +func (v *Value) Complex64Slice(optionalDefault ...[]complex64) []complex64 { + if s, ok := v.data.([]complex64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustComplex64Slice gets the value as a []complex64. +// +// Panics if the object is not a []complex64. +func (v *Value) MustComplex64Slice() []complex64 { + return v.data.([]complex64) +} + +// IsComplex64 gets whether the object contained is a complex64 or not. +func (v *Value) IsComplex64() bool { + _, ok := v.data.(complex64) + return ok +} + +// IsComplex64Slice gets whether the object contained is a []complex64 or not. +func (v *Value) IsComplex64Slice() bool { + _, ok := v.data.([]complex64) + return ok +} + +// EachComplex64 calls the specified callback for each object +// in the []complex64. +// +// Panics if the object is the wrong type. +func (v *Value) EachComplex64(callback func(int, complex64) bool) *Value { + for index, val := range v.MustComplex64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereComplex64 uses the specified decider function to select items +// from the []complex64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereComplex64(decider func(int, complex64) bool) *Value { + var selected []complex64 + v.EachComplex64(func(index int, val complex64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupComplex64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]complex64. +func (v *Value) GroupComplex64(grouper func(int, complex64) string) *Value { + groups := make(map[string][]complex64) + v.EachComplex64(func(index int, val complex64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]complex64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceComplex64 uses the specified function to replace each complex64s +// by iterating each item. The data in the returned result will be a +// []complex64 containing the replaced items. +func (v *Value) ReplaceComplex64(replacer func(int, complex64) complex64) *Value { + arr := v.MustComplex64Slice() + replaced := make([]complex64, len(arr)) + v.EachComplex64(func(index int, val complex64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectComplex64 uses the specified collector function to collect a value +// for each of the complex64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectComplex64(collector func(int, complex64) interface{}) *Value { + arr := v.MustComplex64Slice() + collected := make([]interface{}, len(arr)) + v.EachComplex64(func(index int, val complex64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Complex128 (complex128 and []complex128) +*/ + +// Complex128 gets the value as a complex128, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Complex128(optionalDefault ...complex128) complex128 { + if s, ok := v.data.(complex128); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustComplex128 gets the value as a complex128. +// +// Panics if the object is not a complex128. +func (v *Value) MustComplex128() complex128 { + return v.data.(complex128) +} + +// Complex128Slice gets the value as a []complex128, returns the optionalDefault +// value or nil if the value is not a []complex128. +func (v *Value) Complex128Slice(optionalDefault ...[]complex128) []complex128 { + if s, ok := v.data.([]complex128); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustComplex128Slice gets the value as a []complex128. +// +// Panics if the object is not a []complex128. +func (v *Value) MustComplex128Slice() []complex128 { + return v.data.([]complex128) +} + +// IsComplex128 gets whether the object contained is a complex128 or not. +func (v *Value) IsComplex128() bool { + _, ok := v.data.(complex128) + return ok +} + +// IsComplex128Slice gets whether the object contained is a []complex128 or not. +func (v *Value) IsComplex128Slice() bool { + _, ok := v.data.([]complex128) + return ok +} + +// EachComplex128 calls the specified callback for each object +// in the []complex128. +// +// Panics if the object is the wrong type. +func (v *Value) EachComplex128(callback func(int, complex128) bool) *Value { + for index, val := range v.MustComplex128Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereComplex128 uses the specified decider function to select items +// from the []complex128. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereComplex128(decider func(int, complex128) bool) *Value { + var selected []complex128 + v.EachComplex128(func(index int, val complex128) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupComplex128 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]complex128. +func (v *Value) GroupComplex128(grouper func(int, complex128) string) *Value { + groups := make(map[string][]complex128) + v.EachComplex128(func(index int, val complex128) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]complex128, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceComplex128 uses the specified function to replace each complex128s +// by iterating each item. The data in the returned result will be a +// []complex128 containing the replaced items. +func (v *Value) ReplaceComplex128(replacer func(int, complex128) complex128) *Value { + arr := v.MustComplex128Slice() + replaced := make([]complex128, len(arr)) + v.EachComplex128(func(index int, val complex128) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectComplex128 uses the specified collector function to collect a value +// for each of the complex128s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectComplex128(collector func(int, complex128) interface{}) *Value { + arr := v.MustComplex128Slice() + collected := make([]interface{}, len(arr)) + v.EachComplex128(func(index int, val complex128) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} diff --git a/vendor/github.com/stretchr/objx/value.go b/vendor/github.com/stretchr/objx/value.go new file mode 100644 index 00000000..e4b4a143 --- /dev/null +++ b/vendor/github.com/stretchr/objx/value.go @@ -0,0 +1,53 @@ +package objx + +import ( + "fmt" + "strconv" +) + +// Value provides methods for extracting interface{} data in various +// types. +type Value struct { + // data contains the raw data being managed by this Value + data interface{} +} + +// Data returns the raw data contained by this Value +func (v *Value) Data() interface{} { + return v.data +} + +// String returns the value always as a string +func (v *Value) String() string { + switch { + case v.IsStr(): + return v.Str() + case v.IsBool(): + return strconv.FormatBool(v.Bool()) + case v.IsFloat32(): + return strconv.FormatFloat(float64(v.Float32()), 'f', -1, 32) + case v.IsFloat64(): + return strconv.FormatFloat(v.Float64(), 'f', -1, 64) + case v.IsInt(): + return strconv.FormatInt(int64(v.Int()), 10) + case v.IsInt8(): + return strconv.FormatInt(int64(v.Int8()), 10) + case v.IsInt16(): + return strconv.FormatInt(int64(v.Int16()), 10) + case v.IsInt32(): + return strconv.FormatInt(int64(v.Int32()), 10) + case v.IsInt64(): + return strconv.FormatInt(v.Int64(), 10) + case v.IsUint(): + return strconv.FormatUint(uint64(v.Uint()), 10) + case v.IsUint8(): + return strconv.FormatUint(uint64(v.Uint8()), 10) + case v.IsUint16(): + return strconv.FormatUint(uint64(v.Uint16()), 10) + case v.IsUint32(): + return strconv.FormatUint(uint64(v.Uint32()), 10) + case v.IsUint64(): + return strconv.FormatUint(v.Uint64(), 10) + } + return fmt.Sprintf("%#v", v.Data()) +} diff --git a/vendor/github.com/stretchr/testify/mock/doc.go b/vendor/github.com/stretchr/testify/mock/doc.go new file mode 100644 index 00000000..7324128e --- /dev/null +++ b/vendor/github.com/stretchr/testify/mock/doc.go @@ -0,0 +1,44 @@ +// Package mock provides a system by which it is possible to mock your objects +// and verify calls are happening as expected. +// +// Example Usage +// +// The mock package provides an object, Mock, that tracks activity on another object. It is usually +// embedded into a test object as shown below: +// +// type MyTestObject struct { +// // add a Mock object instance +// mock.Mock +// +// // other fields go here as normal +// } +// +// When implementing the methods of an interface, you wire your functions up +// to call the Mock.Called(args...) method, and return the appropriate values. +// +// For example, to mock a method that saves the name and age of a person and returns +// the year of their birth or an error, you might write this: +// +// func (o *MyTestObject) SavePersonDetails(firstname, lastname string, age int) (int, error) { +// args := o.Called(firstname, lastname, age) +// return args.Int(0), args.Error(1) +// } +// +// The Int, Error and Bool methods are examples of strongly typed getters that take the argument +// index position. Given this argument list: +// +// (12, true, "Something") +// +// You could read them out strongly typed like this: +// +// args.Int(0) +// args.Bool(1) +// args.String(2) +// +// For objects of your own type, use the generic Arguments.Get(index) method and make a type assertion: +// +// return args.Get(0).(*MyObject), args.Get(1).(*AnotherObjectOfMine) +// +// This may cause a panic if the object you are getting is nil (the type assertion will fail), in those +// cases you should check for nil first. +package mock diff --git a/vendor/github.com/stretchr/testify/mock/mock.go b/vendor/github.com/stretchr/testify/mock/mock.go new file mode 100644 index 00000000..cc4f642b --- /dev/null +++ b/vendor/github.com/stretchr/testify/mock/mock.go @@ -0,0 +1,885 @@ +package mock + +import ( + "errors" + "fmt" + "reflect" + "regexp" + "runtime" + "strings" + "sync" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/pmezard/go-difflib/difflib" + "github.com/stretchr/objx" + "github.com/stretchr/testify/assert" +) + +// TestingT is an interface wrapper around *testing.T +type TestingT interface { + Logf(format string, args ...interface{}) + Errorf(format string, args ...interface{}) + FailNow() +} + +/* + Call +*/ + +// Call represents a method call and is used for setting expectations, +// as well as recording activity. +type Call struct { + Parent *Mock + + // The name of the method that was or will be called. + Method string + + // Holds the arguments of the method. + Arguments Arguments + + // Holds the arguments that should be returned when + // this method is called. + ReturnArguments Arguments + + // Holds the caller info for the On() call + callerInfo []string + + // The number of times to return the return arguments when setting + // expectations. 0 means to always return the value. + Repeatability int + + // Amount of times this call has been called + totalCalls int + + // Call to this method can be optional + optional bool + + // Holds a channel that will be used to block the Return until it either + // receives a message or is closed. nil means it returns immediately. + WaitFor <-chan time.Time + + waitTime time.Duration + + // Holds a handler used to manipulate arguments content that are passed by + // reference. It's useful when mocking methods such as unmarshalers or + // decoders. + RunFn func(Arguments) +} + +func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments ...interface{}) *Call { + return &Call{ + Parent: parent, + Method: methodName, + Arguments: methodArguments, + ReturnArguments: make([]interface{}, 0), + callerInfo: callerInfo, + Repeatability: 0, + WaitFor: nil, + RunFn: nil, + } +} + +func (c *Call) lock() { + c.Parent.mutex.Lock() +} + +func (c *Call) unlock() { + c.Parent.mutex.Unlock() +} + +// Return specifies the return arguments for the expectation. +// +// Mock.On("DoSomething").Return(errors.New("failed")) +func (c *Call) Return(returnArguments ...interface{}) *Call { + c.lock() + defer c.unlock() + + c.ReturnArguments = returnArguments + + return c +} + +// Once indicates that that the mock should only return the value once. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once() +func (c *Call) Once() *Call { + return c.Times(1) +} + +// Twice indicates that that the mock should only return the value twice. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice() +func (c *Call) Twice() *Call { + return c.Times(2) +} + +// Times indicates that that the mock should only return the indicated number +// of times. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5) +func (c *Call) Times(i int) *Call { + c.lock() + defer c.unlock() + c.Repeatability = i + return c +} + +// WaitUntil sets the channel that will block the mock's return until its closed +// or a message is received. +// +// Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second)) +func (c *Call) WaitUntil(w <-chan time.Time) *Call { + c.lock() + defer c.unlock() + c.WaitFor = w + return c +} + +// After sets how long to block until the call returns +// +// Mock.On("MyMethod", arg1, arg2).After(time.Second) +func (c *Call) After(d time.Duration) *Call { + c.lock() + defer c.unlock() + c.waitTime = d + return c +} + +// Run sets a handler to be called before returning. It can be used when +// mocking a method such as unmarshalers that takes a pointer to a struct and +// sets properties in such struct +// +// Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}").Return().Run(func(args Arguments) { +// arg := args.Get(0).(*map[string]interface{}) +// arg["foo"] = "bar" +// }) +func (c *Call) Run(fn func(args Arguments)) *Call { + c.lock() + defer c.unlock() + c.RunFn = fn + return c +} + +// Maybe allows the method call to be optional. Not calling an optional method +// will not cause an error while asserting expectations +func (c *Call) Maybe() *Call { + c.lock() + defer c.unlock() + c.optional = true + return c +} + +// On chains a new expectation description onto the mocked interface. This +// allows syntax like. +// +// Mock. +// On("MyMethod", 1).Return(nil). +// On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error")) +func (c *Call) On(methodName string, arguments ...interface{}) *Call { + return c.Parent.On(methodName, arguments...) +} + +// Mock is the workhorse used to track activity on another object. +// For an example of its usage, refer to the "Example Usage" section at the top +// of this document. +type Mock struct { + // Represents the calls that are expected of + // an object. + ExpectedCalls []*Call + + // Holds the calls that were made to this mocked object. + Calls []Call + + // test is An optional variable that holds the test struct, to be used when an + // invalid mock call was made. + test TestingT + + // TestData holds any data that might be useful for testing. Testify ignores + // this data completely allowing you to do whatever you like with it. + testData objx.Map + + mutex sync.Mutex +} + +// TestData holds any data that might be useful for testing. Testify ignores +// this data completely allowing you to do whatever you like with it. +func (m *Mock) TestData() objx.Map { + + if m.testData == nil { + m.testData = make(objx.Map) + } + + return m.testData +} + +/* + Setting expectations +*/ + +// Test sets the test struct variable of the mock object +func (m *Mock) Test(t TestingT) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.test = t +} + +// fail fails the current test with the given formatted format and args. +// In case that a test was defined, it uses the test APIs for failing a test, +// otherwise it uses panic. +func (m *Mock) fail(format string, args ...interface{}) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.test == nil { + panic(fmt.Sprintf(format, args...)) + } + m.test.Errorf(format, args...) + m.test.FailNow() +} + +// On starts a description of an expectation of the specified method +// being called. +// +// Mock.On("MyMethod", arg1, arg2) +func (m *Mock) On(methodName string, arguments ...interface{}) *Call { + for _, arg := range arguments { + if v := reflect.ValueOf(arg); v.Kind() == reflect.Func { + panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg)) + } + } + + m.mutex.Lock() + defer m.mutex.Unlock() + c := newCall(m, methodName, assert.CallerInfo(), arguments...) + m.ExpectedCalls = append(m.ExpectedCalls, c) + return c +} + +// /* +// Recording and responding to activity +// */ + +func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) { + for i, call := range m.ExpectedCalls { + if call.Method == method && call.Repeatability > -1 { + + _, diffCount := call.Arguments.Diff(arguments) + if diffCount == 0 { + return i, call + } + + } + } + return -1, nil +} + +func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call, string) { + var diffCount int + var closestCall *Call + var err string + + for _, call := range m.expectedCalls() { + if call.Method == method { + + errInfo, tempDiffCount := call.Arguments.Diff(arguments) + if tempDiffCount < diffCount || diffCount == 0 { + diffCount = tempDiffCount + closestCall = call + err = errInfo + } + + } + } + + return closestCall, err +} + +func callString(method string, arguments Arguments, includeArgumentValues bool) string { + + var argValsString string + if includeArgumentValues { + var argVals []string + for argIndex, arg := range arguments { + argVals = append(argVals, fmt.Sprintf("%d: %#v", argIndex, arg)) + } + argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t")) + } + + return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString) +} + +// Called tells the mock object that a method has been called, and gets an array +// of arguments to return. Panics if the call is unexpected (i.e. not preceded by +// appropriate .On .Return() calls) +// If Call.WaitFor is set, blocks until the channel is closed or receives a message. +func (m *Mock) Called(arguments ...interface{}) Arguments { + // get the calling function's name + pc, _, _, ok := runtime.Caller(1) + if !ok { + panic("Couldn't get the caller information") + } + functionPath := runtime.FuncForPC(pc).Name() + //Next four lines are required to use GCCGO function naming conventions. + //For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock + //uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree + //With GCCGO we need to remove interface information starting from pN
. + re := regexp.MustCompile("\\.pN\\d+_") + if re.MatchString(functionPath) { + functionPath = re.Split(functionPath, -1)[0] + } + parts := strings.Split(functionPath, ".") + functionName := parts[len(parts)-1] + return m.MethodCalled(functionName, arguments...) +} + +// MethodCalled tells the mock object that the given method has been called, and gets +// an array of arguments to return. Panics if the call is unexpected (i.e. not preceded +// by appropriate .On .Return() calls) +// If Call.WaitFor is set, blocks until the channel is closed or receives a message. +func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments { + m.mutex.Lock() + //TODO: could combine expected and closes in single loop + found, call := m.findExpectedCall(methodName, arguments...) + + if found < 0 { + // we have to fail here - because we don't know what to do + // as the return arguments. This is because: + // + // a) this is a totally unexpected call to this method, + // b) the arguments are not what was expected, or + // c) the developer has forgotten to add an accompanying On...Return pair. + + closestCall, mismatch := m.findClosestCall(methodName, arguments...) + m.mutex.Unlock() + + if closestCall != nil { + m.fail("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\nDiff: %s", + callString(methodName, arguments, true), + callString(methodName, closestCall.Arguments, true), + diffArguments(closestCall.Arguments, arguments), + strings.TrimSpace(mismatch), + ) + } else { + m.fail("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", methodName, methodName, callString(methodName, arguments, true), assert.CallerInfo()) + } + } + + if call.Repeatability == 1 { + call.Repeatability = -1 + } else if call.Repeatability > 1 { + call.Repeatability-- + } + call.totalCalls++ + + // add the call + m.Calls = append(m.Calls, *newCall(m, methodName, assert.CallerInfo(), arguments...)) + m.mutex.Unlock() + + // block if specified + if call.WaitFor != nil { + <-call.WaitFor + } else { + time.Sleep(call.waitTime) + } + + m.mutex.Lock() + runFn := call.RunFn + m.mutex.Unlock() + + if runFn != nil { + runFn(arguments) + } + + m.mutex.Lock() + returnArgs := call.ReturnArguments + m.mutex.Unlock() + + return returnArgs +} + +/* + Assertions +*/ + +type assertExpectationser interface { + AssertExpectations(TestingT) bool +} + +// AssertExpectationsForObjects asserts that everything specified with On and Return +// of the specified objects was in fact called as expected. +// +// Calls may have occurred in any order. +func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + for _, obj := range testObjects { + if m, ok := obj.(Mock); ok { + t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)") + obj = &m + } + m := obj.(assertExpectationser) + if !m.AssertExpectations(t) { + t.Logf("Expectations didn't match for Mock: %+v", reflect.TypeOf(m)) + return false + } + } + return true +} + +// AssertExpectations asserts that everything specified with On and Return was +// in fact called as expected. Calls may have occurred in any order. +func (m *Mock) AssertExpectations(t TestingT) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + var somethingMissing bool + var failedExpectations int + + // iterate through each expectation + expectedCalls := m.expectedCalls() + for _, expectedCall := range expectedCalls { + if !expectedCall.optional && !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments) && expectedCall.totalCalls == 0 { + somethingMissing = true + failedExpectations++ + t.Logf("FAIL:\t%s(%s)\n\t\tat: %s", expectedCall.Method, expectedCall.Arguments.String(), expectedCall.callerInfo) + } else { + if expectedCall.Repeatability > 0 { + somethingMissing = true + failedExpectations++ + t.Logf("FAIL:\t%s(%s)\n\t\tat: %s", expectedCall.Method, expectedCall.Arguments.String(), expectedCall.callerInfo) + } else { + t.Logf("PASS:\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String()) + } + } + } + + if somethingMissing { + t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo()) + } + + return !somethingMissing +} + +// AssertNumberOfCalls asserts that the method was called expectedCalls times. +func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + var actualCalls int + for _, call := range m.calls() { + if call.Method == methodName { + actualCalls++ + } + } + return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls)) +} + +// AssertCalled asserts that the method was called. +// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. +func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + if !m.methodWasCalled(methodName, arguments) { + var calledWithArgs []string + for _, call := range m.calls() { + calledWithArgs = append(calledWithArgs, fmt.Sprintf("%v", call.Arguments)) + } + if len(calledWithArgs) == 0 { + return assert.Fail(t, "Should have called with given arguments", + fmt.Sprintf("Expected %q to have been called with:\n%v\nbut no actual calls happened", methodName, arguments)) + } + return assert.Fail(t, "Should have called with given arguments", + fmt.Sprintf("Expected %q to have been called with:\n%v\nbut actual calls were:\n %v", methodName, arguments, strings.Join(calledWithArgs, "\n"))) + } + return true +} + +// AssertNotCalled asserts that the method was not called. +// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. +func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + if m.methodWasCalled(methodName, arguments) { + return assert.Fail(t, "Should not have called with given arguments", + fmt.Sprintf("Expected %q to not have been called with:\n%v\nbut actually it was.", methodName, arguments)) + } + return true +} + +func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool { + for _, call := range m.calls() { + if call.Method == methodName { + + _, differences := Arguments(expected).Diff(call.Arguments) + + if differences == 0 { + // found the expected call + return true + } + + } + } + // we didn't find the expected call + return false +} + +func (m *Mock) expectedCalls() []*Call { + return append([]*Call{}, m.ExpectedCalls...) +} + +func (m *Mock) calls() []Call { + return append([]Call{}, m.Calls...) +} + +/* + Arguments +*/ + +// Arguments holds an array of method arguments or return values. +type Arguments []interface{} + +const ( + // Anything is used in Diff and Assert when the argument being tested + // shouldn't be taken into consideration. + Anything = "mock.Anything" +) + +// AnythingOfTypeArgument is a string that contains the type of an argument +// for use when type checking. Used in Diff and Assert. +type AnythingOfTypeArgument string + +// AnythingOfType returns an AnythingOfTypeArgument object containing the +// name of the type to check for. Used in Diff and Assert. +// +// For example: +// Assert(t, AnythingOfType("string"), AnythingOfType("int")) +func AnythingOfType(t string) AnythingOfTypeArgument { + return AnythingOfTypeArgument(t) +} + +// argumentMatcher performs custom argument matching, returning whether or +// not the argument is matched by the expectation fixture function. +type argumentMatcher struct { + // fn is a function which accepts one argument, and returns a bool. + fn reflect.Value +} + +func (f argumentMatcher) Matches(argument interface{}) bool { + expectType := f.fn.Type().In(0) + expectTypeNilSupported := false + switch expectType.Kind() { + case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Ptr: + expectTypeNilSupported = true + } + + argType := reflect.TypeOf(argument) + var arg reflect.Value + if argType == nil { + arg = reflect.New(expectType).Elem() + } else { + arg = reflect.ValueOf(argument) + } + + if argType == nil && !expectTypeNilSupported { + panic(errors.New("attempting to call matcher with nil for non-nil expected type")) + } + if argType == nil || argType.AssignableTo(expectType) { + result := f.fn.Call([]reflect.Value{arg}) + return result[0].Bool() + } + return false +} + +func (f argumentMatcher) String() string { + return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).Name()) +} + +// MatchedBy can be used to match a mock call based on only certain properties +// from a complex struct or some calculation. It takes a function that will be +// evaluated with the called argument and will return true when there's a match +// and false otherwise. +// +// Example: +// m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" })) +// +// |fn|, must be a function accepting a single argument (of the expected type) +// which returns a bool. If |fn| doesn't match the required signature, +// MatchedBy() panics. +func MatchedBy(fn interface{}) argumentMatcher { + fnType := reflect.TypeOf(fn) + + if fnType.Kind() != reflect.Func { + panic(fmt.Sprintf("assert: arguments: %s is not a func", fn)) + } + if fnType.NumIn() != 1 { + panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn)) + } + if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool { + panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn)) + } + + return argumentMatcher{fn: reflect.ValueOf(fn)} +} + +// Get Returns the argument at the specified index. +func (args Arguments) Get(index int) interface{} { + if index+1 > len(args) { + panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args))) + } + return args[index] +} + +// Is gets whether the objects match the arguments specified. +func (args Arguments) Is(objects ...interface{}) bool { + for i, obj := range args { + if obj != objects[i] { + return false + } + } + return true +} + +// Diff gets a string describing the differences between the arguments +// and the specified objects. +// +// Returns the diff string and number of differences found. +func (args Arguments) Diff(objects []interface{}) (string, int) { + //TODO: could return string as error and nil for No difference + + var output = "\n" + var differences int + + var maxArgCount = len(args) + if len(objects) > maxArgCount { + maxArgCount = len(objects) + } + + for i := 0; i < maxArgCount; i++ { + var actual, expected interface{} + var actualFmt, expectedFmt string + + if len(objects) <= i { + actual = "(Missing)" + actualFmt = "(Missing)" + } else { + actual = objects[i] + actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual) + } + + if len(args) <= i { + expected = "(Missing)" + expectedFmt = "(Missing)" + } else { + expected = args[i] + expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected) + } + + if matcher, ok := expected.(argumentMatcher); ok { + if matcher.Matches(actual) { + output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher) + } else { + differences++ + output = fmt.Sprintf("%s\t%d: PASS: %s not matched by %s\n", output, i, actualFmt, matcher) + } + } else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() { + + // type checking + if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) { + // not match + differences++ + output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt) + } + + } else { + + // normal checking + + if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { + // match + output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt) + } else { + // not match + differences++ + output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt) + } + } + + } + + if differences == 0 { + return "No differences.", differences + } + + return output, differences + +} + +// Assert compares the arguments with the specified objects and fails if +// they do not exactly match. +func (args Arguments) Assert(t TestingT, objects ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + // get the differences + diff, diffCount := args.Diff(objects) + + if diffCount == 0 { + return true + } + + // there are differences... report them... + t.Logf(diff) + t.Errorf("%sArguments do not match.", assert.CallerInfo()) + + return false + +} + +// String gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +// +// If no index is provided, String() returns a complete string representation +// of the arguments. +func (args Arguments) String(indexOrNil ...int) string { + + if len(indexOrNil) == 0 { + // normal String() method - return a string representation of the args + var argsStr []string + for _, arg := range args { + argsStr = append(argsStr, fmt.Sprintf("%s", reflect.TypeOf(arg))) + } + return strings.Join(argsStr, ",") + } else if len(indexOrNil) == 1 { + // Index has been specified - get the argument at that index + var index = indexOrNil[0] + var s string + var ok bool + if s, ok = args.Get(index).(string); !ok { + panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index))) + } + return s + } + + panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil))) + +} + +// Int gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +func (args Arguments) Int(index int) int { + var s int + var ok bool + if s, ok = args.Get(index).(int); !ok { + panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index))) + } + return s +} + +// Error gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +func (args Arguments) Error(index int) error { + obj := args.Get(index) + var s error + var ok bool + if obj == nil { + return nil + } + if s, ok = obj.(error); !ok { + panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, args.Get(index))) + } + return s +} + +// Bool gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +func (args Arguments) Bool(index int) bool { + var s bool + var ok bool + if s, ok = args.Get(index).(bool); !ok { + panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index))) + } + return s +} + +func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) { + t := reflect.TypeOf(v) + k := t.Kind() + + if k == reflect.Ptr { + t = t.Elem() + k = t.Kind() + } + return t, k +} + +func diffArguments(expected Arguments, actual Arguments) string { + if len(expected) != len(actual) { + return fmt.Sprintf("Provided %v arguments, mocked for %v arguments", len(expected), len(actual)) + } + + for x := range expected { + if diffString := diff(expected[x], actual[x]); diffString != "" { + return fmt.Sprintf("Difference found in argument %v:\n\n%s", x, diffString) + } + } + + return "" +} + +// diff returns a diff of both values as long as both are of the same type and +// are a struct, map, slice or array. Otherwise it returns an empty string. +func diff(expected interface{}, actual interface{}) string { + if expected == nil || actual == nil { + return "" + } + + et, ek := typeAndKind(expected) + at, _ := typeAndKind(actual) + + if et != at { + return "" + } + + if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array { + return "" + } + + e := spewConfig.Sdump(expected) + a := spewConfig.Sdump(actual) + + diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ + A: difflib.SplitLines(e), + B: difflib.SplitLines(a), + FromFile: "Expected", + FromDate: "", + ToFile: "Actual", + ToDate: "", + Context: 1, + }) + + return diff +} + +var spewConfig = spew.ConfigState{ + Indent: " ", + DisablePointerAddresses: true, + DisableCapacities: true, + SortKeys: true, +} + +type tHelper interface { + Helper() +}