forked from TrueCloudLab/lego
Use Testify. (#630)
This commit is contained in:
parent
bba134ce87
commit
55361cea8c
7 changed files with 478 additions and 392 deletions
|
@ -8,18 +8,19 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewClient(t *testing.T) {
|
func TestNewClient(t *testing.T) {
|
||||||
keyBits := 32 // small value keeps test fast
|
keyBits := 32 // small value keeps test fast
|
||||||
keyType := RSA2048
|
keyType := RSA2048
|
||||||
key, err := rsa.GenerateKey(rand.Reader, keyBits)
|
key, err := rsa.GenerateKey(rand.Reader, keyBits)
|
||||||
if err != nil {
|
require.NoError(t, err, "Could not generate test key")
|
||||||
t.Fatal("Could not generate test key:", err)
|
|
||||||
}
|
|
||||||
user := mockUser{
|
user := mockUser{
|
||||||
email: "test@test.com",
|
email: "test@test.com",
|
||||||
regres: new(RegistrationResource),
|
regres: new(RegistrationResource),
|
||||||
|
@ -38,32 +39,19 @@ func TestNewClient(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client, err := NewClient(ts.URL, user, keyType)
|
client, err := NewClient(ts.URL, user, keyType)
|
||||||
if err != nil {
|
require.NoError(t, err, "Could not create client")
|
||||||
t.Fatalf("Could not create client: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if client.jws == nil {
|
require.NotNil(t, client.jws, "client.jws")
|
||||||
t.Fatalf("Expected client.jws to not be nil")
|
assert.Equal(t, key, client.jws.privKey, "client.jws.privKey")
|
||||||
}
|
assert.Equal(t, keyType, client.keyType, "client.keyType")
|
||||||
if expected, actual := key, client.jws.privKey; actual != expected {
|
assert.Len(t, client.solvers, 2, "solvers")
|
||||||
t.Errorf("Expected jws.privKey to be %p but was %p", expected, actual)
|
|
||||||
}
|
|
||||||
|
|
||||||
if client.keyType != keyType {
|
|
||||||
t.Errorf("Expected keyType to be %s but was %s", keyType, client.keyType)
|
|
||||||
}
|
|
||||||
|
|
||||||
if expected, actual := 2, len(client.solvers); actual != expected {
|
|
||||||
t.Fatalf("Expected %d solver(s), got %d", expected, actual)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientOptPort(t *testing.T) {
|
func TestClientOptPort(t *testing.T) {
|
||||||
keyBits := 32 // small value keeps test fast
|
keyBits := 32 // small value keeps test fast
|
||||||
key, err := rsa.GenerateKey(rand.Reader, keyBits)
|
key, err := rsa.GenerateKey(rand.Reader, keyBits)
|
||||||
if err != nil {
|
require.NoError(t, err, "Could not generate test key")
|
||||||
t.Fatal("Could not generate test key:", err)
|
|
||||||
}
|
|
||||||
user := mockUser{
|
user := mockUser{
|
||||||
email: "test@test.com",
|
email: "test@test.com",
|
||||||
regres: new(RegistrationResource),
|
regres: new(RegistrationResource),
|
||||||
|
@ -83,33 +71,26 @@ func TestClientOptPort(t *testing.T) {
|
||||||
|
|
||||||
optPort := "1234"
|
optPort := "1234"
|
||||||
optHost := ""
|
optHost := ""
|
||||||
|
|
||||||
client, err := NewClient(ts.URL, user, RSA2048)
|
client, err := NewClient(ts.URL, user, RSA2048)
|
||||||
if err != nil {
|
require.NoError(t, err, "Could not create client")
|
||||||
t.Fatalf("Could not create client: %v", err)
|
|
||||||
}
|
|
||||||
client.SetHTTPAddress(net.JoinHostPort(optHost, optPort))
|
client.SetHTTPAddress(net.JoinHostPort(optHost, optPort))
|
||||||
|
|
||||||
httpSolver, ok := client.solvers[HTTP01].(*httpChallenge)
|
require.IsType(t, &httpChallenge{}, client.solvers[HTTP01])
|
||||||
if !ok {
|
httpSolver := client.solvers[HTTP01].(*httpChallenge)
|
||||||
t.Fatal("Expected http-01 solver to be httpChallenge type")
|
|
||||||
}
|
assert.Equal(t, httpSolver.jws, client.jws, "Expected http-01 to have same jws as client")
|
||||||
if httpSolver.jws != client.jws {
|
|
||||||
t.Error("Expected http-01 to have same jws as client")
|
httpProviderServer := httpSolver.provider.(*HTTPProviderServer)
|
||||||
}
|
assert.Equal(t, optPort, httpProviderServer.port, "port")
|
||||||
if got := httpSolver.provider.(*HTTPProviderServer).port; got != optPort {
|
assert.Equal(t, optHost, httpProviderServer.iface, "iface")
|
||||||
t.Errorf("Expected http-01 to have port %s but was %s", optPort, got)
|
|
||||||
}
|
|
||||||
if got := httpSolver.provider.(*HTTPProviderServer).iface; got != optHost {
|
|
||||||
t.Errorf("Expected http-01 to have iface %s but was %s", optHost, got)
|
|
||||||
}
|
|
||||||
|
|
||||||
// test setting different host
|
// test setting different host
|
||||||
optHost = "127.0.0.1"
|
optHost = "127.0.0.1"
|
||||||
client.SetHTTPAddress(net.JoinHostPort(optHost, optPort))
|
client.SetHTTPAddress(net.JoinHostPort(optHost, optPort))
|
||||||
|
|
||||||
if got := httpSolver.provider.(*HTTPProviderServer).iface; got != optHost {
|
assert.Equal(t, optHost, httpSolver.provider.(*HTTPProviderServer).iface, "iface")
|
||||||
t.Errorf("Expected http-01 to have iface %s but was %s", optHost, got)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNotHoldingLockWhileMakingHTTPRequests(t *testing.T) {
|
func TestNotHoldingLockWhileMakingHTTPRequests(t *testing.T) {
|
||||||
|
@ -121,7 +102,9 @@ func TestNotHoldingLockWhileMakingHTTPRequests(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
privKey, _ := rsa.GenerateKey(rand.Reader, 512)
|
privKey, err := rsa.GenerateKey(rand.Reader, 512)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
j := &jws{privKey: privKey, getNonceURL: ts.URL}
|
j := &jws{privKey: privKey, getNonceURL: ts.URL}
|
||||||
ch := make(chan bool)
|
ch := make(chan bool)
|
||||||
resultCh := make(chan bool)
|
resultCh := make(chan bool)
|
||||||
|
@ -147,10 +130,12 @@ func TestNotHoldingLockWhileMakingHTTPRequests(t *testing.T) {
|
||||||
|
|
||||||
func TestValidate(t *testing.T) {
|
func TestValidate(t *testing.T) {
|
||||||
var statuses []string
|
var statuses []string
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Minimal stub ACME server for validation.
|
// Minimal stub ACME server for validation.
|
||||||
w.Header().Add("Replay-Nonce", "12345")
|
w.Header().Add("Replay-Nonce", "12345")
|
||||||
w.Header().Add("Retry-After", "0")
|
w.Header().Add("Retry-After", "0")
|
||||||
|
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodHead:
|
case http.MethodHead:
|
||||||
case http.MethodPost:
|
case http.MethodPost:
|
||||||
|
@ -169,29 +154,57 @@ func TestValidate(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
privKey, _ := rsa.GenerateKey(rand.Reader, 512)
|
privKey, err := rsa.GenerateKey(rand.Reader, 512)
|
||||||
|
require.NoError(t, err)
|
||||||
j := &jws{privKey: privKey, getNonceURL: ts.URL}
|
j := &jws{privKey: privKey, getNonceURL: ts.URL}
|
||||||
|
|
||||||
tsts := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
statuses []string
|
statuses []string
|
||||||
want string
|
want string
|
||||||
}{
|
}{
|
||||||
{"POST-unexpected", []string{"weird"}, "unexpected"},
|
{
|
||||||
{"POST-valid", []string{"valid"}, ""},
|
name: "POST-unexpected",
|
||||||
{"POST-invalid", []string{"invalid"}, "Error"},
|
statuses: []string{"weird"},
|
||||||
{"GET-unexpected", []string{"pending", "weird"}, "unexpected"},
|
want: "unexpected",
|
||||||
{"GET-valid", []string{"pending", "valid"}, ""},
|
},
|
||||||
{"GET-invalid", []string{"pending", "invalid"}, "Error"},
|
{
|
||||||
|
name: "POST-valid",
|
||||||
|
statuses: []string{"valid"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "POST-invalid",
|
||||||
|
statuses: []string{"invalid"},
|
||||||
|
want: "Error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GET-unexpected",
|
||||||
|
statuses: []string{"pending", "weird"},
|
||||||
|
want: "unexpected",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GET-valid",
|
||||||
|
statuses: []string{"pending", "valid"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GET-invalid",
|
||||||
|
statuses: []string{"pending", "invalid"},
|
||||||
|
want: "Error",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tst := range tsts {
|
for _, test := range testCases {
|
||||||
statuses = tst.statuses
|
t.Run(test.name, func(t *testing.T) {
|
||||||
if err := validate(j, "example.com", ts.URL, challenge{Type: "http-01", Token: "token"}); err == nil && tst.want != "" {
|
statuses = test.statuses
|
||||||
t.Errorf("[%s] validate: got error %v, want something with %q", tst.name, err, tst.want)
|
|
||||||
} else if err != nil && !strings.Contains(err.Error(), tst.want) {
|
err := validate(j, "example.com", ts.URL, challenge{Type: "http-01", Token: "token"})
|
||||||
t.Errorf("[%s] validate: got error %v, want something with %q", tst.name, err, tst.want)
|
if test.want == "" {
|
||||||
}
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), test.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -217,10 +230,10 @@ func TestGetChallenges(t *testing.T) {
|
||||||
|
|
||||||
keyBits := 512 // small value keeps test fast
|
keyBits := 512 // small value keeps test fast
|
||||||
keyType := RSA2048
|
keyType := RSA2048
|
||||||
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, keyBits)
|
key, err := rsa.GenerateKey(rand.Reader, keyBits)
|
||||||
if err != nil {
|
require.NoError(t, err, "Could not generate test key")
|
||||||
t.Fatal("Could not generate test key:", err)
|
|
||||||
}
|
|
||||||
user := mockUser{
|
user := mockUser{
|
||||||
email: "test@test.com",
|
email: "test@test.com",
|
||||||
regres: &RegistrationResource{URI: ts.URL},
|
regres: &RegistrationResource{URI: ts.URL},
|
||||||
|
@ -228,23 +241,19 @@ func TestGetChallenges(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := NewClient(ts.URL, user, keyType)
|
client, err := NewClient(ts.URL, user, keyType)
|
||||||
if err != nil {
|
require.NoError(t, err, "Could not create client")
|
||||||
t.Fatalf("Could not create client: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.createOrderForIdentifiers([]string{"example.com"})
|
_, err = client.createOrderForIdentifiers([]string{"example.com"})
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Fatal("Expecting \"Server did not provide next link to proceed\" error, got nil")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveAccountByKey(t *testing.T) {
|
func TestResolveAccountByKey(t *testing.T) {
|
||||||
keyBits := 512
|
keyBits := 512
|
||||||
keyType := RSA2048
|
keyType := RSA2048
|
||||||
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, keyBits)
|
key, err := rsa.GenerateKey(rand.Reader, keyBits)
|
||||||
if err != nil {
|
require.NoError(t, err, "Could not generate test key")
|
||||||
t.Fatal("Could not generate test key:", err)
|
|
||||||
}
|
|
||||||
user := mockUser{
|
user := mockUser{
|
||||||
email: "test@test.com",
|
email: "test@test.com",
|
||||||
regres: new(RegistrationResource),
|
regres: new(RegistrationResource),
|
||||||
|
@ -275,15 +284,12 @@ func TestResolveAccountByKey(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client, err := NewClient(ts.URL+"/directory", user, keyType)
|
client, err := NewClient(ts.URL+"/directory", user, keyType)
|
||||||
if err != nil {
|
require.NoError(t, err, "Could not create client")
|
||||||
t.Fatalf("Could not create client: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res, err := client.ResolveAccountByKey(); err != nil {
|
res, err := client.ResolveAccountByKey()
|
||||||
t.Fatalf("Unexpected error resolving account by key: %v", err)
|
require.NoError(t, err, "Unexpected error resolving account by key")
|
||||||
} else if res.Body.Status != "valid" {
|
|
||||||
t.Errorf("Unexpected account status: %v", res.Body.Status)
|
assert.Equal(t, "valid", res.Body.Status, "Unexpected account status")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeJSONResponse marshals the body as JSON and writes it to the response.
|
// writeJSONResponse marshals the body as JSON and writes it to the response.
|
||||||
|
@ -301,7 +307,7 @@ func writeJSONResponse(w http.ResponseWriter, body interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// stubValidate is like validate, except it does nothing.
|
// stubValidate is like validate, except it does nothing.
|
||||||
func stubValidate(j *jws, domain, uri string, chlng challenge) error {
|
func stubValidate(_ *jws, _, _ string, _ challenge) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,31 +6,26 @@ import (
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGeneratePrivateKey(t *testing.T) {
|
func TestGeneratePrivateKey(t *testing.T) {
|
||||||
key, err := generatePrivateKey(RSA2048)
|
key, err := generatePrivateKey(RSA2048)
|
||||||
if err != nil {
|
require.NoError(t, err, "Error generating private key")
|
||||||
t.Error("Error generating private key:", err)
|
|
||||||
}
|
assert.NotNil(t, key)
|
||||||
if key == nil {
|
|
||||||
t.Error("Expected key to not be nil, but it was")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGenerateCSR(t *testing.T) {
|
func TestGenerateCSR(t *testing.T) {
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 512)
|
key, err := rsa.GenerateKey(rand.Reader, 512)
|
||||||
if err != nil {
|
require.NoError(t, err, "Error generating private key")
|
||||||
t.Fatal("Error generating private key:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
csr, err := generateCsr(key, "fizz.buzz", nil, true)
|
csr, err := generateCsr(key, "fizz.buzz", nil, true)
|
||||||
if err != nil {
|
require.NoError(t, err, "Error generating CSR")
|
||||||
t.Error("Error generating CSR:", err)
|
|
||||||
}
|
assert.NotEmpty(t, csr)
|
||||||
if len(csr) == 0 {
|
|
||||||
t.Error("Expected CSR with data, but it was nil or length 0")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPEMEncode(t *testing.T) {
|
func TestPEMEncode(t *testing.T) {
|
||||||
|
@ -38,50 +33,38 @@ func TestPEMEncode(t *testing.T) {
|
||||||
|
|
||||||
reader := MockRandReader{b: buf}
|
reader := MockRandReader{b: buf}
|
||||||
key, err := rsa.GenerateKey(reader, 32)
|
key, err := rsa.GenerateKey(reader, 32)
|
||||||
if err != nil {
|
require.NoError(t, err, "Error generating private key")
|
||||||
t.Fatal("Error generating private key:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data := pemEncode(key)
|
data := pemEncode(key)
|
||||||
|
require.NotNil(t, data)
|
||||||
if data == nil {
|
assert.Len(t, data, 127)
|
||||||
t.Fatal("Expected result to not be nil, but it was")
|
|
||||||
}
|
|
||||||
if len(data) != 127 {
|
|
||||||
t.Errorf("Expected PEM encoding to be length 127, but it was %d", len(data))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPEMCertExpiration(t *testing.T) {
|
func TestPEMCertExpiration(t *testing.T) {
|
||||||
privKey, err := generatePrivateKey(RSA2048)
|
privKey, err := generatePrivateKey(RSA2048)
|
||||||
if err != nil {
|
require.NoError(t, err, "Error generating private key")
|
||||||
t.Fatal("Error generating private key:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expiration := time.Now().Add(365)
|
expiration := time.Now().Add(365)
|
||||||
expiration = expiration.Round(time.Second)
|
expiration = expiration.Round(time.Second)
|
||||||
certBytes, err := generateDerCert(privKey.(*rsa.PrivateKey), expiration, "test.com", nil)
|
certBytes, err := generateDerCert(privKey.(*rsa.PrivateKey), expiration, "test.com", nil)
|
||||||
if err != nil {
|
require.NoError(t, err, "Error generating cert")
|
||||||
t.Fatal("Error generating cert:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := bytes.NewBufferString("TestingRSAIsSoMuchFun")
|
buf := bytes.NewBufferString("TestingRSAIsSoMuchFun")
|
||||||
|
|
||||||
// Some random string should return an error.
|
// Some random string should return an error.
|
||||||
if ctime, err := GetPEMCertExpiration(buf.Bytes()); err == nil {
|
ctime, err := GetPEMCertExpiration(buf.Bytes())
|
||||||
t.Errorf("Expected getCertExpiration to return an error for garbage string but returned %v", ctime)
|
assert.Errorf(t, err, "Expected getCertExpiration to return an error for garbage string but returned %v", ctime)
|
||||||
}
|
|
||||||
|
|
||||||
// A DER encoded certificate should return an error.
|
// A DER encoded certificate should return an error.
|
||||||
if _, err := GetPEMCertExpiration(certBytes); err == nil {
|
_, err = GetPEMCertExpiration(certBytes)
|
||||||
t.Errorf("Expected getCertExpiration to return an error for DER certificates but returned none.")
|
require.Error(t, err, "Expected getCertExpiration to return an error for DER certificates")
|
||||||
}
|
|
||||||
|
|
||||||
// A PEM encoded certificate should work ok.
|
// A PEM encoded certificate should work ok.
|
||||||
pemCert := pemEncode(derCertificateBytes(certBytes))
|
pemCert := pemEncode(derCertificateBytes(certBytes))
|
||||||
if ctime, err := GetPEMCertExpiration(pemCert); err != nil || !ctime.Equal(expiration.UTC()) {
|
ctime, err = GetPEMCertExpiration(pemCert)
|
||||||
t.Errorf("Expected getCertExpiration to return %v but returned %v. Error: %v", expiration, ctime, err)
|
require.NoError(t, err)
|
||||||
}
|
|
||||||
|
assert.Equal(t, expiration.UTC(), ctime)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockRandReader struct {
|
type MockRandReader struct {
|
||||||
|
|
|
@ -37,7 +37,7 @@ var defaultNameservers = []string{
|
||||||
"google-public-dns-b.google.com:53",
|
"google-public-dns-b.google.com:53",
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecursiveNameservers are used to pre-check DNS propagations
|
// RecursiveNameservers are used to pre-check DNS propagation
|
||||||
var RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
|
var RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
|
||||||
|
|
||||||
// DNSTimeout is used to override the default DNS timeout of 10 seconds.
|
// DNSTimeout is used to override the default DNS timeout of 10 seconds.
|
||||||
|
@ -235,7 +235,7 @@ func lookupNameservers(fqdn string) ([]string, error) {
|
||||||
|
|
||||||
zone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
|
zone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not determine the zone: %v", err)
|
return nil, fmt.Errorf("could not determine the zone: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := dnsQuery(zone, dns.TypeNS, RecursiveNameservers, true)
|
r, err := dnsQuery(zone, dns.TypeNS, RecursiveNameservers, true)
|
||||||
|
@ -252,7 +252,7 @@ func lookupNameservers(fqdn string) ([]string, error) {
|
||||||
if len(authoritativeNss) > 0 {
|
if len(authoritativeNss) > 0 {
|
||||||
return authoritativeNss, nil
|
return authoritativeNss, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("Could not determine authoritative nameservers")
|
return nil, fmt.Errorf("could not determine authoritative nameservers")
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindZoneByFqdn determines the zone apex for the given fqdn by recursing up the
|
// FindZoneByFqdn determines the zone apex for the given fqdn by recursing up the
|
||||||
|
@ -274,7 +274,7 @@ func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
|
||||||
|
|
||||||
// Any response code other than NOERROR and NXDOMAIN is treated as error
|
// Any response code other than NOERROR and NXDOMAIN is treated as error
|
||||||
if in.Rcode != dns.RcodeNameError && in.Rcode != dns.RcodeSuccess {
|
if in.Rcode != dns.RcodeNameError && in.Rcode != dns.RcodeSuccess {
|
||||||
return "", fmt.Errorf("Unexpected response code '%s' for %s",
|
return "", fmt.Errorf("unexpected response code '%s' for %s",
|
||||||
dns.RcodeToString[in.Rcode], domain)
|
dns.RcodeToString[in.Rcode], domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -297,7 +297,7 @@ func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("Could not find the start of authority")
|
return "", fmt.Errorf("could not find the start of authority")
|
||||||
}
|
}
|
||||||
|
|
||||||
// dnsMsgContainsCNAME checks for a CNAME answer in msg
|
// dnsMsgContainsCNAME checks for a CNAME answer in msg
|
||||||
|
|
|
@ -7,103 +7,27 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var lookupNameserversTestsOK = []struct {
|
|
||||||
fqdn string
|
|
||||||
nss []string
|
|
||||||
}{
|
|
||||||
{"books.google.com.ng.",
|
|
||||||
[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
|
|
||||||
},
|
|
||||||
{"www.google.com.",
|
|
||||||
[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
|
|
||||||
},
|
|
||||||
{"physics.georgetown.edu.",
|
|
||||||
[]string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var lookupNameserversTestsErr = []struct {
|
|
||||||
fqdn string
|
|
||||||
error string
|
|
||||||
}{
|
|
||||||
// invalid tld
|
|
||||||
{"_null.n0n0.",
|
|
||||||
"Could not determine the zone",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var findZoneByFqdnTests = []struct {
|
|
||||||
fqdn string
|
|
||||||
zone string
|
|
||||||
}{
|
|
||||||
{"mail.google.com.", "google.com."}, // domain is a CNAME
|
|
||||||
{"foo.google.com.", "google.com."}, // domain is a non-existent subdomain
|
|
||||||
{"example.com.ac.", "ac."}, // domain is a eTLD
|
|
||||||
{"cross-zone-example.assets.sh.", "assets.sh."}, // domain is a cross-zone CNAME
|
|
||||||
}
|
|
||||||
|
|
||||||
var checkAuthoritativeNssTests = []struct {
|
|
||||||
fqdn, value string
|
|
||||||
ns []string
|
|
||||||
ok bool
|
|
||||||
}{
|
|
||||||
// TXT RR w/ expected value
|
|
||||||
{"8.8.8.8.asn.routeviews.org.", "151698.8.8.024", []string{"asnums.routeviews.org."},
|
|
||||||
true,
|
|
||||||
},
|
|
||||||
// No TXT RR
|
|
||||||
{"ns1.google.com.", "", []string{"ns2.google.com."},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var checkAuthoritativeNssTestsErr = []struct {
|
|
||||||
fqdn, value string
|
|
||||||
ns []string
|
|
||||||
error string
|
|
||||||
}{
|
|
||||||
// TXT RR /w unexpected value
|
|
||||||
{"8.8.8.8.asn.routeviews.org.", "fe01=", []string{"asnums.routeviews.org."},
|
|
||||||
"did not return the expected TXT record",
|
|
||||||
},
|
|
||||||
// No TXT RR
|
|
||||||
{"ns1.google.com.", "fe01=", []string{"ns2.google.com."},
|
|
||||||
"did not return the expected TXT record",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var checkResolvConfServersTests = []struct {
|
|
||||||
fixture string
|
|
||||||
expected []string
|
|
||||||
defaults []string
|
|
||||||
}{
|
|
||||||
{"testdata/resolv.conf.1", []string{"10.200.3.249:53", "10.200.3.250:5353", "[2001:4860:4860::8844]:53", "[10.0.0.1]:5353"}, []string{"127.0.0.1:53"}},
|
|
||||||
{"testdata/resolv.conf.nonexistant", []string{"127.0.0.1:53"}, []string{"127.0.0.1:53"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSValidServerResponse(t *testing.T) {
|
func TestDNSValidServerResponse(t *testing.T) {
|
||||||
PreCheckDNS = func(fqdn, value string) (bool, error) {
|
PreCheckDNS = func(fqdn, value string) (bool, error) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
privKey, _ := rsa.GenerateKey(rand.Reader, 512)
|
|
||||||
|
privKey, err := rsa.GenerateKey(rand.Reader, 512)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Add("Replay-Nonce", "12345")
|
w.Header().Add("Replay-Nonce", "12345")
|
||||||
w.Write([]byte("{\"type\":\"dns01\",\"status\":\"valid\",\"uri\":\"http://some.url\",\"token\":\"http8\"}"))
|
w.Write([]byte("{\"type\":\"dns01\",\"status\":\"valid\",\"uri\":\"http://some.url\",\"token\":\"http8\"}"))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
manualProvider, _ := NewDNSProviderManual()
|
|
||||||
jws := &jws{privKey: privKey, getNonceURL: ts.URL}
|
|
||||||
solver := &dnsChallenge{jws: jws, validate: validate, provider: manualProvider}
|
|
||||||
clientChallenge := challenge{Type: "dns01", Status: "pending", URL: ts.URL, Token: "http8"}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Second * 2)
|
time.Sleep(time.Second * 2)
|
||||||
f := bufio.NewWriter(os.Stdout)
|
f := bufio.NewWriter(os.Stdout)
|
||||||
|
@ -111,90 +35,282 @@ func TestDNSValidServerResponse(t *testing.T) {
|
||||||
f.WriteString("\n")
|
f.WriteString("\n")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := solver.Solve(clientChallenge, "example.com"); err != nil {
|
manualProvider, err := NewDNSProviderManual()
|
||||||
t.Errorf("VALID: Expected Solve to return no error but the error was -> %v", err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientChallenge := challenge{Type: "dns01", Status: "pending", URL: ts.URL, Token: "http8"}
|
||||||
|
|
||||||
|
solver := &dnsChallenge{
|
||||||
|
jws: &jws{privKey: privKey, getNonceURL: ts.URL},
|
||||||
|
validate: validate,
|
||||||
|
provider: manualProvider,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = solver.Solve(clientChallenge, "example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPreCheckDNS(t *testing.T) {
|
func TestPreCheckDNS(t *testing.T) {
|
||||||
ok, err := PreCheckDNS("acme-staging.api.letsencrypt.org", "fe01=")
|
ok, err := PreCheckDNS("acme-staging.api.letsencrypt.org", "fe01=")
|
||||||
if err != nil || !ok {
|
if err != nil || !ok {
|
||||||
t.Errorf("preCheckDNS failed for acme-staging.api.letsencrypt.org")
|
t.Errorf("PreCheckDNS failed for acme-staging.api.letsencrypt.org")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLookupNameserversOK(t *testing.T) {
|
func TestLookupNameserversOK(t *testing.T) {
|
||||||
for _, tt := range lookupNameserversTestsOK {
|
testCases := []struct {
|
||||||
nss, err := lookupNameservers(tt.fqdn)
|
fqdn string
|
||||||
if err != nil {
|
nss []string
|
||||||
t.Fatalf("#%s: got %q; want nil", tt.fqdn, err)
|
}{
|
||||||
}
|
{
|
||||||
|
fqdn: "books.google.com.ng.",
|
||||||
|
nss: []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
fqdn: "www.google.com.",
|
||||||
|
nss: []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
fqdn: "physics.georgetown.edu.",
|
||||||
|
nss: []string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
sort.Strings(nss)
|
for _, test := range testCases {
|
||||||
sort.Strings(tt.nss)
|
test := test
|
||||||
|
t.Run(test.fqdn, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
if !reflect.DeepEqual(nss, tt.nss) {
|
nss, err := lookupNameservers(test.fqdn)
|
||||||
t.Errorf("#%s: got %v; want %v", tt.fqdn, nss, tt.nss)
|
require.NoError(t, err)
|
||||||
}
|
|
||||||
|
sort.Strings(nss)
|
||||||
|
sort.Strings(test.nss)
|
||||||
|
|
||||||
|
assert.EqualValues(t, test.nss, nss)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLookupNameserversErr(t *testing.T) {
|
func TestLookupNameserversErr(t *testing.T) {
|
||||||
for _, tt := range lookupNameserversTestsErr {
|
testCases := []struct {
|
||||||
_, err := lookupNameservers(tt.fqdn)
|
desc string
|
||||||
if err == nil {
|
fqdn string
|
||||||
t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
|
error string
|
||||||
}
|
}{
|
||||||
|
{
|
||||||
|
desc: "invalid tld",
|
||||||
|
fqdn: "_null.n0n0.",
|
||||||
|
error: "could not determine the zone",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
if !strings.Contains(err.Error(), tt.error) {
|
for _, test := range testCases {
|
||||||
t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
|
test := test
|
||||||
continue
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
}
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := lookupNameservers(test.fqdn)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), test.error)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindZoneByFqdn(t *testing.T) {
|
func TestFindZoneByFqdn(t *testing.T) {
|
||||||
for _, tt := range findZoneByFqdnTests {
|
testCases := []struct {
|
||||||
res, err := FindZoneByFqdn(tt.fqdn, RecursiveNameservers)
|
desc string
|
||||||
if err != nil {
|
fqdn string
|
||||||
t.Errorf("FindZoneByFqdn failed for %s: %v", tt.fqdn, err)
|
zone string
|
||||||
}
|
}{
|
||||||
if res != tt.zone {
|
{
|
||||||
t.Errorf("%s: got %s; want %s", tt.fqdn, res, tt.zone)
|
desc: "domain is a CNAME",
|
||||||
}
|
fqdn: "mail.google.com.",
|
||||||
|
zone: "google.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "domain is a non-existent subdomain",
|
||||||
|
fqdn: "foo.google.com.",
|
||||||
|
zone: "google.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "domain is a eTLD",
|
||||||
|
fqdn: "example.com.ac.",
|
||||||
|
zone: "ac.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "domain is a cross-zone CNAME",
|
||||||
|
fqdn: "cross-zone-example.assets.sh.",
|
||||||
|
zone: "assets.sh.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
zone, err := FindZoneByFqdn(test.fqdn, RecursiveNameservers)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, test.zone, zone)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckAuthoritativeNss(t *testing.T) {
|
func TestCheckAuthoritativeNss(t *testing.T) {
|
||||||
for _, tt := range checkAuthoritativeNssTests {
|
testCases := []struct {
|
||||||
ok, _ := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
|
desc string
|
||||||
if ok != tt.ok {
|
fqdn, value string
|
||||||
t.Errorf("%s: got %t; want %t", tt.fqdn, ok, tt.ok)
|
ns []string
|
||||||
}
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "TXT RR w/ expected value",
|
||||||
|
fqdn: "8.8.8.8.asn.routeviews.org.",
|
||||||
|
value: "151698.8.8.024",
|
||||||
|
ns: []string{"asnums.routeviews.org."},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "No TXT RR",
|
||||||
|
fqdn: "ns1.google.com.",
|
||||||
|
ns: []string{"ns2.google.com."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ok, _ := checkAuthoritativeNss(test.fqdn, test.value, test.ns)
|
||||||
|
assert.Equal(t, test.expected, ok, test.fqdn)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckAuthoritativeNssErr(t *testing.T) {
|
func TestCheckAuthoritativeNssErr(t *testing.T) {
|
||||||
for _, tt := range checkAuthoritativeNssTestsErr {
|
testCases := []struct {
|
||||||
_, err := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
|
desc string
|
||||||
if err == nil {
|
fqdn, value string
|
||||||
t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
|
ns []string
|
||||||
}
|
error string
|
||||||
if !strings.Contains(err.Error(), tt.error) {
|
}{
|
||||||
t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
|
{
|
||||||
continue
|
desc: "TXT RR /w unexpected value",
|
||||||
}
|
fqdn: "8.8.8.8.asn.routeviews.org.",
|
||||||
|
value: "fe01=",
|
||||||
|
ns: []string{"asnums.routeviews.org."},
|
||||||
|
error: "did not return the expected TXT record",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "No TXT RR",
|
||||||
|
fqdn: "ns1.google.com.",
|
||||||
|
value: "fe01=",
|
||||||
|
ns: []string{"ns2.google.com."},
|
||||||
|
error: "did not return the expected TXT record",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := checkAuthoritativeNss(test.fqdn, test.value, test.ns)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), test.error)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveConfServers(t *testing.T) {
|
func TestResolveConfServers(t *testing.T) {
|
||||||
for _, tt := range checkResolvConfServersTests {
|
var testCases = []struct {
|
||||||
result := getNameservers(tt.fixture, tt.defaults)
|
fixture string
|
||||||
|
expected []string
|
||||||
|
defaults []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
fixture: "testdata/resolv.conf.1",
|
||||||
|
defaults: []string{"127.0.0.1:53"},
|
||||||
|
expected: []string{"10.200.3.249:53", "10.200.3.250:5353", "[2001:4860:4860::8844]:53", "[10.0.0.1]:5353"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
fixture: "testdata/resolv.conf.nonexistant",
|
||||||
|
defaults: []string{"127.0.0.1:53"},
|
||||||
|
expected: []string{"127.0.0.1:53"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
sort.Strings(result)
|
for _, test := range testCases {
|
||||||
sort.Strings(tt.expected)
|
t.Run(test.fixture, func(t *testing.T) {
|
||||||
if !reflect.DeepEqual(result, tt.expected) {
|
|
||||||
t.Errorf("#%s: expected %q; got %q", tt.fixture, tt.expected, result)
|
result := getNameservers(test.fixture, test.defaults)
|
||||||
}
|
|
||||||
|
sort.Strings(result)
|
||||||
|
sort.Strings(test.expected)
|
||||||
|
|
||||||
|
assert.Equal(t, test.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToFqdn(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
domain string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "simple",
|
||||||
|
domain: "foo.bar.com",
|
||||||
|
expected: "foo.bar.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "already FQDN",
|
||||||
|
domain: "foo.bar.com.",
|
||||||
|
expected: "foo.bar.com.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fqdn := ToFqdn(test.domain)
|
||||||
|
assert.Equal(t, test.expected, fqdn)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnFqdn(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
fqdn string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "simple",
|
||||||
|
fqdn: "foo.bar.com.",
|
||||||
|
expected: "foo.bar.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "already domain",
|
||||||
|
fqdn: "foo.bar.com",
|
||||||
|
expected: "foo.bar.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
domain := UnFqdn(test.fqdn)
|
||||||
|
|
||||||
|
assert.Equal(t, test.expected, domain)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,14 +4,13 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHTTPChallenge(t *testing.T) {
|
func TestHTTPChallenge(t *testing.T) {
|
||||||
privKey, _ := rsa.GenerateKey(rand.Reader, 512)
|
|
||||||
j := &jws{privKey: privKey}
|
|
||||||
clientChallenge := challenge{Type: string(HTTP01), Token: "http1"}
|
|
||||||
mockValidate := func(_ *jws, _, _ string, chlng challenge) error {
|
mockValidate := func(_ *jws, _, _ string, chlng challenge) error {
|
||||||
uri := "http://localhost:23457/.well-known/acme-challenge/" + chlng.Token
|
uri := "http://localhost:23457/.well-known/acme-challenge/" + chlng.Token
|
||||||
resp, err := httpGet(uri)
|
resp, err := httpGet(uri)
|
||||||
|
@ -36,22 +35,36 @@ func TestHTTPChallenge(t *testing.T) {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
solver := &httpChallenge{jws: j, validate: mockValidate, provider: &HTTPProviderServer{port: "23457"}}
|
|
||||||
|
|
||||||
if err := solver.Solve(clientChallenge, "localhost:23457"); err != nil {
|
privKey, err := rsa.GenerateKey(rand.Reader, 512)
|
||||||
t.Errorf("Solve error: got %v, want nil", err)
|
require.NoError(t, err, "Could not generate test key")
|
||||||
|
|
||||||
|
solver := &httpChallenge{
|
||||||
|
jws: &jws{privKey: privKey},
|
||||||
|
validate: mockValidate,
|
||||||
|
provider: &HTTPProviderServer{port: "23457"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
clientChallenge := challenge{Type: string(HTTP01), Token: "http1"}
|
||||||
|
|
||||||
|
err = solver.Solve(clientChallenge, "localhost:23457")
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPChallengeInvalidPort(t *testing.T) {
|
func TestHTTPChallengeInvalidPort(t *testing.T) {
|
||||||
privKey, _ := rsa.GenerateKey(rand.Reader, 128)
|
privKey, err := rsa.GenerateKey(rand.Reader, 128)
|
||||||
j := &jws{privKey: privKey}
|
require.NoError(t, err, "Could not generate test key")
|
||||||
clientChallenge := challenge{Type: string(HTTP01), Token: "http2"}
|
|
||||||
solver := &httpChallenge{jws: j, validate: stubValidate, provider: &HTTPProviderServer{port: "123456"}}
|
|
||||||
|
|
||||||
if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil {
|
solver := &httpChallenge{
|
||||||
t.Errorf("Solve error: got %v, want error", err)
|
jws: &jws{privKey: privKey},
|
||||||
} else if want, want18 := "invalid port 123456", "123456: invalid port"; !strings.HasSuffix(err.Error(), want) && !strings.HasSuffix(err.Error(), want18) {
|
validate: stubValidate,
|
||||||
t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want)
|
provider: &HTTPProviderServer{port: "123456"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
clientChallenge := challenge{Type: string(HTTP01), Token: "http2"}
|
||||||
|
|
||||||
|
err = solver.Solve(clientChallenge, "localhost:123456")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid port")
|
||||||
|
assert.Contains(t, err.Error(), "123456")
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,9 +7,12 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHTTPHeadUserAgent(t *testing.T) {
|
func TestHTTPUserAgent(t *testing.T) {
|
||||||
var ua, method string
|
var ua, method string
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ua = r.Header.Get("User-Agent")
|
ua = r.Header.Get("User-Agent")
|
||||||
|
@ -17,72 +20,43 @@ func TestHTTPHeadUserAgent(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
_, err := httpHead(ts.URL)
|
testCases := []struct {
|
||||||
if err != nil {
|
method string
|
||||||
t.Fatal(err)
|
call func(u string) (resp *http.Response, err error)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
method: http.MethodGet,
|
||||||
|
call: httpGet,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
method: http.MethodHead,
|
||||||
|
call: httpHead,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
method: http.MethodPost,
|
||||||
|
call: func(u string) (resp *http.Response, err error) {
|
||||||
|
return httpPost(u, "text/plain", strings.NewReader("falalalala"))
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if method != http.MethodHead {
|
for _, test := range testCases {
|
||||||
t.Errorf("Expected method to be HEAD, got %s", method)
|
t.Run(test.method, func(t *testing.T) {
|
||||||
}
|
|
||||||
if !strings.Contains(ua, ourUserAgent) {
|
|
||||||
t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPGetUserAgent(t *testing.T) {
|
_, err := test.call(ts.URL)
|
||||||
var ua, method string
|
require.NoError(t, err)
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ua = r.Header.Get("User-Agent")
|
|
||||||
method = r.Method
|
|
||||||
}))
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
res, err := httpGet(ts.URL)
|
assert.Equal(t, test.method, method)
|
||||||
if err != nil {
|
assert.Contains(t, ua, ourUserAgent, "User-Agent")
|
||||||
t.Fatal(err)
|
})
|
||||||
}
|
|
||||||
res.Body.Close()
|
|
||||||
|
|
||||||
if method != http.MethodGet {
|
|
||||||
t.Errorf("Expected method to be GET, got %s", method)
|
|
||||||
}
|
|
||||||
if !strings.Contains(ua, ourUserAgent) {
|
|
||||||
t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPPostUserAgent(t *testing.T) {
|
|
||||||
var ua, method string
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ua = r.Header.Get("User-Agent")
|
|
||||||
method = r.Method
|
|
||||||
}))
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
res, err := httpPost(ts.URL, "text/plain", strings.NewReader("falalalala"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
res.Body.Close()
|
|
||||||
|
|
||||||
if method != http.MethodPost {
|
|
||||||
t.Errorf("Expected method to be POST, got %s", method)
|
|
||||||
}
|
|
||||||
if !strings.Contains(ua, ourUserAgent) {
|
|
||||||
t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserAgent(t *testing.T) {
|
func TestUserAgent(t *testing.T) {
|
||||||
ua := userAgent()
|
ua := userAgent()
|
||||||
|
|
||||||
if !strings.Contains(ua, defaultGoUserAgent) {
|
assert.Contains(t, ua, defaultGoUserAgent)
|
||||||
t.Errorf("Expected UA to contain %s, got '%s'", defaultGoUserAgent, ua)
|
assert.Contains(t, ua, ourUserAgent)
|
||||||
}
|
|
||||||
if !strings.Contains(ua, ourUserAgent) {
|
|
||||||
t.Errorf("Expected UA to contain %s, got '%s'", ourUserAgent, ua)
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(ua, " ") {
|
if strings.HasSuffix(ua, " ") {
|
||||||
t.Errorf("UA should not have trailing spaces; got '%s'", ua)
|
t.Errorf("UA should not have trailing spaces; got '%s'", ua)
|
||||||
}
|
}
|
||||||
|
@ -90,15 +64,10 @@ func TestUserAgent(t *testing.T) {
|
||||||
// customize the UA by appending a value
|
// customize the UA by appending a value
|
||||||
UserAgent = "MyApp/1.2.3"
|
UserAgent = "MyApp/1.2.3"
|
||||||
ua = userAgent()
|
ua = userAgent()
|
||||||
if !strings.Contains(ua, defaultGoUserAgent) {
|
|
||||||
t.Errorf("Expected UA to contain %s, got '%s'", defaultGoUserAgent, ua)
|
assert.Contains(t, ua, defaultGoUserAgent)
|
||||||
}
|
assert.Contains(t, ua, ourUserAgent)
|
||||||
if !strings.Contains(ua, ourUserAgent) {
|
assert.Contains(t, ua, UserAgent)
|
||||||
t.Errorf("Expected UA to contain %s, got '%s'", ourUserAgent, ua)
|
|
||||||
}
|
|
||||||
if !strings.Contains(ua, UserAgent) {
|
|
||||||
t.Errorf("Expected custom UA to contain %s, got '%s'", UserAgent, ua)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestInitCertPool tests the http.go initCertPool function for customizing the
|
// TestInitCertPool tests the http.go initCertPool function for customizing the
|
||||||
|
@ -185,25 +154,27 @@ p9BI7gVKtWSZYegicA==
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, test := range testCases {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(test.Name, func(t *testing.T) {
|
||||||
os.Setenv(caCertificatesEnvVar, tc.EnvVar)
|
os.Setenv(caCertificatesEnvVar, test.EnvVar)
|
||||||
defer os.Setenv(caCertificatesEnvVar, "")
|
defer os.Setenv(caCertificatesEnvVar, "")
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r == nil && tc.ExpectPanic {
|
r := recover()
|
||||||
t.Errorf("expected initCertPool() to panic, it did not")
|
|
||||||
} else if r != nil && !tc.ExpectPanic {
|
if test.ExpectPanic {
|
||||||
t.Errorf("expected initCertPool() to not panic, but it did")
|
assert.NotNil(t, r, "expected initCertPool() to panic")
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, r, "expected initCertPool() to not panic")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
result := initCertPool()
|
result := initCertPool()
|
||||||
|
|
||||||
if result == nil && !tc.ExpectNil {
|
if test.ExpectNil {
|
||||||
t.Errorf("initCertPool() returned nil, expected non-nil")
|
assert.Nil(t, result)
|
||||||
} else if result != nil && tc.ExpectNil {
|
} else {
|
||||||
t.Errorf("initCertPool() returned non-nil, expected nil")
|
assert.NotNil(t, result)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,41 +7,29 @@ import (
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTLSALPNChallenge(t *testing.T) {
|
func TestTLSALPNChallenge(t *testing.T) {
|
||||||
domain := "localhost:23457"
|
domain := "localhost:23457"
|
||||||
privKey, _ := rsa.GenerateKey(rand.Reader, 512)
|
|
||||||
j := &jws{privKey: privKey}
|
|
||||||
clientChallenge := challenge{Type: string(TLSALPN01), Token: "tlsalpn1"}
|
|
||||||
mockValidate := func(_ *jws, _, _ string, chlng challenge) error {
|
mockValidate := func(_ *jws, _, _ string, chlng challenge) error {
|
||||||
conn, err := tls.Dial("tcp", domain, &tls.Config{
|
conn, err := tls.Dial("tcp", domain, &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
assert.NoError(t, err, "Expected to connect to challenge server without an error")
|
||||||
t.Errorf("Expected to connect to challenge server without an error. %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expect the server to only return one certificate
|
// Expect the server to only return one certificate
|
||||||
connState := conn.ConnectionState()
|
connState := conn.ConnectionState()
|
||||||
if count := len(connState.PeerCertificates); count != 1 {
|
assert.Len(t, connState.PeerCertificates, 1, "Expected the challenge server to return exactly one certificate")
|
||||||
t.Errorf("Expected the challenge server to return exactly one certificate but got %d", count)
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteCert := connState.PeerCertificates[0]
|
remoteCert := connState.PeerCertificates[0]
|
||||||
if count := len(remoteCert.DNSNames); count != 1 {
|
assert.Len(t, remoteCert.DNSNames, 1, "Expected the challenge certificate to have exactly one DNSNames entry")
|
||||||
t.Errorf("Expected the challenge certificate to have exactly one DNSNames entry but had %d", count)
|
assert.Equal(t, domain, remoteCert.DNSNames[0], "challenge certificate DNSName ")
|
||||||
}
|
assert.NotEmpty(t, remoteCert.Extensions, "Expected the challenge certificate to contain extensions")
|
||||||
|
|
||||||
if remoteCert.DNSNames[0] != domain {
|
|
||||||
t.Errorf("Expected the challenge certificate DNSName to match %s but was %s", domain, remoteCert.DNSNames[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(remoteCert.Extensions) == 0 {
|
|
||||||
t.Error("Expected the challenge certificate to contain extensions, it contained nothing")
|
|
||||||
}
|
|
||||||
|
|
||||||
idx := -1
|
idx := -1
|
||||||
for i, ext := range remoteCert.Extensions {
|
for i, ext := range remoteCert.Extensions {
|
||||||
|
@ -51,42 +39,51 @@ func TestTLSALPNChallenge(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if idx == -1 {
|
require.NotEqual(t, -1, idx, "Expected the challenge certificate to contain an extension with the id-pe-acmeIdentifier id,")
|
||||||
t.Fatal("Expected the challenge certificate to contain an extension with the id-pe-acmeIdentifier id, it did not")
|
|
||||||
}
|
|
||||||
|
|
||||||
ext := remoteCert.Extensions[idx]
|
ext := remoteCert.Extensions[idx]
|
||||||
|
assert.True(t, ext.Critical, "Expected the challenge certificate id-pe-acmeIdentifier extension to be marked as critical")
|
||||||
if !ext.Critical {
|
|
||||||
t.Error("Expected the challenge certificate id-pe-acmeIdentifier extension to be marked as critical, it was not")
|
|
||||||
}
|
|
||||||
|
|
||||||
zBytes := sha256.Sum256([]byte(chlng.KeyAuthorization))
|
zBytes := sha256.Sum256([]byte(chlng.KeyAuthorization))
|
||||||
value, err := asn1.Marshal(zBytes[:sha256.Size])
|
value, err := asn1.Marshal(zBytes[:sha256.Size])
|
||||||
if err != nil {
|
require.NoError(t, err, "Expected marshaling of the keyAuth to return no error")
|
||||||
t.Fatalf("Expected marshaling of the keyAuth to return no error, but was %v", err)
|
|
||||||
}
|
|
||||||
if subtle.ConstantTimeCompare(value[:], ext.Value) != 1 {
|
if subtle.ConstantTimeCompare(value[:], ext.Value) != 1 {
|
||||||
t.Errorf("Expected the challenge certificate id-pe-acmeIdentifier extension to contain the SHA-256 digest of the keyAuth, %v, but was %v", zBytes[:], ext.Value)
|
t.Errorf("Expected the challenge certificate id-pe-acmeIdentifier extension to contain the SHA-256 digest of the keyAuth, %v, but was %v", zBytes[:], ext.Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
solver := &tlsALPNChallenge{jws: j, validate: mockValidate, provider: &TLSALPNProviderServer{port: "23457"}}
|
|
||||||
if err := solver.Solve(clientChallenge, domain); err != nil {
|
privKey, err := rsa.GenerateKey(rand.Reader, 512)
|
||||||
t.Errorf("Solve error: got %v, want nil", err)
|
require.NoError(t, err, "Could not generate test key")
|
||||||
|
|
||||||
|
solver := &tlsALPNChallenge{
|
||||||
|
jws: &jws{privKey: privKey},
|
||||||
|
validate: mockValidate,
|
||||||
|
provider: &TLSALPNProviderServer{port: "23457"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
clientChallenge := challenge{Type: string(TLSALPN01), Token: "tlsalpn1"}
|
||||||
|
|
||||||
|
err = solver.Solve(clientChallenge, domain)
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSALPNChallengeInvalidPort(t *testing.T) {
|
func TestTLSALPNChallengeInvalidPort(t *testing.T) {
|
||||||
privKey, _ := rsa.GenerateKey(rand.Reader, 128)
|
privKey, err := rsa.GenerateKey(rand.Reader, 128)
|
||||||
j := &jws{privKey: privKey}
|
require.NoError(t, err, "Could not generate test key")
|
||||||
clientChallenge := challenge{Type: string(TLSALPN01), Token: "tlsalpn1"}
|
|
||||||
solver := &tlsALPNChallenge{jws: j, validate: stubValidate, provider: &TLSALPNProviderServer{port: "123456"}}
|
|
||||||
|
|
||||||
if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil {
|
solver := &tlsALPNChallenge{
|
||||||
t.Errorf("Solve error: got %v, want error", err)
|
jws: &jws{privKey: privKey},
|
||||||
} else if want, want18 := "invalid port 123456", "123456: invalid port"; !strings.HasSuffix(err.Error(), want) && !strings.HasSuffix(err.Error(), want18) {
|
validate: stubValidate,
|
||||||
t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want)
|
provider: &TLSALPNProviderServer{port: "123456"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
clientChallenge := challenge{Type: string(TLSALPN01), Token: "tlsalpn1"}
|
||||||
|
|
||||||
|
err = solver.Solve(clientChallenge, "localhost:123456")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid port")
|
||||||
|
assert.Contains(t, err.Error(), "123456")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue