From 55361cea8c4f8594c60ff6f3cc59a71ba0d9fd93 Mon Sep 17 00:00:00 2001 From: Ludovic Fernandez Date: Sat, 15 Sep 2018 19:16:35 +0200 Subject: [PATCH] Use Testify. (#630) --- acme/client_test.go | 166 +++++++------- acme/crypto_test.go | 63 ++---- acme/dns_challenge.go | 10 +- acme/dns_challenge_test.go | 382 +++++++++++++++++++++----------- acme/http_challenge_test.go | 43 ++-- acme/http_test.go | 125 ++++------- acme/tls_alpn_challenge_test.go | 81 ++++--- 7 files changed, 478 insertions(+), 392 deletions(-) diff --git a/acme/client_test.go b/acme/client_test.go index a84b6f30..4e39b003 100644 --- a/acme/client_test.go +++ b/acme/client_test.go @@ -8,18 +8,19 @@ import ( "net" "net/http" "net/http/httptest" - "strings" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewClient(t *testing.T) { keyBits := 32 // small value keeps test fast keyType := RSA2048 key, err := rsa.GenerateKey(rand.Reader, keyBits) - if err != nil { - t.Fatal("Could not generate test key:", err) - } + require.NoError(t, err, "Could not generate test key") + user := mockUser{ email: "test@test.com", regres: new(RegistrationResource), @@ -38,32 +39,19 @@ func TestNewClient(t *testing.T) { })) client, err := NewClient(ts.URL, user, keyType) - if err != nil { - t.Fatalf("Could not create client: %v", err) - } + require.NoError(t, err, "Could not create client") - if client.jws == nil { - t.Fatalf("Expected client.jws to not be nil") - } - if expected, actual := key, client.jws.privKey; actual != expected { - t.Errorf("Expected jws.privKey to be %p but was %p", expected, actual) - } - - if client.keyType != keyType { - t.Errorf("Expected keyType to be %s but was %s", keyType, client.keyType) - } - - if expected, actual := 2, len(client.solvers); actual != expected { - t.Fatalf("Expected %d solver(s), got %d", expected, actual) - } + require.NotNil(t, client.jws, "client.jws") + assert.Equal(t, key, client.jws.privKey, "client.jws.privKey") + assert.Equal(t, keyType, client.keyType, "client.keyType") + assert.Len(t, client.solvers, 2, "solvers") } func TestClientOptPort(t *testing.T) { keyBits := 32 // small value keeps test fast key, err := rsa.GenerateKey(rand.Reader, keyBits) - if err != nil { - t.Fatal("Could not generate test key:", err) - } + require.NoError(t, err, "Could not generate test key") + user := mockUser{ email: "test@test.com", regres: new(RegistrationResource), @@ -83,33 +71,26 @@ func TestClientOptPort(t *testing.T) { optPort := "1234" optHost := "" + client, err := NewClient(ts.URL, user, RSA2048) - if err != nil { - t.Fatalf("Could not create client: %v", err) - } + require.NoError(t, err, "Could not create client") + client.SetHTTPAddress(net.JoinHostPort(optHost, optPort)) - httpSolver, ok := client.solvers[HTTP01].(*httpChallenge) - if !ok { - t.Fatal("Expected http-01 solver to be httpChallenge type") - } - if httpSolver.jws != client.jws { - t.Error("Expected http-01 to have same jws as client") - } - if got := httpSolver.provider.(*HTTPProviderServer).port; got != optPort { - t.Errorf("Expected http-01 to have port %s but was %s", optPort, got) - } - if got := httpSolver.provider.(*HTTPProviderServer).iface; got != optHost { - t.Errorf("Expected http-01 to have iface %s but was %s", optHost, got) - } + require.IsType(t, &httpChallenge{}, client.solvers[HTTP01]) + httpSolver := client.solvers[HTTP01].(*httpChallenge) + + assert.Equal(t, httpSolver.jws, client.jws, "Expected http-01 to have same jws as client") + + httpProviderServer := httpSolver.provider.(*HTTPProviderServer) + assert.Equal(t, optPort, httpProviderServer.port, "port") + assert.Equal(t, optHost, httpProviderServer.iface, "iface") // test setting different host optHost = "127.0.0.1" client.SetHTTPAddress(net.JoinHostPort(optHost, optPort)) - if got := httpSolver.provider.(*HTTPProviderServer).iface; got != optHost { - t.Errorf("Expected http-01 to have iface %s but was %s", optHost, got) - } + assert.Equal(t, optHost, httpSolver.provider.(*HTTPProviderServer).iface, "iface") } func TestNotHoldingLockWhileMakingHTTPRequests(t *testing.T) { @@ -121,7 +102,9 @@ func TestNotHoldingLockWhileMakingHTTPRequests(t *testing.T) { })) defer ts.Close() - privKey, _ := rsa.GenerateKey(rand.Reader, 512) + privKey, err := rsa.GenerateKey(rand.Reader, 512) + require.NoError(t, err) + j := &jws{privKey: privKey, getNonceURL: ts.URL} ch := make(chan bool) resultCh := make(chan bool) @@ -147,10 +130,12 @@ func TestNotHoldingLockWhileMakingHTTPRequests(t *testing.T) { func TestValidate(t *testing.T) { var statuses []string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Minimal stub ACME server for validation. w.Header().Add("Replay-Nonce", "12345") w.Header().Add("Retry-After", "0") + switch r.Method { case http.MethodHead: case http.MethodPost: @@ -169,29 +154,57 @@ func TestValidate(t *testing.T) { })) defer ts.Close() - privKey, _ := rsa.GenerateKey(rand.Reader, 512) + privKey, err := rsa.GenerateKey(rand.Reader, 512) + require.NoError(t, err) j := &jws{privKey: privKey, getNonceURL: ts.URL} - tsts := []struct { + testCases := []struct { name string statuses []string want string }{ - {"POST-unexpected", []string{"weird"}, "unexpected"}, - {"POST-valid", []string{"valid"}, ""}, - {"POST-invalid", []string{"invalid"}, "Error"}, - {"GET-unexpected", []string{"pending", "weird"}, "unexpected"}, - {"GET-valid", []string{"pending", "valid"}, ""}, - {"GET-invalid", []string{"pending", "invalid"}, "Error"}, + { + name: "POST-unexpected", + statuses: []string{"weird"}, + want: "unexpected", + }, + { + name: "POST-valid", + statuses: []string{"valid"}, + }, + { + name: "POST-invalid", + statuses: []string{"invalid"}, + want: "Error", + }, + { + name: "GET-unexpected", + statuses: []string{"pending", "weird"}, + want: "unexpected", + }, + { + name: "GET-valid", + statuses: []string{"pending", "valid"}, + }, + { + name: "GET-invalid", + statuses: []string{"pending", "invalid"}, + want: "Error", + }, } - for _, tst := range tsts { - statuses = tst.statuses - if err := validate(j, "example.com", ts.URL, challenge{Type: "http-01", Token: "token"}); err == nil && tst.want != "" { - t.Errorf("[%s] validate: got error %v, want something with %q", tst.name, err, tst.want) - } else if err != nil && !strings.Contains(err.Error(), tst.want) { - t.Errorf("[%s] validate: got error %v, want something with %q", tst.name, err, tst.want) - } + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + statuses = test.statuses + + err := validate(j, "example.com", ts.URL, challenge{Type: "http-01", Token: "token"}) + if test.want == "" { + assert.NoError(t, err) + } else { + assert.Error(t, err) + assert.Contains(t, err.Error(), test.want) + } + }) } } @@ -217,10 +230,10 @@ func TestGetChallenges(t *testing.T) { keyBits := 512 // small value keeps test fast keyType := RSA2048 + key, err := rsa.GenerateKey(rand.Reader, keyBits) - if err != nil { - t.Fatal("Could not generate test key:", err) - } + require.NoError(t, err, "Could not generate test key") + user := mockUser{ email: "test@test.com", regres: &RegistrationResource{URI: ts.URL}, @@ -228,23 +241,19 @@ func TestGetChallenges(t *testing.T) { } client, err := NewClient(ts.URL, user, keyType) - if err != nil { - t.Fatalf("Could not create client: %v", err) - } + require.NoError(t, err, "Could not create client") _, err = client.createOrderForIdentifiers([]string{"example.com"}) - if err != nil { - t.Fatal("Expecting \"Server did not provide next link to proceed\" error, got nil") - } + assert.NoError(t, err) } func TestResolveAccountByKey(t *testing.T) { keyBits := 512 keyType := RSA2048 + key, err := rsa.GenerateKey(rand.Reader, keyBits) - if err != nil { - t.Fatal("Could not generate test key:", err) - } + require.NoError(t, err, "Could not generate test key") + user := mockUser{ email: "test@test.com", regres: new(RegistrationResource), @@ -275,15 +284,12 @@ func TestResolveAccountByKey(t *testing.T) { })) client, err := NewClient(ts.URL+"/directory", user, keyType) - if err != nil { - t.Fatalf("Could not create client: %v", err) - } + require.NoError(t, err, "Could not create client") - if res, err := client.ResolveAccountByKey(); err != nil { - t.Fatalf("Unexpected error resolving account by key: %v", err) - } else if res.Body.Status != "valid" { - t.Errorf("Unexpected account status: %v", res.Body.Status) - } + res, err := client.ResolveAccountByKey() + require.NoError(t, err, "Unexpected error resolving account by key") + + assert.Equal(t, "valid", res.Body.Status, "Unexpected account status") } // writeJSONResponse marshals the body as JSON and writes it to the response. @@ -301,7 +307,7 @@ func writeJSONResponse(w http.ResponseWriter, body interface{}) { } // stubValidate is like validate, except it does nothing. -func stubValidate(j *jws, domain, uri string, chlng challenge) error { +func stubValidate(_ *jws, _, _ string, _ challenge) error { return nil } diff --git a/acme/crypto_test.go b/acme/crypto_test.go index f32611e1..242b2cb6 100644 --- a/acme/crypto_test.go +++ b/acme/crypto_test.go @@ -6,31 +6,26 @@ import ( "crypto/rsa" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGeneratePrivateKey(t *testing.T) { key, err := generatePrivateKey(RSA2048) - if err != nil { - t.Error("Error generating private key:", err) - } - if key == nil { - t.Error("Expected key to not be nil, but it was") - } + require.NoError(t, err, "Error generating private key") + + assert.NotNil(t, key) } func TestGenerateCSR(t *testing.T) { key, err := rsa.GenerateKey(rand.Reader, 512) - if err != nil { - t.Fatal("Error generating private key:", err) - } + require.NoError(t, err, "Error generating private key") csr, err := generateCsr(key, "fizz.buzz", nil, true) - if err != nil { - t.Error("Error generating CSR:", err) - } - if len(csr) == 0 { - t.Error("Expected CSR with data, but it was nil or length 0") - } + require.NoError(t, err, "Error generating CSR") + + assert.NotEmpty(t, csr) } func TestPEMEncode(t *testing.T) { @@ -38,50 +33,38 @@ func TestPEMEncode(t *testing.T) { reader := MockRandReader{b: buf} key, err := rsa.GenerateKey(reader, 32) - if err != nil { - t.Fatal("Error generating private key:", err) - } + require.NoError(t, err, "Error generating private key") data := pemEncode(key) - - if data == nil { - t.Fatal("Expected result to not be nil, but it was") - } - if len(data) != 127 { - t.Errorf("Expected PEM encoding to be length 127, but it was %d", len(data)) - } + require.NotNil(t, data) + assert.Len(t, data, 127) } func TestPEMCertExpiration(t *testing.T) { privKey, err := generatePrivateKey(RSA2048) - if err != nil { - t.Fatal("Error generating private key:", err) - } + require.NoError(t, err, "Error generating private key") expiration := time.Now().Add(365) expiration = expiration.Round(time.Second) certBytes, err := generateDerCert(privKey.(*rsa.PrivateKey), expiration, "test.com", nil) - if err != nil { - t.Fatal("Error generating cert:", err) - } + require.NoError(t, err, "Error generating cert") buf := bytes.NewBufferString("TestingRSAIsSoMuchFun") // Some random string should return an error. - if ctime, err := GetPEMCertExpiration(buf.Bytes()); err == nil { - t.Errorf("Expected getCertExpiration to return an error for garbage string but returned %v", ctime) - } + ctime, err := GetPEMCertExpiration(buf.Bytes()) + assert.Errorf(t, err, "Expected getCertExpiration to return an error for garbage string but returned %v", ctime) // A DER encoded certificate should return an error. - if _, err := GetPEMCertExpiration(certBytes); err == nil { - t.Errorf("Expected getCertExpiration to return an error for DER certificates but returned none.") - } + _, err = GetPEMCertExpiration(certBytes) + require.Error(t, err, "Expected getCertExpiration to return an error for DER certificates") // A PEM encoded certificate should work ok. pemCert := pemEncode(derCertificateBytes(certBytes)) - if ctime, err := GetPEMCertExpiration(pemCert); err != nil || !ctime.Equal(expiration.UTC()) { - t.Errorf("Expected getCertExpiration to return %v but returned %v. Error: %v", expiration, ctime, err) - } + ctime, err = GetPEMCertExpiration(pemCert) + require.NoError(t, err) + + assert.Equal(t, expiration.UTC(), ctime) } type MockRandReader struct { diff --git a/acme/dns_challenge.go b/acme/dns_challenge.go index 65427491..73956625 100644 --- a/acme/dns_challenge.go +++ b/acme/dns_challenge.go @@ -37,7 +37,7 @@ var defaultNameservers = []string{ "google-public-dns-b.google.com:53", } -// RecursiveNameservers are used to pre-check DNS propagations +// RecursiveNameservers are used to pre-check DNS propagation var RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers) // DNSTimeout is used to override the default DNS timeout of 10 seconds. @@ -235,7 +235,7 @@ func lookupNameservers(fqdn string) ([]string, error) { zone, err := FindZoneByFqdn(fqdn, RecursiveNameservers) if err != nil { - return nil, fmt.Errorf("Could not determine the zone: %v", err) + return nil, fmt.Errorf("could not determine the zone: %v", err) } r, err := dnsQuery(zone, dns.TypeNS, RecursiveNameservers, true) @@ -252,7 +252,7 @@ func lookupNameservers(fqdn string) ([]string, error) { if len(authoritativeNss) > 0 { return authoritativeNss, nil } - return nil, fmt.Errorf("Could not determine authoritative nameservers") + return nil, fmt.Errorf("could not determine authoritative nameservers") } // FindZoneByFqdn determines the zone apex for the given fqdn by recursing up the @@ -274,7 +274,7 @@ func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) { // Any response code other than NOERROR and NXDOMAIN is treated as error if in.Rcode != dns.RcodeNameError && in.Rcode != dns.RcodeSuccess { - return "", fmt.Errorf("Unexpected response code '%s' for %s", + return "", fmt.Errorf("unexpected response code '%s' for %s", dns.RcodeToString[in.Rcode], domain) } @@ -297,7 +297,7 @@ func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) { } } - return "", fmt.Errorf("Could not find the start of authority") + return "", fmt.Errorf("could not find the start of authority") } // dnsMsgContainsCNAME checks for a CNAME answer in msg diff --git a/acme/dns_challenge_test.go b/acme/dns_challenge_test.go index e561739b..5d80bac7 100644 --- a/acme/dns_challenge_test.go +++ b/acme/dns_challenge_test.go @@ -7,103 +7,27 @@ import ( "net/http" "net/http/httptest" "os" - "reflect" "sort" - "strings" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -var lookupNameserversTestsOK = []struct { - fqdn string - nss []string -}{ - {"books.google.com.ng.", - []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."}, - }, - {"www.google.com.", - []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."}, - }, - {"physics.georgetown.edu.", - []string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."}, - }, -} - -var lookupNameserversTestsErr = []struct { - fqdn string - error string -}{ - // invalid tld - {"_null.n0n0.", - "Could not determine the zone", - }, -} - -var findZoneByFqdnTests = []struct { - fqdn string - zone string -}{ - {"mail.google.com.", "google.com."}, // domain is a CNAME - {"foo.google.com.", "google.com."}, // domain is a non-existent subdomain - {"example.com.ac.", "ac."}, // domain is a eTLD - {"cross-zone-example.assets.sh.", "assets.sh."}, // domain is a cross-zone CNAME -} - -var checkAuthoritativeNssTests = []struct { - fqdn, value string - ns []string - ok bool -}{ - // TXT RR w/ expected value - {"8.8.8.8.asn.routeviews.org.", "151698.8.8.024", []string{"asnums.routeviews.org."}, - true, - }, - // No TXT RR - {"ns1.google.com.", "", []string{"ns2.google.com."}, - false, - }, -} - -var checkAuthoritativeNssTestsErr = []struct { - fqdn, value string - ns []string - error string -}{ - // TXT RR /w unexpected value - {"8.8.8.8.asn.routeviews.org.", "fe01=", []string{"asnums.routeviews.org."}, - "did not return the expected TXT record", - }, - // No TXT RR - {"ns1.google.com.", "fe01=", []string{"ns2.google.com."}, - "did not return the expected TXT record", - }, -} - -var checkResolvConfServersTests = []struct { - fixture string - expected []string - defaults []string -}{ - {"testdata/resolv.conf.1", []string{"10.200.3.249:53", "10.200.3.250:5353", "[2001:4860:4860::8844]:53", "[10.0.0.1]:5353"}, []string{"127.0.0.1:53"}}, - {"testdata/resolv.conf.nonexistant", []string{"127.0.0.1:53"}, []string{"127.0.0.1:53"}}, -} - func TestDNSValidServerResponse(t *testing.T) { PreCheckDNS = func(fqdn, value string) (bool, error) { return true, nil } - privKey, _ := rsa.GenerateKey(rand.Reader, 512) + + privKey, err := rsa.GenerateKey(rand.Reader, 512) + require.NoError(t, err) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Replay-Nonce", "12345") w.Write([]byte("{\"type\":\"dns01\",\"status\":\"valid\",\"uri\":\"http://some.url\",\"token\":\"http8\"}")) })) - manualProvider, _ := NewDNSProviderManual() - jws := &jws{privKey: privKey, getNonceURL: ts.URL} - solver := &dnsChallenge{jws: jws, validate: validate, provider: manualProvider} - clientChallenge := challenge{Type: "dns01", Status: "pending", URL: ts.URL, Token: "http8"} - go func() { time.Sleep(time.Second * 2) f := bufio.NewWriter(os.Stdout) @@ -111,90 +35,282 @@ func TestDNSValidServerResponse(t *testing.T) { f.WriteString("\n") }() - if err := solver.Solve(clientChallenge, "example.com"); err != nil { - t.Errorf("VALID: Expected Solve to return no error but the error was -> %v", err) + manualProvider, err := NewDNSProviderManual() + require.NoError(t, err) + + clientChallenge := challenge{Type: "dns01", Status: "pending", URL: ts.URL, Token: "http8"} + + solver := &dnsChallenge{ + jws: &jws{privKey: privKey, getNonceURL: ts.URL}, + validate: validate, + provider: manualProvider, } + + err = solver.Solve(clientChallenge, "example.com") + require.NoError(t, err) } func TestPreCheckDNS(t *testing.T) { ok, err := PreCheckDNS("acme-staging.api.letsencrypt.org", "fe01=") if err != nil || !ok { - t.Errorf("preCheckDNS failed for acme-staging.api.letsencrypt.org") + t.Errorf("PreCheckDNS failed for acme-staging.api.letsencrypt.org") } } func TestLookupNameserversOK(t *testing.T) { - for _, tt := range lookupNameserversTestsOK { - nss, err := lookupNameservers(tt.fqdn) - if err != nil { - t.Fatalf("#%s: got %q; want nil", tt.fqdn, err) - } + testCases := []struct { + fqdn string + nss []string + }{ + { + fqdn: "books.google.com.ng.", + nss: []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."}, + }, + { + fqdn: "www.google.com.", + nss: []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."}, + }, + { + fqdn: "physics.georgetown.edu.", + nss: []string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."}, + }, + } - sort.Strings(nss) - sort.Strings(tt.nss) + for _, test := range testCases { + test := test + t.Run(test.fqdn, func(t *testing.T) { + t.Parallel() - if !reflect.DeepEqual(nss, tt.nss) { - t.Errorf("#%s: got %v; want %v", tt.fqdn, nss, tt.nss) - } + nss, err := lookupNameservers(test.fqdn) + require.NoError(t, err) + + sort.Strings(nss) + sort.Strings(test.nss) + + assert.EqualValues(t, test.nss, nss) + }) } } func TestLookupNameserversErr(t *testing.T) { - for _, tt := range lookupNameserversTestsErr { - _, err := lookupNameservers(tt.fqdn) - if err == nil { - t.Fatalf("#%s: expected %q (error); got ", tt.fqdn, tt.error) - } + testCases := []struct { + desc string + fqdn string + error string + }{ + { + desc: "invalid tld", + fqdn: "_null.n0n0.", + error: "could not determine the zone", + }, + } - if !strings.Contains(err.Error(), tt.error) { - t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err) - continue - } + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + _, err := lookupNameservers(test.fqdn) + require.Error(t, err) + assert.Contains(t, err.Error(), test.error) + }) } } func TestFindZoneByFqdn(t *testing.T) { - for _, tt := range findZoneByFqdnTests { - res, err := FindZoneByFqdn(tt.fqdn, RecursiveNameservers) - if err != nil { - t.Errorf("FindZoneByFqdn failed for %s: %v", tt.fqdn, err) - } - if res != tt.zone { - t.Errorf("%s: got %s; want %s", tt.fqdn, res, tt.zone) - } + testCases := []struct { + desc string + fqdn string + zone string + }{ + { + desc: "domain is a CNAME", + fqdn: "mail.google.com.", + zone: "google.com.", + }, + { + desc: "domain is a non-existent subdomain", + fqdn: "foo.google.com.", + zone: "google.com.", + }, + { + desc: "domain is a eTLD", + fqdn: "example.com.ac.", + zone: "ac.", + }, + { + desc: "domain is a cross-zone CNAME", + fqdn: "cross-zone-example.assets.sh.", + zone: "assets.sh.", + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + zone, err := FindZoneByFqdn(test.fqdn, RecursiveNameservers) + require.NoError(t, err) + + assert.Equal(t, test.zone, zone) + }) } } func TestCheckAuthoritativeNss(t *testing.T) { - for _, tt := range checkAuthoritativeNssTests { - ok, _ := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns) - if ok != tt.ok { - t.Errorf("%s: got %t; want %t", tt.fqdn, ok, tt.ok) - } + testCases := []struct { + desc string + fqdn, value string + ns []string + expected bool + }{ + { + desc: "TXT RR w/ expected value", + fqdn: "8.8.8.8.asn.routeviews.org.", + value: "151698.8.8.024", + ns: []string{"asnums.routeviews.org."}, + expected: true, + }, + { + desc: "No TXT RR", + fqdn: "ns1.google.com.", + ns: []string{"ns2.google.com."}, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + ok, _ := checkAuthoritativeNss(test.fqdn, test.value, test.ns) + assert.Equal(t, test.expected, ok, test.fqdn) + }) } } func TestCheckAuthoritativeNssErr(t *testing.T) { - for _, tt := range checkAuthoritativeNssTestsErr { - _, err := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns) - if err == nil { - t.Fatalf("#%s: expected %q (error); got ", tt.fqdn, tt.error) - } - if !strings.Contains(err.Error(), tt.error) { - t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err) - continue - } + testCases := []struct { + desc string + fqdn, value string + ns []string + error string + }{ + { + desc: "TXT RR /w unexpected value", + fqdn: "8.8.8.8.asn.routeviews.org.", + value: "fe01=", + ns: []string{"asnums.routeviews.org."}, + error: "did not return the expected TXT record", + }, + { + desc: "No TXT RR", + fqdn: "ns1.google.com.", + value: "fe01=", + ns: []string{"ns2.google.com."}, + error: "did not return the expected TXT record", + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + _, err := checkAuthoritativeNss(test.fqdn, test.value, test.ns) + require.Error(t, err) + assert.Contains(t, err.Error(), test.error) + }) } } func TestResolveConfServers(t *testing.T) { - for _, tt := range checkResolvConfServersTests { - result := getNameservers(tt.fixture, tt.defaults) + var testCases = []struct { + fixture string + expected []string + defaults []string + }{ + { + fixture: "testdata/resolv.conf.1", + defaults: []string{"127.0.0.1:53"}, + expected: []string{"10.200.3.249:53", "10.200.3.250:5353", "[2001:4860:4860::8844]:53", "[10.0.0.1]:5353"}, + }, + { + fixture: "testdata/resolv.conf.nonexistant", + defaults: []string{"127.0.0.1:53"}, + expected: []string{"127.0.0.1:53"}, + }, + } - sort.Strings(result) - sort.Strings(tt.expected) - if !reflect.DeepEqual(result, tt.expected) { - t.Errorf("#%s: expected %q; got %q", tt.fixture, tt.expected, result) - } + for _, test := range testCases { + t.Run(test.fixture, func(t *testing.T) { + + result := getNameservers(test.fixture, test.defaults) + + sort.Strings(result) + sort.Strings(test.expected) + + assert.Equal(t, test.expected, result) + }) + } +} + +func TestToFqdn(t *testing.T) { + testCases := []struct { + desc string + domain string + expected string + }{ + { + desc: "simple", + domain: "foo.bar.com", + expected: "foo.bar.com.", + }, + { + desc: "already FQDN", + domain: "foo.bar.com.", + expected: "foo.bar.com.", + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + fqdn := ToFqdn(test.domain) + assert.Equal(t, test.expected, fqdn) + }) + } +} + +func TestUnFqdn(t *testing.T) { + testCases := []struct { + desc string + fqdn string + expected string + }{ + { + desc: "simple", + fqdn: "foo.bar.com.", + expected: "foo.bar.com", + }, + { + desc: "already domain", + fqdn: "foo.bar.com", + expected: "foo.bar.com", + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + domain := UnFqdn(test.fqdn) + + assert.Equal(t, test.expected, domain) + }) } } diff --git a/acme/http_challenge_test.go b/acme/http_challenge_test.go index 10f92028..ba0e8bf7 100644 --- a/acme/http_challenge_test.go +++ b/acme/http_challenge_test.go @@ -4,14 +4,13 @@ import ( "crypto/rand" "crypto/rsa" "io/ioutil" - "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestHTTPChallenge(t *testing.T) { - privKey, _ := rsa.GenerateKey(rand.Reader, 512) - j := &jws{privKey: privKey} - clientChallenge := challenge{Type: string(HTTP01), Token: "http1"} mockValidate := func(_ *jws, _, _ string, chlng challenge) error { uri := "http://localhost:23457/.well-known/acme-challenge/" + chlng.Token resp, err := httpGet(uri) @@ -36,22 +35,36 @@ func TestHTTPChallenge(t *testing.T) { return nil } - solver := &httpChallenge{jws: j, validate: mockValidate, provider: &HTTPProviderServer{port: "23457"}} - if err := solver.Solve(clientChallenge, "localhost:23457"); err != nil { - t.Errorf("Solve error: got %v, want nil", err) + privKey, err := rsa.GenerateKey(rand.Reader, 512) + require.NoError(t, err, "Could not generate test key") + + solver := &httpChallenge{ + jws: &jws{privKey: privKey}, + validate: mockValidate, + provider: &HTTPProviderServer{port: "23457"}, } + + clientChallenge := challenge{Type: string(HTTP01), Token: "http1"} + + err = solver.Solve(clientChallenge, "localhost:23457") + assert.NoError(t, err) } func TestHTTPChallengeInvalidPort(t *testing.T) { - privKey, _ := rsa.GenerateKey(rand.Reader, 128) - j := &jws{privKey: privKey} - clientChallenge := challenge{Type: string(HTTP01), Token: "http2"} - solver := &httpChallenge{jws: j, validate: stubValidate, provider: &HTTPProviderServer{port: "123456"}} + privKey, err := rsa.GenerateKey(rand.Reader, 128) + require.NoError(t, err, "Could not generate test key") - if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { - t.Errorf("Solve error: got %v, want error", err) - } else if want, want18 := "invalid port 123456", "123456: invalid port"; !strings.HasSuffix(err.Error(), want) && !strings.HasSuffix(err.Error(), want18) { - t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want) + solver := &httpChallenge{ + jws: &jws{privKey: privKey}, + validate: stubValidate, + provider: &HTTPProviderServer{port: "123456"}, } + + clientChallenge := challenge{Type: string(HTTP01), Token: "http2"} + + err = solver.Solve(clientChallenge, "localhost:123456") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid port") + assert.Contains(t, err.Error(), "123456") } diff --git a/acme/http_test.go b/acme/http_test.go index 370e1245..2c5654be 100644 --- a/acme/http_test.go +++ b/acme/http_test.go @@ -7,9 +7,12 @@ import ( "os" "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestHTTPHeadUserAgent(t *testing.T) { +func TestHTTPUserAgent(t *testing.T) { var ua, method string ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ua = r.Header.Get("User-Agent") @@ -17,72 +20,43 @@ func TestHTTPHeadUserAgent(t *testing.T) { })) defer ts.Close() - _, err := httpHead(ts.URL) - if err != nil { - t.Fatal(err) + testCases := []struct { + method string + call func(u string) (resp *http.Response, err error) + }{ + { + method: http.MethodGet, + call: httpGet, + }, + { + method: http.MethodHead, + call: httpHead, + }, + { + method: http.MethodPost, + call: func(u string) (resp *http.Response, err error) { + return httpPost(u, "text/plain", strings.NewReader("falalalala")) + }, + }, } - if method != http.MethodHead { - t.Errorf("Expected method to be HEAD, got %s", method) - } - if !strings.Contains(ua, ourUserAgent) { - t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua) - } -} + for _, test := range testCases { + t.Run(test.method, func(t *testing.T) { -func TestHTTPGetUserAgent(t *testing.T) { - var ua, method string - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ua = r.Header.Get("User-Agent") - method = r.Method - })) - defer ts.Close() + _, err := test.call(ts.URL) + require.NoError(t, err) - res, err := httpGet(ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - - if method != http.MethodGet { - t.Errorf("Expected method to be GET, got %s", method) - } - if !strings.Contains(ua, ourUserAgent) { - t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua) - } -} - -func TestHTTPPostUserAgent(t *testing.T) { - var ua, method string - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ua = r.Header.Get("User-Agent") - method = r.Method - })) - defer ts.Close() - - res, err := httpPost(ts.URL, "text/plain", strings.NewReader("falalalala")) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - - if method != http.MethodPost { - t.Errorf("Expected method to be POST, got %s", method) - } - if !strings.Contains(ua, ourUserAgent) { - t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua) + assert.Equal(t, test.method, method) + assert.Contains(t, ua, ourUserAgent, "User-Agent") + }) } } func TestUserAgent(t *testing.T) { ua := userAgent() - if !strings.Contains(ua, defaultGoUserAgent) { - t.Errorf("Expected UA to contain %s, got '%s'", defaultGoUserAgent, ua) - } - if !strings.Contains(ua, ourUserAgent) { - t.Errorf("Expected UA to contain %s, got '%s'", ourUserAgent, ua) - } + assert.Contains(t, ua, defaultGoUserAgent) + assert.Contains(t, ua, ourUserAgent) if strings.HasSuffix(ua, " ") { t.Errorf("UA should not have trailing spaces; got '%s'", ua) } @@ -90,15 +64,10 @@ func TestUserAgent(t *testing.T) { // customize the UA by appending a value UserAgent = "MyApp/1.2.3" ua = userAgent() - if !strings.Contains(ua, defaultGoUserAgent) { - t.Errorf("Expected UA to contain %s, got '%s'", defaultGoUserAgent, ua) - } - if !strings.Contains(ua, ourUserAgent) { - t.Errorf("Expected UA to contain %s, got '%s'", ourUserAgent, ua) - } - if !strings.Contains(ua, UserAgent) { - t.Errorf("Expected custom UA to contain %s, got '%s'", UserAgent, ua) - } + + assert.Contains(t, ua, defaultGoUserAgent) + assert.Contains(t, ua, ourUserAgent) + assert.Contains(t, ua, UserAgent) } // TestInitCertPool tests the http.go initCertPool function for customizing the @@ -185,25 +154,27 @@ p9BI7gVKtWSZYegicA== }, } - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - os.Setenv(caCertificatesEnvVar, tc.EnvVar) + for _, test := range testCases { + t.Run(test.Name, func(t *testing.T) { + os.Setenv(caCertificatesEnvVar, test.EnvVar) defer os.Setenv(caCertificatesEnvVar, "") defer func() { - if r := recover(); r == nil && tc.ExpectPanic { - t.Errorf("expected initCertPool() to panic, it did not") - } else if r != nil && !tc.ExpectPanic { - t.Errorf("expected initCertPool() to not panic, but it did") + r := recover() + + if test.ExpectPanic { + assert.NotNil(t, r, "expected initCertPool() to panic") + } else { + assert.Nil(t, r, "expected initCertPool() to not panic") } }() result := initCertPool() - if result == nil && !tc.ExpectNil { - t.Errorf("initCertPool() returned nil, expected non-nil") - } else if result != nil && tc.ExpectNil { - t.Errorf("initCertPool() returned non-nil, expected nil") + if test.ExpectNil { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) } }) } diff --git a/acme/tls_alpn_challenge_test.go b/acme/tls_alpn_challenge_test.go index d53834f2..4b090177 100644 --- a/acme/tls_alpn_challenge_test.go +++ b/acme/tls_alpn_challenge_test.go @@ -7,41 +7,29 @@ import ( "crypto/subtle" "crypto/tls" "encoding/asn1" - "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestTLSALPNChallenge(t *testing.T) { domain := "localhost:23457" - privKey, _ := rsa.GenerateKey(rand.Reader, 512) - j := &jws{privKey: privKey} - clientChallenge := challenge{Type: string(TLSALPN01), Token: "tlsalpn1"} + mockValidate := func(_ *jws, _, _ string, chlng challenge) error { conn, err := tls.Dial("tcp", domain, &tls.Config{ InsecureSkipVerify: true, }) - if err != nil { - t.Errorf("Expected to connect to challenge server without an error. %v", err) - } + assert.NoError(t, err, "Expected to connect to challenge server without an error") // Expect the server to only return one certificate connState := conn.ConnectionState() - if count := len(connState.PeerCertificates); count != 1 { - t.Errorf("Expected the challenge server to return exactly one certificate but got %d", count) - } + assert.Len(t, connState.PeerCertificates, 1, "Expected the challenge server to return exactly one certificate") remoteCert := connState.PeerCertificates[0] - if count := len(remoteCert.DNSNames); count != 1 { - t.Errorf("Expected the challenge certificate to have exactly one DNSNames entry but had %d", count) - } - - if remoteCert.DNSNames[0] != domain { - t.Errorf("Expected the challenge certificate DNSName to match %s but was %s", domain, remoteCert.DNSNames[0]) - } - - if len(remoteCert.Extensions) == 0 { - t.Error("Expected the challenge certificate to contain extensions, it contained nothing") - } + assert.Len(t, remoteCert.DNSNames, 1, "Expected the challenge certificate to have exactly one DNSNames entry") + assert.Equal(t, domain, remoteCert.DNSNames[0], "challenge certificate DNSName ") + assert.NotEmpty(t, remoteCert.Extensions, "Expected the challenge certificate to contain extensions") idx := -1 for i, ext := range remoteCert.Extensions { @@ -51,42 +39,51 @@ func TestTLSALPNChallenge(t *testing.T) { } } - if idx == -1 { - t.Fatal("Expected the challenge certificate to contain an extension with the id-pe-acmeIdentifier id, it did not") - } + require.NotEqual(t, -1, idx, "Expected the challenge certificate to contain an extension with the id-pe-acmeIdentifier id,") ext := remoteCert.Extensions[idx] - - if !ext.Critical { - t.Error("Expected the challenge certificate id-pe-acmeIdentifier extension to be marked as critical, it was not") - } + assert.True(t, ext.Critical, "Expected the challenge certificate id-pe-acmeIdentifier extension to be marked as critical") zBytes := sha256.Sum256([]byte(chlng.KeyAuthorization)) value, err := asn1.Marshal(zBytes[:sha256.Size]) - if err != nil { - t.Fatalf("Expected marshaling of the keyAuth to return no error, but was %v", err) - } + require.NoError(t, err, "Expected marshaling of the keyAuth to return no error") + if subtle.ConstantTimeCompare(value[:], ext.Value) != 1 { t.Errorf("Expected the challenge certificate id-pe-acmeIdentifier extension to contain the SHA-256 digest of the keyAuth, %v, but was %v", zBytes[:], ext.Value) } return nil } - solver := &tlsALPNChallenge{jws: j, validate: mockValidate, provider: &TLSALPNProviderServer{port: "23457"}} - if err := solver.Solve(clientChallenge, domain); err != nil { - t.Errorf("Solve error: got %v, want nil", err) + + privKey, err := rsa.GenerateKey(rand.Reader, 512) + require.NoError(t, err, "Could not generate test key") + + solver := &tlsALPNChallenge{ + jws: &jws{privKey: privKey}, + validate: mockValidate, + provider: &TLSALPNProviderServer{port: "23457"}, } + + clientChallenge := challenge{Type: string(TLSALPN01), Token: "tlsalpn1"} + + err = solver.Solve(clientChallenge, domain) + assert.NoError(t, err) } func TestTLSALPNChallengeInvalidPort(t *testing.T) { - privKey, _ := rsa.GenerateKey(rand.Reader, 128) - j := &jws{privKey: privKey} - clientChallenge := challenge{Type: string(TLSALPN01), Token: "tlsalpn1"} - solver := &tlsALPNChallenge{jws: j, validate: stubValidate, provider: &TLSALPNProviderServer{port: "123456"}} + privKey, err := rsa.GenerateKey(rand.Reader, 128) + require.NoError(t, err, "Could not generate test key") - if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { - t.Errorf("Solve error: got %v, want error", err) - } else if want, want18 := "invalid port 123456", "123456: invalid port"; !strings.HasSuffix(err.Error(), want) && !strings.HasSuffix(err.Error(), want18) { - t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want) + solver := &tlsALPNChallenge{ + jws: &jws{privKey: privKey}, + validate: stubValidate, + provider: &TLSALPNProviderServer{port: "123456"}, } + + clientChallenge := challenge{Type: string(TLSALPN01), Token: "tlsalpn1"} + + err = solver.Solve(clientChallenge, "localhost:123456") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid port") + assert.Contains(t, err.Error(), "123456") }