package internal

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
	"golang.org/x/oauth2"
)

// Names of the Checkdomain nameservers.
// CheckNameservers requires both of them to be assigned to a domain.
const (
	ns1 = "ns.checkdomain.de"
	ns2 = "ns2.checkdomain.de"
)

// DefaultEndpoint the default API endpoint.
const DefaultEndpoint = "https://api.checkdomain.de"

// domainNotFound is the ID returned together with an error when a domain cannot be resolved.
const domainNotFound = -1

// max page limit that the checkdomain api allows.
const maxLimit = 100

// max integer value.
const maxInt = int((^uint(0)) >> 1)

// Client the Checkdomain API client.
type Client struct {
	// domainIDMapping caches the domain IDs resolved by GetDomainIDByName, keyed by the requested name.
	domainIDMapping map[string]int
	// domainIDMu guards domainIDMapping.
	domainIDMu sync.Mutex

	BaseURL    *url.URL
	httpClient *http.Client
}

// NewClient creates a new Client.
// When hc is nil, an HTTP client with a 10 seconds timeout is used.
func NewClient(hc *http.Client) *Client {
	// DefaultEndpoint is a constant, well-formed URL: the parsing error can be ignored.
	baseURL, _ := url.Parse(DefaultEndpoint)

	if hc == nil {
		hc = &http.Client{Timeout: 10 * time.Second}
	}

	return &Client{
		BaseURL:         baseURL,
		httpClient:      hc,
		domainIDMapping: make(map[string]int),
	}
}

// GetDomainIDByName returns the ID of the first listed domain
// that is equal to name or that is a parent domain of name.
// Successful lookups are cached by name (see CleanCache).
// When no domain matches, domainNotFound (-1) and an error are returned.
func (c *Client) GetDomainIDByName(ctx context.Context, name string) (int, error) {
	// Load from cache if exists
	c.domainIDMu.Lock()
	id, ok := c.domainIDMapping[name]
	c.domainIDMu.Unlock()

	if ok {
		return id, nil
	}

	// Find out by querying API
	domains, err := c.listDomains(ctx)
	if err != nil {
		return domainNotFound, err
	}

	// Linear search over all registered domains
	for _, domain := range domains {
		if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) {
			c.domainIDMu.Lock()
			c.domainIDMapping[name] = domain.ID
			c.domainIDMu.Unlock()

			return domain.ID, nil
		}
	}

	return domainNotFound, errors.New("domain not found")
}

// listDomains fetches all the pages of the domain listing and returns the concatenated result.
func (c *Client) listDomains(ctx context.Context) ([]*Domain, error) {
	endpoint := c.BaseURL.JoinPath("v1", "domains")

	// Checkdomain also provides a query param 'query' which allows filtering domains for a string.
	// But that functionality is kinda broken,
	// so we scan through the whole list of registered domains to later find the one that is of interest to us.
	q := endpoint.Query()
	q.Set("limit", strconv.Itoa(maxLimit))

	currentPage := 1
	totalPages := maxInt

	var domainList []*Domain

	for currentPage <= totalPages {
		q.Set("page", strconv.Itoa(currentPage))
		endpoint.RawQuery = q.Encode()

		req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
		if err != nil {
			return nil, fmt.Errorf("failed to make request: %w", err)
		}

		var res DomainListingResponse
		if err := c.do(req, &res); err != nil {
			return nil, fmt.Errorf("failed to send domain listing request: %w", err)
		}

		// This is the first response,
		// so we update totalPages and allocate the slice memory.
		if totalPages == maxInt {
			totalPages = res.Pages

			// The total is provided by the server:
			// make panics on a negative capacity, so the value is clamped.
			capacity := res.Total
			if capacity < 0 {
				capacity = 0
			}

			domainList = make([]*Domain, 0, capacity)
		}

		domainList = append(domainList, res.Embedded.Domains...)
		currentPage++
	}

	return domainList, nil
}

// getNameserverInfo fetches the nameserver information of the domain.
func (c *Client) getNameserverInfo(ctx context.Context, domainID int) (*NameserverResponse, error) {
	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers")

	req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}

	res := &NameserverResponse{}
	if err := c.do(req, res); err != nil {
		return nil, err
	}

	return res, nil
}

// CheckNameservers returns an error unless both Checkdomain nameservers (ns1 and ns2)
// are listed among the nameservers of the domain.
func (c *Client) CheckNameservers(ctx context.Context, domainID int) error {
	info, err := c.getNameserverInfo(ctx, domainID)
	if err != nil {
		return err
	}

	var found1, found2 bool

	for _, item := range info.Nameservers {
		switch item.Name {
		case ns1:
			found1 = true
		case ns2:
			found2 = true
		}
	}

	if !found1 || !found2 {
		return errors.New("not using checkdomain nameservers, can not update records")
	}

	return nil
}

// CreateRecord adds a record to the domain.
func (c *Client) CreateRecord(ctx context.Context, domainID int, record *Record) error {
	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")

	req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
	if err != nil {
		return err
	}

	return c.do(req, nil)
}

