// Forked from TrueCloudLab/lego.

package internal
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// Checkdomain's own nameservers. CheckNameservers requires a domain to list
// both of them before its records are updated.
const (
	ns1 = "ns.checkdomain.de"
	ns2 = "ns2.checkdomain.de"
)

// DefaultEndpoint is the default Checkdomain API endpoint.
const DefaultEndpoint = "https://api.checkdomain.de"

const (
	// domainNotFound is the ID returned by GetDomainIDByName when no
	// registered domain matches.
	domainNotFound = -1

	// maxLimit is the largest page size the Checkdomain API allows.
	maxLimit = 100

	// maxInt is the largest int value; the listing functions use it as the
	// "total page count not known yet" sentinel.
	maxInt = int(^uint(0) >> 1)
)
|
|
|
|
// Client is a Checkdomain API client.
//
// It caches the domain IDs resolved by GetDomainIDByName; the cache map is
// guarded by domainIDMu.
type Client struct {
	// domainIDMapping maps a queried name to the resolved domain ID.
	// Only access it while holding domainIDMu.
	domainIDMapping map[string]int
	domainIDMu      sync.Mutex

	// BaseURL is the API root; request paths are joined onto it.
	BaseURL *url.URL

	// httpClient performs all API requests.
	httpClient *http.Client
}
|
|
|
|
// NewClient creates a new Client.
|
|
func NewClient(hc *http.Client) *Client {
|
|
baseURL, _ := url.Parse(DefaultEndpoint)
|
|
|
|
if hc == nil {
|
|
hc = &http.Client{Timeout: 10 * time.Second}
|
|
}
|
|
|
|
return &Client{
|
|
BaseURL: baseURL,
|
|
httpClient: hc,
|
|
domainIDMapping: make(map[string]int),
|
|
}
|
|
}
|
|
|
|
func (c *Client) GetDomainIDByName(ctx context.Context, name string) (int, error) {
|
|
// Load from cache if exists
|
|
c.domainIDMu.Lock()
|
|
id, ok := c.domainIDMapping[name]
|
|
c.domainIDMu.Unlock()
|
|
if ok {
|
|
return id, nil
|
|
}
|
|
|
|
// Find out by querying API
|
|
domains, err := c.listDomains(ctx)
|
|
if err != nil {
|
|
return domainNotFound, err
|
|
}
|
|
|
|
// Linear search over all registered domains
|
|
for _, domain := range domains {
|
|
if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) {
|
|
c.domainIDMu.Lock()
|
|
c.domainIDMapping[name] = domain.ID
|
|
c.domainIDMu.Unlock()
|
|
|
|
return domain.ID, nil
|
|
}
|
|
}
|
|
|
|
return domainNotFound, errors.New("domain not found")
|
|
}
|
|
|
|
func (c *Client) listDomains(ctx context.Context) ([]*Domain, error) {
|
|
endpoint := c.BaseURL.JoinPath("v1", "domains")
|
|
|
|
// Checkdomain also provides a query param 'query' which allows filtering domains for a string.
|
|
// But that functionality is kinda broken,
|
|
// so we scan through the whole list of registered domains to later find the one that is of interest to us.
|
|
q := endpoint.Query()
|
|
q.Set("limit", strconv.Itoa(maxLimit))
|
|
|
|
currentPage := 1
|
|
totalPages := maxInt
|
|
|
|
var domainList []*Domain
|
|
for currentPage <= totalPages {
|
|
q.Set("page", strconv.Itoa(currentPage))
|
|
endpoint.RawQuery = q.Encode()
|
|
|
|
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to make request: %w", err)
|
|
}
|
|
|
|
var res DomainListingResponse
|
|
if err := c.do(req, &res); err != nil {
|
|
return nil, fmt.Errorf("failed to send domain listing request: %w", err)
|
|
}
|
|
|
|
// This is the first response,
|
|
// so we update totalPages and allocate the slice memory.
|
|
if totalPages == maxInt {
|
|
totalPages = res.Pages
|
|
domainList = make([]*Domain, 0, res.Total)
|
|
}
|
|
|
|
domainList = append(domainList, res.Embedded.Domains...)
|
|
currentPage++
|
|
}
|
|
|
|
return domainList, nil
|
|
}
|
|
|
|
func (c *Client) getNameserverInfo(ctx context.Context, domainID int) (*NameserverResponse, error) {
|
|
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers")
|
|
|
|
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
res := &NameserverResponse{}
|
|
if err := c.do(req, res); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (c *Client) CheckNameservers(ctx context.Context, domainID int) error {
|
|
info, err := c.getNameserverInfo(ctx, domainID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var found1, found2 bool
|
|
for _, item := range info.Nameservers {
|
|
switch item.Name {
|
|
case ns1:
|
|
found1 = true
|
|
case ns2:
|
|
found2 = true
|
|
}
|
|
}
|
|
|
|
if !found1 || !found2 {
|
|
return errors.New("not using checkdomain nameservers, can not update records")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) CreateRecord(ctx context.Context, domainID int, record *Record) error {
|
|
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
|
|
|
|
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return c.do(req, nil)
|
|
}
|
|
|
|
// DeleteTXTRecord Checkdomain doesn't seem provide a way to delete records but one can replace all records at once.
|
|
// The current solution is to fetch all records and then use that list minus the record deleted as the new record list.
|
|
// TODO: Simplify this function once Checkdomain do provide the functionality.
|
|
func (c *Client) DeleteTXTRecord(ctx context.Context, domainID int, recordName, recordValue string) error {
|
|
domainInfo, err := c.getDomainInfo(ctx, domainID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
nsInfo, err := c.getNameserverInfo(ctx, domainID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
allRecords, err := c.listRecords(ctx, domainID, "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".")
|
|
|
|
var recordsToKeep []*Record
|
|
|
|
// Find and delete matching records
|
|
for _, record := range allRecords {
|
|
if skipRecord(recordName, recordValue, record, nsInfo) {
|
|
continue
|
|
}
|
|
|
|
// Checkdomain API can return records without any TTL set (indicated by the value of 0).
|
|
// The API Call to replace the records would fail if we wouldn't specify a value.
|
|
// Thus, we use the default TTL queried beforehand
|
|
if record.TTL == 0 {
|
|
record.TTL = nsInfo.SOA.TTL
|
|
}
|
|
|
|
recordsToKeep = append(recordsToKeep, record)
|
|
}
|
|
|
|
return c.replaceRecords(ctx, domainID, recordsToKeep)
|
|
}
|
|
|
|
func (c *Client) getDomainInfo(ctx context.Context, domainID int) (*DomainResponse, error) {
|
|
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID))
|
|
|
|
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var res DomainResponse
|
|
err = c.do(req, &res)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &res, nil
|
|
}
|
|
|
|
func (c *Client) listRecords(ctx context.Context, domainID int, recordType string) ([]*Record, error) {
|
|
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
|
|
|
|
q := endpoint.Query()
|
|
q.Set("limit", strconv.Itoa(maxLimit))
|
|
if recordType != "" {
|
|
q.Set("type", recordType)
|
|
}
|
|
|
|
currentPage := 1
|
|
totalPages := maxInt
|
|
|
|
var recordList []*Record
|
|
for currentPage <= totalPages {
|
|
q.Set("page", strconv.Itoa(currentPage))
|
|
endpoint.RawQuery = q.Encode()
|
|
|
|
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
var res RecordListingResponse
|
|
if err := c.do(req, &res); err != nil {
|
|
return nil, fmt.Errorf("failed to send record listing request: %w", err)
|
|
}
|
|
|
|
// This is the first response, so we update totalPages and allocate the slice memory.
|
|
if totalPages == maxInt {
|
|
totalPages = res.Pages
|
|
recordList = make([]*Record, 0, res.Total)
|
|
}
|
|
|
|
recordList = append(recordList, res.Embedded.Records...)
|
|
currentPage++
|
|
}
|
|
|
|
return recordList, nil
|
|
}
|
|
|
|
func (c *Client) replaceRecords(ctx context.Context, domainID int, records []*Record) error {
|
|
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
|
|
|
|
req, err := newJSONRequest(ctx, http.MethodPut, endpoint, records)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return c.do(req, nil)
|
|
}
|
|
|
|
func (c *Client) do(req *http.Request, result any) error {
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return errutils.NewHTTPDoError(req, err)
|
|
}
|
|
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
if resp.StatusCode/100 != 2 {
|
|
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
|
|
}
|
|
|
|
if result == nil {
|
|
return nil
|
|
}
|
|
|
|
raw, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return errutils.NewReadResponseError(req, resp.StatusCode, err)
|
|
}
|
|
|
|
err = json.Unmarshal(raw, result)
|
|
if err != nil {
|
|
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) CleanCache(fqdn string) {
|
|
c.domainIDMu.Lock()
|
|
delete(c.domainIDMapping, fqdn)
|
|
c.domainIDMu.Unlock()
|
|
}
|
|
|
|
func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool {
|
|
// Skip empty records
|
|
if record.Value == "" {
|
|
return true
|
|
}
|
|
|
|
// Skip some special records, otherwise we would get a "Nameserver update failed"
|
|
if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") {
|
|
return true
|
|
}
|
|
|
|
nameMatch := recordName == "" || record.Name == recordName
|
|
valueMatch := recordValue == "" || record.Value == recordValue
|
|
|
|
// Skip our matching record
|
|
if record.Type == "TXT" && nameMatch && valueMatch {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// newJSONRequest builds an HTTP request for endpoint, bound to ctx.
//
// The Accept header always asks for JSON. When payload is non-nil, it is
// JSON-encoded as the request body and Content-Type is set to
// application/json; otherwise the body is empty.
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
	hasPayload := payload != nil
	body := &bytes.Buffer{}

	if hasPayload {
		if err := json.NewEncoder(body).Encode(payload); err != nil {
			return nil, fmt.Errorf("failed to create request JSON body: %w", err)
		}
	}

	req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), body)
	if err != nil {
		return nil, fmt.Errorf("unable to create request: %w", err)
	}

	req.Header.Set("Accept", "application/json")

	if hasPayload {
		req.Header.Set("Content-Type", "application/json")
	}

	return req, nil
}
|
|
|
|
func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client {
|
|
if client == nil {
|
|
client = &http.Client{Timeout: 5 * time.Second}
|
|
}
|
|
|
|
client.Transport = &oauth2.Transport{
|
|
Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}),
|
|
Base: client.Transport,
|
|
}
|
|
|
|
return client
|
|
}
|