chore: refactor clients (#1868)

This commit is contained in:
Ludovic Fernandez 2023-05-05 09:49:38 +02:00 committed by GitHub
parent 0c1303c1bd
commit aeec5be129
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
431 changed files with 16635 additions and 10176 deletions

View file

@ -161,7 +161,7 @@ issues:
linters:
- gocyclo
- funlen
- path: providers/dns/checkdomain/client.go
- path: providers/dns/checkdomain/internal/types.go
text: '`payed` is a misspelling of `paid`'
- path: providers/dns/namecheap/namecheap_test.go
text: 'cognitive complexity (\d+) of func `TestDNSProvider_getHosts` is high'
@ -174,7 +174,7 @@ issues:
text: 'yodaStyleExpr'
- path: providers/dns/dns_providers.go
text: 'Function name: NewDNSChallengeProviderByName,'
- path: providers/dns/sakuracloud/client.go
- path: providers/dns/sakuracloud/wrapper.go
text: 'mu is a global variable'
- path: providers/dns/hosttech/internal/client_test.go
text: 'Duplicate words \(0\) found'

View file

@ -51,7 +51,7 @@ Detailed documentation is available [here](https://go-acme.github.io/lego/dns).
|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|
| [Akamai EdgeDNS](https://go-acme.github.io/lego/dns/edgedns/) | [Alibaba Cloud DNS](https://go-acme.github.io/lego/dns/alidns/) | [all-inkl](https://go-acme.github.io/lego/dns/allinkl/) | [Amazon Lightsail](https://go-acme.github.io/lego/dns/lightsail/) |
| [Amazon Route 53](https://go-acme.github.io/lego/dns/route53/) | [ArvanCloud](https://go-acme.github.io/lego/dns/arvancloud/) | [Aurora DNS](https://go-acme.github.io/lego/dns/auroradns/) | [Autodns](https://go-acme.github.io/lego/dns/autodns/) |
| [Azure](https://go-acme.github.io/lego/dns/azure/) | [Bindman](https://go-acme.github.io/lego/dns/bindman/) | [Bluecat](https://go-acme.github.io/lego/dns/bluecat/) | [BRANDIT](https://go-acme.github.io/lego/dns/brandit/) |
| [Azure](https://go-acme.github.io/lego/dns/azure/) | [Bindman](https://go-acme.github.io/lego/dns/bindman/) | [Bluecat](https://go-acme.github.io/lego/dns/bluecat/) | [Brandit](https://go-acme.github.io/lego/dns/brandit/) |
| [Bunny](https://go-acme.github.io/lego/dns/bunny/) | [Checkdomain](https://go-acme.github.io/lego/dns/checkdomain/) | [Civo](https://go-acme.github.io/lego/dns/civo/) | [CloudDNS](https://go-acme.github.io/lego/dns/clouddns/) |
| [Cloudflare](https://go-acme.github.io/lego/dns/cloudflare/) | [ClouDNS](https://go-acme.github.io/lego/dns/cloudns/) | [CloudXNS](https://go-acme.github.io/lego/dns/cloudxns/) | [ConoHa](https://go-acme.github.io/lego/dns/conoha/) |
| [Constellix](https://go-acme.github.io/lego/dns/constellix/) | [deSEC.io](https://go-acme.github.io/lego/dns/desec/) | [Designate DNSaaS for Openstack](https://go-acme.github.io/lego/dns/designate/) | [Digital Ocean](https://go-acme.github.io/lego/dns/digitalocean/) |

View file

@ -335,7 +335,7 @@ func displayDNSHelp(w io.Writer, name string) error {
case "brandit":
// generated from: providers/dns/brandit/brandit.toml
ew.writeln(`Configuration for BRANDIT.`)
ew.writeln(`Configuration for Brandit.`)
ew.writeln(`Code: 'brandit'`)
ew.writeln(`Since: 'v4.11.0'`)
ew.writeln()

View file

@ -1,5 +1,5 @@
---
title: "BRANDIT"
title: "Brandit"
date: 2019-03-03T16:39:46+01:00
draft: false
slug: brandit
@ -14,7 +14,7 @@ dnsprovider:
<!-- THIS DOCUMENTATION IS AUTO-GENERATED. PLEASE DO NOT EDIT. -->
Configuration for [BRANDIT](https://www.brandit.com/).
Configuration for [Brandit](https://www.brandit.com/).
<!--more-->
@ -23,7 +23,7 @@ Configuration for [BRANDIT](https://www.brandit.com/).
- Since: v4.11.0
Here is an example bash command using the BRANDIT provider:
Here is an example bash command using the Brandit provider:
```bash
BRANDIT_API_KEY=xxxxxxxxxxxxxxxxxxxxx \

View file

@ -61,7 +61,7 @@ More information [here]({{< ref "dns#configuration-and-credentials" >}}).
## More information
- [API documentation](https://docs.otc.t-systems.com/en-us/dns/index.html)
- [API documentation](https://docs.otc.t-systems.com/domain-name-service/api-ref/index.html)
<!-- THIS DOCUMENTATION IS AUTO-GENERATED. PLEASE DO NOT EDIT. -->
<!-- providers/dns/otc/otc.toml -->

14
go.mod
View file

@ -63,9 +63,9 @@ require (
github.com/vultr/govultr/v2 v2.17.2
github.com/yandex-cloud/go-genproto v0.0.0-20220805142335-27b56ddae16f
github.com/yandex-cloud/go-sdk v0.0.0-20220805164847-cf028e604997
golang.org/x/crypto v0.5.0
golang.org/x/net v0.7.0
golang.org/x/oauth2 v0.5.0
golang.org/x/crypto v0.7.0
golang.org/x/net v0.8.0
golang.org/x/oauth2 v0.6.0
golang.org/x/time v0.3.0
google.golang.org/api v0.111.0
gopkg.in/ns1/ns1-go.v2 v2.6.5
@ -126,10 +126,10 @@ require (
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
go.opencensus.io v0.24.0 // indirect
go.uber.org/ratelimit v0.2.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/sys v0.5.0 // indirect
golang.org/x/text v0.7.0 // indirect
golang.org/x/tools v0.1.12 // indirect
golang.org/x/mod v0.8.0 // indirect
golang.org/x/sys v0.6.0 // indirect
golang.org/x/text v0.8.0 // indirect
golang.org/x/tools v0.6.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20230223222841-637eb2293923 // indirect
google.golang.org/grpc v1.53.0 // indirect

30
go.sum
View file

@ -595,8 +595,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE=
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -619,8 +619,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -650,14 +650,14 @@ golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.5.0 h1:HuArIo48skDwlrvM3sEdHXElYslAMsf3KwRkkW4MC4s=
golang.org/x/oauth2 v0.5.0/go.mod h1:9/XBHVqLaWO3/BRHs5jbpYCnOZVjj5V0ndyaAM7KB4I=
golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw=
golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -712,12 +712,12 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY=
golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
@ -726,8 +726,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@ -758,8 +758,8 @@ golang.org/x/tools v0.0.0-20200918232735-d647fc253266/go.mod h1:z6u4i615ZeAfBE4X
golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.0.0-20210114065538-d78b04bdf963/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View file

@ -1,7 +1,6 @@
package wait
import (
"errors"
"fmt"
"time"
@ -18,9 +17,9 @@ func For(msg string, timeout, interval time.Duration, f func() (bool, error)) er
select {
case <-timeUp:
if lastErr == nil {
return errors.New("time limit exceeded")
return fmt.Errorf("%s: time limit exceeded", msg)
}
return fmt.Errorf("time limit exceeded: last error: %w", lastErr)
return fmt.Errorf("%s: time limit exceeded: last error: %w", msg, lastErr)
default:
}

View file

@ -198,7 +198,7 @@ func (d *DNSProvider) getHostedZone(domain string) (string, error) {
authZone, err := dns01.FindZoneByFqdn(domain)
if err != nil {
return "", err
return "", fmt.Errorf("could not find zone for FQDN %q: %w", domain, err)
}
var hostedZone alidns.DomainInDescribeDomains

View file

@ -2,6 +2,7 @@
package allinkl
import (
"context"
"errors"
"fmt"
"net/http"
@ -49,6 +50,8 @@ func NewDefaultConfig() *Config {
// DNSProvider implements the challenge.Provider interface.
type DNSProvider struct {
config *Config
identifier *internal.Identifier
client *internal.Client
recordIDs map[string]string
@ -80,7 +83,13 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("allinkl: missing credentials")
}
client := internal.NewClient(config.Login, config.Password)
identifier := internal.NewIdentifier(config.Login, config.Password)
if config.HTTPClient != nil {
identifier.HTTPClient = config.HTTPClient
}
client := internal.NewClient(config.Login)
if config.HTTPClient != nil {
client.HTTPClient = config.HTTPClient
@ -88,6 +97,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return &DNSProvider{
config: config,
identifier: identifier,
client: client,
recordIDs: make(map[string]string),
}, nil
@ -105,14 +115,18 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("allinkl: could not determine zone for domain %q: %w", domain, err)
return fmt.Errorf("allinkl: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
credential, err := d.client.Authentication(60, true)
ctx := context.Background()
credential, err := d.identifier.Authentication(ctx, 60, true)
if err != nil {
return fmt.Errorf("allinkl: %w", err)
}
ctx = internal.WithContext(ctx, credential)
subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)
if err != nil {
return fmt.Errorf("allinkl: %w", err)
@ -125,7 +139,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
RecordData: info.Value,
}
recordID, err := d.client.AddDNSSettings(credential, record)
recordID, err := d.client.AddDNSSettings(ctx, record)
if err != nil {
return fmt.Errorf("allinkl: %w", err)
}
@ -141,11 +155,15 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
credential, err := d.client.Authentication(60, true)
ctx := context.Background()
credential, err := d.identifier.Authentication(ctx, 60, true)
if err != nil {
return fmt.Errorf("allinkl: %w", err)
}
ctx = internal.WithContext(ctx, credential)
// gets the record's unique ID from when we created it
d.recordIDsMu.Lock()
recordID, ok := d.recordIDs[token]
@ -154,7 +172,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return fmt.Errorf("allinkl: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token)
}
_, err = d.client.DeleteDNSSettings(credential, recordID)
_, err = d.client.DeleteDNSSettings(ctx, recordID)
if err != nil {
return fmt.Errorf("allinkl: %w", err)
}

View file

@ -2,126 +2,64 @@ package internal
import (
"bytes"
"context"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
"github.com/mitchellh/mapstructure"
)
const (
authEndpoint = "https://kasapi.kasserver.com/soap/KasAuth.php"
apiEndpoint = "https://kasapi.kasserver.com/soap/KasApi.php"
)
const apiEndpoint = "https://kasapi.kasserver.com/soap/KasApi.php"
type Authentication interface {
Authentication(ctx context.Context, sessionLifetime int, sessionUpdateLifetime bool) (string, error)
}
// Client a KAS server client.
type Client struct {
login string
password string
authEndpoint string
apiEndpoint string
HTTPClient *http.Client
floodTime time.Time
muFloodTime sync.Mutex
baseURL string
HTTPClient *http.Client
}
// NewClient creates a new Client.
func NewClient(login string, password string) *Client {
func NewClient(login string) *Client {
return &Client{
login: login,
password: password,
authEndpoint: authEndpoint,
apiEndpoint: apiEndpoint,
baseURL: apiEndpoint,
HTTPClient: &http.Client{Timeout: 10 * time.Second},
}
}
// Authentication Creates a credential token.
// - sessionLifetime: Validity of the token in seconds.
// - sessionUpdateLifetime: with `true` the session is extended with every request.
func (c Client) Authentication(sessionLifetime int, sessionUpdateLifetime bool) (string, error) {
sul := "N"
if sessionUpdateLifetime {
sul = "Y"
}
ar := AuthRequest{
Login: c.login,
AuthData: c.password,
AuthType: "plain",
SessionLifetime: sessionLifetime,
SessionUpdateLifetime: sul,
}
body, err := json.Marshal(ar)
if err != nil {
return "", fmt.Errorf("request marshal: %w", err)
}
payload := []byte(strings.TrimSpace(fmt.Sprintf(kasAuthEnvelope, body)))
req, err := http.NewRequest(http.MethodPost, c.authEndpoint, bytes.NewReader(payload))
if err != nil {
return "", fmt.Errorf("request creation: %w", err)
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return "", fmt.Errorf("request execution: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("invalid status code: %d %s", resp.StatusCode, string(data))
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("response read: %w", err)
}
var e KasAuthEnvelope
decoder := xml.NewTokenDecoder(Trimmer{decoder: xml.NewDecoder(bytes.NewReader(data))})
err = decoder.Decode(&e)
if err != nil {
return "", fmt.Errorf("response xml decode: %w", err)
}
if e.Body.Fault != nil {
return "", e.Body.Fault
}
return e.Body.KasAuthResponse.Return.Text, nil
}
// GetDNSSettings Reading out the DNS settings of a zone.
// - zone: host zone.
// - recordID: the ID of the resource record (optional).
func (c *Client) GetDNSSettings(credentialToken, zone, recordID string) ([]ReturnInfo, error) {
func (c *Client) GetDNSSettings(ctx context.Context, zone, recordID string) ([]ReturnInfo, error) {
requestParams := map[string]string{"zone_host": zone}
if recordID != "" {
requestParams["record_id"] = recordID
}
item, err := c.do(credentialToken, "get_dns_settings", requestParams)
req, err := c.newRequest(ctx, "get_dns_settings", requestParams)
if err != nil {
return nil, err
}
raw := getValue(item)
var g GetDNSSettingsAPIResponse
err = mapstructure.Decode(raw, &g)
err = c.do(req, &g)
if err != nil {
return nil, fmt.Errorf("response struct decode: %w", err)
return nil, err
}
c.updateFloodTime(g.Response.KasFloodDelay)
@ -130,18 +68,16 @@ func (c *Client) GetDNSSettings(credentialToken, zone, recordID string) ([]Retur
}
// AddDNSSettings Creation of a DNS resource record.
func (c *Client) AddDNSSettings(credentialToken string, record DNSRequest) (string, error) {
item, err := c.do(credentialToken, "add_dns_settings", record)
func (c *Client) AddDNSSettings(ctx context.Context, record DNSRequest) (string, error) {
req, err := c.newRequest(ctx, "add_dns_settings", record)
if err != nil {
return "", err
}
raw := getValue(item)
var g AddDNSSettingsAPIResponse
err = mapstructure.Decode(raw, &g)
err = c.do(req, &g)
if err != nil {
return "", fmt.Errorf("response struct decode: %w", err)
return "", err
}
c.updateFloodTime(g.Response.KasFloodDelay)
@ -150,20 +86,18 @@ func (c *Client) AddDNSSettings(credentialToken string, record DNSRequest) (stri
}
// DeleteDNSSettings Deleting a DNS Resource Record.
func (c *Client) DeleteDNSSettings(credentialToken, recordID string) (bool, error) {
func (c *Client) DeleteDNSSettings(ctx context.Context, recordID string) (bool, error) {
requestParams := map[string]string{"record_id": recordID}
item, err := c.do(credentialToken, "delete_dns_settings", requestParams)
req, err := c.newRequest(ctx, "delete_dns_settings", requestParams)
if err != nil {
return false, err
}
raw := getValue(item)
var g DeleteDNSSettingsAPIResponse
err = mapstructure.Decode(raw, &g)
err = c.do(req, &g)
if err != nil {
return false, fmt.Errorf("response struct decode: %w", err)
return false, err
}
c.updateFloodTime(g.Response.KasFloodDelay)
@ -171,65 +105,72 @@ func (c *Client) DeleteDNSSettings(credentialToken, recordID string) (bool, erro
return g.Response.ReturnInfo, nil
}
func (c Client) do(credentialToken, action string, requestParams interface{}) (*Item, error) {
time.Sleep(time.Until(c.floodTime))
func (c *Client) newRequest(ctx context.Context, action string, requestParams any) (*http.Request, error) {
ar := KasRequest{
Login: c.login,
AuthType: "session",
AuthData: credentialToken,
AuthData: getToken(ctx),
Action: action,
RequestParams: requestParams,
}
body, err := json.Marshal(ar)
if err != nil {
return nil, fmt.Errorf("request marshal: %w", err)
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
payload := []byte(strings.TrimSpace(fmt.Sprintf(kasAPIEnvelope, body)))
req, err := http.NewRequest(http.MethodPost, c.apiEndpoint, bytes.NewReader(payload))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL, bytes.NewReader(payload))
if err != nil {
return nil, fmt.Errorf("request creation: %w", err)
return nil, fmt.Errorf("unable to create request: %w", err)
}
return req, nil
}
func (c *Client) do(req *http.Request, result any) error {
c.muFloodTime.Lock()
time.Sleep(time.Until(c.floodTime))
c.muFloodTime.Unlock()
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request execution: %w", err)
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("invalid status code: %d %s", resp.StatusCode, string(data))
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
data, err := io.ReadAll(resp.Body)
envlp, err := decodeXML[KasAPIResponseEnvelope](resp.Body)
if err != nil {
return nil, fmt.Errorf("response read: %w", err)
return err
}
var e KasAPIResponseEnvelope
decoder := xml.NewTokenDecoder(Trimmer{decoder: xml.NewDecoder(bytes.NewReader(data))})
err = decoder.Decode(&e)
if envlp.Body.Fault != nil {
return envlp.Body.Fault
}
raw := getValue(envlp.Body.KasAPIResponse.Return)
err = mapstructure.Decode(raw, result)
if err != nil {
return nil, fmt.Errorf("response xml decode: %w", err)
return fmt.Errorf("response struct decode: %w", err)
}
if e.Body.Fault != nil {
return nil, e.Body.Fault
}
return e.Body.KasAPIResponse.Return, nil
return nil
}
func (c *Client) updateFloodTime(delay float64) {
c.muFloodTime.Lock()
c.floodTime = time.Now().Add(time.Duration(delay * float64(time.Second)))
c.muFloodTime.Unlock()
}
func getValue(item *Item) interface{} {
func getValue(item *Item) any {
switch {
case item.Raw != "":
v, _ := strconv.ParseBool(item.Raw)
@ -253,7 +194,7 @@ func getValue(item *Item) interface{} {
return getValue(item.Value)
case len(item.Items) > 0 && item.Type == "SOAP-ENC:Array":
var v []interface{}
var v []any
for _, i := range item.Items {
v = append(v, getValue(i))
}
@ -261,7 +202,7 @@ func getValue(item *Item) interface{} {
return v
case len(item.Items) > 0:
v := map[string]interface{}{}
v := map[string]any{}
for _, i := range item.Items {
v[getKey(i)] = getValue(i)
}

View file

@ -13,36 +13,6 @@ import (
"github.com/stretchr/testify/require"
)
func TestClient_Authentication(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc("/", testHandler("auth.xml"))
client := NewClient("user", "secret")
client.authEndpoint = server.URL
credentialToken, err := client.Authentication(60, false)
require.NoError(t, err)
assert.Equal(t, "593959ca04f0de9689b586c6a647d15d", credentialToken)
}
func TestClient_Authentication_error(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc("/", testHandler("auth_fault.xml"))
client := NewClient("user", "secret")
client.authEndpoint = server.URL
_, err := client.Authentication(60, false)
require.Error(t, err)
}
func TestClient_GetDNSSettings(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
@ -50,12 +20,10 @@ func TestClient_GetDNSSettings(t *testing.T) {
mux.HandleFunc("/", testHandler("get_dns_settings.xml"))
client := NewClient("user", "secret")
client.apiEndpoint = server.URL
client := NewClient("user")
client.baseURL = server.URL
token := "sha1secret"
records, err := client.GetDNSSettings(token, "example.com", "")
records, err := client.GetDNSSettings(mockContext(), "example.com", "")
require.NoError(t, err)
expected := []ReturnInfo{
@ -134,10 +102,8 @@ func TestClient_AddDNSSettings(t *testing.T) {
mux.HandleFunc("/", testHandler("add_dns_settings.xml"))
client := NewClient("user", "secret")
client.apiEndpoint = server.URL
token := "sha1secret"
client := NewClient("user")
client.baseURL = server.URL
record := DNSRequest{
ZoneHost: "42cnc.de.",
@ -146,7 +112,7 @@ func TestClient_AddDNSSettings(t *testing.T) {
RecordData: "abcdefgh",
}
recordID, err := client.AddDNSSettings(token, record)
recordID, err := client.AddDNSSettings(mockContext(), record)
require.NoError(t, err)
assert.Equal(t, "57347444", recordID)
@ -159,12 +125,10 @@ func TestClient_DeleteDNSSettings(t *testing.T) {
mux.HandleFunc("/", testHandler("delete_dns_settings.xml"))
client := NewClient("user", "secret")
client.apiEndpoint = server.URL
client := NewClient("user")
client.baseURL = server.URL
token := "sha1secret"
r, err := client.DeleteDNSSettings(token, "57347450")
r, err := client.DeleteDNSSettings(mockContext(), "57347450")
require.NoError(t, err)
assert.True(t, r)

View file

@ -0,0 +1,104 @@
package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
// authEndpoint represents the Identity API endpoint to call.
const authEndpoint = "https://kasapi.kasserver.com/soap/KasAuth.php"
type token string
const tokenKey token = "token"
// Identifier generates credential tokens.
type Identifier struct {
login string
password string
authEndpoint string
HTTPClient *http.Client
}
// NewIdentifier creates a new Identifier.
func NewIdentifier(login string, password string) *Identifier {
return &Identifier{
login: login,
password: password,
authEndpoint: authEndpoint,
HTTPClient: &http.Client{Timeout: 10 * time.Second},
}
}
// Authentication Creates a credential token.
// - sessionLifetime: Validity of the token in seconds.
// - sessionUpdateLifetime: with `true` the session is extended with every request.
func (c *Identifier) Authentication(ctx context.Context, sessionLifetime int, sessionUpdateLifetime bool) (string, error) {
sul := "N"
if sessionUpdateLifetime {
sul = "Y"
}
ar := AuthRequest{
Login: c.login,
AuthData: c.password,
AuthType: "plain",
SessionLifetime: sessionLifetime,
SessionUpdateLifetime: sul,
}
body, err := json.Marshal(ar)
if err != nil {
return "", fmt.Errorf("failed to create request JSON body: %w", err)
}
payload := []byte(strings.TrimSpace(fmt.Sprintf(kasAuthEnvelope, body)))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.authEndpoint, bytes.NewReader(payload))
if err != nil {
return "", fmt.Errorf("unable to create request: %w", err)
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return "", errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
envlp, err := decodeXML[KasAuthEnvelope](resp.Body)
if err != nil {
return "", err
}
if envlp.Body.Fault != nil {
return "", envlp.Body.Fault
}
return envlp.Body.KasAuthResponse.Return.Text, nil
}
func WithContext(ctx context.Context, credential string) context.Context {
return context.WithValue(ctx, tokenKey, credential)
}
func getToken(ctx context.Context) string {
credential, ok := ctx.Value(tokenKey).(string)
if !ok {
return ""
}
return credential
}

View file

@ -0,0 +1,45 @@
package internal
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func mockContext() context.Context {
return context.WithValue(context.Background(), tokenKey, "593959ca04f0de9689b586c6a647d15d")
}
func TestIdentifier_Authentication(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc("/", testHandler("auth.xml"))
client := NewIdentifier("user", "secret")
client.authEndpoint = server.URL
credentialToken, err := client.Authentication(context.Background(), 60, false)
require.NoError(t, err)
assert.Equal(t, "593959ca04f0de9689b586c6a647d15d", credentialToken)
}
func TestIdentifier_Authentication_error(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc("/", testHandler("auth_fault.xml"))
client := NewIdentifier("user", "secret")
client.authEndpoint = server.URL
_, err := client.Authentication(context.Background(), 60, false)
require.Error(t, err)
}

View file

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/xml"
"fmt"
"io"
)
// Trimmer trim all XML fields.
@ -44,3 +45,18 @@ type Item struct {
Value *Item `xml:"value" json:"value,omitempty"`
Items []*Item `xml:"item" json:"item,omitempty"`
}
func decodeXML[T any](reader io.Reader) (*T, error) {
raw, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var result T
err = xml.NewTokenDecoder(Trimmer{decoder: xml.NewDecoder(bytes.NewReader(raw))}).Decode(&result)
if err != nil {
return nil, fmt.Errorf("decode XML response: %w", err)
}
return &result, nil
}

View file

@ -35,7 +35,7 @@ type KasRequest struct {
// Action API function.
Action string `json:"kas_action,omitempty"`
// RequestParams Parameters to the API function.
RequestParams interface{} `json:"KasRequestParams,omitempty"`
RequestParams any `json:"KasRequestParams,omitempty"`
}
type DNSRequest struct {
@ -64,7 +64,7 @@ type GetDNSSettingsResponse struct {
}
type ReturnInfo struct {
ID interface{} `json:"record_id,omitempty" mapstructure:"record_id"`
ID any `json:"record_id,omitempty" mapstructure:"record_id"`
Zone string `json:"record_zone,omitempty" mapstructure:"record_zone"`
Name string `json:"record_name,omitempty" mapstructure:"record_name"`
Type string `json:"record_type,omitempty" mapstructure:"record_type"`

View file

@ -2,6 +2,7 @@
package arvancloud
import (
"context"
"errors"
"fmt"
"net/http"
@ -108,11 +109,13 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
authZone, err := getZone(info.EffectiveFQDN)
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return err
return fmt.Errorf("arvancloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
authZone = dns01.UnFqdn(authZone)
subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)
if err != nil {
return fmt.Errorf("arvancloud: %w", err)
@ -131,7 +134,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
},
}
newRecord, err := d.client.CreateRecord(authZone, record)
newRecord, err := d.client.CreateRecord(context.Background(), authZone, record)
if err != nil {
return fmt.Errorf("arvancloud: failed to add TXT record: fqdn=%s: %w", info.EffectiveFQDN, err)
}
@ -147,11 +150,13 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
authZone, err := getZone(info.EffectiveFQDN)
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return err
return fmt.Errorf("arvancloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
authZone = dns01.UnFqdn(authZone)
// gets the record's unique ID from when we created it
d.recordIDsMu.Lock()
recordID, ok := d.recordIDs[token]
@ -160,7 +165,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return fmt.Errorf("arvancloud: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token)
}
if err := d.client.DeleteRecord(authZone, recordID); err != nil {
if err := d.client.DeleteRecord(context.Background(), authZone, recordID); err != nil {
return fmt.Errorf("arvancloud: failed to delate TXT record: id=%s: %w", recordID, err)
}
@ -171,12 +176,3 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return nil
}
func getZone(fqdn string) (string, error) {
authZone, err := dns01.FindZoneByFqdn(fqdn)
if err != nil {
return "", err
}
return dns01.UnFqdn(authZone), nil
}

View file

@ -2,39 +2,45 @@ package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
// defaultBaseURL represents the API endpoint to call.
const defaultBaseURL = "https://napi.arvancloud.ir"
const authHeader = "Authorization"
const authorizationHeader = "Authorization"
// Client the ArvanCloud client.
type Client struct {
HTTPClient *http.Client
BaseURL string
apiKey string
baseURL *url.URL
HTTPClient *http.Client
}
// NewClient Creates a new ArvanCloud client.
// NewClient Creates a new Client.
func NewClient(apiKey string) *Client {
baseURL, _ := url.Parse(defaultBaseURL)
return &Client{
HTTPClient: http.DefaultClient,
BaseURL: defaultBaseURL,
apiKey: apiKey,
baseURL: baseURL,
HTTPClient: &http.Client{Timeout: 5 * time.Second},
}
}
// GetTxtRecord gets a TXT record.
func (c *Client) GetTxtRecord(domain, name, value string) (*DNSRecord, error) {
records, err := c.getRecords(domain, name)
func (c *Client) GetTxtRecord(ctx context.Context, domain, name, value string) (*DNSRecord, error) {
records, err := c.getRecords(ctx, domain, name)
if err != nil {
return nil, err
}
@ -49,11 +55,8 @@ func (c *Client) GetTxtRecord(domain, name, value string) (*DNSRecord, error) {
}
// https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.list
func (c *Client) getRecords(domain, search string) ([]DNSRecord, error) {
endpoint, err := c.createEndpoint("cdn", "4.0", "domains", domain, "dns-records")
if err != nil {
return nil, fmt.Errorf("failed to create endpoint: %w", err)
}
func (c *Client) getRecords(ctx context.Context, domain, search string) ([]DNSRecord, error) {
endpoint := c.baseURL.JoinPath("cdn", "4.0", "domains", domain, "dns-records")
if search != "" {
query := endpoint.Query()
@ -61,123 +64,110 @@ func (c *Client) getRecords(domain, search string) ([]DNSRecord, error) {
endpoint.RawQuery = query.Encode()
}
resp, err := c.do(http.MethodGet, endpoint.String(), nil)
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
response := &apiResponse[[]DNSRecord]{}
err = c.do(req, http.StatusOK, response)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
return nil, fmt.Errorf("could not get records %s: Domain: %s: %w", search, domain, err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("could not get records %s: Domain: %s; Status: %s; Body: %s",
search, domain, resp.Status, string(body))
}
response := &apiResponse{}
err = json.Unmarshal(body, response)
if err != nil {
return nil, fmt.Errorf("failed to decode response body: %w", err)
}
var records []DNSRecord
err = json.Unmarshal(response.Data, &records)
if err != nil {
return nil, fmt.Errorf("failed to decode records: %w", err)
}
return records, nil
return response.Data, nil
}
// CreateRecord creates a DNS record.
// https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.create
func (c *Client) CreateRecord(domain string, record DNSRecord) (*DNSRecord, error) {
reqBody, err := json.Marshal(record)
func (c *Client) CreateRecord(ctx context.Context, domain string, record DNSRecord) (*DNSRecord, error) {
endpoint := c.baseURL.JoinPath("cdn", "4.0", "domains", domain, "dns-records")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
if err != nil {
return nil, err
}
endpoint, err := c.createEndpoint("cdn", "4.0", "domains", domain, "dns-records")
response := &apiResponse[*DNSRecord]{}
err = c.do(req, http.StatusCreated, response)
if err != nil {
return nil, fmt.Errorf("failed to create endpoint: %w", err)
return nil, fmt.Errorf("could not create record; Domain: %s: %w", domain, err)
}
resp, err := c.do(http.MethodPost, endpoint.String(), bytes.NewReader(reqBody))
if err != nil {
return nil, err
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("could not create record %s; Domain: %s; Status: %s; Body: %s", string(reqBody), domain, resp.Status, string(body))
}
response := &apiResponse{}
err = json.Unmarshal(body, response)
if err != nil {
return nil, fmt.Errorf("failed to decode response body: %w", err)
}
var newRecord DNSRecord
err = json.Unmarshal(response.Data, &newRecord)
if err != nil {
return nil, fmt.Errorf("failed to decode record: %w", err)
}
return &newRecord, nil
return response.Data, nil
}
// DeleteRecord deletes a DNS record.
// https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.remove
func (c *Client) DeleteRecord(domain, id string) error {
endpoint, err := c.createEndpoint("cdn", "4.0", "domains", domain, "dns-records", id)
if err != nil {
return fmt.Errorf("failed to create endpoint: %w", err)
}
func (c *Client) DeleteRecord(ctx context.Context, domain, id string) error {
endpoint := c.baseURL.JoinPath("cdn", "4.0", "domains", domain, "dns-records", id)
resp, err := c.do(http.MethodDelete, endpoint.String(), nil)
req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("could not delete record %s; Domain: %s; Status: %s; Body: %s", id, domain, resp.Status, string(body))
err = c.do(req, http.StatusOK, nil)
if err != nil {
return fmt.Errorf("could not delete record %s; Domain: %s: %w", id, domain, err)
}
return nil
}
func (c *Client) do(method, endpoint string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest(method, endpoint, body)
func (c *Client) do(req *http.Request, expectedStatus int, result any) error {
req.Header.Set(authorizationHeader, c.apiKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, err
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != expectedStatus {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
err = json.Unmarshal(raw, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if body != nil {
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
req.Header.Set(authHeader, c.apiKey)
return c.HTTPClient.Do(req)
}
func (c *Client) createEndpoint(parts ...string) (*url.URL, error) {
baseURL, err := url.Parse(c.BaseURL)
if err != nil {
return nil, err
}
return baseURL.JoinPath(parts...), nil
return req, nil
}
func equalsTXTRecord(record DNSRecord, name, value string) bool {
@ -189,7 +179,7 @@ func equalsTXTRecord(record DNSRecord, name, value string) bool {
return false
}
data, ok := record.Value.(map[string]interface{})
data, ok := record.Value.(map[string]any)
if !ok {
return false
}

View file

@ -1,10 +1,12 @@
package internal
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
@ -12,21 +14,34 @@ import (
"github.com/stretchr/testify/require"
)
func TestClient_GetTxtRecord(t *testing.T) {
func setupTest(t *testing.T, apiKey string) (*Client, *http.ServeMux) {
t.Helper()
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
const domain = "example.com"
client := NewClient(apiKey)
client.baseURL, _ = url.Parse(server.URL)
client.HTTPClient = server.Client()
return client, mux
}
func TestClient_GetTxtRecord(t *testing.T) {
const apiKey = "myKeyA"
client, mux := setupTest(t, apiKey)
const domain = "example.com"
mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed)
return
}
auth := req.Header.Get(authHeader)
auth := req.Header.Get(authorizationHeader)
if auth != apiKey {
http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized)
return
@ -46,20 +61,16 @@ func TestClient_GetTxtRecord(t *testing.T) {
}
})
client := NewClient(apiKey)
client.BaseURL = server.URL
_, err := client.GetTxtRecord(domain, "_acme-challenge", "txtxtxt")
_, err := client.GetTxtRecord(context.Background(), domain, "_acme-challenge", "txtxtxt")
require.NoError(t, err)
}
func TestClient_CreateRecord(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
const apiKey = "myKeyB"
client, mux := setupTest(t, apiKey)
const domain = "example.com"
const apiKey = "myKeyB"
mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
@ -67,7 +78,7 @@ func TestClient_CreateRecord(t *testing.T) {
return
}
auth := req.Header.Get(authHeader)
auth := req.Header.Get(authorizationHeader)
if auth != apiKey {
http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized)
return
@ -88,9 +99,6 @@ func TestClient_CreateRecord(t *testing.T) {
}
})
client := NewClient(apiKey)
client.BaseURL = server.URL
record := DNSRecord{
Name: "_acme-challenge",
Type: "txt",
@ -98,7 +106,7 @@ func TestClient_CreateRecord(t *testing.T) {
TTL: 600,
}
newRecord, err := client.CreateRecord(domain, record)
newRecord, err := client.CreateRecord(context.Background(), domain, record)
require.NoError(t, err)
expected := &DNSRecord{
@ -119,12 +127,11 @@ func TestClient_CreateRecord(t *testing.T) {
}
func TestClient_DeleteRecord(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
const apiKey = "myKeyC"
client, mux := setupTest(t, apiKey)
const domain = "example.com"
const apiKey = "myKeyC"
const recordID = "recordId"
mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records/"+recordID, func(rw http.ResponseWriter, req *http.Request) {
@ -133,16 +140,13 @@ func TestClient_DeleteRecord(t *testing.T) {
return
}
auth := req.Header.Get(authHeader)
auth := req.Header.Get(authorizationHeader)
if auth != apiKey {
http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized)
return
}
})
client := NewClient(apiKey)
client.BaseURL = server.URL
err := client.DeleteRecord(domain, recordID)
err := client.DeleteRecord(context.Background(), domain, recordID)
require.NoError(t, err)
}

View file

@ -1,17 +1,15 @@
package internal
import "encoding/json"
type apiResponse struct {
type apiResponse[T any] struct {
Message string `json:"message"`
Data json.RawMessage `json:"data"`
Data T `json:"data"`
}
// DNSRecord a DNS record.
type DNSRecord struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`
Value interface{} `json:"value,omitempty"`
Value any `json:"value,omitempty"`
Name string `json:"name,omitempty"`
TTL int `json:"ttl,omitempty"`
UpstreamHTTPS string `json:"upstream_https,omitempty"`

View file

@ -108,7 +108,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("aurora: could not determine zone for domain %q: %w", domain, err)
return fmt.Errorf("aurora: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
// 1. Aurora will happily create the TXT record when it is provided a fqdn,
@ -155,24 +155,24 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
d.recordIDsMu.Unlock()
if !ok {
return fmt.Errorf("unknown recordID for %q", info.EffectiveFQDN)
return fmt.Errorf("aurora: unknown recordID for %q", info.EffectiveFQDN)
}
authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(info.EffectiveFQDN))
if err != nil {
return fmt.Errorf("could not determine zone for domain %q: %w", domain, err)
return fmt.Errorf("aurora: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
authZone = dns01.UnFqdn(authZone)
zone, err := d.getZoneInformationByName(authZone)
if err != nil {
return err
return fmt.Errorf("aurora: %w", err)
}
_, _, err = d.client.DeleteRecord(zone.ID, recordID)
if err != nil {
return err
return fmt.Errorf("aurora: %w", err)
}
d.recordIDsMu.Lock()

View file

@ -2,6 +2,7 @@
package autodns
import (
"context"
"errors"
"fmt"
"net/http"
@ -10,6 +11,7 @@ import (
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/platform/config/env"
"github.com/go-acme/lego/v4/providers/dns/autodns/internal"
)
// Environment variables names.
@ -27,11 +29,6 @@ const (
EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT"
)
const (
defaultEndpointContext int = 4
defaultTTL int = 600
)
// Config is used to configure the creation of the DNSProvider.
type Config struct {
Endpoint *url.URL
@ -46,12 +43,12 @@ type Config struct {
// NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config {
endpoint, _ := url.Parse(env.GetOrDefaultString(EnvAPIEndpoint, defaultEndpoint))
endpoint, _ := url.Parse(env.GetOrDefaultString(EnvAPIEndpoint, internal.DefaultEndpoint))
return &Config{
Endpoint: endpoint,
Context: env.GetOrDefaultInt(EnvAPIEndpointContext, defaultEndpointContext),
TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL),
Context: env.GetOrDefaultInt(EnvAPIEndpointContext, internal.DefaultEndpointContext),
TTL: env.GetOrDefaultInt(EnvTTL, 600),
PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 2*time.Minute),
PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 2*time.Second),
HTTPClient: &http.Client{
@ -63,6 +60,7 @@ func NewDefaultConfig() *Config {
// DNSProvider implements the challenge.Provider interface.
type DNSProvider struct {
config *Config
client *internal.Client
}
// NewDNSProvider returns a DNSProvider instance configured for autoDNS.
@ -94,7 +92,17 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("autodns: missing password")
}
return &DNSProvider{config: config}, nil
client := internal.NewClient(config.Username, config.Password, config.Context)
if config.Endpoint != nil {
client.BaseURL = config.Endpoint
}
if config.HTTPClient != nil {
client.HTTPClient = config.HTTPClient
}
return &DNSProvider{config: config, client: client}, nil
}
// Timeout returns the timeout and interval to use when checking for DNS propagation.
@ -107,7 +115,7 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
records := []*ResourceRecord{{
records := []*internal.ResourceRecord{{
Name: info.EffectiveFQDN,
TTL: int64(d.config.TTL),
Type: "TXT",
@ -115,7 +123,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
}}
// TODO(ldez) replace domain by FQDN to follow CNAME.
_, err := d.addTxtRecord(domain, records)
_, err := d.client.AddTxtRecords(context.Background(), domain, records)
if err != nil {
return fmt.Errorf("autodns: %w", err)
}
@ -127,7 +135,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
records := []*ResourceRecord{{
records := []*internal.ResourceRecord{{
Name: info.EffectiveFQDN,
TTL: int64(d.config.TTL),
Type: "TXT",
@ -135,7 +143,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
}}
// TODO(ldez) replace domain by FQDN to follow CNAME.
if err := d.removeTXTRecord(domain, records); err != nil {
if err := d.client.RemoveTXTRecords(context.Background(), domain, records); err != nil {
return fmt.Errorf("autodns: %w", err)
}

View file

@ -1,159 +0,0 @@
package autodns
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
)
const (
defaultEndpoint = "https://api.autodns.com/v1/"
)
type ResponseMessage struct {
Text string `json:"text"`
Messages []string `json:"messages"`
Objects []string `json:"objects"`
Code string `json:"code"`
Status string `json:"status"`
}
type ResponseStatus struct {
Code string `json:"code"`
Text string `json:"text"`
Type string `json:"type"`
}
type ResponseObject struct {
Type string `json:"type"`
Value string `json:"value"`
Summary int32 `json:"summary"`
Data string
}
type DataZoneResponse struct {
STID string `json:"stid"`
CTID string `json:"ctid"`
Messages []*ResponseMessage `json:"messages"`
Status *ResponseStatus `json:"status"`
Object interface{} `json:"object"`
Data []*Zone `json:"data"`
}
// ResourceRecord holds a resource record.
type ResourceRecord struct {
Name string `json:"name"`
TTL int64 `json:"ttl"`
Type string `json:"type"`
Value string `json:"value"`
Pref int32 `json:"pref,omitempty"`
}
// Zone is an autodns zone record with all for us relevant fields.
type Zone struct {
Name string `json:"origin"`
ResourceRecords []*ResourceRecord `json:"resourceRecords"`
Action string `json:"action"`
VirtualNameServer string `json:"virtualNameServer"`
}
type ZoneStream struct {
Adds []*ResourceRecord `json:"adds"`
Removes []*ResourceRecord `json:"rems"`
}
func (d *DNSProvider) addTxtRecord(domain string, records []*ResourceRecord) (*Zone, error) {
zoneStream := &ZoneStream{Adds: records}
return d.makeZoneUpdateRequest(zoneStream, domain)
}
func (d *DNSProvider) removeTXTRecord(domain string, records []*ResourceRecord) error {
zoneStream := &ZoneStream{Removes: records}
_, err := d.makeZoneUpdateRequest(zoneStream, domain)
return err
}
func (d *DNSProvider) makeZoneUpdateRequest(zoneStream *ZoneStream, domain string) (*Zone, error) {
reqBody := &bytes.Buffer{}
if err := json.NewEncoder(reqBody).Encode(zoneStream); err != nil {
return nil, err
}
endpoint := d.config.Endpoint.JoinPath("zone", domain, "_stream")
req, err := d.makeRequest(http.MethodPost, endpoint.String(), reqBody)
if err != nil {
return nil, err
}
var resp *Zone
if err := d.sendRequest(req, &resp); err != nil {
return nil, err
}
return resp, nil
}
func (d *DNSProvider) makeRequest(method, endpoint string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequest(method, endpoint, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Domainrobot-Context", strconv.Itoa(d.config.Context))
req.SetBasicAuth(d.config.Username, d.config.Password)
return req, nil
}
func (d *DNSProvider) sendRequest(req *http.Request, result interface{}) error {
resp, err := d.config.HTTPClient.Do(req)
if err != nil {
return err
}
if err = checkResponse(resp); err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
err = json.Unmarshal(raw, result)
if err != nil {
return fmt.Errorf("unmarshaling %T error [status code=%d]: %w: %s", result, resp.StatusCode, err, string(raw))
}
return err
}
func checkResponse(resp *http.Response) error {
if resp.StatusCode < http.StatusBadRequest {
return nil
}
if resp.Body == nil {
return fmt.Errorf("response body is nil, status code=%d", resp.StatusCode)
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("unable to read body: status code=%d, error=%w", resp.StatusCode, err)
}
return fmt.Errorf("status code=%d: %s", resp.StatusCode, string(raw))
}

View file

@ -0,0 +1,132 @@
package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
// DefaultEndpoint default API endpoint.
const DefaultEndpoint = "https://api.autodns.com/v1/"
// DefaultEndpointContext default API endpoint context.
const DefaultEndpointContext int = 4
// Client the Autodns API client.
type Client struct {
username string
password string
context int
BaseURL *url.URL
HTTPClient *http.Client
}
// NewClient creates a new Client.
func NewClient(username string, password string, clientContext int) *Client {
baseURL, _ := url.Parse(DefaultEndpoint)
return &Client{
username: username,
password: password,
context: clientContext,
BaseURL: baseURL,
HTTPClient: &http.Client{Timeout: 5 * time.Second},
}
}
// AddTxtRecords adds TXT records.
func (c *Client) AddTxtRecords(ctx context.Context, domain string, records []*ResourceRecord) (*Zone, error) {
zoneStream := &ZoneStream{Adds: records}
return c.updateZone(ctx, domain, zoneStream)
}
// RemoveTXTRecords removes TXT records.
func (c *Client) RemoveTXTRecords(ctx context.Context, domain string, records []*ResourceRecord) error {
zoneStream := &ZoneStream{Removes: records}
_, err := c.updateZone(ctx, domain, zoneStream)
return err
}
// https://github.com/InterNetX/domainrobot-api/blob/bdc8fe92a2f32fcbdb29e30bf6006ab446f81223/src/domainrobot.json#L21090
func (c *Client) updateZone(ctx context.Context, domain string, zoneStream *ZoneStream) (*Zone, error) {
endpoint := c.BaseURL.JoinPath("zone", domain, "_stream")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, zoneStream)
if err != nil {
return nil, err
}
var zone *Zone
if err := c.do(req, &zone); err != nil {
return nil, err
}
return zone, nil
}
func (c *Client) do(req *http.Request, result any) error {
req.Header.Set("X-Domainrobot-Context", strconv.Itoa(c.context))
req.SetBasicAuth(c.username, c.password)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode/100 != 2 {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
err = json.Unmarshal(raw, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}

View file

@ -0,0 +1,96 @@
package internal
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTest(t *testing.T, method, pattern string, status int, file string) *Client {
t.Helper()
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc(pattern, func(rw http.ResponseWriter, req *http.Request) {
if req.Method != method {
http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusBadRequest)
return
}
apiUser, apiKey, ok := req.BasicAuth()
if apiUser != "user" || apiKey != "secret" || !ok {
http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
if file == "" {
rw.WriteHeader(status)
return
}
open, err := os.Open(filepath.Join("fixtures", file))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
defer func() { _ = open.Close() }()
rw.WriteHeader(status)
_, err = io.Copy(rw, open)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
})
client := NewClient("user", "secret", 123)
client.HTTPClient = server.Client()
client.BaseURL, _ = url.Parse(server.URL)
return client
}
func TestClient_AddTxtRecords(t *testing.T) {
client := setupTest(t, http.MethodPost, "/zone/example.com/_stream", http.StatusOK, "add-record.json")
records := []*ResourceRecord{{}}
zone, err := client.AddTxtRecords(context.Background(), "example.com", records)
require.NoError(t, err)
expected := &Zone{
Name: "example.com",
ResourceRecords: []*ResourceRecord{{
Name: "example.com",
TTL: 120,
Type: "TXT",
Value: "txt",
Pref: 1,
}},
Action: "xxx",
VirtualNameServer: "yyy",
}
assert.Equal(t, expected, zone)
}
func TestClient_RemoveTXTRecords(t *testing.T) {
client := setupTest(t, http.MethodPost, "/zone/example.com/_stream", http.StatusOK, "add-record.json")
records := []*ResourceRecord{{}}
err := client.RemoveTXTRecords(context.Background(), "example.com", records)
require.NoError(t, err)
}

View file

@ -0,0 +1,14 @@
{
"origin": "example.com",
"resourceRecords": [
{
"name": "example.com",
"ttl": 120,
"type": "TXT",
"value": "txt",
"pref": 1
}
],
"action": "xxx",
"virtualNameServer": "yyy"
}

View file

@ -0,0 +1,14 @@
{
"origin": "example.com",
"resourceRecords": [
{
"name": "example.com",
"ttl": 120,
"type": "TXT",
"value": "txt",
"pref": 1
}
],
"action": "xxx",
"virtualNameServer": "yyy"
}

View file

@ -0,0 +1,57 @@
package internal
type ResponseMessage struct {
Text string `json:"text"`
Messages []string `json:"messages"`
Objects []string `json:"objects"`
Code string `json:"code"`
Status string `json:"status"`
}
type ResponseStatus struct {
Code string `json:"code"`
Text string `json:"text"`
Type string `json:"type"`
}
type ResponseObject struct {
Type string `json:"type"`
Value string `json:"value"`
Summary int32 `json:"summary"`
Data string
}
type DataZoneResponse struct {
STID string `json:"stid"`
CTID string `json:"ctid"`
Messages []*ResponseMessage `json:"messages"`
Status *ResponseStatus `json:"status"`
Object any `json:"object"`
Data []*Zone `json:"data"`
}
// ResourceRecord holds a resource record.
// https://help.internetx.com/display/APIXMLEN/Resource+Record+Object
type ResourceRecord struct {
Name string `json:"name"`
TTL int64 `json:"ttl"`
Type string `json:"type"`
Value string `json:"value"`
Pref int32 `json:"pref,omitempty"`
}
// Zone is an autodns zone record with all for us relevant fields.
// https://help.internetx.com/display/APIXMLEN/Zone+Object
type Zone struct {
Name string `json:"origin"`
ResourceRecords []*ResourceRecord `json:"resourceRecords"`
Action string `json:"action"`
VirtualNameServer string `json:"virtualNameServer"`
}
// ZoneStream body of the requests.
// https://github.com/InterNetX/domainrobot-api/blob/bdc8fe92a2f32fcbdb29e30bf6006ab446f81223/src/domainrobot.json#L35914-L35932
type ZoneStream struct {
Adds []*ResourceRecord `json:"adds"`
Removes []*ResourceRecord `json:"rems"`
}

View file

@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/Azure/go-autorest/autorest"
@ -14,6 +15,7 @@ import (
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/platform/config/env"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
const defaultMetadataEndpoint = "http://169.254.169.254"
@ -122,7 +124,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
}
if config.HTTPClient == nil {
config.HTTPClient = http.DefaultClient
config.HTTPClient = &http.Client{Timeout: 5 * time.Second}
}
authorizer, err := getAuthorizer(config)
@ -208,8 +210,12 @@ func getMetadata(config *Config, field string) (string, error) {
metadataEndpoint = defaultMetadataEndpoint
}
resource := fmt.Sprintf("%s/metadata/instance/compute/%s", metadataEndpoint, field)
req, err := http.NewRequest(http.MethodGet, resource, nil)
endpoint, err := url.JoinPath(metadataEndpoint, "metadata", "instance", "compute", field)
if err != nil {
return "", err
}
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
if err != nil {
return "", err
}
@ -223,14 +229,15 @@ func getMetadata(config *Config, field string) (string, error) {
resp, err := config.HTTPClient.Do(req)
if err != nil {
return "", err
return "", errutils.NewHTTPDoError(req, err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
return "", errutils.NewReadResponseError(req, resp.StatusCode, err)
}
return string(respBody), nil
return string(raw), nil
}

View file

@ -118,7 +118,7 @@ func (d *dnsProviderPrivate) getHostedZoneID(ctx context.Context, fqdn string) (
authZone, err := dns01.FindZoneByFqdn(fqdn)
if err != nil {
return "", err
return "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err)
}
dc := privatedns.NewPrivateZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID)

View file

@ -118,7 +118,7 @@ func (d *dnsProviderPublic) getHostedZoneID(ctx context.Context, fqdn string) (s
authZone, err := dns01.FindZoneByFqdn(fqdn)
if err != nil {
return "", err
return "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err)
}
dc := dns.NewZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID)

View file

@ -2,6 +2,7 @@
package bluecat
import (
"context"
"errors"
"fmt"
"net/http"
@ -97,7 +98,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("bluecat: credentials missing")
}
client := internal.NewClient(config.BaseURL)
client := internal.NewClient(config.BaseURL, config.UserName, config.Password)
if config.HTTPClient != nil {
client.HTTPClient = config.HTTPClient
@ -112,17 +113,17 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
err := d.client.Login(d.config.UserName, d.config.Password)
ctx, err := d.client.CreateAuthenticatedContext(context.Background())
if err != nil {
return fmt.Errorf("bluecat: login: %w", err)
}
viewID, err := d.client.LookupViewID(d.config.ConfigName, d.config.DNSView)
viewID, err := d.client.LookupViewID(ctx, d.config.ConfigName, d.config.DNSView)
if err != nil {
return fmt.Errorf("bluecat: lookupViewID: %w", err)
}
parentZoneID, name, err := d.client.LookupParentZoneID(viewID, info.EffectiveFQDN)
parentZoneID, name, err := d.client.LookupParentZoneID(ctx, viewID, info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("bluecat: lookupParentZoneID: %w", err)
}
@ -137,17 +138,17 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
Properties: fmt.Sprintf("ttl=%d|absoluteName=%s|txt=%s|", d.config.TTL, info.EffectiveFQDN, info.Value),
}
_, err = d.client.AddEntity(parentZoneID, txtRecord)
_, err = d.client.AddEntity(ctx, parentZoneID, txtRecord)
if err != nil {
return fmt.Errorf("bluecat: add TXT record: %w", err)
}
err = d.client.Deploy(parentZoneID)
err = d.client.Deploy(ctx, parentZoneID)
if err != nil {
return fmt.Errorf("bluecat: deploy: %w", err)
}
err = d.client.Logout()
err = d.client.Logout(ctx)
if err != nil {
return fmt.Errorf("bluecat: logout: %w", err)
}
@ -159,37 +160,37 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
err := d.client.Login(d.config.UserName, d.config.Password)
ctx, err := d.client.CreateAuthenticatedContext(context.Background())
if err != nil {
return fmt.Errorf("bluecat: login: %w", err)
}
viewID, err := d.client.LookupViewID(d.config.ConfigName, d.config.DNSView)
viewID, err := d.client.LookupViewID(ctx, d.config.ConfigName, d.config.DNSView)
if err != nil {
return fmt.Errorf("bluecat: lookupViewID: %w", err)
}
parentZoneID, name, err := d.client.LookupParentZoneID(viewID, info.EffectiveFQDN)
parentZoneID, name, err := d.client.LookupParentZoneID(ctx, viewID, info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("bluecat: lookupParentZoneID: %w", err)
}
txtRecord, err := d.client.GetEntityByName(parentZoneID, name, internal.TXTType)
txtRecord, err := d.client.GetEntityByName(ctx, parentZoneID, name, internal.TXTType)
if err != nil {
return fmt.Errorf("bluecat: get TXT record: %w", err)
}
err = d.client.Delete(txtRecord.ID)
err = d.client.Delete(ctx, txtRecord.ID)
if err != nil {
return fmt.Errorf("bluecat: delete TXT record: %w", err)
}
err = d.client.Deploy(parentZoneID)
err = d.client.Deploy(ctx, parentZoneID)
if err != nil {
return fmt.Errorf("bluecat: deploy: %w", err)
}
err = d.client.Logout()
err = d.client.Logout(ctx)
if err != nil {
return fmt.Errorf("bluecat: logout: %w", err)
}

View file

@ -2,14 +2,18 @@ package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
// Object types.
@ -20,153 +24,88 @@ const (
TXTType = "TXTRecord"
)
const authorizationHeader = "Authorization"
type Client struct {
HTTPClient *http.Client
username string
password string
baseURL string
token string
tokenExp *regexp.Regexp
baseURL *url.URL
HTTPClient *http.Client
}
func NewClient(baseURL string) *Client {
func NewClient(baseURL string, username, password string) *Client {
bu, _ := url.Parse(baseURL)
return &Client{
HTTPClient: &http.Client{Timeout: 30 * time.Second},
baseURL: baseURL,
username: username,
password: password,
tokenExp: regexp.MustCompile("BAMAuthToken: [^ ]+"),
baseURL: bu,
HTTPClient: &http.Client{Timeout: 30 * time.Second},
}
}
// Login Logs in as API user.
// Authenticates and receives a token to be used in for subsequent requests.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/GET/v1/login/9.1.0
func (c *Client) Login(username, password string) error {
queryArgs := map[string]string{
"username": username,
"password": password,
}
resp, err := c.sendRequest(http.MethodGet, "login", nil, queryArgs)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return &APIError{
StatusCode: resp.StatusCode,
Resource: "login",
Message: string(data),
}
}
authBytes, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
authResp := string(authBytes)
if strings.Contains(authResp, "Authentication Error") {
return fmt.Errorf("request failed: %s", strings.Trim(authResp, `"`))
}
// Upon success, API responds with "Session Token-> BAMAuthToken: dQfuRMTUxNjc3MjcyNDg1ODppcGFybXM= <- for User : username"
c.token = c.tokenExp.FindString(authResp)
return nil
}
// Logout Logs out of the current API session.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/GET/v1/logout/9.1.0
func (c *Client) Logout() error {
if c.token == "" {
// nothing to do
return nil
}
resp, err := c.sendRequest(http.MethodGet, "logout", nil, nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return &APIError{
StatusCode: resp.StatusCode,
Resource: "logout",
Message: string(data),
}
}
authBytes, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
authResp := string(authBytes)
if !strings.Contains(authResp, "successfully") {
return fmt.Errorf("request failed to delete session: %s", strings.Trim(authResp, `"`))
}
c.token = ""
return nil
}
// Deploy the DNS config for the specified entity to the authoritative servers.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/POST/v1/quickDeploy/9.1.0
func (c *Client) Deploy(entityID uint) error {
queryArgs := map[string]string{
"entityId": strconv.FormatUint(uint64(entityID), 10),
}
// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/POST/v1/quickDeploy/9.5.0
func (c *Client) Deploy(ctx context.Context, entityID uint) error {
endpoint := c.createEndpoint("quickDeploy")
resp, err := c.sendRequest(http.MethodPost, "quickDeploy", nil, queryArgs)
q := endpoint.Query()
q.Set("entityId", strconv.FormatUint(uint64(entityID), 10))
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, nil)
if err != nil {
return err
}
defer resp.Body.Close()
resp, err := c.doAuthenticated(ctx, req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
// The API doc says that 201 is expected but in the reality 200 is return.
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return &APIError{
StatusCode: resp.StatusCode,
Resource: "quickDeploy",
Message: string(data),
}
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
return nil
}
// AddEntity A generic method for adding configurations, DNS zones, and DNS resource records.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/POST/v1/addEntity/9.1.0
func (c *Client) AddEntity(parentID uint, entity Entity) (uint64, error) {
queryArgs := map[string]string{
"parentId": strconv.FormatUint(uint64(parentID), 10),
}
// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/POST/v1/addEntity/9.5.0
func (c *Client) AddEntity(ctx context.Context, parentID uint, entity Entity) (uint64, error) {
endpoint := c.createEndpoint("addEntity")
resp, err := c.sendRequest(http.MethodPost, "addEntity", entity, queryArgs)
q := endpoint.Query()
q.Set("parentId", strconv.FormatUint(uint64(parentID), 10))
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, entity)
if err != nil {
return 0, err
}
defer resp.Body.Close()
resp, err := c.doAuthenticated(ctx, req)
if err != nil {
return 0, errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return 0, &APIError{
StatusCode: resp.StatusCode,
Resource: "addEntity",
Message: string(data),
}
return 0, errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
addTxtBytes, _ := io.ReadAll(resp.Body)
raw, _ := io.ReadAll(resp.Body)
// addEntity responds only with body text containing the ID of the created record
addTxtResp := string(addTxtBytes)
addTxtResp := string(raw)
id, err := strconv.ParseUint(addTxtResp, 10, 64)
if err != nil {
return 0, fmt.Errorf("addEntity request failed: %s", addTxtResp)
@ -176,73 +115,84 @@ func (c *Client) AddEntity(parentID uint, entity Entity) (uint64, error) {
}
// GetEntityByName Returns objects from the database referenced by their database ID and with its properties fields populated.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/GET/v1/getEntityById/9.1.0
func (c *Client) GetEntityByName(parentID uint, name, objType string) (*EntityResponse, error) {
queryArgs := map[string]string{
"parentId": strconv.FormatUint(uint64(parentID), 10),
"name": name,
"type": objType,
}
// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/GET/v1/getEntityById/9.5.0
func (c *Client) GetEntityByName(ctx context.Context, parentID uint, name, objType string) (*EntityResponse, error) {
endpoint := c.createEndpoint("getEntityByName")
resp, err := c.sendRequest(http.MethodGet, "getEntityByName", nil, queryArgs)
q := endpoint.Query()
q.Set("parentId", strconv.FormatUint(uint64(parentID), 10))
q.Set("name", name)
q.Set("type", objType)
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
resp, err := c.doAuthenticated(ctx, req)
if err != nil {
return nil, errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return nil, &APIError{
StatusCode: resp.StatusCode,
Resource: "getEntityByName",
Message: string(data),
}
return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
var txtRec EntityResponse
if err = json.NewDecoder(resp.Body).Decode(&txtRec); err != nil {
return nil, fmt.Errorf("JSON decode: %w", err)
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errutils.NewReadResponseError(req, resp.StatusCode, err)
}
return &txtRec, nil
var entity EntityResponse
err = json.Unmarshal(raw, &entity)
if err != nil {
return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return &entity, nil
}
// Delete Deletes an object using the generic delete method.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/DELETE/v1/delete/9.1.0
func (c *Client) Delete(objectID uint) error {
queryArgs := map[string]string{
"objectId": strconv.FormatUint(uint64(objectID), 10),
}
// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/DELETE/v1/delete/9.5.0
func (c *Client) Delete(ctx context.Context, objectID uint) error {
endpoint := c.createEndpoint("delete")
resp, err := c.sendRequest(http.MethodDelete, "delete", nil, queryArgs)
q := endpoint.Query()
q.Set("objectId", strconv.FormatUint(uint64(objectID), 10))
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err
}
defer resp.Body.Close()
// The API doc says that 204 is expected but in the reality 200 is return.
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return &APIError{
StatusCode: resp.StatusCode,
Resource: "delete",
Message: string(data),
resp, err := c.doAuthenticated(ctx, req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
// The API doc says that 204 is expected but in the reality 200 is returned.
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
return nil
}
// LookupViewID Find the DNS view with the given name within.
func (c *Client) LookupViewID(configName, viewName string) (uint, error) {
func (c *Client) LookupViewID(ctx context.Context, configName, viewName string) (uint, error) {
// Lookup the entity ID of the configuration named in our properties.
conf, err := c.GetEntityByName(0, configName, ConfigType)
conf, err := c.GetEntityByName(ctx, 0, configName, ConfigType)
if err != nil {
return 0, err
}
view, err := c.GetEntityByName(conf.ID, viewName, ViewType)
view, err := c.GetEntityByName(ctx, conf.ID, viewName, ViewType)
if err != nil {
return 0, err
}
@ -252,7 +202,7 @@ func (c *Client) LookupViewID(configName, viewName string) (uint, error) {
// LookupParentZoneID Return the entityId of the parent zone by recursing from the root view.
// Also return the simple name of the host.
func (c *Client) LookupParentZoneID(viewID uint, fqdn string) (uint, string, error) {
func (c *Client) LookupParentZoneID(ctx context.Context, viewID uint, fqdn string) (uint, string, error) {
if fqdn == "" {
return viewID, "", nil
}
@ -263,7 +213,7 @@ func (c *Client) LookupParentZoneID(viewID uint, fqdn string) (uint, string, err
parentViewID := viewID
for i := len(zones) - 1; i > -1; i-- {
zone, err := c.GetEntityByName(parentViewID, zones[i], ZoneType)
zone, err := c.GetEntityByName(ctx, parentViewID, zones[i], ZoneType)
if err != nil {
return 0, "", fmt.Errorf("could not find zone named %s: %w", name, err)
}
@ -282,32 +232,39 @@ func (c *Client) LookupParentZoneID(viewID uint, fqdn string) (uint, string, err
return parentViewID, name, nil
}
// Send a REST request, using query parameters specified.
// The Authorization header will be set if we have an active auth token.
func (c *Client) sendRequest(method, resource string, payload interface{}, queryParams map[string]string) (*http.Response, error) {
url := fmt.Sprintf("%s/Services/REST/v1/%s", c.baseURL, resource)
body, err := json.Marshal(payload)
if err != nil {
return nil, err
func (c *Client) createEndpoint(resource string) *url.URL {
return c.baseURL.JoinPath("Services", "REST", "v1", resource)
}
req, err := http.NewRequest(method, url, bytes.NewReader(body))
if err != nil {
return nil, err
func (c *Client) doAuthenticated(ctx context.Context, req *http.Request) (*http.Response, error) {
tok := getToken(ctx)
if tok != "" {
req.Header.Set(authorizationHeader, tok)
}
req.Header.Set("Content-Type", "application/json")
if c.token != "" {
req.Header.Set("Authorization", c.token)
}
q := req.URL.Query()
for k, v := range queryParams {
q.Set(k, v)
}
req.URL.RawQuery = q.Encode()
return c.HTTPClient.Do(req)
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}

View file

@ -1,6 +1,7 @@
package internal
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
@ -15,7 +16,8 @@ func TestClient_LookupParentZoneID(t *testing.T) {
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
client := NewClient(server.URL)
client := NewClient(server.URL, "user", "secret")
client.HTTPClient = server.Client()
mux.HandleFunc("/Services/REST/v1/getEntityByName", func(rw http.ResponseWriter, req *http.Request) {
query := req.URL.Query()
@ -33,7 +35,7 @@ func TestClient_LookupParentZoneID(t *testing.T) {
http.Error(rw, "{}", http.StatusOK)
})
parentID, name, err := client.LookupParentZoneID(2, "foo.example.com")
parentID, name, err := client.LookupParentZoneID(context.Background(), 2, "foo.example.com")
require.NoError(t, err)
assert.EqualValues(t, 2, parentID)

View file

@ -0,0 +1,115 @@
package internal
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
type token string
const tokenKey token = "token"
// login Logs in as API user.
// Authenticates and receives a token to be used in for subsequent requests.
// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/GET/v1/login/9.5.0
func (c *Client) login(ctx context.Context) (string, error) {
endpoint := c.createEndpoint("login")
q := endpoint.Query()
q.Set("username", c.username)
q.Set("password", c.password)
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return "", err
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return "", errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return "", errutils.NewReadResponseError(req, resp.StatusCode, err)
}
authResp := string(raw)
if strings.Contains(authResp, "Authentication Error") {
return "", fmt.Errorf("request failed: %s", strings.Trim(authResp, `"`))
}
// Upon success, API responds with "Session Token-> BAMAuthToken: dQfuRMTUxNjc3MjcyNDg1ODppcGFybXM= <- for User : username"
tok := c.tokenExp.FindString(authResp)
return tok, nil
}
// Logout Logs out of the current API session.
// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/GET/v1/logout/9.5.0
func (c *Client) Logout(ctx context.Context) error {
if getToken(ctx) == "" {
// nothing to do
return nil
}
endpoint := c.createEndpoint("logout")
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return err
}
resp, err := c.doAuthenticated(ctx, req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
authResp := string(raw)
if !strings.Contains(authResp, "successfully") {
return fmt.Errorf("request failed to delete session: %s", strings.Trim(authResp, `"`))
}
return nil
}
func (c *Client) CreateAuthenticatedContext(ctx context.Context) (context.Context, error) {
tok, err := c.login(ctx)
if err != nil {
return nil, err
}
return context.WithValue(ctx, tokenKey, tok), nil
}
func getToken(ctx context.Context) string {
tok, ok := ctx.Value(tokenKey).(string)
if !ok {
return ""
}
return tok
}

View file

@ -0,0 +1,59 @@
package internal
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const fakeToken = "BAMAuthToken: dQfuRMTUxNjc3MjcyNDg1ODppcGFybXM="
func TestClient_CreateAuthenticatedContext(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
client := NewClient(server.URL, "user", "secret")
client.HTTPClient = server.Client()
mux.HandleFunc("/Services/REST/v1/login", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return
}
query := req.URL.Query()
if query.Get("username") != "user" {
http.Error(rw, fmt.Sprintf("invalid username %s", query.Get("username")), http.StatusUnauthorized)
return
}
if query.Get("password") != "secret" {
http.Error(rw, fmt.Sprintf("invalid password %s", query.Get("password")), http.StatusUnauthorized)
return
}
_, _ = fmt.Fprint(rw, fakeToken)
})
mux.HandleFunc("/Services/REST/v1/delete", func(rw http.ResponseWriter, req *http.Request) {
authorization := req.Header.Get(authorizationHeader)
if authorization != fakeToken {
http.Error(rw, fmt.Sprintf("invalid credential: %s", authorization), http.StatusUnauthorized)
return
}
})
ctx, err := client.CreateAuthenticatedContext(context.Background())
require.NoError(t, err)
at := getToken(ctx)
assert.Equal(t, fakeToken, at)
err = client.Delete(ctx, 123)
require.NoError(t, err)
}

View file

@ -1,7 +1,5 @@
package internal
import "fmt"
// Entity JSON body for Bluecat entity requests.
type Entity struct {
ID string `json:"id,omitempty"`
@ -17,13 +15,3 @@ type EntityResponse struct {
Type string `json:"type"`
Properties string `json:"properties"`
}
type APIError struct {
StatusCode int
Resource string
Message string
}
func (a APIError) Error() string {
return fmt.Sprintf("resource: %s, status code: %d, message: %s", a.Resource, a.StatusCode, a.Message)
}

View file

@ -1,9 +1,11 @@
package brandit
import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"sync"
"time"
@ -12,8 +14,6 @@ import (
"github.com/go-acme/lego/v4/providers/dns/brandit/internal"
)
const defaultTTL = 600
// Environment variables names.
const (
envNamespace = "BRANDIT_"
@ -25,7 +25,6 @@ const (
EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
EnvPollingInterval = envNamespace + "POLLING_INTERVAL"
EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT"
DefaultBrandItPropagationTimeout = 600 * time.Second
)
// Config is used to configure the creation of the DNSProvider.
@ -42,8 +41,8 @@ type Config struct {
// NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config {
return &Config{
TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL),
PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, DefaultBrandItPropagationTimeout),
TTL: env.GetOrDefaultInt(EnvTTL, 600),
PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 10*time.Minute),
PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval),
HTTPClient: &http.Client{
Timeout: env.GetOrDefaultSecond(EnvHTTPTimeout, 30*time.Second),
@ -97,13 +96,19 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
}, nil
}
// Timeout returns the timeout and interval to use when checking for DNS propagation.
// Adjusting here to cope with spikes in propagation times.
func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
return d.config.PropagationTimeout, d.config.PollingInterval
}
// Present creates a TXT record using the specified parameters.
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("brandit: %w", err)
return fmt.Errorf("brandit: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)
@ -111,6 +116,8 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
return fmt.Errorf("brandit: %w", err)
}
ctx := context.Background()
record := internal.Record{
Type: "TXT",
Name: subDomain,
@ -119,18 +126,18 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
}
// find the account associated with the domain
account, err := d.client.StatusDomain(dns01.UnFqdn(authZone))
account, err := d.client.StatusDomain(ctx, dns01.UnFqdn(authZone))
if err != nil {
return fmt.Errorf("brandit: status domain: %w", err)
}
// Find the next record id
recordID, err := d.client.ListRecords(account.Response.Registrar[0], dns01.UnFqdn(authZone))
recordID, err := d.client.ListRecords(ctx, account.Registrar[0], dns01.UnFqdn(authZone))
if err != nil {
return fmt.Errorf("brandit: list records: %w", err)
}
result, err := d.client.AddRecord(dns01.UnFqdn(authZone), account.Response.Registrar[0], fmt.Sprint(recordID.Response.Total[0]), record)
result, err := d.client.AddRecord(ctx, dns01.UnFqdn(authZone), account.Registrar[0], strconv.Itoa(recordID.Total[0]), record)
if err != nil {
return fmt.Errorf("brandit: add record: %w", err)
}
@ -148,7 +155,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("brandit: %w", err)
return fmt.Errorf("brandit: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
// gets the record's unique ID
@ -159,25 +166,27 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return fmt.Errorf("brandit: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token)
}
ctx := context.Background()
// find the account associated with the domain
account, err := d.client.StatusDomain(dns01.UnFqdn(authZone))
account, err := d.client.StatusDomain(ctx, dns01.UnFqdn(authZone))
if err != nil {
return fmt.Errorf("brandit: status domain: %w", err)
}
records, err := d.client.ListRecords(account.Response.Registrar[0], dns01.UnFqdn(authZone))
records, err := d.client.ListRecords(ctx, account.Registrar[0], dns01.UnFqdn(authZone))
if err != nil {
return fmt.Errorf("brandit: list records: %w", err)
}
var recordID int
for i, r := range records.Response.RR {
for i, r := range records.RR {
if r == dnsRecord {
recordID = i
}
}
_, err = d.client.DeleteRecord(dns01.UnFqdn(authZone), account.Response.Registrar[0], dnsRecord, fmt.Sprint(recordID))
err = d.client.DeleteRecord(ctx, dns01.UnFqdn(authZone), account.Registrar[0], dnsRecord, strconv.Itoa(recordID))
if err != nil {
return fmt.Errorf("brandit: delete record: %w", err)
}
@ -189,9 +198,3 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return nil
}
// Timeout returns the timeout and interval to use when checking for DNS propagation.
// Adjusting here to cope with spikes in propagation times.
func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
return d.config.PropagationTimeout, d.config.PollingInterval
}

View file

@ -1,4 +1,4 @@
Name = "BRANDIT"
Name = "Brandit"
Description = ''''''
URL = "https://www.brandit.com/"
Code = "brandit"

View file

@ -1,6 +1,7 @@
package internal
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
@ -12,6 +13,8 @@ import (
"net/url"
"strings"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
const defaultBaseURL = "https://portal.brandit.com/api/v3/"
@ -20,7 +23,8 @@ const defaultBaseURL = "https://portal.brandit.com/api/v3/"
type Client struct {
apiUsername string
apiKey string
BaseURL string
baseURL string
HTTPClient *http.Client
}
@ -33,70 +37,69 @@ func NewClient(apiUsername, apiKey string) (*Client, error) {
return &Client{
apiUsername: apiUsername,
apiKey: apiKey,
BaseURL: defaultBaseURL,
baseURL: defaultBaseURL,
HTTPClient: &http.Client{Timeout: 10 * time.Second},
}, nil
}
// ListRecords lists all records.
// https://portal.brandit.com/apidocv3#listDNSRR
func (c *Client) ListRecords(account, dnsZone string) (*ListRecords, error) {
// Create a new query
func (c *Client) ListRecords(ctx context.Context, account, dnsZone string) (*ListRecordsResponse, error) {
query := url.Values{}
query.Add("command", "listDNSRR")
query.Add("account", account)
query.Add("dnszone", dnsZone)
result := &ListRecords{}
result := &Response[*ListRecordsResponse]{}
err := c.do(query, result)
err := c.do(ctx, query, result)
if err != nil {
return nil, fmt.Errorf("do: %w", err)
return nil, err
}
for len(result.Response.RR) < result.Response.Total[0] {
query.Add("first", fmt.Sprint(result.Response.Last[0]+1))
tmp := &ListRecords{}
err := c.do(query, tmp)
tmp := &Response[*ListRecordsResponse]{}
err := c.do(ctx, query, tmp)
if err != nil {
return nil, fmt.Errorf("do: %w", err)
return nil, err
}
result.Response.RR = append(result.Response.RR, tmp.Response.RR...)
result.Response.Last = tmp.Response.Last
}
return result, nil
return result.Response, nil
}
// AddRecord adds a DNS record.
// https://portal.brandit.com/apidocv3#addDNSRR
func (c *Client) AddRecord(domainName, account, newRecordID string, record Record) (*AddRecord, error) {
// Create a new query
func (c *Client) AddRecord(ctx context.Context, domainName, account, newRecordID string, record Record) (*AddRecord, error) {
value := strings.Join([]string{record.Name, fmt.Sprint(record.TTL), "IN", record.Type, record.Content}, " ")
query := url.Values{}
query.Add("command", "addDNSRR")
query.Add("account", account)
query.Add("dnszone", domainName)
query.Add("rrdata", strings.Join([]string{record.Name, fmt.Sprint(record.TTL), "IN", record.Type, record.Content}, " "))
query.Add("rrdata", value)
query.Add("key", newRecordID)
result := &AddRecord{}
err := c.do(query, result)
err := c.do(ctx, query, result)
if err != nil {
return nil, fmt.Errorf("do: %w", err)
return nil, err
}
result.Record = strings.Join([]string{record.Name, fmt.Sprint(record.TTL), "IN", record.Type, record.Content}, " ")
result.Record = value
return result, nil
}
// DeleteRecord deletes a DNS record.
// https://portal.brandit.com/apidocv3#deleteDNSRR
func (c *Client) DeleteRecord(domainName, account, dnsRecord, recordID string) (*DeleteRecord, error) {
// Create a new query
func (c *Client) DeleteRecord(ctx context.Context, domainName, account, dnsRecord, recordID string) error {
query := url.Values{}
query.Add("command", "deleteDNSRR")
query.Add("account", account)
@ -104,68 +107,70 @@ func (c *Client) DeleteRecord(domainName, account, dnsRecord, recordID string) (
query.Add("rrdata", dnsRecord)
query.Add("key", recordID)
result := &DeleteRecord{}
err := c.do(query, result)
if err != nil {
return nil, fmt.Errorf("do: %w", err)
}
return result, nil
return c.do(ctx, query, nil)
}
// StatusDomain returns the status of a domain and account associated with it.
// https://portal.brandit.com/apidocv3#statusDomain
func (c *Client) StatusDomain(domain string) (*StatusDomain, error) {
// Create a new query
func (c *Client) StatusDomain(ctx context.Context, domain string) (*StatusResponse, error) {
query := url.Values{}
query.Add("command", "statusDomain")
query.Add("domain", domain)
result := &StatusDomain{}
result := &Response[*StatusResponse]{}
err := c.do(query, result)
err := c.do(ctx, query, result)
if err != nil {
return nil, fmt.Errorf("do: %w", err)
return nil, err
}
return result, nil
return result.Response, nil
}
func (c *Client) do(query url.Values, result any) error {
// Add signature
v, err := sign(c.apiUsername, c.apiKey, query)
if err != nil {
return fmt.Errorf("signature: %w", err)
}
resp, err := c.HTTPClient.PostForm(c.BaseURL, v)
func (c *Client) do(ctx context.Context, query url.Values, result any) error {
values, err := sign(c.apiUsername, c.apiKey, query)
if err != nil {
return err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL, strings.NewReader(values.Encode()))
if err != nil {
return fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.HTTPClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read response body: %w", err)
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
// Unmarshal the error response, because the API returns a 200 OK even if there is an error.
var apiError APIError
err = json.Unmarshal(raw, &apiError)
if err != nil {
return fmt.Errorf("unmarshal error response: %w %s", err, string(raw))
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
if apiError.Code > 299 || apiError.Status != "success" {
return apiError
}
if result == nil {
return nil
}
err = json.Unmarshal(raw, result)
if err != nil {
return fmt.Errorf("unmarshal response body: %w %s", err, string(raw))
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil

View file

@ -1,30 +1,32 @@
package internal
import (
"context"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTest(t *testing.T, file string) *Client {
func setupTest(t *testing.T, filename string) *Client {
t.Helper()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
open, err := os.Open(file)
file, err := os.Open(filepath.Join("fixtures", filename))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
defer func() { _ = open.Close() }()
defer func() { _ = file.Close() }()
rw.WriteHeader(http.StatusOK)
_, err = io.Copy(rw, open)
_, err = io.Copy(rw, file)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
@ -36,19 +38,18 @@ func setupTest(t *testing.T, file string) *Client {
require.NoError(t, err)
client.HTTPClient = server.Client()
client.BaseURL = server.URL
client.baseURL = server.URL
return client
}
func TestClient_StatusDomain(t *testing.T) {
client := setupTest(t, "./fixtures/status-domain.json")
client := setupTest(t, "status-domain.json")
domain, err := client.StatusDomain("example.com")
domain, err := client.StatusDomain(context.Background(), "example.com")
require.NoError(t, err)
expected := &StatusDomain{
Response: StatusResponse{
expected := &StatusResponse{
RenewalMode: []string{"DEFAULT"},
Status: []string{"clientTransferProhibited"},
TransferLock: []int{1},
@ -73,23 +74,25 @@ func TestClient_StatusDomain(t *testing.T) {
OwnerContact: []string{"example"},
CreatedBy: []string{"example"},
TransferMode: []string{"auto"},
},
Code: 200,
Status: "success",
Error: "",
}
assert.Equal(t, expected, domain)
}
func TestClient_ListRecords(t *testing.T) {
client := setupTest(t, "./fixtures/list-records.json")
func TestClient_StatusDomain_error(t *testing.T) {
client := setupTest(t, "error.json")
resp, err := client.ListRecords("example", "example.com")
_, err := client.StatusDomain(context.Background(), "example.com")
require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."})
}
func TestClient_ListRecords(t *testing.T) {
client := setupTest(t, "list-records.json")
resp, err := client.ListRecords(context.Background(), "example", "example.com")
require.NoError(t, err)
expected := &ListRecords{
Response: ListRecordsResponse{
expected := &ListRecordsResponse{
Limit: []int{100},
Column: []string{"rr"},
Count: []int{1},
@ -97,17 +100,20 @@ func TestClient_ListRecords(t *testing.T) {
Total: []int{1},
RR: []string{"example.com. 600 IN TXT txttxttxt"},
Last: []int{0},
},
Code: 200,
Status: "success",
Error: "",
}
assert.Equal(t, expected, resp)
}
func TestClient_ListRecords_error(t *testing.T) {
client := setupTest(t, "error.json")
_, err := client.ListRecords(context.Background(), "example", "example.com")
require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."})
}
func TestClient_AddRecord(t *testing.T) {
client := setupTest(t, "./fixtures/add-record.json")
client := setupTest(t, "add-record.json")
testRecord := Record{
ID: 2565,
@ -116,7 +122,7 @@ func TestClient_AddRecord(t *testing.T) {
Content: "txttxttxt",
TTL: 600,
}
resp, err := client.AddRecord("example.com", "test", "2565", testRecord)
resp, err := client.AddRecord(context.Background(), "example.com", "test", "2565", testRecord)
require.NoError(t, err)
expected := &AddRecord{
@ -133,17 +139,31 @@ func TestClient_AddRecord(t *testing.T) {
assert.Equal(t, expected, resp)
}
func TestClient_AddRecord_error(t *testing.T) {
client := setupTest(t, "error.json")
testRecord := Record{
ID: 2565,
Type: "TXT",
Name: "example.com",
Content: "txttxttxt",
TTL: 600,
}
_, err := client.AddRecord(context.Background(), "example.com", "test", "2565", testRecord)
require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."})
}
func TestClient_DeleteRecord(t *testing.T) {
client := setupTest(t, "./fixtures/delete-record.json")
client := setupTest(t, "delete-record.json")
resp, err := client.DeleteRecord("example.com", "test", "example.com 600 IN TXT txttxttxt", "2374")
err := client.DeleteRecord(context.Background(), "example.com", "test", "example.com 600 IN TXT txttxttxt", "2374")
require.NoError(t, err)
expected := &DeleteRecord{
Code: 200,
Status: "success",
Error: "",
}
assert.Equal(t, expected, resp)
func TestClient_DeleteRecord_error(t *testing.T) {
client := setupTest(t, "error.json")
err := client.DeleteRecord(context.Background(), "example.com", "test", "example.com 600 IN TXT txttxttxt", "2374")
require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."})
}

View file

@ -0,0 +1,5 @@
{
"code": 402,
"status": "error",
"error": "Invalid user."
}

View file

@ -2,8 +2,8 @@ package internal
import "fmt"
type StatusDomain struct {
Response StatusResponse `json:"response,omitempty"`
type Response[T any] struct {
Response T `json:"response,omitempty"`
Code int `json:"code"`
Status string `json:"status"`
Error string `json:"error"`
@ -36,13 +36,6 @@ type StatusResponse struct {
TransferMode []string `json:"transfermode"`
}
type ListRecords struct {
Response ListRecordsResponse `json:"response,omitempty"`
Code int `json:"code"`
Status string `json:"status"`
Error string `json:"error"`
}
type ListRecordsResponse struct {
Limit []int `json:"limit,omitempty"`
Column []string `json:"column,omitempty"`
@ -83,9 +76,3 @@ type Record struct {
Content string `json:"content,omitempty"`
TTL int `json:"ttl,omitempty"` // default 600
}
type DeleteRecord struct {
Code int `json:"code"`
Status string `json:"status"`
Error string `json:"error"`
}

View file

@ -2,15 +2,16 @@
package checkdomain
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"sync"
"time"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/platform/config/env"
"github.com/go-acme/lego/v4/providers/dns/checkdomain/internal"
)
// Environment variables names.
@ -26,11 +27,6 @@ const (
EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT"
)
const (
defaultEndpoint = "https://api.checkdomain.de"
defaultTTL = 300
)
// Config is used to configure the creation of the DNSProvider.
type Config struct {
Endpoint *url.URL
@ -44,7 +40,7 @@ type Config struct {
// NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config {
return &Config{
TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL),
TTL: env.GetOrDefaultInt(EnvTTL, 300),
PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 5*time.Minute),
PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 7*time.Second),
HTTPClient: &http.Client{
@ -56,9 +52,7 @@ func NewDefaultConfig() *Config {
// DNSProvider implements the challenge.Provider interface.
type DNSProvider struct {
config *Config
domainIDMu sync.Mutex
domainIDMapping map[string]int
client *internal.Client
}
// NewDNSProvider returns a DNSProvider instance configured for CheckDomain.
@ -71,7 +65,7 @@ func NewDNSProvider() (*DNSProvider, error) {
config := NewDefaultConfig()
config.Token = values[EnvToken]
endpoint, err := url.Parse(env.GetOrDefaultString(EnvEndpoint, defaultEndpoint))
endpoint, err := url.Parse(env.GetOrDefaultString(EnvEndpoint, internal.DefaultEndpoint))
if err != nil {
return nil, fmt.Errorf("checkdomain: invalid %s: %w", EnvEndpoint, err)
}
@ -89,32 +83,33 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("checkdomain: missing token")
}
if config.HTTPClient == nil {
config.HTTPClient = http.DefaultClient
client := internal.NewClient(internal.OAuthStaticAccessToken(config.HTTPClient, config.Token))
if config.Endpoint != nil {
client.BaseURL = config.Endpoint
}
return &DNSProvider{
config: config,
domainIDMapping: make(map[string]int),
}, nil
return &DNSProvider{config: config, client: client}, nil
}
// Present creates a TXT record to fulfill the dns-01 challenge.
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
ctx := context.Background()
// TODO(ldez) replace domain by FQDN to follow CNAME.
domainID, err := d.getDomainIDByName(domain)
domainID, err := d.client.GetDomainIDByName(ctx, domain)
if err != nil {
return fmt.Errorf("checkdomain: %w", err)
}
err = d.checkNameservers(domainID)
err = d.client.CheckNameservers(ctx, domainID)
if err != nil {
return fmt.Errorf("checkdomain: %w", err)
}
info := dns01.GetChallengeInfo(domain, keyAuth)
err = d.createRecord(domainID, &Record{
err = d.client.CreateRecord(ctx, domainID, &internal.Record{
Name: info.EffectiveFQDN,
TTL: d.config.TTL,
Type: "TXT",
@ -130,28 +125,28 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
// CleanUp removes the TXT record previously created.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
ctx := context.Background()
// TODO(ldez) replace domain by FQDN to follow CNAME.
domainID, err := d.getDomainIDByName(domain)
domainID, err := d.client.GetDomainIDByName(ctx, domain)
if err != nil {
return fmt.Errorf("checkdomain: %w", err)
}
err = d.checkNameservers(domainID)
err = d.client.CheckNameservers(ctx, domainID)
if err != nil {
return fmt.Errorf("checkdomain: %w", err)
}
info := dns01.GetChallengeInfo(domain, keyAuth)
err = d.deleteTXTRecord(domainID, info.EffectiveFQDN, info.Value)
defer d.client.CleanCache(info.EffectiveFQDN)
err = d.client.DeleteTXTRecord(ctx, domainID, info.EffectiveFQDN, info.Value)
if err != nil {
return fmt.Errorf("checkdomain: %w", err)
}
d.domainIDMu.Lock()
delete(d.domainIDMapping, info.EffectiveFQDN)
d.domainIDMu.Unlock()
return nil
}

View file

@ -5,6 +5,7 @@ import (
"testing"
"github.com/go-acme/lego/v4/platform/tester"
"github.com/go-acme/lego/v4/providers/dns/checkdomain/internal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -83,7 +84,7 @@ func TestNewDNSProviderConfig(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
config := NewDefaultConfig()
config.Endpoint, _ = url.Parse(defaultEndpoint)
config.Endpoint, _ = url.Parse(internal.DefaultEndpoint)
if test.token != "" {
config.Token = test.token

View file

@ -1,416 +0,0 @@
package checkdomain
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
)
const (
ns1 = "ns.checkdomain.de"
ns2 = "ns2.checkdomain.de"
)
const domainNotFound = -1
// max page limit that the checkdomain api allows.
const maxLimit = 100
// max integer value.
const maxInt = int((^uint(0)) >> 1)
type (
// Some fields have been omitted from the structs
// because they are not required for this application.
DomainListingResponse struct {
Page int `json:"page"`
Limit int `json:"limit"`
Pages int `json:"pages"`
Total int `json:"total"`
Embedded EmbeddedDomainList `json:"_embedded"`
}
EmbeddedDomainList struct {
Domains []*Domain `json:"domains"`
}
Domain struct {
ID int `json:"id"`
Name string `json:"name"`
}
DomainResponse struct {
ID int `json:"id"`
Name string `json:"name"`
Created string `json:"created"`
PaidUp string `json:"payed_up"`
Active bool `json:"active"`
}
NameserverResponse struct {
General NameserverGeneral `json:"general"`
Nameservers []*Nameserver `json:"nameservers"`
SOA NameserverSOA `json:"soa"`
}
NameserverGeneral struct {
IPv4 string `json:"ip_v4"`
IPv6 string `json:"ip_v6"`
IncludeWWW bool `json:"include_www"`
}
NameserverSOA struct {
Mail string `json:"mail"`
Refresh int `json:"refresh"`
Retry int `json:"retry"`
Expiry int `json:"expiry"`
TTL int `json:"ttl"`
}
Nameserver struct {
Name string `json:"name"`
}
RecordListingResponse struct {
Page int `json:"page"`
Limit int `json:"limit"`
Pages int `json:"pages"`
Total int `json:"total"`
Embedded EmbeddedRecordList `json:"_embedded"`
}
EmbeddedRecordList struct {
Records []*Record `json:"records"`
}
Record struct {
Name string `json:"name"`
Value string `json:"value"`
TTL int `json:"ttl"`
Priority int `json:"priority"`
Type string `json:"type"`
}
)
func (d *DNSProvider) getDomainIDByName(name string) (int, error) {
// Load from cache if exists
d.domainIDMu.Lock()
id, ok := d.domainIDMapping[name]
d.domainIDMu.Unlock()
if ok {
return id, nil
}
// Find out by querying API
domains, err := d.listDomains()
if err != nil {
return domainNotFound, err
}
// Linear search over all registered domains
for _, domain := range domains {
if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) {
d.domainIDMu.Lock()
d.domainIDMapping[name] = domain.ID
d.domainIDMu.Unlock()
return domain.ID, nil
}
}
return domainNotFound, errors.New("domain not found")
}
func (d *DNSProvider) listDomains() ([]*Domain, error) {
req, err := d.makeRequest(http.MethodGet, "/v1/domains", http.NoBody)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
// Checkdomain also provides a query param 'query' which allows filtering domains for a string.
// But that functionality is kinda broken,
// so we scan through the whole list of registered domains to later find the one that is of interest to us.
q := req.URL.Query()
q.Set("limit", strconv.Itoa(maxLimit))
currentPage := 1
totalPages := maxInt
var domainList []*Domain
for currentPage <= totalPages {
q.Set("page", strconv.Itoa(currentPage))
req.URL.RawQuery = q.Encode()
var res DomainListingResponse
if err := d.sendRequest(req, &res); err != nil {
return nil, fmt.Errorf("failed to send domain listing request: %w", err)
}
// This is the first response,
// so we update totalPages and allocate the slice memory.
if totalPages == maxInt {
totalPages = res.Pages
domainList = make([]*Domain, 0, res.Total)
}
domainList = append(domainList, res.Embedded.Domains...)
currentPage++
}
return domainList, nil
}
func (d *DNSProvider) getNameserverInfo(domainID int) (*NameserverResponse, error) {
req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d/nameservers", domainID), http.NoBody)
if err != nil {
return nil, err
}
res := &NameserverResponse{}
if err := d.sendRequest(req, res); err != nil {
return nil, err
}
return res, nil
}
func (d *DNSProvider) checkNameservers(domainID int) error {
info, err := d.getNameserverInfo(domainID)
if err != nil {
return err
}
var found1, found2 bool
for _, item := range info.Nameservers {
switch item.Name {
case ns1:
found1 = true
case ns2:
found2 = true
}
}
if !found1 || !found2 {
return errors.New("not using checkdomain nameservers, can not update records")
}
return nil
}
func (d *DNSProvider) createRecord(domainID int, record *Record) error {
bs, err := json.Marshal(record)
if err != nil {
return fmt.Errorf("encoding record failed: %w", err)
}
req, err := d.makeRequest(http.MethodPost, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), bytes.NewReader(bs))
if err != nil {
return err
}
return d.sendRequest(req, nil)
}
// Checkdomain doesn't seem provide a way to delete records but one can replace all records at once.
// The current solution is to fetch all records and then use that list minus the record deleted as the new record list.
// TODO: Simplify this function once Checkdomain do provide the functionality.
func (d *DNSProvider) deleteTXTRecord(domainID int, recordName, recordValue string) error {
domainInfo, err := d.getDomainInfo(domainID)
if err != nil {
return err
}
nsInfo, err := d.getNameserverInfo(domainID)
if err != nil {
return err
}
allRecords, err := d.listRecords(domainID, "")
if err != nil {
return err
}
recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".")
var recordsToKeep []*Record
// Find and delete matching records
for _, record := range allRecords {
if skipRecord(recordName, recordValue, record, nsInfo) {
continue
}
// Checkdomain API can return records without any TTL set (indicated by the value of 0).
// The API Call to replace the records would fail if we wouldn't specify a value.
// Thus, we use the default TTL queried beforehand
if record.TTL == 0 {
record.TTL = nsInfo.SOA.TTL
}
recordsToKeep = append(recordsToKeep, record)
}
return d.replaceRecords(domainID, recordsToKeep)
}
func (d *DNSProvider) getDomainInfo(domainID int) (*DomainResponse, error) {
req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d", domainID), http.NoBody)
if err != nil {
return nil, err
}
var res DomainResponse
err = d.sendRequest(req, &res)
if err != nil {
return nil, err
}
return &res, nil
}
func (d *DNSProvider) listRecords(domainID int, recordType string) ([]*Record, error) {
req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), http.NoBody)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
q := req.URL.Query()
q.Set("limit", strconv.Itoa(maxLimit))
if recordType != "" {
q.Set("type", recordType)
}
currentPage := 1
totalPages := maxInt
var recordList []*Record
for currentPage <= totalPages {
q.Set("page", strconv.Itoa(currentPage))
req.URL.RawQuery = q.Encode()
var res RecordListingResponse
if err := d.sendRequest(req, &res); err != nil {
return nil, fmt.Errorf("failed to send record listing request: %w", err)
}
// This is the first response, so we update totalPages and allocate the slice memory.
if totalPages == maxInt {
totalPages = res.Pages
recordList = make([]*Record, 0, res.Total)
}
recordList = append(recordList, res.Embedded.Records...)
currentPage++
}
return recordList, nil
}
func (d *DNSProvider) replaceRecords(domainID int, records []*Record) error {
bs, err := json.Marshal(records)
if err != nil {
return fmt.Errorf("encoding record failed: %w", err)
}
req, err := d.makeRequest(http.MethodPut, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), bytes.NewReader(bs))
if err != nil {
return err
}
return d.sendRequest(req, nil)
}
func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool {
// Skip empty records
if record.Value == "" {
return true
}
// Skip some special records, otherwise we would get a "Nameserver update failed"
if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") {
return true
}
nameMatch := recordName == "" || record.Name == recordName
valueMatch := recordValue == "" || record.Value == recordValue
// Skip our matching record
if record.Type == "TXT" && nameMatch && valueMatch {
return true
}
return false
}
func (d *DNSProvider) makeRequest(method, resource string, body io.Reader) (*http.Request, error) {
uri, err := d.config.Endpoint.Parse(resource)
if err != nil {
return nil, err
}
req, err := http.NewRequest(method, uri.String(), body)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+d.config.Token)
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}
func (d *DNSProvider) sendRequest(req *http.Request, result interface{}) error {
resp, err := d.config.HTTPClient.Do(req)
if err != nil {
return err
}
if err = checkResponse(resp); err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
err = json.Unmarshal(raw, result)
if err != nil {
return fmt.Errorf("unmarshaling %T error [status code=%d]: %w: %s", result, resp.StatusCode, err, string(raw))
}
return nil
}
func checkResponse(resp *http.Response) error {
if resp.StatusCode < http.StatusBadRequest {
return nil
}
if resp.Body == nil {
return fmt.Errorf("response body is nil, status code=%d", resp.StatusCode)
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("unable to read body: status code=%d, error=%w", resp.StatusCode, err)
}
return fmt.Errorf("status code=%d: %s", resp.StatusCode, string(raw))
}

View file

@ -0,0 +1,383 @@
package internal
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
"golang.org/x/oauth2"
)
const (
ns1 = "ns.checkdomain.de"
ns2 = "ns2.checkdomain.de"
)
// DefaultEndpoint the default API endpoint.
const DefaultEndpoint = "https://api.checkdomain.de"
const domainNotFound = -1
// max page limit that the checkdomain api allows.
const maxLimit = 100
// max integer value.
const maxInt = int((^uint(0)) >> 1)
// Client the Autodns API client.
type Client struct {
domainIDMapping map[string]int
domainIDMu sync.Mutex
BaseURL *url.URL
httpClient *http.Client
}
// NewClient creates a new Client.
func NewClient(hc *http.Client) *Client {
baseURL, _ := url.Parse(DefaultEndpoint)
if hc == nil {
hc = &http.Client{Timeout: 10 * time.Second}
}
return &Client{
BaseURL: baseURL,
httpClient: hc,
domainIDMapping: make(map[string]int),
}
}
func (c *Client) GetDomainIDByName(ctx context.Context, name string) (int, error) {
// Load from cache if exists
c.domainIDMu.Lock()
id, ok := c.domainIDMapping[name]
c.domainIDMu.Unlock()
if ok {
return id, nil
}
// Find out by querying API
domains, err := c.listDomains(ctx)
if err != nil {
return domainNotFound, err
}
// Linear search over all registered domains
for _, domain := range domains {
if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) {
c.domainIDMu.Lock()
c.domainIDMapping[name] = domain.ID
c.domainIDMu.Unlock()
return domain.ID, nil
}
}
return domainNotFound, errors.New("domain not found")
}
func (c *Client) listDomains(ctx context.Context) ([]*Domain, error) {
endpoint := c.BaseURL.JoinPath("v1", "domains")
// Checkdomain also provides a query param 'query' which allows filtering domains for a string.
// But that functionality is kinda broken,
// so we scan through the whole list of registered domains to later find the one that is of interest to us.
q := endpoint.Query()
q.Set("limit", strconv.Itoa(maxLimit))
currentPage := 1
totalPages := maxInt
var domainList []*Domain
for currentPage <= totalPages {
q.Set("page", strconv.Itoa(currentPage))
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
var res DomainListingResponse
if err := c.do(req, &res); err != nil {
return nil, fmt.Errorf("failed to send domain listing request: %w", err)
}
// This is the first response,
// so we update totalPages and allocate the slice memory.
if totalPages == maxInt {
totalPages = res.Pages
domainList = make([]*Domain, 0, res.Total)
}
domainList = append(domainList, res.Embedded.Domains...)
currentPage++
}
return domainList, nil
}
func (c *Client) getNameserverInfo(ctx context.Context, domainID int) (*NameserverResponse, error) {
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers")
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
res := &NameserverResponse{}
if err := c.do(req, res); err != nil {
return nil, err
}
return res, nil
}
func (c *Client) CheckNameservers(ctx context.Context, domainID int) error {
info, err := c.getNameserverInfo(ctx, domainID)
if err != nil {
return err
}
var found1, found2 bool
for _, item := range info.Nameservers {
switch item.Name {
case ns1:
found1 = true
case ns2:
found2 = true
}
}
if !found1 || !found2 {
return errors.New("not using checkdomain nameservers, can not update records")
}
return nil
}
func (c *Client) CreateRecord(ctx context.Context, domainID int, record *Record) error {
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
if err != nil {
return err
}
return c.do(req, nil)
}
// DeleteTXTRecord Checkdomain doesn't seem provide a way to delete records but one can replace all records at once.
// The current solution is to fetch all records and then use that list minus the record deleted as the new record list.
// TODO: Simplify this function once Checkdomain do provide the functionality.
func (c *Client) DeleteTXTRecord(ctx context.Context, domainID int, recordName, recordValue string) error {
domainInfo, err := c.getDomainInfo(ctx, domainID)
if err != nil {
return err
}
nsInfo, err := c.getNameserverInfo(ctx, domainID)
if err != nil {
return err
}
allRecords, err := c.listRecords(ctx, domainID, "")
if err != nil {
return err
}
recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".")
var recordsToKeep []*Record
// Find and delete matching records
for _, record := range allRecords {
if skipRecord(recordName, recordValue, record, nsInfo) {
continue
}
// Checkdomain API can return records without any TTL set (indicated by the value of 0).
// The API Call to replace the records would fail if we wouldn't specify a value.
// Thus, we use the default TTL queried beforehand
if record.TTL == 0 {
record.TTL = nsInfo.SOA.TTL
}
recordsToKeep = append(recordsToKeep, record)
}
return c.replaceRecords(ctx, domainID, recordsToKeep)
}
func (c *Client) getDomainInfo(ctx context.Context, domainID int) (*DomainResponse, error) {
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID))
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
var res DomainResponse
err = c.do(req, &res)
if err != nil {
return nil, err
}
return &res, nil
}
func (c *Client) listRecords(ctx context.Context, domainID int, recordType string) ([]*Record, error) {
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
q := endpoint.Query()
q.Set("limit", strconv.Itoa(maxLimit))
if recordType != "" {
q.Set("type", recordType)
}
currentPage := 1
totalPages := maxInt
var recordList []*Record
for currentPage <= totalPages {
q.Set("page", strconv.Itoa(currentPage))
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
var res RecordListingResponse
if err := c.do(req, &res); err != nil {
return nil, fmt.Errorf("failed to send record listing request: %w", err)
}
// This is the first response, so we update totalPages and allocate the slice memory.
if totalPages == maxInt {
totalPages = res.Pages
recordList = make([]*Record, 0, res.Total)
}
recordList = append(recordList, res.Embedded.Records...)
currentPage++
}
return recordList, nil
}
func (c *Client) replaceRecords(ctx context.Context, domainID int, records []*Record) error {
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
req, err := newJSONRequest(ctx, http.MethodPut, endpoint, records)
if err != nil {
return err
}
return c.do(req, nil)
}
func (c *Client) do(req *http.Request, result any) error {
resp, err := c.httpClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode/100 != 2 {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
err = json.Unmarshal(raw, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}
func (c *Client) CleanCache(fqdn string) {
c.domainIDMu.Lock()
delete(c.domainIDMapping, fqdn)
c.domainIDMu.Unlock()
}
func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool {
// Skip empty records
if record.Value == "" {
return true
}
// Skip some special records, otherwise we would get a "Nameserver update failed"
if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") {
return true
}
nameMatch := recordName == "" || record.Name == recordName
valueMatch := recordValue == "" || record.Value == recordValue
// Skip our matching record
if record.Type == "TXT" && nameMatch && valueMatch {
return true
}
return false
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}
func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client {
if client == nil {
client = &http.Client{Timeout: 5 * time.Second}
}
client.Transport = &oauth2.Transport{
Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}),
Base: client.Transport,
}
return client
}

View file

@ -1,6 +1,8 @@
package checkdomain
package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@ -15,32 +17,42 @@ import (
"github.com/stretchr/testify/require"
)
func setupTestProvider(t *testing.T) (*DNSProvider, *http.ServeMux) {
func setupTest(t *testing.T) (*Client, *http.ServeMux) {
t.Helper()
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
config := NewDefaultConfig()
config.Endpoint, _ = url.Parse(server.URL)
config.Token = "secret"
client := NewClient(OAuthStaticAccessToken(server.Client(), "secret"))
client.BaseURL, _ = url.Parse(server.URL)
p, err := NewDNSProviderConfig(config)
require.NoError(t, err)
return p, mux
return client, mux
}
func Test_getDomainIDByName(t *testing.T) {
prd, handler := setupTestProvider(t)
func checkAuthorizationHeader(req *http.Request) error {
val := req.Header.Get("Authorization")
if val != "Bearer secret" {
return fmt.Errorf("invalid header value, got: %s want %s", val, "Bearer secret")
}
return nil
}
handler.HandleFunc("/v1/domains", func(rw http.ResponseWriter, req *http.Request) {
func TestClient_GetDomainIDByName(t *testing.T) {
client, mux := setupTest(t)
mux.HandleFunc("/v1/domains", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest)
return
}
err := checkAuthorizationHeader(req)
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
domainList := DomainListingResponse{
Embedded: EmbeddedDomainList{Domains: []*Domain{
{ID: 1, Name: "test.com"},
@ -48,28 +60,34 @@ func Test_getDomainIDByName(t *testing.T) {
}},
}
err := json.NewEncoder(rw).Encode(domainList)
err = json.NewEncoder(rw).Encode(domainList)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
})
id, err := prd.getDomainIDByName("test.com")
id, err := client.GetDomainIDByName(context.Background(), "test.com")
require.NoError(t, err)
assert.Equal(t, 1, id)
}
func Test_checkNameservers(t *testing.T) {
prd, handler := setupTestProvider(t)
func TestClient_CheckNameservers(t *testing.T) {
client, mux := setupTest(t)
handler.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) {
mux.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest)
return
}
err := checkAuthorizationHeader(req)
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
nsResp := NameserverResponse{
Nameservers: []*Nameserver{
{Name: ns1},
@ -78,33 +96,39 @@ func Test_checkNameservers(t *testing.T) {
},
}
err := json.NewEncoder(rw).Encode(nsResp)
err = json.NewEncoder(rw).Encode(nsResp)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
})
err := prd.checkNameservers(1)
err := client.CheckNameservers(context.Background(), 1)
require.NoError(t, err)
}
func Test_createRecord(t *testing.T) {
prd, handler := setupTestProvider(t)
func TestClient_CreateRecord(t *testing.T) {
client, mux := setupTest(t)
handler.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) {
mux.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest)
return
}
err := checkAuthorizationHeader(req)
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
content, err := io.ReadAll(req.Body)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
if string(content) != `{"name":"test.com","value":"value","ttl":300,"priority":0,"type":"TXT"}` {
if string(bytes.TrimSpace(content)) != `{"name":"test.com","value":"value","ttl":300,"priority":0,"type":"TXT"}` {
http.Error(rw, "invalid request body: "+string(content), http.StatusBadRequest)
return
}
@ -117,12 +141,12 @@ func Test_createRecord(t *testing.T) {
Value: "value",
}
err := prd.createRecord(1, record)
err := client.CreateRecord(context.Background(), 1, record)
require.NoError(t, err)
}
func Test_deleteTXTRecord(t *testing.T) {
prd, handler := setupTestProvider(t)
func TestClient_DeleteTXTRecord(t *testing.T) {
client, mux := setupTest(t)
domainName := "lego.test"
recordValue := "test"
@ -158,20 +182,26 @@ func Test_deleteTXTRecord(t *testing.T) {
},
}
handler.HandleFunc("/v1/domains/1", func(rw http.ResponseWriter, req *http.Request) {
mux.HandleFunc("/v1/domains/1", func(rw http.ResponseWriter, req *http.Request) {
err := checkAuthorizationHeader(req)
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
resp := DomainResponse{
ID: 1,
Name: domainName,
}
err := json.NewEncoder(rw).Encode(resp)
err = json.NewEncoder(rw).Encode(resp)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
})
handler.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) {
mux.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest)
return
@ -188,7 +218,7 @@ func Test_deleteTXTRecord(t *testing.T) {
}
})
handler.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) {
mux.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) {
switch req.Method {
case http.MethodGet:
resp := RecordListingResponse{
@ -226,6 +256,6 @@ func Test_deleteTXTRecord(t *testing.T) {
})
info := dns01.GetChallengeInfo(domainName, "abc")
err := prd.deleteTXTRecord(1, info.EffectiveFQDN, recordValue)
err := client.DeleteTXTRecord(context.Background(), 1, info.EffectiveFQDN, recordValue)
require.NoError(t, err)
}

View file

@ -0,0 +1,73 @@
package internal
// Some fields have been omitted from the structs
// because they are not required for this application.
type DomainListingResponse struct {
Page int `json:"page"`
Limit int `json:"limit"`
Pages int `json:"pages"`
Total int `json:"total"`
Embedded EmbeddedDomainList `json:"_embedded"`
}
type EmbeddedDomainList struct {
Domains []*Domain `json:"domains"`
}
type Domain struct {
ID int `json:"id"`
Name string `json:"name"`
}
type DomainResponse struct {
ID int `json:"id"`
Name string `json:"name"`
Created string `json:"created"`
PaidUp string `json:"payed_up"`
Active bool `json:"active"`
}
type NameserverResponse struct {
General NameserverGeneral `json:"general"`
Nameservers []*Nameserver `json:"nameservers"`
SOA NameserverSOA `json:"soa"`
}
type NameserverGeneral struct {
IPv4 string `json:"ip_v4"`
IPv6 string `json:"ip_v6"`
IncludeWWW bool `json:"include_www"`
}
type NameserverSOA struct {
Mail string `json:"mail"`
Refresh int `json:"refresh"`
Retry int `json:"retry"`
Expiry int `json:"expiry"`
TTL int `json:"ttl"`
}
type Nameserver struct {
Name string `json:"name"`
}
type RecordListingResponse struct {
Page int `json:"page"`
Limit int `json:"limit"`
Pages int `json:"pages"`
Total int `json:"total"`
Embedded EmbeddedRecordList `json:"_embedded"`
}
type EmbeddedRecordList struct {
Records []*Record `json:"records"`
}
type Record struct {
Name string `json:"name"`
Value string `json:"value"`
TTL int `json:"ttl"`
Priority int `json:"priority"`
Type string `json:"type"`
}

View file

@ -93,11 +93,13 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
zone, err := getZone(info.EffectiveFQDN)
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("civo: failed to find zone: fqdn=%s: %w", info.EffectiveFQDN, err)
return fmt.Errorf("civo: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
zone := dns01.UnFqdn(authZone)
dnsDomain, err := d.client.GetDNSDomain(zone)
if err != nil {
return fmt.Errorf("civo: %w", err)
@ -125,11 +127,13 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
zone, err := getZone(info.EffectiveFQDN)
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("civo: failed to find zone: fqdn=%s: %w", info.EffectiveFQDN, err)
return fmt.Errorf("civo: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
zone := dns01.UnFqdn(authZone)
dnsDomain, err := d.client.GetDNSDomain(zone)
if err != nil {
return fmt.Errorf("civo: %w", err)
@ -166,12 +170,3 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
return d.config.PropagationTimeout, d.config.PollingInterval
}
func getZone(fqdn string) (string, error) {
authZone, err := dns01.FindZoneByFqdn(fqdn)
if err != nil {
return "", err
}
return dns01.UnFqdn(authZone), nil
}

View file

@ -2,6 +2,7 @@
package clouddns
import (
"context"
"errors"
"fmt"
"net/http"
@ -89,10 +90,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
client.HTTPClient = config.HTTPClient
}
return &DNSProvider{
client: client,
config: config,
}, nil
return &DNSProvider{client: client, config: config}, nil
}
// Timeout returns the timeout and interval to use when checking for DNS propagation.
@ -107,12 +105,17 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("clouddns: %w", err)
return fmt.Errorf("clouddns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
err = d.client.AddRecord(authZone, info.EffectiveFQDN, info.Value)
ctx, err := d.client.CreateAuthenticatedContext(context.Background())
if err != nil {
return fmt.Errorf("clouddns: %w", err)
return err
}
err = d.client.AddRecord(ctx, authZone, info.EffectiveFQDN, info.Value)
if err != nil {
return fmt.Errorf("clouddns: add record: %w", err)
}
return nil
@ -124,12 +127,17 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("clouddns: %w", err)
return fmt.Errorf("clouddns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
err = d.client.DeleteRecord(authZone, info.EffectiveFQDN)
ctx, err := d.client.CreateAuthenticatedContext(context.Background())
if err != nil {
return fmt.Errorf("clouddns: %w", err)
return err
}
err = d.client.DeleteRecord(ctx, authZone, info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("clouddns: delete record: %w", err)
}
return nil

View file

@ -2,117 +2,127 @@ package internal
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
const (
apiBaseURL = "https://admin.vshosting.cloud/clouddns"
loginURL = "https://admin.vshosting.cloud/api/public/auth/login"
)
const apiBaseURL = "https://admin.vshosting.cloud/clouddns"
const authorizationHeader = "Authorization"
// Client handles all communication with CloudDNS API.
type Client struct {
AccessToken string
ClientID string
Email string
Password string
TTL int
HTTPClient *http.Client
clientID string
email string
password string
ttl int
apiBaseURL string
loginURL string
apiBaseURL *url.URL
loginURL *url.URL
HTTPClient *http.Client
}
// NewClient returns a Client instance configured to handle CloudDNS API communication.
func NewClient(clientID, email, password string, ttl int) *Client {
baseURL, _ := url.Parse(apiBaseURL)
loginBaseURL, _ := url.Parse(loginURL)
return &Client{
ClientID: clientID,
Email: email,
Password: password,
TTL: ttl,
HTTPClient: &http.Client{},
apiBaseURL: apiBaseURL,
loginURL: loginURL,
clientID: clientID,
email: email,
password: password,
ttl: ttl,
apiBaseURL: baseURL,
loginURL: loginBaseURL,
HTTPClient: &http.Client{Timeout: 5 * time.Second},
}
}
// AddRecord is a high level method to add a new record into CloudDNS zone.
func (c *Client) AddRecord(zone, recordName, recordValue string) error {
domain, err := c.getDomain(zone)
func (c *Client) AddRecord(ctx context.Context, zone, recordName, recordValue string) error {
domain, err := c.getDomain(ctx, zone)
if err != nil {
return err
}
record := Record{DomainID: domain.ID, Name: recordName, Value: recordValue, Type: "TXT"}
err = c.addTxtRecord(record)
err = c.addTxtRecord(ctx, record)
if err != nil {
return err
}
return c.publishRecords(domain.ID)
return c.publishRecords(ctx, domain.ID)
}
// DeleteRecord is a high level method to remove a record from zone.
func (c *Client) DeleteRecord(zone, recordName string) error {
domain, err := c.getDomain(zone)
func (c *Client) DeleteRecord(ctx context.Context, zone, recordName string) error {
domain, err := c.getDomain(ctx, zone)
if err != nil {
return err
}
record, err := c.getRecord(domain.ID, recordName)
record, err := c.getRecord(ctx, domain.ID, recordName)
if err != nil {
return err
}
err = c.deleteRecord(record)
err = c.deleteRecord(ctx, record)
if err != nil {
return err
}
return c.publishRecords(domain.ID)
return c.publishRecords(ctx, domain.ID)
}
func (c *Client) addTxtRecord(record Record) error {
body, err := json.Marshal(record)
func (c *Client) addTxtRecord(ctx context.Context, record Record) error {
endpoint := c.apiBaseURL.JoinPath("record-txt")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
if err != nil {
return err
}
_, err = c.doAPIRequest(http.MethodPost, "record-txt", bytes.NewReader(body))
return c.do(req, nil)
}
func (c *Client) deleteRecord(ctx context.Context, record Record) error {
endpoint := c.apiBaseURL.JoinPath("record", record.ID)
req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err
}
func (c *Client) deleteRecord(record Record) error {
endpoint := fmt.Sprintf("record/%s", record.ID)
_, err := c.doAPIRequest(http.MethodDelete, endpoint, nil)
return err
return c.do(req, nil)
}
func (c *Client) getDomain(zone string) (Domain, error) {
func (c *Client) getDomain(ctx context.Context, zone string) (Domain, error) {
searchQuery := SearchQuery{
Search: []Search{
{Name: "clientId", Operator: "eq", Value: c.ClientID},
{Name: "clientId", Operator: "eq", Value: c.clientID},
{Name: "domainName", Operator: "eq", Value: zone},
},
}
body, err := json.Marshal(searchQuery)
if err != nil {
return Domain{}, err
}
endpoint := c.apiBaseURL.JoinPath("domain", "search")
resp, err := c.doAPIRequest(http.MethodPost, "domain/search", bytes.NewReader(body))
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, searchQuery)
if err != nil {
return Domain{}, err
}
var result SearchResponse
err = json.Unmarshal(resp, &result)
err = c.do(req, &result)
if err != nil {
return Domain{}, err
}
@ -124,15 +134,16 @@ func (c *Client) getDomain(zone string) (Domain, error) {
return result.Items[0], nil
}
func (c *Client) getRecord(domainID, recordName string) (Record, error) {
endpoint := fmt.Sprintf("domain/%s", domainID)
resp, err := c.doAPIRequest(http.MethodGet, endpoint, nil)
func (c *Client) getRecord(ctx context.Context, domainID, recordName string) (Record, error) {
endpoint := c.apiBaseURL.JoinPath("domain", domainID)
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return Record{}, err
}
var result DomainInfo
err = json.Unmarshal(resp, &result)
err = c.do(req, &result)
if err != nil {
return Record{}, err
}
@ -146,116 +157,85 @@ func (c *Client) getRecord(domainID, recordName string) (Record, error) {
return Record{}, fmt.Errorf("record not found: domainID %s, name %s", domainID, recordName)
}
func (c *Client) publishRecords(domainID string) error {
body, err := json.Marshal(DomainInfo{SoaTTL: c.TTL})
func (c *Client) publishRecords(ctx context.Context, domainID string) error {
endpoint := c.apiBaseURL.JoinPath("domain", domainID, "publish")
payload := DomainInfo{SoaTTL: c.ttl}
req, err := newJSONRequest(ctx, http.MethodPut, endpoint, payload)
if err != nil {
return err
}
endpoint := fmt.Sprintf("domain/%s/publish", domainID)
_, err = c.doAPIRequest(http.MethodPut, endpoint, bytes.NewReader(body))
return err
return c.do(req, nil)
}
func (c *Client) login() error {
authorization := Authorization{Email: c.Email, Password: c.Password}
func (c *Client) do(req *http.Request, result any) error {
at := getAccessToken(req.Context())
if at != "" {
req.Header.Set(authorizationHeader, "Bearer "+at)
}
body, err := json.Marshal(authorization)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return err
return errutils.NewHTTPDoError(req, err)
}
req, err := http.NewRequest(http.MethodPost, c.loginURL, bytes.NewReader(body))
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode/100 != 2 {
return parseError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return err
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
req.Header.Set("Content-Type", "application/json")
content, err := c.doRequest(req)
err = json.Unmarshal(raw, result)
if err != nil {
return err
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
var result AuthResponse
err = json.Unmarshal(content, &result)
if err != nil {
return err
}
c.AccessToken = result.Auth.AccessToken
return nil
}
func (c *Client) doAPIRequest(method, endpoint string, body io.Reader) ([]byte, error) {
if c.AccessToken == "" {
err := c.login()
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
url := fmt.Sprintf("%s/%s", c.apiBaseURL, endpoint)
req, err := c.newRequest(method, url, body)
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, err
return nil, fmt.Errorf("unable to create request: %w", err)
}
content, err := c.doRequest(req)
if err != nil {
return nil, err
}
return content, nil
}
func (c *Client) newRequest(method, reqURL string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequest(method, reqURL, body)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
if payload != nil {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.AccessToken))
}
return req, nil
}
func (c *Client) doRequest(req *http.Request) ([]byte, error) {
resp, err := c.HTTPClient.Do(req)
func parseError(req *http.Request, resp *http.Response) error {
raw, _ := io.ReadAll(resp.Body)
var response APIError
err := json.Unmarshal(raw, &response)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
return nil, readError(req, resp)
return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw)
}
content, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return content, nil
}
func readError(req *http.Request, resp *http.Response) error {
content, err := io.ReadAll(resp.Body)
if err != nil {
return errors.New(toUnreadableBodyMessage(req, content))
}
var errInfo APIError
err = json.Unmarshal(content, &errInfo)
if err != nil {
return fmt.Errorf("APIError unmarshaling error: %w: %s", err, toUnreadableBodyMessage(req, content))
}
return fmt.Errorf("HTTP %d: code %v: %s", resp.StatusCode, errInfo.Error.Code, errInfo.Error.Message)
}
func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string {
return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody))
return fmt.Errorf("[status code %d] %w", resp.StatusCode, response.Error)
}

View file

@ -1,16 +1,33 @@
package internal
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/require"
)
func TestClient_AddRecord(t *testing.T) {
func setupTest(t *testing.T) (*Client, *http.ServeMux) {
t.Helper()
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
client := NewClient("clientID", "email@example.com", "secret", 300)
client.HTTPClient = server.Client()
client.apiBaseURL, _ = url.Parse(server.URL + "/api")
client.loginURL, _ = url.Parse(server.URL + "/login")
return client, mux
}
func TestClient_AddRecord(t *testing.T) {
client, mux := setupTest(t)
mux.HandleFunc("/api/domain/search", func(rw http.ResponseWriter, req *http.Request) {
response := SearchResponse{
@ -45,19 +62,12 @@ func TestClient_AddRecord(t *testing.T) {
}
})
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
client := NewClient("clientID", "email@example.com", "secret", 300)
client.apiBaseURL = server.URL + "/api"
client.loginURL = server.URL + "/login"
err := client.AddRecord("example.com", "_acme-challenge.example.com", "txt")
err := client.AddRecord(context.Background(), "example.com", "_acme-challenge.example.com", "txt")
require.NoError(t, err)
}
func TestClient_DeleteRecord(t *testing.T) {
mux := http.NewServeMux()
client, mux := setupTest(t)
mux.HandleFunc("/api/domain/search", func(rw http.ResponseWriter, req *http.Request) {
response := SearchResponse{
@ -114,13 +124,9 @@ func TestClient_DeleteRecord(t *testing.T) {
}
})
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
ctx, err := client.CreateAuthenticatedContext(context.Background())
require.NoError(t, err)
client := NewClient("clientID", "email@example.com", "secret", 300)
client.apiBaseURL = server.URL + "/api"
client.loginURL = server.URL + "/login"
err := client.DeleteRecord("example.com", "_acme-challenge.example.com")
err = client.DeleteRecord(ctx, "example.com", "_acme-challenge.example.com")
require.NoError(t, err)
}

View file

@ -0,0 +1,47 @@
package internal
import (
"context"
"net/http"
)
const loginURL = "https://admin.vshosting.cloud/api/public/auth/login"
type token string
const accessTokenKey token = "accessToken"
func (c *Client) login(ctx context.Context) (*AuthResponse, error) {
authorization := Authorization{Email: c.email, Password: c.password}
req, err := newJSONRequest(ctx, http.MethodPost, c.loginURL, authorization)
if err != nil {
return nil, err
}
var result AuthResponse
err = c.do(req, &result)
if err != nil {
return nil, err
}
return &result, nil
}
func (c *Client) CreateAuthenticatedContext(ctx context.Context) (context.Context, error) {
tok, err := c.login(ctx)
if err != nil {
return nil, err
}
return context.WithValue(ctx, accessTokenKey, tok.Auth.AccessToken), nil
}
func getAccessToken(ctx context.Context) string {
tok, ok := ctx.Value(accessTokenKey).(string)
if !ok {
return ""
}
return tok
}

View file

@ -0,0 +1,46 @@
package internal
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClient_CreateAuthenticatedContext(t *testing.T) {
client, mux := setupTest(t)
mux.HandleFunc("/login", func(rw http.ResponseWriter, req *http.Request) {
response := AuthResponse{
Auth: Auth{
AccessToken: "at",
RefreshToken: "",
},
}
err := json.NewEncoder(rw).Encode(response)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
})
mux.HandleFunc("/api/record/xxx", func(rw http.ResponseWriter, req *http.Request) {
authorization := req.Header.Get(authorizationHeader)
if authorization != "Bearer at" {
http.Error(rw, "invalid credential: "+authorization, http.StatusUnauthorized)
return
}
})
ctx, err := client.CreateAuthenticatedContext(context.Background())
require.NoError(t, err)
at := getAccessToken(ctx)
assert.Equal(t, "at", at)
err = client.deleteRecord(ctx, Record{ID: "xxx"})
require.NoError(t, err)
}

View file

@ -1,5 +1,7 @@
package internal
import "fmt"
type APIError struct {
Error ErrorContent `json:"error"`
}
@ -9,6 +11,10 @@ type ErrorContent struct {
Message string `json:"message,omitempty"`
}
func (e ErrorContent) Error() string {
return fmt.Sprintf("%d: %s", e.Code, e.Message)
}
type Authorization struct {
Email string `json:"email,omitempty"`
Password string `json:"password,omitempty"`

View file

@ -126,7 +126,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("cloudflare: %w", err)
return fmt.Errorf("cloudflare: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
zoneID, err := d.client.ZoneIDByName(authZone)
@ -165,7 +165,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("cloudflare: %w", err)
return fmt.Errorf("cloudflare: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
zoneID, err := d.client.ZoneIDByName(authZone)

View file

@ -2,6 +2,7 @@
package cloudns
import (
"context"
"errors"
"fmt"
"net/http"
@ -104,29 +105,33 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
zone, err := d.client.GetZone(info.EffectiveFQDN)
ctx := context.Background()
zone, err := d.client.GetZone(ctx, info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("ClouDNS: %w", err)
}
err = d.client.AddTxtRecord(zone.Name, info.EffectiveFQDN, info.Value, d.config.TTL)
err = d.client.AddTxtRecord(ctx, zone.Name, info.EffectiveFQDN, info.Value, d.config.TTL)
if err != nil {
return fmt.Errorf("ClouDNS: %w", err)
}
return d.waitNameservers(domain, zone)
return d.waitNameservers(ctx, domain, zone)
}
// CleanUp removes the TXT records matching the specified parameters.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
zone, err := d.client.GetZone(info.EffectiveFQDN)
ctx := context.Background()
zone, err := d.client.GetZone(ctx, info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("ClouDNS: %w", err)
}
records, err := d.client.ListTxtRecords(zone.Name, info.EffectiveFQDN)
records, err := d.client.ListTxtRecords(ctx, zone.Name, info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("ClouDNS: %w", err)
}
@ -136,7 +141,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
}
for _, record := range records {
err = d.client.RemoveTxtRecord(record.ID, zone.Name)
err = d.client.RemoveTxtRecord(ctx, record.ID, zone.Name)
if err != nil {
return fmt.Errorf("ClouDNS: %w", err)
}
@ -153,9 +158,9 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
// waitNameservers At the time of writing 4 servers are found as authoritative, but 8 are reported during the sync.
// If this is not done, the secondary verification done by Let's Encrypt server will fail quire a bit.
func (d *DNSProvider) waitNameservers(domain string, zone *internal.Zone) error {
func (d *DNSProvider) waitNameservers(ctx context.Context, domain string, zone *internal.Zone) error {
return wait.For("Nameserver sync on "+domain, d.config.PropagationTimeout, d.config.PollingInterval, func() (bool, error) {
syncProgress, err := d.client.GetUpdateStatus(zone.Name)
syncProgress, err := d.client.GetUpdateStatus(ctx, zone.Name)
if err != nil {
return false, err
}

View file

@ -1,6 +1,7 @@
package internal
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -8,8 +9,10 @@ import (
"net/http"
"net/url"
"strconv"
"time"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
const defaultBaseURL = "https://api.cloudns.net/dns/"
@ -19,8 +22,9 @@ type Client struct {
authID string
subAuthID string
authPassword string
HTTPClient *http.Client
BaseURL *url.URL
HTTPClient *http.Client
}
// NewClient creates a ClouDNS client.
@ -42,16 +46,16 @@ func NewClient(authID, subAuthID, authPassword string) (*Client, error) {
authID: authID,
subAuthID: subAuthID,
authPassword: authPassword,
HTTPClient: &http.Client{},
BaseURL: baseURL,
HTTPClient: &http.Client{Timeout: 10 * time.Second},
}, nil
}
// GetZone Get domain name information for a FQDN.
func (c *Client) GetZone(authFQDN string) (*Zone, error) {
func (c *Client) GetZone(ctx context.Context, authFQDN string) (*Zone, error) {
authZone, err := dns01.FindZoneByFqdn(authFQDN)
if err != nil {
return nil, err
return nil, fmt.Errorf("could not find zone for FQDN %q: %w", authFQDN, err)
}
authZoneName := dns01.UnFqdn(authZone)
@ -62,16 +66,21 @@ func (c *Client) GetZone(authFQDN string) (*Zone, error) {
q.Set("domain-name", authZoneName)
endpoint.RawQuery = q.Encode()
result, err := c.doRequest(http.MethodGet, endpoint)
req, err := c.newRequest(ctx, http.MethodGet, endpoint)
if err != nil {
return nil, err
}
rawMessage, err := c.do(req)
if err != nil {
return nil, err
}
var zone Zone
if len(result) > 0 {
if err = json.Unmarshal(result, &zone); err != nil {
return nil, fmt.Errorf("failed to unmarshal zone: %w", err)
if len(rawMessage) > 0 {
if err = json.Unmarshal(rawMessage, &zone); err != nil {
return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
}
}
@ -83,7 +92,7 @@ func (c *Client) GetZone(authFQDN string) (*Zone, error) {
}
// FindTxtRecord returns the TXT record a zone ID and a FQDN.
func (c *Client) FindTxtRecord(zoneName, fqdn string) (*TXTRecord, error) {
func (c *Client) FindTxtRecord(ctx context.Context, zoneName, fqdn string) (*TXTRecord, error) {
subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName)
if err != nil {
return nil, err
@ -97,19 +106,24 @@ func (c *Client) FindTxtRecord(zoneName, fqdn string) (*TXTRecord, error) {
q.Set("type", "TXT")
endpoint.RawQuery = q.Encode()
result, err := c.doRequest(http.MethodGet, endpoint)
req, err := c.newRequest(ctx, http.MethodGet, endpoint)
if err != nil {
return nil, err
}
rawMessage, err := c.do(req)
if err != nil {
return nil, err
}
// the API returns [] when there is no records.
if string(result) == "[]" {
if string(rawMessage) == "[]" {
return nil, nil
}
var records map[string]TXTRecord
if err = json.Unmarshal(result, &records); err != nil {
return nil, fmt.Errorf("failed to unmarshall TXT records: %w: %s", err, string(result))
if err = json.Unmarshal(rawMessage, &records); err != nil {
return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
}
for _, record := range records {
@ -122,7 +136,7 @@ func (c *Client) FindTxtRecord(zoneName, fqdn string) (*TXTRecord, error) {
}
// ListTxtRecords returns the TXT records a zone ID and a FQDN.
func (c *Client) ListTxtRecords(zoneName, fqdn string) ([]TXTRecord, error) {
func (c *Client) ListTxtRecords(ctx context.Context, zoneName, fqdn string) ([]TXTRecord, error) {
subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName)
if err != nil {
return nil, err
@ -136,19 +150,24 @@ func (c *Client) ListTxtRecords(zoneName, fqdn string) ([]TXTRecord, error) {
q.Set("type", "TXT")
endpoint.RawQuery = q.Encode()
result, err := c.doRequest(http.MethodGet, endpoint)
req, err := c.newRequest(ctx, http.MethodGet, endpoint)
if err != nil {
return nil, err
}
rawMessage, err := c.do(req)
if err != nil {
return nil, err
}
// the API returns [] when there is no records.
if string(result) == "[]" {
if string(rawMessage) == "[]" {
return nil, nil
}
var raw map[string]TXTRecord
if err = json.Unmarshal(result, &raw); err != nil {
return nil, fmt.Errorf("failed to unmarshall TXT records: %w: %s", err, string(result))
if err = json.Unmarshal(rawMessage, &raw); err != nil {
return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
}
var records []TXTRecord
@ -162,7 +181,7 @@ func (c *Client) ListTxtRecords(zoneName, fqdn string) ([]TXTRecord, error) {
}
// AddTxtRecord adds a TXT record.
func (c *Client) AddTxtRecord(zoneName, fqdn, value string, ttl int) error {
func (c *Client) AddTxtRecord(ctx context.Context, zoneName, fqdn, value string, ttl int) error {
subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName)
if err != nil {
return err
@ -178,14 +197,19 @@ func (c *Client) AddTxtRecord(zoneName, fqdn, value string, ttl int) error {
q.Set("record-type", "TXT")
endpoint.RawQuery = q.Encode()
raw, err := c.doRequest(http.MethodPost, endpoint)
req, err := c.newRequest(ctx, http.MethodPost, endpoint)
if err != nil {
return err
}
rawMessage, err := c.do(req)
if err != nil {
return err
}
resp := apiResponse{}
if err = json.Unmarshal(raw, &resp); err != nil {
return fmt.Errorf("failed to unmarshal API response: %w: %s", err, string(raw))
if err = json.Unmarshal(rawMessage, &resp); err != nil {
return errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
}
if resp.Status != "Success" {
@ -196,7 +220,7 @@ func (c *Client) AddTxtRecord(zoneName, fqdn, value string, ttl int) error {
}
// RemoveTxtRecord removes a TXT record.
func (c *Client) RemoveTxtRecord(recordID int, zoneName string) error {
func (c *Client) RemoveTxtRecord(ctx context.Context, recordID int, zoneName string) error {
endpoint := c.BaseURL.JoinPath("delete-record.json")
q := endpoint.Query()
@ -204,14 +228,19 @@ func (c *Client) RemoveTxtRecord(recordID int, zoneName string) error {
q.Set("record-id", strconv.Itoa(recordID))
endpoint.RawQuery = q.Encode()
raw, err := c.doRequest(http.MethodPost, endpoint)
req, err := c.newRequest(ctx, http.MethodPost, endpoint)
if err != nil {
return err
}
rawMessage, err := c.do(req)
if err != nil {
return err
}
resp := apiResponse{}
if err = json.Unmarshal(raw, &resp); err != nil {
return fmt.Errorf("failed to unmarshal API response: %w: %s", err, string(raw))
if err = json.Unmarshal(rawMessage, &resp); err != nil {
return errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
}
if resp.Status != "Success" {
@ -222,26 +251,31 @@ func (c *Client) RemoveTxtRecord(recordID int, zoneName string) error {
}
// GetUpdateStatus gets sync progress of all CloudDNS NS servers.
func (c *Client) GetUpdateStatus(zoneName string) (*SyncProgress, error) {
func (c *Client) GetUpdateStatus(ctx context.Context, zoneName string) (*SyncProgress, error) {
endpoint := c.BaseURL.JoinPath("update-status.json")
q := endpoint.Query()
q.Set("domain-name", zoneName)
endpoint.RawQuery = q.Encode()
result, err := c.doRequest(http.MethodGet, endpoint)
req, err := c.newRequest(ctx, http.MethodGet, endpoint)
if err != nil {
return nil, err
}
rawMessage, err := c.do(req)
if err != nil {
return nil, err
}
// the API returns [] when there is no records.
if string(result) == "[]" {
if string(rawMessage) == "[]" {
return nil, errors.New("no nameservers records returned")
}
var records []UpdateRecord
if err = json.Unmarshal(result, &records); err != nil {
return nil, fmt.Errorf("failed to unmarshal UpdateRecord: %w: %s", err, string(result))
if err = json.Unmarshal(rawMessage, &records); err != nil {
return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
}
updatedCount := 0
@ -254,33 +288,8 @@ func (c *Client) GetUpdateStatus(zoneName string) (*SyncProgress, error) {
return &SyncProgress{Complete: updatedCount == len(records), Updated: updatedCount, Total: len(records)}, nil
}
func (c *Client) doRequest(method string, uri *url.URL) (json.RawMessage, error) {
req, err := c.buildRequest(method, uri)
if err != nil {
return nil, err
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
content, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.New(toUnreadableBodyMessage(req, content))
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("invalid code (%d), error: %s", resp.StatusCode, content)
}
return content, nil
}
func (c *Client) buildRequest(method string, uri *url.URL) (*http.Request, error) {
q := uri.Query()
func (c *Client) newRequest(ctx context.Context, method string, endpoint *url.URL) (*http.Request, error) {
q := endpoint.Query()
if c.subAuthID != "" {
q.Set("sub-auth-id", c.subAuthID)
@ -290,18 +299,34 @@ func (c *Client) buildRequest(method string, uri *url.URL) (*http.Request, error
q.Set("auth-password", c.authPassword)
uri.RawQuery = q.Encode()
endpoint.RawQuery = q.Encode()
req, err := http.NewRequest(method, uri.String(), nil)
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), nil)
if err != nil {
return nil, fmt.Errorf("invalid request: %w", err)
return nil, fmt.Errorf("unable to create request: %w", err)
}
return req, nil
}
func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string {
return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody))
func (c *Client) do(req *http.Request) (json.RawMessage, error) {
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errutils.NewReadResponseError(req, resp.StatusCode, err)
}
return raw, nil
}
// Rounds the given TTL in seconds to the next accepted value.

View file

@ -1,6 +1,7 @@
package internal
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@ -11,6 +12,21 @@ import (
"github.com/stretchr/testify/require"
)
func setupTest(t *testing.T, subAuthID string, handler http.HandlerFunc) *Client {
t.Helper()
server := httptest.NewServer(handler)
t.Cleanup(server.Close)
client, err := NewClient("myAuthID", subAuthID, "myAuthPassword")
require.NoError(t, err)
client.BaseURL, _ = url.Parse(server.URL)
client.HTTPClient = server.Client()
return client
}
func handlerMock(method string, jsonData []byte) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
if req.Method != method {
@ -109,22 +125,16 @@ func TestClient_GetZone(t *testing.T) {
authFQDN: "_acme-challenge.foo.com.",
apiResponse: `[{}]`,
expected: expected{
errorMsg: "failed to unmarshal zone: json: cannot unmarshal array into Go value of type internal.Zone",
errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type internal.Zone",
},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(handlerMock(http.MethodGet, []byte(test.apiResponse)))
t.Cleanup(server.Close)
client := setupTest(t, "", handlerMock(http.MethodGet, []byte(test.apiResponse)))
client, err := NewClient("myAuthID", "", "myAuthPassword")
require.NoError(t, err)
client.BaseURL, _ = url.Parse(server.URL)
zone, err := client.GetZone(test.authFQDN)
zone, err := client.GetZone(context.Background(), test.authFQDN)
if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg)
@ -222,22 +232,16 @@ func TestClient_FindTxtRecord(t *testing.T) {
zoneName: "example.com",
apiResponse: `[{}]`,
expected: expected{
errorMsg: "failed to unmarshall TXT records: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord: [{}]",
errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord",
},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(handlerMock(http.MethodGet, []byte(test.apiResponse)))
t.Cleanup(server.Close)
client := setupTest(t, "", handlerMock(http.MethodGet, []byte(test.apiResponse)))
client, err := NewClient("myAuthID", "", "myAuthPassword")
require.NoError(t, err)
client.BaseURL, _ = url.Parse(server.URL)
txtRecord, err := client.FindTxtRecord(test.zoneName, test.authFQDN)
txtRecord, err := client.FindTxtRecord(context.Background(), test.zoneName, test.authFQDN)
if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg)
@ -337,22 +341,16 @@ func TestClient_ListTxtRecord(t *testing.T) {
zoneName: "example.com",
apiResponse: `[{}]`,
expected: expected{
errorMsg: "failed to unmarshall TXT records: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord: [{}]",
errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord",
},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(handlerMock(http.MethodGet, []byte(test.apiResponse)))
t.Cleanup(server.Close)
client := setupTest(t, "", handlerMock(http.MethodGet, []byte(test.apiResponse)))
client, err := NewClient("myAuthID", "", "myAuthPassword")
require.NoError(t, err)
client.BaseURL, _ = url.Parse(server.URL)
txtRecords, err := client.ListTxtRecords(test.zoneName, test.authFQDN)
txtRecords, err := client.ListTxtRecords(context.Background(), test.zoneName, test.authFQDN)
if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg)
@ -440,14 +438,14 @@ func TestClient_AddTxtRecord(t *testing.T) {
apiResponse: `[{}]`,
expected: expected{
query: `auth-id=myAuthID&auth-password=myAuthPassword&domain-name=bar.com&host=_acme-challenge&record=TXTtxtTXTtxtTXTtxtTXTtxt&record-type=TXT&ttl=300`,
errorMsg: "failed to unmarshal API response: json: cannot unmarshal array into Go value of type internal.apiResponse: [{}]",
errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type internal.apiResponse",
},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
client := setupTest(t, test.subAuthID, func(rw http.ResponseWriter, req *http.Request) {
if test.expected.query != req.URL.RawQuery {
msg := fmt.Sprintf("got: %s, want: %s", test.expected.query, req.URL.RawQuery)
http.Error(rw, msg, http.StatusBadRequest)
@ -455,15 +453,9 @@ func TestClient_AddTxtRecord(t *testing.T) {
}
handlerMock(http.MethodPost, []byte(test.apiResponse))(rw, req)
}))
t.Cleanup(server.Close)
})
client, err := NewClient(test.authID, test.subAuthID, "myAuthPassword")
require.NoError(t, err)
client.BaseURL, _ = url.Parse(server.URL)
err = client.AddTxtRecord(test.zoneName, test.authFQDN, test.value, test.ttl)
err := client.AddTxtRecord(context.Background(), test.zoneName, test.authFQDN, test.value, test.ttl)
if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg)
@ -513,7 +505,7 @@ func TestClient_RemoveTxtRecord(t *testing.T) {
apiResponse: `[{}]`,
expected: expected{
query: `auth-id=myAuthID&auth-password=myAuthPassword&domain-name=foo-plus.com&record-id=44`,
errorMsg: "failed to unmarshal API response: json: cannot unmarshal array into Go value of type internal.apiResponse: [{}]",
errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type internal.apiResponse",
},
},
}
@ -536,7 +528,7 @@ func TestClient_RemoveTxtRecord(t *testing.T) {
client.BaseURL, _ = url.Parse(server.URL)
err = client.RemoveTxtRecord(test.id, test.zoneName)
err = client.RemoveTxtRecord(context.Background(), test.id, test.zoneName)
if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg)
@ -592,7 +584,7 @@ func TestClient_GetUpdateStatus(t *testing.T) {
authFQDN: "_acme-challenge.foo.com.",
zoneName: "test-zone",
apiResponse: `[x]`,
expected: expected{errorMsg: "failed to unmarshal UpdateRecord: invalid character 'x' looking for beginning of value: [x]"},
expected: expected{errorMsg: "unable to unmarshal response: [status code: 200] body: [x] error: invalid character 'x' looking for beginning of value"},
},
}
@ -606,7 +598,7 @@ func TestClient_GetUpdateStatus(t *testing.T) {
client.BaseURL, _ = url.Parse(server.URL)
syncProgress, err := client.GetUpdateStatus(test.zoneName)
syncProgress, err := client.GetUpdateStatus(context.Background(), test.zoneName)
if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg)

View file

@ -2,6 +2,7 @@
package cloudxns
import (
"context"
"errors"
"fmt"
"net/http"
@ -59,7 +60,7 @@ type DNSProvider struct {
func NewDNSProvider() (*DNSProvider, error) {
values, err := env.Get(EnvAPIKey, EnvSecretKey)
if err != nil {
return nil, fmt.Errorf("CloudXNS: %w", err)
return nil, fmt.Errorf("cloudxns: %w", err)
}
config := NewDefaultConfig()
@ -72,15 +73,17 @@ func NewDNSProvider() (*DNSProvider, error) {
// NewDNSProviderConfig return a DNSProvider instance configured for CloudXNS.
func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
if config == nil {
return nil, errors.New("CloudXNS: the configuration of the DNS provider is nil")
return nil, errors.New("cloudxns: the configuration of the DNS provider is nil")
}
client, err := internal.NewClient(config.APIKey, config.SecretKey)
if err != nil {
return nil, err
return nil, fmt.Errorf("cloudxns: %w", err)
}
if config.HTTPClient != nil {
client.HTTPClient = config.HTTPClient
}
return &DNSProvider{client: client, config: config}, nil
}
@ -89,29 +92,43 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
challengeInfo := dns01.GetChallengeInfo(domain, keyAuth)
info, err := d.client.GetDomainInformation(challengeInfo.EffectiveFQDN)
ctx := context.Background()
info, err := d.client.GetDomainInformation(ctx, challengeInfo.EffectiveFQDN)
if err != nil {
return err
return fmt.Errorf("cloudxns: %w", err)
}
return d.client.AddTxtRecord(info, challengeInfo.EffectiveFQDN, challengeInfo.Value, d.config.TTL)
err = d.client.AddTxtRecord(ctx, info, challengeInfo.EffectiveFQDN, challengeInfo.Value, d.config.TTL)
if err != nil {
return fmt.Errorf("cloudxns: %w", err)
}
return nil
}
// CleanUp removes the TXT record matching the specified parameters.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
challengeInfo := dns01.GetChallengeInfo(domain, keyAuth)
info, err := d.client.GetDomainInformation(challengeInfo.EffectiveFQDN)
ctx := context.Background()
info, err := d.client.GetDomainInformation(ctx, challengeInfo.EffectiveFQDN)
if err != nil {
return err
return fmt.Errorf("cloudxns: %w", err)
}
record, err := d.client.FindTxtRecord(info.ID, challengeInfo.EffectiveFQDN)
record, err := d.client.FindTxtRecord(ctx, info.ID, challengeInfo.EffectiveFQDN)
if err != nil {
return err
return fmt.Errorf("cloudxns: %w", err)
}
return d.client.RemoveTxtRecord(record.RecordID, info.ID)
err = d.client.RemoveTxtRecord(ctx, record.RecordID, info.ID)
if err != nil {
return fmt.Errorf("cloudxns: %w", err)
}
return nil
}
// Timeout returns the timeout and interval to use when checking for DNS propagation.

View file

@ -34,7 +34,7 @@ func TestNewDNSProvider(t *testing.T) {
EnvAPIKey: "",
EnvSecretKey: "",
},
expected: "CloudXNS: some credentials information are missing: CLOUDXNS_API_KEY,CLOUDXNS_SECRET_KEY",
expected: "cloudxns: some credentials information are missing: CLOUDXNS_API_KEY,CLOUDXNS_SECRET_KEY",
},
{
desc: "missing API key",
@ -42,7 +42,7 @@ func TestNewDNSProvider(t *testing.T) {
EnvAPIKey: "",
EnvSecretKey: "456",
},
expected: "CloudXNS: some credentials information are missing: CLOUDXNS_API_KEY",
expected: "cloudxns: some credentials information are missing: CLOUDXNS_API_KEY",
},
{
desc: "missing secret key",
@ -50,7 +50,7 @@ func TestNewDNSProvider(t *testing.T) {
EnvAPIKey: "123",
EnvSecretKey: "",
},
expected: "CloudXNS: some credentials information are missing: CLOUDXNS_SECRET_KEY",
expected: "cloudxns: some credentials information are missing: CLOUDXNS_SECRET_KEY",
},
}
@ -89,17 +89,17 @@ func TestNewDNSProviderConfig(t *testing.T) {
},
{
desc: "missing credentials",
expected: "CloudXNS: credentials missing: apiKey",
expected: "cloudxns: credentials missing: apiKey",
},
{
desc: "missing api key",
secretKey: "456",
expected: "CloudXNS: credentials missing: apiKey",
expected: "cloudxns: credentials missing: apiKey",
},
{
desc: "missing secret key",
apiKey: "123",
expected: "CloudXNS: credentials missing: secretKey",
expected: "cloudxns: credentials missing: secretKey",
},
}

View file

@ -2,6 +2,7 @@ package internal
import (
"bytes"
"context"
"crypto/md5"
"encoding/hex"
"encoding/json"
@ -9,83 +10,63 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"time"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
const defaultBaseURL = "https://www.cloudxns.net/api2/"
type apiResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data,omitempty"`
}
// Data Domain information.
type Data struct {
ID string `json:"id"`
Domain string `json:"domain"`
TTL int `json:"ttl,omitempty"`
}
// TXTRecord a TXT record.
type TXTRecord struct {
ID int `json:"domain_id,omitempty"`
RecordID string `json:"record_id,omitempty"`
Host string `json:"host"`
Value string `json:"value"`
Type string `json:"type"`
LineID int `json:"line_id,string"`
TTL int `json:"ttl,string"`
}
// NewClient creates a CloudXNS client.
func NewClient(apiKey, secretKey string) (*Client, error) {
if apiKey == "" {
return nil, errors.New("CloudXNS: credentials missing: apiKey")
}
if secretKey == "" {
return nil, errors.New("CloudXNS: credentials missing: secretKey")
}
return &Client{
apiKey: apiKey,
secretKey: secretKey,
HTTPClient: &http.Client{},
BaseURL: defaultBaseURL,
}, nil
}
// Client CloudXNS client.
type Client struct {
apiKey string
secretKey string
baseURL *url.URL
HTTPClient *http.Client
BaseURL string
}
// NewClient creates a CloudXNS client.
func NewClient(apiKey, secretKey string) (*Client, error) {
if apiKey == "" {
return nil, errors.New("credentials missing: apiKey")
}
if secretKey == "" {
return nil, errors.New("credentials missing: secretKey")
}
baseURL, _ := url.Parse(defaultBaseURL)
return &Client{
apiKey: apiKey,
secretKey: secretKey,
baseURL: baseURL,
HTTPClient: &http.Client{Timeout: 10 * time.Second},
}, nil
}
// GetDomainInformation Get domain name information for a FQDN.
func (c *Client) GetDomainInformation(fqdn string) (*Data, error) {
authZone, err := dns01.FindZoneByFqdn(fqdn)
func (c *Client) GetDomainInformation(ctx context.Context, fqdn string) (*Data, error) {
endpoint := c.baseURL.JoinPath("domain")
req, err := c.newRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
result, err := c.doRequest(http.MethodGet, "domain", nil)
authZone, err := dns01.FindZoneByFqdn(fqdn)
if err != nil {
return nil, err
return nil, fmt.Errorf("cloudflare: could not find zone for FQDN %q: %w", fqdn, err)
}
var domains []Data
if len(result) > 0 {
err = json.Unmarshal(result, &domains)
err = c.do(req, &domains)
if err != nil {
return nil, fmt.Errorf("CloudXNS: domains unmarshaling error: %w", err)
}
return nil, err
}
for _, data := range domains {
@ -94,20 +75,28 @@ func (c *Client) GetDomainInformation(fqdn string) (*Data, error) {
}
}
return nil, fmt.Errorf("CloudXNS: zone %s not found for domain %s", authZone, fqdn)
return nil, fmt.Errorf("zone %s not found for domain %s", authZone, fqdn)
}
// FindTxtRecord return the TXT record a zone ID and a FQDN.
func (c *Client) FindTxtRecord(zoneID, fqdn string) (*TXTRecord, error) {
result, err := c.doRequest(http.MethodGet, fmt.Sprintf("record/%s?host_id=0&offset=0&row_num=2000", zoneID), nil)
func (c *Client) FindTxtRecord(ctx context.Context, zoneID, fqdn string) (*TXTRecord, error) {
endpoint := c.baseURL.JoinPath("record", zoneID)
query := endpoint.Query()
query.Set("host_id", "0")
query.Set("offset", "0")
query.Set("row_num", "2000")
endpoint.RawQuery = query.Encode()
req, err := c.newRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
var records []TXTRecord
err = json.Unmarshal(result, &records)
err = c.do(req, &records)
if err != nil {
return nil, fmt.Errorf("CloudXNS: TXT record unmarshaling error: %w", err)
return nil, err
}
for _, record := range records {
@ -116,22 +105,24 @@ func (c *Client) FindTxtRecord(zoneID, fqdn string) (*TXTRecord, error) {
}
}
return nil, fmt.Errorf("CloudXNS: no existing record found for %q", fqdn)
return nil, fmt.Errorf("no existing record found for %q", fqdn)
}
// AddTxtRecord add a TXT record.
func (c *Client) AddTxtRecord(info *Data, fqdn, value string, ttl int) error {
func (c *Client) AddTxtRecord(ctx context.Context, info *Data, fqdn, value string, ttl int) error {
id, err := strconv.Atoi(info.ID)
if err != nil {
return fmt.Errorf("CloudXNS: invalid zone ID: %w", err)
return fmt.Errorf("invalid zone ID: %w", err)
}
endpoint := c.baseURL.JoinPath("record")
subDomain, err := dns01.ExtractSubDomain(fqdn, info.Domain)
if err != nil {
return fmt.Errorf("CloudXNS: %w", err)
return err
}
payload := TXTRecord{
record := TXTRecord{
ID: id,
Host: subDomain,
Value: value,
@ -140,74 +131,91 @@ func (c *Client) AddTxtRecord(info *Data, fqdn, value string, ttl int) error {
TTL: ttl,
}
body, err := json.Marshal(payload)
req, err := c.newRequest(ctx, http.MethodPost, endpoint, record)
if err != nil {
return fmt.Errorf("CloudXNS: record unmarshaling error: %w", err)
return err
}
_, err = c.doRequest(http.MethodPost, "record", body)
return err
return c.do(req, nil)
}
// RemoveTxtRecord remove a TXT record.
func (c *Client) RemoveTxtRecord(recordID, zoneID string) error {
_, err := c.doRequest(http.MethodDelete, fmt.Sprintf("record/%s/%s", recordID, zoneID), nil)
func (c *Client) RemoveTxtRecord(ctx context.Context, recordID, zoneID string) error {
endpoint := c.baseURL.JoinPath("record", recordID, zoneID)
req, err := c.newRequest(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err
}
func (c *Client) doRequest(method, uri string, body []byte) (json.RawMessage, error) {
req, err := c.buildRequest(method, uri, body)
if err != nil {
return nil, err
return c.do(req, nil)
}
func (c *Client) do(req *http.Request, result any) error {
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("CloudXNS: %w", err)
return errutils.NewHTTPDoError(req, err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
content, err := io.ReadAll(resp.Body)
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("CloudXNS: %s", toUnreadableBodyMessage(req, content))
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
var r apiResponse
err = json.Unmarshal(content, &r)
var response apiResponse
err = json.Unmarshal(raw, &response)
if err != nil {
return nil, fmt.Errorf("CloudXNS: response unmashaling error: %w: %s", err, toUnreadableBodyMessage(req, content))
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
if r.Code != 1 {
return nil, fmt.Errorf("CloudXNS: invalid code (%v), error: %s", r.Code, r.Message)
}
return r.Data, nil
if response.Code != 1 {
return fmt.Errorf("[status code %d] invalid code (%v) error: %s", resp.StatusCode, response.Code, response.Message)
}
func (c *Client) buildRequest(method, uri string, body []byte) (*http.Request, error) {
url := c.BaseURL + uri
if result == nil {
return nil
}
req, err := http.NewRequest(method, url, bytes.NewReader(body))
if len(response.Data) == 0 {
return nil
}
err = json.Unmarshal(response.Data, result)
if err != nil {
return nil, fmt.Errorf("CloudXNS: invalid request: %w", err)
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}
func (c *Client) newRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
requestDate := time.Now().Format(time.RFC1123Z)
req.Header.Set("API-KEY", c.apiKey)
req.Header.Set("API-REQUEST-DATE", requestDate)
req.Header.Set("API-HMAC", c.hmac(url, requestDate, string(body)))
req.Header.Set("API-HMAC", c.hmac(endpoint.String(), requestDate, buf.String()))
req.Header.Set("API-FORMAT", "json")
return req, nil
}
func (c *Client) hmac(url, date, body string) string {
sum := md5.Sum([]byte(c.apiKey + url + body + date + c.secretKey))
func (c *Client) hmac(endpoint, date, body string) string {
sum := md5.Sum([]byte(c.apiKey + endpoint + body + date + c.secretKey))
return hex.EncodeToString(sum[:])
}
func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string {
return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody))
}

View file

@ -1,19 +1,35 @@
package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func handlerMock(method string, response *apiResponse, data interface{}) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
func setupTest(t *testing.T, handler http.HandlerFunc) *Client {
t.Helper()
server := httptest.NewServer(handler)
t.Cleanup(server.Close)
client, _ := NewClient("myKey", "mySecret")
client.baseURL, _ = url.Parse(server.URL + "/")
client.HTTPClient = server.Client()
return client
}
func handlerMock(method string, response *apiResponse, data interface{}) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
if req.Method != method {
content, err := json.Marshal(apiResponse{
Code: 999, // random code only for the test
@ -47,10 +63,10 @@ func handlerMock(method string, response *apiResponse, data interface{}) http.Ha
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
})
}
}
func TestClientGetDomainInformation(t *testing.T) {
func TestClient_GetDomainInformation(t *testing.T) {
type result struct {
domain *Data
error bool
@ -106,13 +122,9 @@ func TestClientGetDomainInformation(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(handlerMock(http.MethodGet, test.response, test.data))
t.Cleanup(server.Close)
client := setupTest(t, handlerMock(http.MethodGet, test.response, test.data))
client, _ := NewClient("myKey", "mySecret")
client.BaseURL = server.URL + "/"
domain, err := client.GetDomainInformation(test.fqdn)
domain, err := client.GetDomainInformation(context.Background(), test.fqdn)
if test.expected.error {
require.Error(t, err)
@ -124,7 +136,7 @@ func TestClientGetDomainInformation(t *testing.T) {
}
}
func TestClientFindTxtRecord(t *testing.T) {
func TestClient_FindTxtRecord(t *testing.T) {
type result struct {
txtRecord *TXTRecord
error bool
@ -210,13 +222,9 @@ func TestClientFindTxtRecord(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(handlerMock(http.MethodGet, test.response, test.txtRecords))
t.Cleanup(server.Close)
client := setupTest(t, handlerMock(http.MethodGet, test.response, test.txtRecords))
client, _ := NewClient("myKey", "mySecret")
client.BaseURL = server.URL + "/"
txtRecord, err := client.FindTxtRecord(test.zoneID, test.fqdn)
txtRecord, err := client.FindTxtRecord(context.Background(), test.zoneID, test.fqdn)
if test.expected.error {
require.Error(t, err)
@ -228,7 +236,7 @@ func TestClientFindTxtRecord(t *testing.T) {
}
}
func TestClientAddTxtRecord(t *testing.T) {
func TestClient_AddTxtRecord(t *testing.T) {
testCases := []struct {
desc string
domain *Data
@ -267,21 +275,17 @@ func TestClientAddTxtRecord(t *testing.T) {
Code: 1,
}
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
client := setupTest(t, func(rw http.ResponseWriter, req *http.Request) {
assert.NotNil(t, req.Body)
content, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Equal(t, test.expected, string(content))
assert.Equal(t, test.expected, string(bytes.TrimSpace(content)))
handlerMock(http.MethodPost, response, nil).ServeHTTP(rw, req)
}))
t.Cleanup(server.Close)
})
client, _ := NewClient("myKey", "mySecret")
client.BaseURL = server.URL + "/"
err := client.AddTxtRecord(test.domain, test.fqdn, test.value, test.ttl)
err := client.AddTxtRecord(context.Background(), test.domain, test.fqdn, test.value, test.ttl)
require.NoError(t, err)
})
}

View file

@ -0,0 +1,28 @@
package internal
import "encoding/json"
type apiResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data,omitempty"`
}
// Data Domain information.
type Data struct {
ID string `json:"id"`
Domain string `json:"domain"`
TTL int `json:"ttl,omitempty"`
}
// TXTRecord a TXT record.
type TXTRecord struct {
ID int `json:"domain_id,omitempty"`
RecordID string `json:"record_id,omitempty"`
Host string `json:"host"`
Value string `json:"value"`
Type string `json:"type"`
LineID int `json:"line_id,string"`
TTL int `json:"ttl,string"`
}

View file

@ -2,6 +2,7 @@
package conoha
import (
"context"
"errors"
"fmt"
"net/http"
@ -85,6 +86,15 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("conoha: some credentials information are missing")
}
identifier, err := internal.NewIdentifier(config.Region)
if err != nil {
return nil, fmt.Errorf("conoha: failed to create identity client: %w", err)
}
if config.HTTPClient != nil {
identifier.HTTPClient = config.HTTPClient
}
auth := internal.Auth{
TenantID: config.TenantID,
PasswordCredentials: internal.PasswordCredentials{
@ -93,11 +103,20 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
},
}
client, err := internal.NewClient(config.Region, auth, config.HTTPClient)
tokens, err := identifier.GetToken(context.TODO(), auth)
if err != nil {
return nil, fmt.Errorf("conoha: failed to login: %w", err)
}
client, err := internal.NewClient(config.Region, tokens.Access.Token.ID)
if err != nil {
return nil, fmt.Errorf("conoha: failed to create client: %w", err)
}
if config.HTTPClient != nil {
client.HTTPClient = config.HTTPClient
}
return &DNSProvider{config: config, client: client}, nil
}
@ -107,10 +126,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return err
return fmt.Errorf("conoha: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
id, err := d.client.GetDomainID(authZone)
ctx := context.Background()
id, err := d.client.GetDomainID(ctx, authZone)
if err != nil {
return fmt.Errorf("conoha: failed to get domain ID: %w", err)
}
@ -122,7 +143,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
TTL: d.config.TTL,
}
err = d.client.CreateRecord(id, record)
err = d.client.CreateRecord(ctx, id, record)
if err != nil {
return fmt.Errorf("conoha: failed to create record: %w", err)
}
@ -136,20 +157,22 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return err
return fmt.Errorf("conoha: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
domID, err := d.client.GetDomainID(authZone)
ctx := context.Background()
domID, err := d.client.GetDomainID(ctx, authZone)
if err != nil {
return fmt.Errorf("conoha: failed to get domain ID: %w", err)
}
recID, err := d.client.GetRecordID(domID, info.EffectiveFQDN, "TXT", info.Value)
recID, err := d.client.GetRecordID(ctx, domID, info.EffectiveFQDN, "TXT", info.Value)
if err != nil {
return fmt.Errorf("conoha: failed to get record ID: %w", err)
}
err = d.client.DeleteRecord(domID, recID)
err = d.client.DeleteRecord(ctx, domID, recID)
if err != nil {
return fmt.Errorf("conoha: failed to delete record: %w", err)
}

View file

@ -29,7 +29,7 @@ func TestNewDNSProvider(t *testing.T) {
EnvAPIUsername: "api_username",
EnvAPIPassword: "api_password",
},
expected: `conoha: failed to create client: failed to login: HTTP request failed with status code 401: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`,
expected: `conoha: failed to login: unexpected status code: [status code: 401] body: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`,
},
{
desc: "missing credentials",
@ -99,7 +99,7 @@ func TestNewDNSProviderConfig(t *testing.T) {
}{
{
desc: "complete credentials, but login failed",
expected: `conoha: failed to create client: failed to login: HTTP request failed with status code 401: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`,
expected: `conoha: failed to login: unexpected status code: [status code: 401] body: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`,
tenant: "tenant_id",
username: "api_username",
password: "api_password",

View file

@ -2,121 +2,45 @@ package internal
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
const (
identityBaseURL = "https://identity.%s.conoha.io"
dnsServiceBaseURL = "https://dns-service.%s.conoha.io"
)
// IdentityRequest is an authentication request body.
type IdentityRequest struct {
Auth Auth `json:"auth"`
}
// Auth is an authentication information.
type Auth struct {
TenantID string `json:"tenantId"`
PasswordCredentials PasswordCredentials `json:"passwordCredentials"`
}
// PasswordCredentials is API-user's credentials.
type PasswordCredentials struct {
Username string `json:"username"`
Password string `json:"password"`
}
// IdentityResponse is an authentication response body.
type IdentityResponse struct {
Access Access `json:"access"`
}
// Access is an identity information.
type Access struct {
Token Token `json:"token"`
}
// Token is an api access token.
type Token struct {
ID string `json:"id"`
}
// DomainListResponse is a response of a domain listing request.
type DomainListResponse struct {
Domains []Domain `json:"domains"`
}
// Domain is a hosted domain entry.
type Domain struct {
ID string `json:"id"`
Name string `json:"name"`
}
// RecordListResponse is a response of record listing request.
type RecordListResponse struct {
Records []Record `json:"records"`
}
// Record is a record entry.
type Record struct {
ID string `json:"id,omitempty"`
Name string `json:"name"`
Type string `json:"type"`
Data string `json:"data"`
TTL int `json:"ttl"`
}
const dnsServiceBaseURL = "https://dns-service.%s.conoha.io"
// Client is a ConoHa API client.
type Client struct {
token string
endpoint string
httpClient *http.Client
baseURL *url.URL
HTTPClient *http.Client
}
// NewClient returns a client instance logged into the ConoHa service.
func NewClient(region string, auth Auth, httpClient *http.Client) (*Client, error) {
if httpClient == nil {
httpClient = &http.Client{}
}
c := &Client{httpClient: httpClient}
c.endpoint = fmt.Sprintf(identityBaseURL, region)
identity, err := c.getIdentity(auth)
if err != nil {
return nil, fmt.Errorf("failed to login: %w", err)
}
c.token = identity.Access.Token.ID
c.endpoint = fmt.Sprintf(dnsServiceBaseURL, region)
return c, nil
}
func (c *Client) getIdentity(auth Auth) (*IdentityResponse, error) {
req := &IdentityRequest{Auth: auth}
identity := &IdentityResponse{}
err := c.do(http.MethodPost, "/v2.0/tokens", req, identity)
func NewClient(region string, token string) (*Client, error) {
baseURL, err := url.Parse(fmt.Sprintf(dnsServiceBaseURL, region))
if err != nil {
return nil, err
}
return identity, nil
return &Client{
token: token,
baseURL: baseURL,
HTTPClient: &http.Client{Timeout: 5 * time.Second},
}, nil
}
// GetDomainID returns an ID of specified domain.
func (c *Client) GetDomainID(domainName string) (string, error) {
domainList := &DomainListResponse{}
err := c.do(http.MethodGet, "/v1/domains", nil, domainList)
func (c *Client) GetDomainID(ctx context.Context, domainName string) (string, error) {
domainList, err := c.getDomains(ctx)
if err != nil {
return "", err
}
@ -126,14 +50,32 @@ func (c *Client) GetDomainID(domainName string) (string, error) {
return domain.ID, nil
}
}
return "", fmt.Errorf("no such domain: %s", domainName)
}
// GetRecordID returns an ID of specified record.
func (c *Client) GetRecordID(domainID, recordName, recordType, data string) (string, error) {
recordList := &RecordListResponse{}
// https://www.conoha.jp/docs/paas-dns-list-domains.php
func (c *Client) getDomains(ctx context.Context) (*DomainListResponse, error) {
endpoint := c.baseURL.JoinPath("v1", "domains")
err := c.do(http.MethodGet, fmt.Sprintf("/v1/domains/%s/records", domainID), nil, recordList)
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
domainList := &DomainListResponse{}
err = c.do(req, domainList)
if err != nil {
return nil, err
}
return domainList, nil
}
// GetRecordID returns an ID of specified record.
func (c *Client) GetRecordID(ctx context.Context, domainID, recordName, recordType, data string) (string, error) {
recordList, err := c.getRecords(ctx, domainID)
if err != nil {
return "", err
}
@ -143,63 +85,119 @@ func (c *Client) GetRecordID(domainID, recordName, recordType, data string) (str
return record.ID, nil
}
}
return "", errors.New("no such record")
}
// https://www.conoha.jp/docs/paas-dns-list-records-in-a-domain.php
func (c *Client) getRecords(ctx context.Context, domainID string) (*RecordListResponse, error) {
endpoint := c.baseURL.JoinPath("v1", "domains", domainID, "records")
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
recordList := &RecordListResponse{}
err = c.do(req, recordList)
if err != nil {
return nil, err
}
return recordList, nil
}
// CreateRecord adds new record.
func (c *Client) CreateRecord(domainID string, record Record) error {
return c.do(http.MethodPost, fmt.Sprintf("/v1/domains/%s/records", domainID), record, nil)
func (c *Client) CreateRecord(ctx context.Context, domainID string, record Record) error {
_, err := c.createRecord(ctx, domainID, record)
return err
}
// https://www.conoha.jp/docs/paas-dns-create-record.php
func (c *Client) createRecord(ctx context.Context, domainID string, record Record) (*Record, error) {
endpoint := c.baseURL.JoinPath("v1", "domains", domainID, "records")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
if err != nil {
return nil, err
}
newRecord := &Record{}
err = c.do(req, newRecord)
if err != nil {
return nil, err
}
return newRecord, nil
}
// DeleteRecord removes specified record.
func (c *Client) DeleteRecord(domainID, recordID string) error {
return c.do(http.MethodDelete, fmt.Sprintf("/v1/domains/%s/records/%s", domainID, recordID), nil, nil)
}
// https://www.conoha.jp/docs/paas-dns-delete-a-record.php
func (c *Client) DeleteRecord(ctx context.Context, domainID, recordID string) error {
endpoint := c.baseURL.JoinPath("v1", "domains", domainID, "records", recordID)
func (c *Client) do(method, path string, payload, result interface{}) error {
body := bytes.NewReader(nil)
if payload != nil {
bodyBytes, err := json.Marshal(payload)
if err != nil {
return err
}
body = bytes.NewReader(bodyBytes)
}
req, err := http.NewRequest(method, c.endpoint+path, body)
req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
return c.do(req, nil)
}
func (c *Client) do(req *http.Request, result any) error {
if c.token != "" {
req.Header.Set("X-Auth-Token", c.token)
resp, err := c.httpClient.Do(req)
if err != nil {
return err
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
respBody, err := io.ReadAll(resp.Body)
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
defer resp.Body.Close()
return fmt.Errorf("HTTP request failed with status code %d: %s", resp.StatusCode, string(respBody))
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
if result != nil {
respBody, err := io.ReadAll(resp.Body)
err = json.Unmarshal(raw, result)
if err != nil {
return err
}
defer resp.Body.Close()
return json.Unmarshal(respBody, result)
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}

View file

@ -1,30 +1,71 @@
package internal
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTest(t *testing.T) (*http.ServeMux, *Client) {
func setupTest(t *testing.T) (*Client, *http.ServeMux) {
t.Helper()
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
client := &Client{
token: "secret",
endpoint: server.URL,
httpClient: server.Client(),
client, err := NewClient("tyo1", "secret")
require.NoError(t, err)
client.HTTPClient = server.Client()
client.baseURL, _ = url.Parse(server.URL)
return client, mux
}
return mux, client
func writeFixtureHandler(method, filename string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
if req.Method != method {
http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return
}
writeFixture(rw, filename)
}
}
func writeBodyHandler(method, content string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
if req.Method != method {
http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return
}
_, err := fmt.Fprint(rw, content)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
}
}
func writeFixture(rw http.ResponseWriter, filename string) {
file, err := os.Open(filepath.Join("fixtures", filename))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
defer func() { _ = file.Close() }()
_, _ = io.Copy(rw, file)
}
func TestClient_GetDomainID(t *testing.T) {
@ -42,91 +83,30 @@ func TestClient_GetDomainID(t *testing.T) {
{
desc: "success",
domainName: "domain1.com.",
handler: func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
return
}
content := `
{
"domains":[
{
"id": "09494b72-b65b-4297-9efb-187f65a0553e",
"name": "domain1.com.",
"ttl": 3600,
"serial": 1351800668,
"email": "nsadmin@example.org",
"gslb": 0,
"created_at": "2012-11-01T20:11:08.000000",
"updated_at": null,
"description": "memo"
},
{
"id": "cf661142-e577-40b5-b3eb-75795cdc0cd7",
"name": "domain2.com.",
"ttl": 7200,
"serial": 1351800670,
"email": "nsadmin2@example.org",
"gslb": 1,
"created_at": "2012-11-01T20:11:08.000000",
"updated_at": "2012-12-01T20:11:08.000000",
"description": "memomemo"
}
]
}
`
_, err := fmt.Fprint(rw, content)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
},
handler: writeFixtureHandler(http.MethodGet, "domains_GET.json"),
expected: expected{domainID: "09494b72-b65b-4297-9efb-187f65a0553e"},
},
{
desc: "non existing domain",
domainName: "domain1.com.",
handler: func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
return
}
_, err := fmt.Fprint(rw, "{}")
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
},
handler: writeBodyHandler(http.MethodGet, "{}"),
expected: expected{error: true},
},
{
desc: "marshaling error",
domainName: "domain1.com.",
handler: func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
return
}
_, err := fmt.Fprint(rw, "[]")
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
},
handler: writeBodyHandler(http.MethodGet, "[]"),
expected: expected{error: true},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
mux, client := setupTest(t)
client, mux := setupTest(t)
mux.Handle("/v1/domains", test.handler)
domainID, err := client.GetDomainID(test.domainName)
domainID, err := client.GetDomainID(context.Background(), test.domainName)
if test.expected.error {
require.Error(t, err)
@ -142,13 +122,13 @@ func TestClient_CreateRecord(t *testing.T) {
testCases := []struct {
desc string
handler http.HandlerFunc
expectError bool
assert require.ErrorAssertionFunc
}{
{
desc: "success",
handler: func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return
}
@ -157,31 +137,34 @@ func TestClient_CreateRecord(t *testing.T) {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
defer req.Body.Close()
defer func() { _ = req.Body.Close() }()
if string(raw) != `{"name":"lego.com.","type":"TXT","data":"txtTXTtxt","ttl":300}` {
if string(bytes.TrimSpace(raw)) != `{"name":"lego.com.","type":"TXT","data":"txtTXTtxt","ttl":300}` {
http.Error(rw, fmt.Sprintf("invalid request body: %s", string(raw)), http.StatusBadRequest)
return
}
writeFixture(rw, "domains-records_POST.json")
},
assert: require.NoError,
},
{
desc: "bad request",
handler: func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return
}
http.Error(rw, "OOPS", http.StatusBadRequest)
},
expectError: true,
assert: require.Error,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
mux, client := setupTest(t)
client, mux := setupTest(t)
mux.Handle("/v1/domains/lego/records", test.handler)
@ -194,13 +177,36 @@ func TestClient_CreateRecord(t *testing.T) {
TTL: 300,
}
err := client.CreateRecord(domainID, record)
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
err := client.CreateRecord(context.Background(), domainID, record)
test.assert(t, err)
})
}
}
func TestClient_GetRecordID(t *testing.T) {
client, mux := setupTest(t)
mux.HandleFunc("/v1/domains/89acac79-38e7-497d-807c-a011e1310438/records",
writeFixtureHandler(http.MethodGet, "domains-records_GET.json"))
recordID, err := client.GetRecordID(context.Background(), "89acac79-38e7-497d-807c-a011e1310438", "www.example.com.", "A", "15.185.172.153")
require.NoError(t, err)
assert.Equal(t, "2e32e609-3a4f-45ba-bdef-e50eacd345ad", recordID)
}
func TestClient_DeleteRecord(t *testing.T) {
client, mux := setupTest(t)
mux.HandleFunc("/v1/domains/89acac79-38e7-497d-807c-a011e1310438/records/2e32e609-3a4f-45ba-bdef-e50eacd345ad", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodDelete {
http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return
}
rw.WriteHeader(http.StatusOK)
})
err := client.DeleteRecord(context.Background(), "89acac79-38e7-497d-807c-a011e1310438", "2e32e609-3a4f-45ba-bdef-e50eacd345ad")
require.NoError(t, err)
}

View file

@ -0,0 +1,43 @@
{
"records": [
{
"id": "2e32e609-3a4f-45ba-bdef-e50eacd345ad",
"name": "www.example.com.",
"type": "A",
"ttl": 3600,
"created_at": "2012-11-02T19:56:26.000000",
"updated_at": "2012-11-04T13:22:36.000000",
"data": "15.185.172.153",
"domain_id": "89acac79-38e7-497d-807c-a011e1310438",
"version": 1,
"gslb_region": "JP",
"gslb_weight": 250,
"gslb_check": 12300
},
{
"id": "8e9ecf3e-fb92-4a3a-a8ae-7596f167bea3",
"name": "host1.example.com.",
"type": "A",
"ttl": 3600,
"created_at": "2012-11-04T13:57:50.000000",
"updated_at": null,
"data": "15.185.172.154",
"domain_id": "89acac79-38e7-497d-807c-a011e1310438",
"version": 1,
"gslb_region": "US",
"gslb_weight": 220,
"gslb_check": 12200
},
{
"id": "4ad19089-3e62-40f8-9482-17cc8ccb92cb",
"name": "web.example.com.",
"type": "CNAME",
"ttl": 3600,
"created_at": "2012-11-04T13:58:16.393735",
"updated_at": null,
"data": "www.example.com.",
"domain_id": "89acac79-38e7-497d-807c-a011e1310438",
"version": 1
}
]
}

View file

@ -0,0 +1,13 @@
{
"id": "2e32e609-3a4f-45ba-bdef-e50eacd345ad",
"name": "www.example.com.",
"type": "A",
"created_at": "2012-11-02T19:56:26.366792",
"updated_at": null,
"domain_id": "89acac79-38e7-497d-807c-a011e1310438",
"ttl": null,
"data": "192.0.2.3",
"gslb_check": 1,
"gslb_region": "JP",
"gslb_weight": 250
}

View file

@ -0,0 +1,26 @@
{
"domains":[
{
"id": "09494b72-b65b-4297-9efb-187f65a0553e",
"name": "domain1.com.",
"ttl": 3600,
"serial": 1351800668,
"email": "nsadmin@example.org",
"gslb": 0,
"created_at": "2012-11-01T20:11:08.000000",
"updated_at": null,
"description": "memo"
},
{
"id": "cf661142-e577-40b5-b3eb-75795cdc0cd7",
"name": "domain2.com.",
"ttl": 7200,
"serial": 1351800670,
"email": "nsadmin2@example.org",
"gslb": 1,
"created_at": "2012-11-01T20:11:08.000000",
"updated_at": "2012-12-01T20:11:08.000000",
"description": "memomemo"
}
]
}

View file

@ -0,0 +1,17 @@
{
"access": {
"token": {
"issued_at": "2015-05-19T07:08:21.927295",
"expires": "2015-05-20T07:08:21Z",
"id": "sample00d88246078f2bexample788f7",
"tenant": {
"name": "example00000000",
"enabled": true,
"tyo1_image_size": "550GB"
},
"endpoints_links": [],
"type": "mailhosting",
"name": "Mail Hosting Service"
}
}
}

View file

@ -0,0 +1,82 @@
package internal
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
const identityBaseURL = "https://identity.%s.conoha.io"
type Identifier struct {
baseURL *url.URL
HTTPClient *http.Client
}
// NewIdentifier creates a new Identifier.
func NewIdentifier(region string) (*Identifier, error) {
baseURL, err := url.Parse(fmt.Sprintf(identityBaseURL, region))
if err != nil {
return nil, err
}
return &Identifier{
baseURL: baseURL,
HTTPClient: &http.Client{Timeout: 5 * time.Second},
}, nil
}
// GetToken gets valid token information.
// https://www.conoha.jp/docs/identity-post_tokens.php
func (c *Identifier) GetToken(ctx context.Context, auth Auth) (*IdentityResponse, error) {
endpoint := c.baseURL.JoinPath("v2.0", "tokens")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, &IdentityRequest{Auth: auth})
if err != nil {
return nil, err
}
identity := &IdentityResponse{}
err = c.do(req, identity)
if err != nil {
return nil, err
}
return identity, nil
}
func (c *Identifier) do(req *http.Request, result any) error {
resp, err := c.HTTPClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
err = json.Unmarshal(raw, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}

View file

@ -0,0 +1,41 @@
package internal
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewClient(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
identifier, err := NewIdentifier("tyo1")
require.NoError(t, err)
identifier.HTTPClient = server.Client()
identifier.baseURL, _ = url.Parse(server.URL)
mux.HandleFunc("/v2.0/tokens", writeFixtureHandler(http.MethodPost, "tokens_POST.json"))
auth := Auth{
TenantID: "487727e3921d44e3bfe7ebb337bf085e",
PasswordCredentials: PasswordCredentials{
Username: "ConoHa",
Password: "paSSword123456#$%",
},
}
token, err := identifier.GetToken(context.Background(), auth)
require.NoError(t, err)
expected := &IdentityResponse{Access: Access{Token: Token{ID: "sample00d88246078f2bexample788f7"}}}
assert.Equal(t, expected, token)
}

View file

@ -0,0 +1,58 @@
package internal
// IdentityRequest is an authentication request body.
type IdentityRequest struct {
Auth Auth `json:"auth"`
}
// Auth is an authentication information.
type Auth struct {
TenantID string `json:"tenantId"`
PasswordCredentials PasswordCredentials `json:"passwordCredentials"`
}
// PasswordCredentials is API-user's credentials.
type PasswordCredentials struct {
Username string `json:"username"`
Password string `json:"password"`
}
// IdentityResponse is an authentication response body.
type IdentityResponse struct {
Access Access `json:"access"`
}
// Access is an identity information.
type Access struct {
Token Token `json:"token"`
}
// Token is an api access token.
type Token struct {
ID string `json:"id"`
}
// DomainListResponse is a response of a domain listing request.
type DomainListResponse struct {
Domains []Domain `json:"domains"`
}
// Domain is a hosted domain entry.
type Domain struct {
ID string `json:"id"`
Name string `json:"name"`
}
// RecordListResponse is a response of record listing request.
type RecordListResponse struct {
Records []Record `json:"records"`
}
// Record is a record entry.
type Record struct {
ID string `json:"id,omitempty"`
Name string `json:"name"`
Type string `json:"type"`
Data string `json:"data"`
TTL int `json:"ttl"`
}

View file

@ -2,6 +2,7 @@
package constellix
import (
"context"
"errors"
"fmt"
"net/http"
@ -101,10 +102,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("constellix: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err)
return fmt.Errorf("constellix: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
dom, err := d.client.Domains.GetByName(dns01.UnFqdn(authZone))
ctx := context.Background()
dom, err := d.client.Domains.GetByName(ctx, dns01.UnFqdn(authZone))
if err != nil {
return fmt.Errorf("constellix: failed to get domain (%s): %w", authZone, err)
}
@ -114,7 +117,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
return fmt.Errorf("constellix: %w", err)
}
records, err := d.client.TxtRecords.Search(dom.ID, internal.Exact, recordName)
records, err := d.client.TxtRecords.Search(ctx, dom.ID, internal.Exact, recordName)
if err != nil {
return fmt.Errorf("constellix: failed to search TXT records: %w", err)
}
@ -125,10 +128,10 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
// TXT record entry already existing
if len(records) == 1 {
return d.appendRecordValue(dom, records[0].ID, info.Value)
return d.appendRecordValue(ctx, dom, records[0].ID, info.Value)
}
err = d.createRecord(dom, info.EffectiveFQDN, recordName, info.Value)
err = d.createRecord(ctx, dom, info.EffectiveFQDN, recordName, info.Value)
if err != nil {
return fmt.Errorf("constellix: %w", err)
}
@ -142,10 +145,12 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("constellix: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err)
return fmt.Errorf("constellix: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
dom, err := d.client.Domains.GetByName(dns01.UnFqdn(authZone))
ctx := context.Background()
dom, err := d.client.Domains.GetByName(ctx, dns01.UnFqdn(authZone))
if err != nil {
return fmt.Errorf("constellix: failed to get domain (%s): %w", authZone, err)
}
@ -155,7 +160,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return fmt.Errorf("constellix: %w", err)
}
records, err := d.client.TxtRecords.Search(dom.ID, internal.Exact, recordName)
records, err := d.client.TxtRecords.Search(ctx, dom.ID, internal.Exact, recordName)
if err != nil {
return fmt.Errorf("constellix: failed to search TXT records: %w", err)
}
@ -168,7 +173,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return nil
}
record, err := d.client.TxtRecords.Get(dom.ID, records[0].ID)
record, err := d.client.TxtRecords.Get(ctx, dom.ID, records[0].ID)
if err != nil {
return fmt.Errorf("constellix: failed to get TXT records: %w", err)
}
@ -179,14 +184,14 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
// only 1 record value, the whole record must be deleted.
if len(record.Value) == 1 {
_, err = d.client.TxtRecords.Delete(dom.ID, record.ID)
_, err = d.client.TxtRecords.Delete(ctx, dom.ID, record.ID)
if err != nil {
return fmt.Errorf("constellix: failed to delete TXT records: %w", err)
}
return nil
}
err = d.removeRecordValue(dom, record, info.Value)
err = d.removeRecordValue(ctx, dom, record, info.Value)
if err != nil {
return fmt.Errorf("constellix: %w", err)
}
@ -194,7 +199,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return nil
}
func (d *DNSProvider) createRecord(dom internal.Domain, fqdn, recordName, value string) error {
func (d *DNSProvider) createRecord(ctx context.Context, dom internal.Domain, fqdn, recordName, value string) error {
request := internal.RecordRequest{
Name: recordName,
TTL: d.config.TTL,
@ -203,7 +208,7 @@ func (d *DNSProvider) createRecord(dom internal.Domain, fqdn, recordName, value
},
}
_, err := d.client.TxtRecords.Create(dom.ID, request)
_, err := d.client.TxtRecords.Create(ctx, dom.ID, request)
if err != nil {
return fmt.Errorf("failed to create TXT record %s: %w", fqdn, err)
}
@ -211,8 +216,8 @@ func (d *DNSProvider) createRecord(dom internal.Domain, fqdn, recordName, value
return nil
}
func (d *DNSProvider) appendRecordValue(dom internal.Domain, recordID int64, value string) error {
record, err := d.client.TxtRecords.Get(dom.ID, recordID)
func (d *DNSProvider) appendRecordValue(ctx context.Context, dom internal.Domain, recordID int64, value string) error {
record, err := d.client.TxtRecords.Get(ctx, dom.ID, recordID)
if err != nil {
return fmt.Errorf("failed to get TXT records: %w", err)
}
@ -227,7 +232,7 @@ func (d *DNSProvider) appendRecordValue(dom internal.Domain, recordID int64, val
RoundRobin: append(record.RoundRobin, internal.RecordValue{Value: fmt.Sprintf(`%q`, value)}),
}
_, err = d.client.TxtRecords.Update(dom.ID, record.ID, request)
_, err = d.client.TxtRecords.Update(ctx, dom.ID, record.ID, request)
if err != nil {
return fmt.Errorf("failed to update TXT records: %w", err)
}
@ -235,7 +240,7 @@ func (d *DNSProvider) appendRecordValue(dom internal.Domain, recordID int64, val
return nil
}
func (d *DNSProvider) removeRecordValue(dom internal.Domain, record *internal.Record, value string) error {
func (d *DNSProvider) removeRecordValue(ctx context.Context, dom internal.Domain, record *internal.Record, value string) error {
request := internal.RecordRequest{
Name: record.Name,
TTL: record.TTL,
@ -247,7 +252,7 @@ func (d *DNSProvider) removeRecordValue(dom internal.Domain, record *internal.Re
}
}
_, err := d.client.TxtRecords.Update(dom.ID, record.ID, request)
_, err := d.client.TxtRecords.Update(ctx, dom.ID, record.ID, request)
if err != nil {
return fmt.Errorf("failed to update TXT records: %w", err)
}

View file

@ -6,6 +6,9 @@ import (
"io"
"net/http"
"net/url"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
const (
@ -28,7 +31,7 @@ type Client struct {
// NewClient Creates a Constellix client.
func NewClient(httpClient *http.Client) *Client {
if httpClient == nil {
httpClient = http.DefaultClient
httpClient = &http.Client{Timeout: 5 * time.Second}
}
client := &Client{
@ -48,13 +51,15 @@ type service struct {
}
// do sends an API request and returns the API response.
func (c *Client) do(req *http.Request, v interface{}) error {
func (c *Client) do(req *http.Request, result any) error {
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
resp, err := c.HTTPClient.Do(req)
if err != nil {
return err
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
err = checkResponse(resp)
@ -64,11 +69,11 @@ func (c *Client) do(req *http.Request, v interface{}) error {
raw, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read body: %w", err)
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
if err = json.Unmarshal(raw, v); err != nil {
return fmt.Errorf("unmarshaling %T error: %w: %s", v, err, string(raw))
if err = json.Unmarshal(raw, result); err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
@ -83,21 +88,21 @@ func checkResponse(resp *http.Response) error {
return nil
}
data, err := io.ReadAll(resp.Body)
if err == nil && data != nil {
msg := &APIError{StatusCode: resp.StatusCode}
raw, err := io.ReadAll(resp.Body)
if err == nil && raw != nil {
errAPI := &APIError{StatusCode: resp.StatusCode}
if json.Unmarshal(data, msg) != nil {
return fmt.Errorf("API error: status code: %d: %v", resp.StatusCode, string(data))
if json.Unmarshal(raw, errAPI) != nil {
return fmt.Errorf("API error: status code: %d: %v", resp.StatusCode, string(raw))
}
switch resp.StatusCode {
case http.StatusNotFound:
return &NotFound{APIError: msg}
return &NotFound{APIError: errAPI}
case http.StatusBadRequest:
return &BadRequest{APIError: msg}
return &BadRequest{APIError: errAPI}
default:
return msg
return errAPI
}
}

View file

@ -1,6 +1,7 @@
package internal
import (
"context"
"errors"
"fmt"
"net/http"
@ -13,15 +14,15 @@ type DomainService service
// GetAll domains.
// https://api-docs.constellix.com/?version=latest#484c3f21-d724-4ee4-a6fa-ab22c8eb9e9b
func (s *DomainService) GetAll(params *PaginationParameters) ([]Domain, error) {
func (s *DomainService) GetAll(ctx context.Context, params *PaginationParameters) ([]Domain, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains")
if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err)
}
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, fmt.Errorf("unable to create request: %w", err)
}
if params != nil {
@ -42,8 +43,8 @@ func (s *DomainService) GetAll(params *PaginationParameters) ([]Domain, error) {
}
// GetByName Gets domain by name.
func (s *DomainService) GetByName(domainName string) (Domain, error) {
domains, err := s.Search(Exact, domainName)
func (s *DomainService) GetByName(ctx context.Context, domainName string) (Domain, error) {
domains, err := s.Search(ctx, Exact, domainName)
if err != nil {
return Domain{}, err
}
@ -61,15 +62,15 @@ func (s *DomainService) GetByName(domainName string) (Domain, error) {
// Search searches for a domain by name.
// https://api-docs.constellix.com/?version=latest#3d7b2679-2209-49f3-b011-b7d24e512008
func (s *DomainService) Search(filter searchFilter, value string) ([]Domain, error) {
func (s *DomainService) Search(ctx context.Context, filter searchFilter, value string) ([]Domain, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", "search")
if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err)
}
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, fmt.Errorf("unable to create request: %w", err)
}
query := req.URL.Query()

View file

@ -1,6 +1,7 @@
package internal
import (
"context"
"io"
"net/http"
"net/http/httptest"
@ -47,7 +48,7 @@ func TestDomainService_GetAll(t *testing.T) {
}
})
data, err := client.Domains.GetAll(nil)
data, err := client.Domains.GetAll(context.Background(), nil)
require.NoError(t, err)
expected := []Domain{
@ -83,7 +84,7 @@ func TestDomainService_Search(t *testing.T) {
}
})
data, err := client.Domains.Search(Exact, "lego.wtf")
data, err := client.Domains.Search(context.Background(), Exact, "lego.wtf")
require.NoError(t, err)
expected := []Domain{

View file

@ -2,6 +2,7 @@ package internal
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@ -14,20 +15,20 @@ type TxtRecordService service
// Create a TXT record.
// https://api-docs.constellix.com/?version=latest#22e24d5b-9ec0-49a7-b2b0-5ff0a28e71be
func (s *TxtRecordService) Create(domainID int64, record RecordRequest) ([]Record, error) {
body, err := json.Marshal(record)
if err != nil {
return nil, fmt.Errorf("failed to marshall request body: %w", err)
}
func (s *TxtRecordService) Create(ctx context.Context, domainID int64, record RecordRequest) ([]Record, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt")
if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err)
}
req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body))
body, err := json.Marshal(record)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
var records []Record
@ -41,15 +42,15 @@ func (s *TxtRecordService) Create(domainID int64, record RecordRequest) ([]Recor
// GetAll TXT records.
// https://api-docs.constellix.com/?version=latest#e7103c53-2ad8-4bc8-b5b3-4c22c4b571b2
func (s *TxtRecordService) GetAll(domainID int64) ([]Record, error) {
func (s *TxtRecordService) GetAll(ctx context.Context, domainID int64) ([]Record, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt")
if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err)
return nil, fmt.Errorf("failed to create endpoint: %w", err)
}
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, fmt.Errorf("unable to create request: %w", err)
}
var records []Record
@ -63,15 +64,15 @@ func (s *TxtRecordService) GetAll(domainID int64) ([]Record, error) {
// Get a TXT record.
// https://api-docs.constellix.com/?version=latest#e7103c53-2ad8-4bc8-b5b3-4c22c4b571b2
func (s *TxtRecordService) Get(domainID, recordID int64) (*Record, error) {
func (s *TxtRecordService) Get(ctx context.Context, domainID, recordID int64) (*Record, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10))
if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err)
}
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, fmt.Errorf("unable to create request: %w", err)
}
var records Record
@ -85,20 +86,20 @@ func (s *TxtRecordService) Get(domainID, recordID int64) (*Record, error) {
// Update a TXT record.
// https://api-docs.constellix.com/?version=latest#d4e9ab2e-fac0-45a6-b0e4-cf62a2d2e3da
func (s *TxtRecordService) Update(domainID, recordID int64, record RecordRequest) (*SuccessMessage, error) {
body, err := json.Marshal(record)
if err != nil {
return nil, fmt.Errorf("failed to marshall request body: %w", err)
}
func (s *TxtRecordService) Update(ctx context.Context, domainID, recordID int64, record RecordRequest) (*SuccessMessage, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10))
if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err)
}
req, err := http.NewRequest(http.MethodPut, endpoint, bytes.NewReader(body))
body, err := json.Marshal(record)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
var msg SuccessMessage
@ -112,15 +113,15 @@ func (s *TxtRecordService) Update(domainID, recordID int64, record RecordRequest
// Delete a TXT record.
// https://api-docs.constellix.com/?version=latest#135947f7-d6c8-481a-83c7-4d387b0bdf9e
func (s *TxtRecordService) Delete(domainID, recordID int64) (*SuccessMessage, error) {
func (s *TxtRecordService) Delete(ctx context.Context, domainID, recordID int64) (*SuccessMessage, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10))
if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err)
}
req, err := http.NewRequest(http.MethodDelete, endpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, fmt.Errorf("unable to create request: %w", err)
}
var msg *SuccessMessage
@ -134,15 +135,15 @@ func (s *TxtRecordService) Delete(domainID, recordID int64) (*SuccessMessage, er
// Search searches for a TXT record by name.
// https://api-docs.constellix.com/?version=latest#81003e4f-bd3f-413f-a18d-6d9d18f10201
func (s *TxtRecordService) Search(domainID int64, filter searchFilter, value string) ([]Record, error) {
func (s *TxtRecordService) Search(ctx context.Context, domainID int64, filter searchFilter, value string) ([]Record, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", "search")
if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err)
}
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, fmt.Errorf("unable to create request: %w", err)
}
query := req.URL.Query()

View file

@ -1,6 +1,7 @@
package internal
import (
"context"
"encoding/json"
"io"
"net/http"
@ -34,7 +35,7 @@ func TestTxtRecordService_Create(t *testing.T) {
}
})
records, err := client.TxtRecords.Create(12345, RecordRequest{})
records, err := client.TxtRecords.Create(context.Background(), 12345, RecordRequest{})
require.NoError(t, err)
recordsJSON, err := json.Marshal(records)
@ -69,7 +70,7 @@ func TestTxtRecordService_GetAll(t *testing.T) {
}
})
records, err := client.TxtRecords.GetAll(12345)
records, err := client.TxtRecords.GetAll(context.Background(), 12345)
require.NoError(t, err)
recordsJSON, err := json.Marshal(records)
@ -104,7 +105,7 @@ func TestTxtRecordService_Get(t *testing.T) {
}
})
record, err := client.TxtRecords.Get(12345, 6789)
record, err := client.TxtRecords.Get(context.Background(), 12345, 6789)
require.NoError(t, err)
expected := &Record{
@ -145,7 +146,7 @@ func TestTxtRecordService_Update(t *testing.T) {
}
})
msg, err := client.TxtRecords.Update(12345, 6789, RecordRequest{})
msg, err := client.TxtRecords.Update(context.Background(), 12345, 6789, RecordRequest{})
require.NoError(t, err)
expected := &SuccessMessage{Success: "Record updated successfully"}
@ -168,7 +169,7 @@ func TestTxtRecordService_Delete(t *testing.T) {
}
})
msg, err := client.TxtRecords.Delete(12345, 6789)
msg, err := client.TxtRecords.Delete(context.Background(), 12345, 6789)
require.NoError(t, err)
expected := &SuccessMessage{Success: "Record deleted successfully"}
@ -198,7 +199,7 @@ func TestTxtRecordService_Search(t *testing.T) {
}
})
records, err := client.TxtRecords.Search(12345, Exact, "test")
records, err := client.TxtRecords.Search(context.Background(), 12345, Exact, "test")
require.NoError(t, err)
recordsJSON, err := json.Marshal(records)

View file

@ -106,7 +106,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("desec: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err)
return fmt.Errorf("desec: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
recordName, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)
@ -156,7 +156,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("desec: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err)
return fmt.Errorf("desec: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
recordName, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)

View file

@ -128,12 +128,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("designate: couldn't get zone ID in Present: %w", err)
return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
zoneID, err := d.getZoneID(authZone)
if err != nil {
return fmt.Errorf("designate: %w", err)
return fmt.Errorf("designate: couldn't get zone ID in Present: %w", err)
}
// use mutex to prevent race condition between creating the record and verifying it
@ -168,7 +168,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return err
return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
zoneID, err := d.getZoneID(authZone)

View file

@ -286,6 +286,9 @@ func setupTestProvider(t *testing.T) string {
t.Helper()
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte(`{
"access": {
@ -319,9 +322,6 @@ func setupTestProvider(t *testing.T) string {
w.WriteHeader(http.StatusOK)
})
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
return server.URL
}

View file

@ -1,131 +0,0 @@
package digitalocean
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/go-acme/lego/v4/challenge/dns01"
)
const defaultBaseURL = "https://api.digitalocean.com"
// txtRecordResponse represents a response from DO's API after making a TXT record.
type txtRecordResponse struct {
DomainRecord record `json:"domain_record"`
}
type record struct {
ID int `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Name string `json:"name,omitempty"`
Data string `json:"data,omitempty"`
TTL int `json:"ttl,omitempty"`
}
type apiError struct {
ID string `json:"id"`
Message string `json:"message"`
}
func (d *DNSProvider) removeTxtRecord(domain string, recordID int) error {
authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(domain))
if err != nil {
return fmt.Errorf("could not determine zone for domain %q: %w", domain, err)
}
reqURL := fmt.Sprintf("%s/v2/domains/%s/records/%d", d.config.BaseURL, dns01.UnFqdn(authZone), recordID)
req, err := d.newRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
resp, err := d.config.HTTPClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
return readError(req, resp)
}
return nil
}
func (d *DNSProvider) addTxtRecord(fqdn, value string) (*txtRecordResponse, error) {
authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(fqdn))
if err != nil {
return nil, fmt.Errorf("could not determine zone for domain %q: %w", fqdn, err)
}
reqData := record{Type: "TXT", Name: fqdn, Data: value, TTL: d.config.TTL}
body, err := json.Marshal(reqData)
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/v2/domains/%s/records", d.config.BaseURL, dns01.UnFqdn(authZone))
req, err := d.newRequest(http.MethodPost, reqURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
resp, err := d.config.HTTPClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
return nil, readError(req, resp)
}
content, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.New(toUnreadableBodyMessage(req, content))
}
// Everything looks good; but we'll need the ID later to delete the record
respData := &txtRecordResponse{}
err = json.Unmarshal(content, respData)
if err != nil {
return nil, fmt.Errorf("%w: %s", err, toUnreadableBodyMessage(req, content))
}
return respData, nil
}
func (d *DNSProvider) newRequest(method, reqURL string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequest(method, reqURL, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", d.config.AuthToken))
return req, nil
}
func readError(req *http.Request, resp *http.Response) error {
content, err := io.ReadAll(resp.Body)
if err != nil {
return errors.New(toUnreadableBodyMessage(req, content))
}
var errInfo apiError
err = json.Unmarshal(content, &errInfo)
if err != nil {
return fmt.Errorf("apiError unmarshaling error: %w: %s", err, toUnreadableBodyMessage(req, content))
}
return fmt.Errorf("HTTP %d: %s: %s", resp.StatusCode, errInfo.ID, errInfo.Message)
}
func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string {
return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody))
}

View file

@ -2,14 +2,17 @@
package digitalocean
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"sync"
"time"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/platform/config/env"
"github.com/go-acme/lego/v4/providers/dns/digitalocean/internal"
)
// Environment variables names.
@ -38,7 +41,7 @@ type Config struct {
// NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config {
return &Config{
BaseURL: env.GetOrDefaultString(EnvAPIUrl, defaultBaseURL),
BaseURL: env.GetOrDefaultString(EnvAPIUrl, internal.DefaultBaseURL),
TTL: env.GetOrDefaultInt(EnvTTL, 30),
PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 60*time.Second),
PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 5*time.Second),
@ -51,6 +54,8 @@ func NewDefaultConfig() *Config {
// DNSProvider implements the challenge.Provider interface.
type DNSProvider struct {
config *Config
client *internal.Client
recordIDs map[string]int
recordIDsMu sync.Mutex
}
@ -80,12 +85,19 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("digitalocean: credentials missing")
}
if config.BaseURL == "" {
config.BaseURL = defaultBaseURL
client := internal.NewClient(internal.OAuthStaticAccessToken(config.HTTPClient, config.AuthToken))
if config.BaseURL != "" {
var err error
client.BaseURL, err = url.Parse(config.BaseURL)
if err != nil {
return nil, fmt.Errorf("digitalocean: %w", err)
}
}
return &DNSProvider{
config: config,
client: client,
recordIDs: make(map[string]int),
}, nil
}
@ -100,7 +112,14 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
respData, err := d.addTxtRecord(info.EffectiveFQDN, info.Value)
authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(info.EffectiveFQDN))
if err != nil {
return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
record := internal.Record{Type: "TXT", Name: info.EffectiveFQDN, Data: info.Value, TTL: d.config.TTL}
respData, err := d.client.AddTxtRecord(context.Background(), authZone, record)
if err != nil {
return fmt.Errorf("digitalocean: %w", err)
}
@ -118,7 +137,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("digitalocean: %w", err)
return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
// get the record's unique ID from when we created it
@ -129,7 +148,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return fmt.Errorf("digitalocean: unknown record ID for '%s'", info.EffectiveFQDN)
}
err = d.removeTxtRecord(authZone, recordID)
err = d.client.RemoveTxtRecord(context.Background(), authZone, recordID)
if err != nil {
return fmt.Errorf("digitalocean: %w", err)
}

View file

@ -1,6 +1,7 @@
package digitalocean
import (
"bytes"
"fmt"
"io"
"net/http"
@ -115,6 +116,7 @@ func TestDNSProvider_Present(t *testing.T) {
mux.HandleFunc("/v2/domains/example.com/records", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method, "method")
assert.Equal(t, "application/json", r.Header.Get("Accept"), "Accept")
assert.Equal(t, "application/json", r.Header.Get("Content-Type"), "Content-Type")
assert.Equal(t, "Bearer asdf1234", r.Header.Get("Authorization"), "Authorization")
@ -125,7 +127,7 @@ func TestDNSProvider_Present(t *testing.T) {
}
expectedReqBody := `{"type":"TXT","name":"_acme-challenge.example.com.","data":"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI","ttl":30}`
assert.Equal(t, expectedReqBody, string(reqBody))
assert.Equal(t, expectedReqBody, string(bytes.TrimSpace(reqBody)))
w.WriteHeader(http.StatusCreated)
_, err = fmt.Fprintf(w, `{
@ -157,7 +159,7 @@ func TestDNSProvider_CleanUp(t *testing.T) {
assert.Equal(t, "/v2/domains/example.com/records/1234567", r.URL.Path, "Path")
// NOTE: Even though the body is empty, DigitalOcean API docs still show setting this Content-Type...
assert.Equal(t, "application/json", r.Header.Get("Accept"), "Accept")
assert.Equal(t, "application/json", r.Header.Get("Content-Type"), "Content-Type")
assert.Equal(t, "Bearer asdf1234", r.Header.Get("Authorization"), "Authorization")

View file

@ -0,0 +1,142 @@
package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"time"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
"golang.org/x/oauth2"
)
// DefaultBaseURL default API endpoint.
const DefaultBaseURL = "https://api.digitalocean.com"
// Client the Digital Ocean API client.
type Client struct {
BaseURL *url.URL
httpClient *http.Client
}
// NewClient creates a new Client.
func NewClient(hc *http.Client) *Client {
baseURL, _ := url.Parse(DefaultBaseURL)
if hc == nil {
hc = &http.Client{Timeout: 5 * time.Second}
}
return &Client{BaseURL: baseURL, httpClient: hc}
}
func (c *Client) AddTxtRecord(ctx context.Context, zone string, record Record) (*TxtRecordResponse, error) {
endpoint := c.BaseURL.JoinPath("v2", "domains", dns01.UnFqdn(zone), "records")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
if err != nil {
return nil, err
}
respData := &TxtRecordResponse{}
err = c.do(req, respData)
if err != nil {
return nil, err
}
return respData, nil
}
func (c *Client) RemoveTxtRecord(ctx context.Context, zone string, recordID int) error {
endpoint := c.BaseURL.JoinPath("v2", "domains", dns01.UnFqdn(zone), "records", strconv.Itoa(recordID))
req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err
}
return c.do(req, nil)
}
func (c *Client) do(req *http.Request, result any) error {
resp, err := c.httpClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= http.StatusBadRequest {
return parseError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
err = json.Unmarshal(raw, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
// NOTE: Even though the body is empty, DigitalOcean API docs still show setting this Content-Type...
req.Header.Set("Content-Type", "application/json")
return req, nil
}
func parseError(req *http.Request, resp *http.Response) error {
raw, _ := io.ReadAll(resp.Body)
var errInfo APIError
err := json.Unmarshal(raw, &errInfo)
if err != nil {
return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw)
}
return fmt.Errorf("[status code %d] %w", resp.StatusCode, errInfo)
}
func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client {
if client == nil {
client = &http.Client{Timeout: 5 * time.Second}
}
client.Transport = &oauth2.Transport{
Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}),
Base: client.Transport,
}
return client
}

View file

@ -0,0 +1,139 @@
package internal
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTest(t *testing.T, pattern string, handler http.HandlerFunc) *Client {
t.Helper()
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
client := NewClient(OAuthStaticAccessToken(server.Client(), "secret"))
client.BaseURL, _ = url.Parse(server.URL)
mux.HandleFunc(pattern, handler)
return client
}
func checkHeader(req *http.Request, name, value string) error {
val := req.Header.Get(name)
if val != value {
return fmt.Errorf("invalid header value, got: %s want %s", val, value)
}
return nil
}
func writeFixture(rw http.ResponseWriter, filename string) {
file, err := os.Open(filepath.Join("fixtures", filename))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
defer func() { _ = file.Close() }()
_, _ = io.Copy(rw, file)
}
func TestClient_AddTxtRecord(t *testing.T) {
client := setupTest(t, "/v2/domains/example.com/records", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed)
return
}
err := checkHeader(req, "Accept", "application/json")
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
err = checkHeader(req, "Content-Type", "application/json")
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
err = checkHeader(req, "Authorization", "Bearer secret")
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
reqBody, err := io.ReadAll(req.Body)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
expectedReqBody := `{"type":"TXT","name":"_acme-challenge.example.com.","data":"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI","ttl":30}`
if expectedReqBody != string(bytes.TrimSpace(reqBody)) {
http.Error(rw, fmt.Sprintf("unexpected request body: %s", string(bytes.TrimSpace(reqBody))), http.StatusBadRequest)
return
}
rw.WriteHeader(http.StatusCreated)
writeFixture(rw, "domains-records_POST.json")
})
record := Record{
Type: "TXT",
Name: "_acme-challenge.example.com.",
Data: "w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI",
TTL: 30,
}
newRecord, err := client.AddTxtRecord(context.Background(), "example.com", record)
require.NoError(t, err)
expected := &TxtRecordResponse{DomainRecord: Record{
ID: 1234567,
Type: "TXT",
Name: "_acme-challenge",
Data: "w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI",
TTL: 0,
}}
assert.Equal(t, expected, newRecord)
}
func TestClient_RemoveTxtRecord(t *testing.T) {
client := setupTest(t, "/v2/domains/example.com/records/1234567", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodDelete {
http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed)
return
}
err := checkHeader(req, "Accept", "application/json")
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
err = checkHeader(req, "Authorization", "Bearer secret")
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
rw.WriteHeader(http.StatusNoContent)
})
err := client.RemoveTxtRecord(context.Background(), "example.com", 1234567)
require.NoError(t, err)
}

View file

@ -0,0 +1,11 @@
{
"domain_record": {
"id": 1234567,
"type": "TXT",
"name": "_acme-challenge",
"data": "w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI",
"priority": null,
"port": null,
"weight": null
}
}

View file

@ -0,0 +1,25 @@
package internal
import "fmt"
// TxtRecordResponse represents a response from DO's API after making a TXT record.
type TxtRecordResponse struct {
DomainRecord Record `json:"domain_record"`
}
type Record struct {
ID int `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Name string `json:"name,omitempty"`
Data string `json:"data,omitempty"`
TTL int `json:"ttl,omitempty"`
}
type APIError struct {
ID string `json:"id"`
Message string `json:"message"`
}
func (a APIError) Error() string {
return fmt.Sprintf("%s: %s", a.ID, a.Message)
}

View file

@ -126,7 +126,7 @@ import (
// NewDNSChallengeProviderByName Factory for DNS providers.
func NewDNSChallengeProviderByName(name string) (challenge.Provider, error) {
switch name {
case "acme-dns":
case "acme-dns": // TODO(ldez): remove "-" in v5
return acmedns.NewDNSProvider()
case "alidns":
return alidns.NewDNSProvider()

View file

@ -2,6 +2,7 @@
package dnshomede
import (
"context"
"errors"
"fmt"
"net/http"
@ -99,7 +100,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
func (d *DNSProvider) Present(domain, _, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
err := d.client.Add(dns01.UnFqdn(info.EffectiveFQDN), info.Value)
err := d.client.Add(context.Background(), dns01.UnFqdn(info.EffectiveFQDN), info.Value)
if err != nil {
return fmt.Errorf("dnshomede: %w", err)
}
@ -111,7 +112,7 @@ func (d *DNSProvider) Present(domain, _, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
err := d.client.Remove(dns01.UnFqdn(info.EffectiveFQDN), info.Value)
err := d.client.Remove(context.Background(), dns01.UnFqdn(info.EffectiveFQDN), info.Value)
if err != nil {
return fmt.Errorf("dnshomede: %w", err)
}

View file

@ -1,6 +1,7 @@
package internal
import (
"context"
"errors"
"fmt"
"io"
@ -9,6 +10,8 @@ import (
"strings"
"sync"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
const (
@ -22,8 +25,8 @@ const defaultBaseURL = "https://www.dnshome.de/dyndns.php"
// Client the dnsHome.de client.
type Client struct {
HTTPClient *http.Client
baseURL string
HTTPClient *http.Client
credentials map[string]string
credMu sync.Mutex
@ -40,75 +43,48 @@ func NewClient(credentials map[string]string) *Client {
// Add adds a TXT record.
// only one TXT record for ACME is allowed, so it will update the "current" TXT record.
func (c *Client) Add(hostname, value string) error {
func (c *Client) Add(ctx context.Context, hostname, value string) error {
domain := strings.TrimPrefix(hostname, "_acme-challenge.")
c.credMu.Lock()
password, ok := c.credentials[domain]
c.credMu.Unlock()
if !ok {
return fmt.Errorf("domain %s not found in credentials, check your credentials map", domain)
}
return c.do(url.UserPassword(domain, password), addAction, value)
return c.doAction(ctx, domain, addAction, value)
}
// Remove removes a TXT record.
// only one TXT record for ACME is allowed, so it will remove "all" the TXT records.
func (c *Client) Remove(hostname, value string) error {
func (c *Client) Remove(ctx context.Context, hostname, value string) error {
domain := strings.TrimPrefix(hostname, "_acme-challenge.")
c.credMu.Lock()
password, ok := c.credentials[domain]
c.credMu.Unlock()
if !ok {
return fmt.Errorf("domain %s not found in credentials, check your credentials map", domain)
return c.doAction(ctx, domain, removeAction, value)
}
return c.do(url.UserPassword(domain, password), removeAction, value)
}
func (c *Client) do(userInfo *url.Userinfo, action, value string) error {
if len(value) < 12 {
return fmt.Errorf("the TXT value must have more than 12 characters: %s", value)
}
apiEndpoint, err := url.Parse(c.baseURL)
func (c *Client) doAction(ctx context.Context, domain, action, value string) error {
endpoint, err := c.createEndpoint(domain, action, value)
if err != nil {
return err
}
apiEndpoint.User = userInfo
query := apiEndpoint.Query()
query.Set("acme", action)
query.Set("txt", value)
apiEndpoint.RawQuery = query.Encode()
req, err := http.NewRequest(http.MethodPost, apiEndpoint.String(), http.NoBody)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), http.NoBody)
if err != nil {
return err
return fmt.Errorf("unable to create request: %w", err)
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return err
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
all, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%d: %s", resp.StatusCode, string(all))
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
all, err := io.ReadAll(resp.Body)
raw, err := io.ReadAll(resp.Body)
if err != nil {
return err
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
output := string(all)
output := string(raw)
if !strings.HasPrefix(output, successCode) {
return errors.New(output)
@ -116,3 +92,31 @@ func (c *Client) do(userInfo *url.Userinfo, action, value string) error {
return nil
}
func (c *Client) createEndpoint(domain, action, value string) (*url.URL, error) {
if len(value) < 12 {
return nil, fmt.Errorf("the TXT value must have more than 12 characters: %s", value)
}
endpoint, err := url.Parse(c.baseURL)
if err != nil {
return nil, err
}
c.credMu.Lock()
password, ok := c.credentials[domain]
c.credMu.Unlock()
if !ok {
return nil, fmt.Errorf("domain %s not found in credentials, check your credentials map", domain)
}
endpoint.User = url.UserPassword(domain, password)
query := endpoint.Query()
query.Set("acme", action)
query.Set("txt", value)
endpoint.RawQuery = query.Encode()
return endpoint, nil
}

View file

@ -1,6 +1,7 @@
package internal
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@ -9,79 +10,55 @@ import (
"github.com/stretchr/testify/require"
)
func TestClient_Add(t *testing.T) {
txtValue := "123456789012"
func setupTest(t *testing.T, credentials map[string]string, handler http.HandlerFunc) *Client {
t.Helper()
mux := http.NewServeMux()
mux.HandleFunc("/", handlerMock(addAction, txtValue))
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
credentials := map[string]string{
"example.org": "secret",
}
mux.HandleFunc("/", handler)
client := NewClient(credentials)
client.HTTPClient = server.Client()
client.baseURL = server.URL
err := client.Add("example.org", txtValue)
return client
}
func TestClient_Add(t *testing.T) {
txtValue := "123456789012"
client := setupTest(t, map[string]string{"example.org": "secret"}, handlerMock(addAction, txtValue))
err := client.Add(context.Background(), "example.org", txtValue)
require.NoError(t, err)
}
func TestClient_Add_error(t *testing.T) {
txtValue := "123456789012"
mux := http.NewServeMux()
mux.HandleFunc("/", handlerMock(addAction, txtValue))
server := httptest.NewServer(mux)
client := setupTest(t, map[string]string{"example.com": "secret"}, handlerMock(addAction, txtValue))
credentials := map[string]string{
"example.com": "secret",
}
client := NewClient(credentials)
client.HTTPClient = server.Client()
client.baseURL = server.URL
err := client.Add("example.org", txtValue)
err := client.Add(context.Background(), "example.org", txtValue)
require.Error(t, err)
}
func TestClient_Remove(t *testing.T) {
txtValue := "ABCDEFGHIJKL"
mux := http.NewServeMux()
mux.HandleFunc("/", handlerMock(removeAction, txtValue))
server := httptest.NewServer(mux)
client := setupTest(t, map[string]string{"example.org": "secret"}, handlerMock(removeAction, txtValue))
credentials := map[string]string{
"example.org": "secret",
}
client := NewClient(credentials)
client.HTTPClient = server.Client()
client.baseURL = server.URL
err := client.Remove("example.org", txtValue)
err := client.Remove(context.Background(), "example.org", txtValue)
require.NoError(t, err)
}
func TestClient_Remove_error(t *testing.T) {
txtValue := "ABCDEFGHIJKL"
mux := http.NewServeMux()
mux.HandleFunc("/", handlerMock(removeAction, txtValue))
server := httptest.NewServer(mux)
client := setupTest(t, map[string]string{"example.com": "secret"}, handlerMock(removeAction, txtValue))
credentials := map[string]string{
"example.com": "secret",
}
client := NewClient(credentials)
client.HTTPClient = server.Client()
client.baseURL = server.URL
err := client.Remove("example.org", txtValue)
err := client.Remove(context.Background(), "example.org", txtValue)
require.Error(t, err)
}

View file

@ -149,7 +149,7 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
func (d *DNSProvider) getHostedZone(domain string) (string, error) {
authZone, err := dns01.FindZoneByFqdn(domain)
if err != nil {
return "", err
return "", fmt.Errorf("could not find zone for FQDN %q: %w", domain, err)
}
accountID, err := d.getAccountID()

View file

@ -2,10 +2,12 @@
package dnsmadeeasy
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
@ -86,12 +88,12 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
var baseURL string
if config.Sandbox {
baseURL = "https://api.sandbox.dnsmadeeasy.com/V2.0"
baseURL = internal.DefaultSandboxBaseURL
} else {
if config.BaseURL == "" {
baseURL = internal.DefaultProdBaseURL
} else {
if len(config.BaseURL) > 0 {
baseURL = config.BaseURL
} else {
baseURL = "https://api.dnsmadeeasy.com/V2.0"
}
}
@ -101,7 +103,10 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
}
client.HTTPClient = config.HTTPClient
client.BaseURL = baseURL
client.BaseURL, err = url.Parse(baseURL)
if err != nil {
return nil, err
}
return &DNSProvider{
client: client,
@ -115,11 +120,13 @@ func (d *DNSProvider) Present(domainName, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to find zone for %s: %w", info.EffectiveFQDN, err)
return fmt.Errorf("dnsmadeeasy: could not find zone for domain %q (%s): %w", domainName, info.EffectiveFQDN, err)
}
ctx := context.Background()
// fetch the domain details
domain, err := d.client.GetDomain(authZone)
domain, err := d.client.GetDomain(ctx, authZone)
if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to get domain for zone %s: %w", authZone, err)
}
@ -128,7 +135,7 @@ func (d *DNSProvider) Present(domainName, token, keyAuth string) error {
name := strings.Replace(info.EffectiveFQDN, "."+authZone, "", 1)
record := &internal.Record{Type: "TXT", Name: name, Value: info.Value, TTL: d.config.TTL}
err = d.client.CreateRecord(domain, record)
err = d.client.CreateRecord(ctx, domain, record)
if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to create record for %s: %w", name, err)
}
@ -141,18 +148,20 @@ func (d *DNSProvider) CleanUp(domainName, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to find zone for %s: %w", info.EffectiveFQDN, err)
return fmt.Errorf("dnsmadeeasy: could not find zone for domain %q (%s): %w", domainName, info.EffectiveFQDN, err)
}
ctx := context.Background()
// fetch the domain details
domain, err := d.client.GetDomain(authZone)
domain, err := d.client.GetDomain(ctx, authZone)
if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to get domain for zone %s: %w", authZone, err)
}
// find matching records
name := strings.Replace(info.EffectiveFQDN, "."+authZone, "", 1)
records, err := d.client.GetRecords(domain, name, "TXT")
records, err := d.client.GetRecords(ctx, domain, name, "TXT")
if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to get records for domain %s: %w", domain.Name, err)
}
@ -160,7 +169,7 @@ func (d *DNSProvider) CleanUp(domainName, token, keyAuth string) error {
// delete records
var lastError error
for _, record := range *records {
err = d.client.DeleteRecord(record)
err = d.client.DeleteRecord(ctx, record)
if err != nil {
lastError = fmt.Errorf("dnsmadeeasy: unable to delete record [id=%d, name=%s]: %w", record.ID, record.Name, err)
}

Some files were not shown because too many files have changed in this diff Show more