// DeleteTXTRecord Checkdomain doesn't seem to provide a way to delete records but one can replace all records at once.
// The current solution is to fetch all records and then use that list minus the record deleted as the new record list.
// An empty recordName or recordValue matches any name or value (see skipRecord).
// TODO: Simplify this function once Checkdomain does provide the functionality.
func (c *Client) DeleteTXTRecord(ctx context.Context, domainID int, recordName, recordValue string) error {
	domainInfo, err := c.getDomainInfo(ctx, domainID)
	if err != nil {
		return err
	}

	nsInfo, err := c.getNameserverInfo(ctx, domainID)
	if err != nil {
		return err
	}

	allRecords, err := c.listRecords(ctx, domainID, "")
	if err != nil {
		return err
	}

	// Convert the FQDN to a name relative to the domain.
	recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".")

	// Must not be nil: a nil slice is encoded as the JSON value `null` instead of an empty array `[]`
	// when no record is left to keep.
	recordsToKeep := make([]*Record, 0, len(allRecords))

	// Find and delete matching records
	for _, record := range allRecords {
		if skipRecord(recordName, recordValue, record, nsInfo) {
			continue
		}

		// Checkdomain API can return records without any TTL set (indicated by the value of 0).
		// The API Call to replace the records would fail if we wouldn't specify a value.
		// Thus, we use the default TTL queried beforehand
		if record.TTL == 0 {
			record.TTL = nsInfo.SOA.TTL
		}

		recordsToKeep = append(recordsToKeep, record)
	}

	return c.replaceRecords(ctx, domainID, recordsToKeep)
}

// getDomainInfo fetches the information of the domain.
func (c *Client) getDomainInfo(ctx context.Context, domainID int) (*DomainResponse, error) {
	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID))

	req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}

	var res DomainResponse

	err = c.do(req, &res)
	if err != nil {
		return nil, err
	}

	return &res, nil
}

// listRecords fetches all the pages of the record listing of the domain and returns the concatenated result.
// When recordType is not empty, it is sent as the `type` query parameter.
func (c *Client) listRecords(ctx context.Context, domainID int, recordType string) ([]*Record, error) {
	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")

	q := endpoint.Query()
	q.Set("limit", strconv.Itoa(maxLimit))

	if recordType != "" {
		q.Set("type", recordType)
	}

	currentPage := 1
	totalPages := maxInt

	var recordList []*Record

	for currentPage <= totalPages {
		q.Set("page", strconv.Itoa(currentPage))
		endpoint.RawQuery = q.Encode()

		req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
		if err != nil {
			return nil, fmt.Errorf("failed to create request: %w", err)
		}

		var res RecordListingResponse
		if err := c.do(req, &res); err != nil {
			return nil, fmt.Errorf("failed to send record listing request: %w", err)
		}

		// This is the first response, so we update totalPages and allocate the slice memory.
		if totalPages == maxInt {
			totalPages = res.Pages

			// The total is provided by the server:
			// make panics on a negative capacity, so the value is clamped.
			capacity := res.Total
			if capacity < 0 {
				capacity = 0
			}

			recordList = make([]*Record, 0, capacity)
		}

		recordList = append(recordList, res.Embedded.Records...)
		currentPage++
	}

	return recordList, nil
}

// replaceRecords sends the given list of records with a PUT request to the records endpoint of the domain.
func (c *Client) replaceRecords(ctx context.Context, domainID int, records []*Record) error {
	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")

	req, err := newJSONRequest(ctx, http.MethodPut, endpoint, records)
	if err != nil {
		return err
	}

	return c.do(req, nil)
}

// do sends the request and, when result is not nil, decodes the JSON response body into it.
// Any status code outside the 2xx range is reported as an error.
func (c *Client) do(req *http.Request, result any) error {
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return errutils.NewHTTPDoError(req, err)
	}

	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode/100 != 2 {
		return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
	}

	if result == nil {
		return nil
	}

	// The raw body is kept to be able to report it when the decoding fails.
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return errutils.NewReadResponseError(req, resp.StatusCode, err)
	}

	err = json.Unmarshal(raw, result)
	if err != nil {
		return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
	}

	return nil
}

// CleanCache removes the cached domain ID stored under the given key.
// The cache is keyed by the exact name previously passed to GetDomainIDByName.
func (c *Client) CleanCache(fqdn string) {
	c.domainIDMu.Lock()
	delete(c.domainIDMapping, fqdn)
	c.domainIDMu.Unlock()
}

// skipRecord reports whether the record must be left out of the list of records sent back to the API:
// empty records, records handled specially by the API, and the TXT record to delete.
// An empty recordName (or recordValue) matches any name (or value).
func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool {
	// Skip empty records
	if record.Value == "" {
		return true
	}

	// Skip some special records, otherwise we would get a "Nameserver update failed"
	if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") {
		return true
	}

	nameMatch := recordName == "" || record.Name == recordName
	valueMatch := recordValue == "" || record.Value == recordValue

	// Skip our matching record
	if record.Type == "TXT" && nameMatch && valueMatch {
		return true
	}

	return false
}

// newJSONRequest creates a request that accepts JSON.
// When payload is not nil, it is encoded as the JSON body of the request
// and the Content-Type header is set accordingly.
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
	buf := new(bytes.Buffer)

	if payload != nil {
		err := json.NewEncoder(buf).Encode(payload)
		if err != nil {
			return nil, fmt.Errorf("failed to create request JSON body: %w", err)
		}
	}

	req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
	if err != nil {
		return nil, fmt.Errorf("unable to create request: %w", err)
	}

	req.Header.Set("Accept", "application/json")

	if payload != nil {
		req.Header.Set("Content-Type", "application/json")
	}

	return req, nil
}

// OAuthStaticAccessToken wraps the transport of the client with an OAuth2 transport using a static access token.
// When client is nil, an HTTP client with a 5 seconds timeout is created.
// Note: the given client is modified in place (its Transport is replaced) and returned.
func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client {
	if client == nil {
		client = &http.Client{Timeout: 5 * time.Second}
	}

	client.Transport = &oauth2.Transport{
		Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}),
		Base:   client.Transport,
	}

	return client
}