diff --git a/go.mod b/go.mod index ff74db1b..0d80c679 100644 --- a/go.mod +++ b/go.mod @@ -42,7 +42,7 @@ require ( github.com/stretchr/testify v1.6.1 github.com/transip/gotransip/v6 v6.2.0 github.com/urfave/cli v1.22.4 - github.com/vultr/govultr v0.5.0 + github.com/vultr/govultr/v2 v2.0.0 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/net v0.0.0-20200822124328-c89045814202 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d diff --git a/go.sum b/go.sum index f80bfd87..370c4fde 100644 --- a/go.sum +++ b/go.sum @@ -191,8 +191,8 @@ github.com/hashicorp/go-cleanhttp v0.5.1 h1:dH3aiDG9Jvb5r5+bYHsikaOUIpcM0xvgMXVo github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= -github.com/hashicorp/go-retryablehttp v0.6.7 h1:8/CAEZt/+F7kR7GevNHulKkUjLht3CPmn7egmhieNKo= -github.com/hashicorp/go-retryablehttp v0.6.7/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY= +github.com/hashicorp/go-retryablehttp v0.6.6 h1:HJunrbHTDDbBb/ay4kxa1n+dLmttUlnP3V9oNE4hmsM= +github.com/hashicorp/go-retryablehttp v0.6.6/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY= github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= @@ -341,8 +341,8 @@ github.com/uber-go/atomic v1.3.2 h1:Azu9lPBWRNKzYXSIwRfgRuDuS0YKsK4NFhiQv98gkxo= github.com/uber-go/atomic v1.3.2/go.mod h1:/Ct5t2lcmbJ4OSe/waGBoaVvVqtO0bmtfVNex1PFV8g= github.com/urfave/cli v1.22.4 h1:u7tSpNPPswAFymm8IehJhy4uJMlUuU/GmqSkvJ1InXA= github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= -github.com/vultr/govultr v0.5.0 h1:iQzYhzbokmpDARbvIkvTkoyS7WMH82zVTKAL1PZ4JOA= -github.com/vultr/govultr v0.5.0/go.mod h1:wZZXZbYbqyY1n3AldoeYNZK4Wnmmoq6dNFkvd5TV3ss= +github.com/vultr/govultr/v2 v2.0.0 h1:+lAtqfWy3g9VwL7tT2Fpyad8Vv4MxOhT/NU8O5dk+EQ= +github.com/vultr/govultr/v2 v2.0.0/go.mod h1:2PsEeg+gs3p/Fo5Pw8F9mv+DUBEOlrNZ8GmCTGmhOhs= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= diff --git a/providers/dns/vultr/vultr.go b/providers/dns/vultr/vultr.go index e2f6cc3a..4ea20dc7 100644 --- a/providers/dns/vultr/vultr.go +++ b/providers/dns/vultr/vultr.go @@ -4,17 +4,16 @@ package vultr import ( "context" - "crypto/tls" "errors" "fmt" "net/http" - "strconv" "strings" "time" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" - "github.com/vultr/govultr" + "github.com/vultr/govultr/v2" + "golang.org/x/oauth2" ) // Environment variables names. @@ -36,6 +35,7 @@ type Config struct { PollingInterval time.Duration TTL int HTTPClient *http.Client + HTTPTimeout time.Duration } // NewDefaultConfig returns a default configuration for the DNSProvider. @@ -44,13 +44,7 @@ func NewDefaultConfig() *Config { TTL: env.GetOrDefaultInt(EnvTTL, dns01.DefaultTTL), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, dns01.DefaultPropagationTimeout), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval), - HTTPClient: &http.Client{ - Timeout: env.GetOrDefaultSecond(EnvHTTPTimeout, 30), - // from Vultr Client - Transport: &http.Transport{ - TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), - }, - }, + HTTPTimeout: env.GetOrDefaultSecond(EnvHTTPTimeout, 30), } } @@ -84,7 +78,17 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("vultr: credentials missing") } - client := govultr.NewClient(config.HTTPClient, config.APIKey) + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = &http.Client{ + Timeout: config.HTTPTimeout, + Transport: &oauth2.Transport{ + Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: config.APIKey}), + }, + } + } + + client := govultr.NewClient(httpClient) return &DNSProvider{client: client, config: config}, nil } @@ -102,7 +106,14 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { name := extractRecordName(fqdn, zoneDomain) - err = d.client.DNSRecord.Create(ctx, zoneDomain, "TXT", name, `"`+value+`"`, d.config.TTL, 0) + req := govultr.DomainRecordReq{ + Name: name, + Type: "TXT", + Data: `"` + value + `"`, + TTL: d.config.TTL, + Priority: func(v int) *int { return &v }(0), + } + _, err = d.client.DomainRecord.Create(ctx, zoneDomain, &req) if err != nil { return fmt.Errorf("vultr: API call failed: %w", err) } @@ -123,7 +134,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { var allErr []string for _, rec := range records { - err := d.client.DNSRecord.Delete(ctx, zoneDomain, strconv.Itoa(rec.RecordID)) + err := d.client.DomainRecord.Delete(ctx, zoneDomain, rec.ID) if err != nil { allErr = append(allErr, err.Error()) } @@ -143,43 +154,67 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { } func (d *DNSProvider) getHostedZone(ctx context.Context, domain string) (string, error) { - domains, err := d.client.DNSDomain.List(ctx) - if err != nil { - return "", fmt.Errorf("API call failed: %w", err) - } + listOptions := &govultr.ListOptions{PerPage: 25} - var hostedDomain govultr.DNSDomain - for _, dom := range domains { - if strings.HasSuffix(domain, dom.Domain) { - if len(dom.Domain) > len(hostedDomain.Domain) { + var hostedDomain govultr.Domain + + for { + domains, meta, err := d.client.Domain.List(ctx, listOptions) + if err != nil { + return "", fmt.Errorf("API call failed: %w", err) + } + + for _, dom := range domains { + if strings.HasSuffix(domain, dom.Domain) && len(dom.Domain) > len(hostedDomain.Domain) { hostedDomain = dom } } + + if domain == hostedDomain.Domain { + break + } + + if meta.Links.Next == "" { + break + } + + listOptions.Cursor = meta.Links.Next } + if hostedDomain.Domain == "" { - return "", fmt.Errorf("no matching Vultr domain found for domain %s", domain) + return "", fmt.Errorf("no matching domain found for domain %s", domain) } return hostedDomain.Domain, nil } -func (d *DNSProvider) findTxtRecords(ctx context.Context, domain, fqdn string) (string, []govultr.DNSRecord, error) { +func (d *DNSProvider) findTxtRecords(ctx context.Context, domain, fqdn string) (string, []govultr.DomainRecord, error) { zoneDomain, err := d.getHostedZone(ctx, domain) if err != nil { return "", nil, err } - var records []govultr.DNSRecord - result, err := d.client.DNSRecord.List(ctx, zoneDomain) - if err != nil { - return "", records, fmt.Errorf("API call has failed: %w", err) - } + listOptions := &govultr.ListOptions{PerPage: 25} - recordName := extractRecordName(fqdn, zoneDomain) - for _, record := range result { - if record.Type == "TXT" && record.Name == recordName { - records = append(records, record) + var records []govultr.DomainRecord + for { + result, meta, err := d.client.DomainRecord.List(ctx, zoneDomain, listOptions) + if err != nil { + return "", records, fmt.Errorf("API call has failed: %w", err) } + + recordName := extractRecordName(fqdn, zoneDomain) + for _, record := range result { + if record.Type == "TXT" && record.Name == recordName { + records = append(records, record) + } + } + + if meta.Links.Next == "" { + break + } + + listOptions.Cursor = meta.Links.Next } return zoneDomain, records, nil diff --git a/providers/dns/vultr/vultr_test.go b/providers/dns/vultr/vultr_test.go index 0144758e..75d64883 100644 --- a/providers/dns/vultr/vultr_test.go +++ b/providers/dns/vultr/vultr_test.go @@ -1,11 +1,19 @@ package vultr import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strconv" "testing" "time" "github.com/go-acme/lego/v4/platform/tester" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/vultr/govultr/v2" ) const envDomain = envNamespace + "TEST_DOMAIN" @@ -90,6 +98,127 @@ func TestNewDNSProviderConfig(t *testing.T) { } } +func TestDNSProvider_getHostedZone(t *testing.T) { + testCases := []struct { + desc string + domain string + expected string + expectedPageCount int + }{ + { + desc: "exact match, in latest page", + domain: "test.my.example.com", + expected: "test.my.example.com", + expectedPageCount: 5, + }, + { + desc: "exact match, in the middle", + domain: "my.example.com", + expected: "my.example.com", + expectedPageCount: 3, + }, + { + desc: "exact match, first page", + domain: "example.com", + expected: "example.com", + expectedPageCount: 1, + }, + { + desc: "match on apex", + domain: "test.example.org", + expected: "example.org", + expectedPageCount: 5, + }, + { + desc: "match on parent", + domain: "test.my.example.net", + expected: "my.example.net", + expectedPageCount: 5, + }, + } + + domains := []govultr.Domain{{Domain: "example.com"}, {Domain: "example.org"}, {Domain: "example.net"}} + + for i := 0; i < 50; i++ { + domains = append(domains, govultr.Domain{Domain: fmt.Sprintf("my%02d.example.com", i)}) + } + + domains = append(domains, govultr.Domain{Domain: "my.example.com"}, govultr.Domain{Domain: "my.example.net"}) + + for i := 50; i < 100; i++ { + domains = append(domains, govultr.Domain{Domain: fmt.Sprintf("my%02d.example.com", i)}) + } + + domains = append(domains, govultr.Domain{Domain: "test.my.example.com"}) + + type domainsBase struct { + Domains []govultr.Domain `json:"domains"` + Meta *govultr.Meta `json:"meta"` + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + client := govultr.NewClient(nil) + err := client.SetBaseURL(server.URL) + require.NoError(t, err) + + p := &DNSProvider{client: client} + + var pageCount int + + mux.HandleFunc("/v2/domains", func(rw http.ResponseWriter, req *http.Request) { + pageCount++ + + query := req.URL.Query() + cursor, _ := strconv.Atoi(query.Get("cursor")) + perPage, _ := strconv.Atoi(query.Get("per_page")) + + var next string + if len(domains)/perPage > cursor { + next = strconv.Itoa(cursor + 1) + } + + start := cursor * perPage + if len(domains) < start { + start = cursor * len(domains) + } + + end := (cursor + 1) * perPage + if len(domains) < end { + end = len(domains) + } + + db := domainsBase{ + Domains: domains[start:end], + Meta: &govultr.Meta{ + Total: len(domains), + Links: &govultr.Links{Next: next}, + }, + } + + err = json.NewEncoder(rw).Encode(db) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }) + + zone, err := p.getHostedZone(context.Background(), test.domain) + require.NoError(t, err) + + assert.Equal(t, test.expected, zone) + assert.Equal(t, test.expectedPageCount, pageCount) + }) + } +} + func TestLivePresent(t *testing.T) { if !envTest.IsLiveTest() { t.Skip("skipping live test")