chore: refactor clients (#1868)

This commit is contained in:
Ludovic Fernandez 2023-05-05 09:49:38 +02:00 committed by GitHub
parent 0c1303c1bd
commit aeec5be129
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
431 changed files with 16635 additions and 10176 deletions

View file

@ -161,7 +161,7 @@ issues:
linters: linters:
- gocyclo - gocyclo
- funlen - funlen
- path: providers/dns/checkdomain/client.go - path: providers/dns/checkdomain/internal/types.go
text: '`payed` is a misspelling of `paid`' text: '`payed` is a misspelling of `paid`'
- path: providers/dns/namecheap/namecheap_test.go - path: providers/dns/namecheap/namecheap_test.go
text: 'cognitive complexity (\d+) of func `TestDNSProvider_getHosts` is high' text: 'cognitive complexity (\d+) of func `TestDNSProvider_getHosts` is high'
@ -174,7 +174,7 @@ issues:
text: 'yodaStyleExpr' text: 'yodaStyleExpr'
- path: providers/dns/dns_providers.go - path: providers/dns/dns_providers.go
text: 'Function name: NewDNSChallengeProviderByName,' text: 'Function name: NewDNSChallengeProviderByName,'
- path: providers/dns/sakuracloud/client.go - path: providers/dns/sakuracloud/wrapper.go
text: 'mu is a global variable' text: 'mu is a global variable'
- path: providers/dns/hosttech/internal/client_test.go - path: providers/dns/hosttech/internal/client_test.go
text: 'Duplicate words \(0\) found' text: 'Duplicate words \(0\) found'

View file

@ -51,7 +51,7 @@ Detailed documentation is available [here](https://go-acme.github.io/lego/dns).
|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------| |---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|
| [Akamai EdgeDNS](https://go-acme.github.io/lego/dns/edgedns/) | [Alibaba Cloud DNS](https://go-acme.github.io/lego/dns/alidns/) | [all-inkl](https://go-acme.github.io/lego/dns/allinkl/) | [Amazon Lightsail](https://go-acme.github.io/lego/dns/lightsail/) | | [Akamai EdgeDNS](https://go-acme.github.io/lego/dns/edgedns/) | [Alibaba Cloud DNS](https://go-acme.github.io/lego/dns/alidns/) | [all-inkl](https://go-acme.github.io/lego/dns/allinkl/) | [Amazon Lightsail](https://go-acme.github.io/lego/dns/lightsail/) |
| [Amazon Route 53](https://go-acme.github.io/lego/dns/route53/) | [ArvanCloud](https://go-acme.github.io/lego/dns/arvancloud/) | [Aurora DNS](https://go-acme.github.io/lego/dns/auroradns/) | [Autodns](https://go-acme.github.io/lego/dns/autodns/) | | [Amazon Route 53](https://go-acme.github.io/lego/dns/route53/) | [ArvanCloud](https://go-acme.github.io/lego/dns/arvancloud/) | [Aurora DNS](https://go-acme.github.io/lego/dns/auroradns/) | [Autodns](https://go-acme.github.io/lego/dns/autodns/) |
| [Azure](https://go-acme.github.io/lego/dns/azure/) | [Bindman](https://go-acme.github.io/lego/dns/bindman/) | [Bluecat](https://go-acme.github.io/lego/dns/bluecat/) | [BRANDIT](https://go-acme.github.io/lego/dns/brandit/) | | [Azure](https://go-acme.github.io/lego/dns/azure/) | [Bindman](https://go-acme.github.io/lego/dns/bindman/) | [Bluecat](https://go-acme.github.io/lego/dns/bluecat/) | [Brandit](https://go-acme.github.io/lego/dns/brandit/) |
| [Bunny](https://go-acme.github.io/lego/dns/bunny/) | [Checkdomain](https://go-acme.github.io/lego/dns/checkdomain/) | [Civo](https://go-acme.github.io/lego/dns/civo/) | [CloudDNS](https://go-acme.github.io/lego/dns/clouddns/) | | [Bunny](https://go-acme.github.io/lego/dns/bunny/) | [Checkdomain](https://go-acme.github.io/lego/dns/checkdomain/) | [Civo](https://go-acme.github.io/lego/dns/civo/) | [CloudDNS](https://go-acme.github.io/lego/dns/clouddns/) |
| [Cloudflare](https://go-acme.github.io/lego/dns/cloudflare/) | [ClouDNS](https://go-acme.github.io/lego/dns/cloudns/) | [CloudXNS](https://go-acme.github.io/lego/dns/cloudxns/) | [ConoHa](https://go-acme.github.io/lego/dns/conoha/) | | [Cloudflare](https://go-acme.github.io/lego/dns/cloudflare/) | [ClouDNS](https://go-acme.github.io/lego/dns/cloudns/) | [CloudXNS](https://go-acme.github.io/lego/dns/cloudxns/) | [ConoHa](https://go-acme.github.io/lego/dns/conoha/) |
| [Constellix](https://go-acme.github.io/lego/dns/constellix/) | [deSEC.io](https://go-acme.github.io/lego/dns/desec/) | [Designate DNSaaS for Openstack](https://go-acme.github.io/lego/dns/designate/) | [Digital Ocean](https://go-acme.github.io/lego/dns/digitalocean/) | | [Constellix](https://go-acme.github.io/lego/dns/constellix/) | [deSEC.io](https://go-acme.github.io/lego/dns/desec/) | [Designate DNSaaS for Openstack](https://go-acme.github.io/lego/dns/designate/) | [Digital Ocean](https://go-acme.github.io/lego/dns/digitalocean/) |

View file

@ -335,7 +335,7 @@ func displayDNSHelp(w io.Writer, name string) error {
case "brandit": case "brandit":
// generated from: providers/dns/brandit/brandit.toml // generated from: providers/dns/brandit/brandit.toml
ew.writeln(`Configuration for BRANDIT.`) ew.writeln(`Configuration for Brandit.`)
ew.writeln(`Code: 'brandit'`) ew.writeln(`Code: 'brandit'`)
ew.writeln(`Since: 'v4.11.0'`) ew.writeln(`Since: 'v4.11.0'`)
ew.writeln() ew.writeln()

View file

@ -1,5 +1,5 @@
--- ---
title: "BRANDIT" title: "Brandit"
date: 2019-03-03T16:39:46+01:00 date: 2019-03-03T16:39:46+01:00
draft: false draft: false
slug: brandit slug: brandit
@ -14,7 +14,7 @@ dnsprovider:
<!-- THIS DOCUMENTATION IS AUTO-GENERATED. PLEASE DO NOT EDIT. --> <!-- THIS DOCUMENTATION IS AUTO-GENERATED. PLEASE DO NOT EDIT. -->
Configuration for [BRANDIT](https://www.brandit.com/). Configuration for [Brandit](https://www.brandit.com/).
<!--more--> <!--more-->
@ -23,7 +23,7 @@ Configuration for [BRANDIT](https://www.brandit.com/).
- Since: v4.11.0 - Since: v4.11.0
Here is an example bash command using the BRANDIT provider: Here is an example bash command using the Brandit provider:
```bash ```bash
BRANDIT_API_KEY=xxxxxxxxxxxxxxxxxxxxx \ BRANDIT_API_KEY=xxxxxxxxxxxxxxxxxxxxx \

View file

@ -61,7 +61,7 @@ More information [here]({{< ref "dns#configuration-and-credentials" >}}).
## More information ## More information
- [API documentation](https://docs.otc.t-systems.com/en-us/dns/index.html) - [API documentation](https://docs.otc.t-systems.com/domain-name-service/api-ref/index.html)
<!-- THIS DOCUMENTATION IS AUTO-GENERATED. PLEASE DO NOT EDIT. --> <!-- THIS DOCUMENTATION IS AUTO-GENERATED. PLEASE DO NOT EDIT. -->
<!-- providers/dns/otc/otc.toml --> <!-- providers/dns/otc/otc.toml -->

14
go.mod
View file

@ -63,9 +63,9 @@ require (
github.com/vultr/govultr/v2 v2.17.2 github.com/vultr/govultr/v2 v2.17.2
github.com/yandex-cloud/go-genproto v0.0.0-20220805142335-27b56ddae16f github.com/yandex-cloud/go-genproto v0.0.0-20220805142335-27b56ddae16f
github.com/yandex-cloud/go-sdk v0.0.0-20220805164847-cf028e604997 github.com/yandex-cloud/go-sdk v0.0.0-20220805164847-cf028e604997
golang.org/x/crypto v0.5.0 golang.org/x/crypto v0.7.0
golang.org/x/net v0.7.0 golang.org/x/net v0.8.0
golang.org/x/oauth2 v0.5.0 golang.org/x/oauth2 v0.6.0
golang.org/x/time v0.3.0 golang.org/x/time v0.3.0
google.golang.org/api v0.111.0 google.golang.org/api v0.111.0
gopkg.in/ns1/ns1-go.v2 v2.6.5 gopkg.in/ns1/ns1-go.v2 v2.6.5
@ -126,10 +126,10 @@ require (
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
go.opencensus.io v0.24.0 // indirect go.opencensus.io v0.24.0 // indirect
go.uber.org/ratelimit v0.2.0 // indirect go.uber.org/ratelimit v0.2.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/mod v0.8.0 // indirect
golang.org/x/sys v0.5.0 // indirect golang.org/x/sys v0.6.0 // indirect
golang.org/x/text v0.7.0 // indirect golang.org/x/text v0.8.0 // indirect
golang.org/x/tools v0.1.12 // indirect golang.org/x/tools v0.6.0 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20230223222841-637eb2293923 // indirect google.golang.org/genproto v0.0.0-20230223222841-637eb2293923 // indirect
google.golang.org/grpc v1.53.0 // indirect google.golang.org/grpc v1.53.0 // indirect

30
go.sum
View file

@ -595,8 +595,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -619,8 +619,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -650,14 +650,14 @@ golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.5.0 h1:HuArIo48skDwlrvM3sEdHXElYslAMsf3KwRkkW4MC4s= golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw=
golang.org/x/oauth2 v0.5.0/go.mod h1:9/XBHVqLaWO3/BRHs5jbpYCnOZVjj5V0ndyaAM7KB4I= golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -712,12 +712,12 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
@ -726,8 +726,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@ -758,8 +758,8 @@ golang.org/x/tools v0.0.0-20200918232735-d647fc253266/go.mod h1:z6u4i615ZeAfBE4X
golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.0.0-20210114065538-d78b04bdf963/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210114065538-d78b04bdf963/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View file

@ -1,7 +1,6 @@
package wait package wait
import ( import (
"errors"
"fmt" "fmt"
"time" "time"
@ -18,9 +17,9 @@ func For(msg string, timeout, interval time.Duration, f func() (bool, error)) er
select { select {
case <-timeUp: case <-timeUp:
if lastErr == nil { if lastErr == nil {
return errors.New("time limit exceeded") return fmt.Errorf("%s: time limit exceeded", msg)
} }
return fmt.Errorf("time limit exceeded: last error: %w", lastErr) return fmt.Errorf("%s: time limit exceeded: last error: %w", msg, lastErr)
default: default:
} }

View file

@ -198,7 +198,7 @@ func (d *DNSProvider) getHostedZone(domain string) (string, error) {
authZone, err := dns01.FindZoneByFqdn(domain) authZone, err := dns01.FindZoneByFqdn(domain)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("could not find zone for FQDN %q: %w", domain, err)
} }
var hostedZone alidns.DomainInDescribeDomains var hostedZone alidns.DomainInDescribeDomains

View file

@ -2,6 +2,7 @@
package allinkl package allinkl
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -49,6 +50,8 @@ func NewDefaultConfig() *Config {
// DNSProvider implements the challenge.Provider interface. // DNSProvider implements the challenge.Provider interface.
type DNSProvider struct { type DNSProvider struct {
config *Config config *Config
identifier *internal.Identifier
client *internal.Client client *internal.Client
recordIDs map[string]string recordIDs map[string]string
@ -80,7 +83,13 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("allinkl: missing credentials") return nil, errors.New("allinkl: missing credentials")
} }
client := internal.NewClient(config.Login, config.Password) identifier := internal.NewIdentifier(config.Login, config.Password)
if config.HTTPClient != nil {
identifier.HTTPClient = config.HTTPClient
}
client := internal.NewClient(config.Login)
if config.HTTPClient != nil { if config.HTTPClient != nil {
client.HTTPClient = config.HTTPClient client.HTTPClient = config.HTTPClient
@ -88,6 +97,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return &DNSProvider{ return &DNSProvider{
config: config, config: config,
identifier: identifier,
client: client, client: client,
recordIDs: make(map[string]string), recordIDs: make(map[string]string),
}, nil }, nil
@ -105,14 +115,18 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("allinkl: could not determine zone for domain %q: %w", domain, err) return fmt.Errorf("allinkl: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
credential, err := d.client.Authentication(60, true) ctx := context.Background()
credential, err := d.identifier.Authentication(ctx, 60, true)
if err != nil { if err != nil {
return fmt.Errorf("allinkl: %w", err) return fmt.Errorf("allinkl: %w", err)
} }
ctx = internal.WithContext(ctx, credential)
subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)
if err != nil { if err != nil {
return fmt.Errorf("allinkl: %w", err) return fmt.Errorf("allinkl: %w", err)
@ -125,7 +139,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
RecordData: info.Value, RecordData: info.Value,
} }
recordID, err := d.client.AddDNSSettings(credential, record) recordID, err := d.client.AddDNSSettings(ctx, record)
if err != nil { if err != nil {
return fmt.Errorf("allinkl: %w", err) return fmt.Errorf("allinkl: %w", err)
} }
@ -141,11 +155,15 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
credential, err := d.client.Authentication(60, true) ctx := context.Background()
credential, err := d.identifier.Authentication(ctx, 60, true)
if err != nil { if err != nil {
return fmt.Errorf("allinkl: %w", err) return fmt.Errorf("allinkl: %w", err)
} }
ctx = internal.WithContext(ctx, credential)
// gets the record's unique ID from when we created it // gets the record's unique ID from when we created it
d.recordIDsMu.Lock() d.recordIDsMu.Lock()
recordID, ok := d.recordIDs[token] recordID, ok := d.recordIDs[token]
@ -154,7 +172,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return fmt.Errorf("allinkl: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) return fmt.Errorf("allinkl: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token)
} }
_, err = d.client.DeleteDNSSettings(credential, recordID) _, err = d.client.DeleteDNSSettings(ctx, recordID)
if err != nil { if err != nil {
return fmt.Errorf("allinkl: %w", err) return fmt.Errorf("allinkl: %w", err)
} }

View file

@ -2,126 +2,64 @@ package internal
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"encoding/xml"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
) )
const ( const apiEndpoint = "https://kasapi.kasserver.com/soap/KasApi.php"
authEndpoint = "https://kasapi.kasserver.com/soap/KasAuth.php"
apiEndpoint = "https://kasapi.kasserver.com/soap/KasApi.php" type Authentication interface {
) Authentication(ctx context.Context, sessionLifetime int, sessionUpdateLifetime bool) (string, error)
}
// Client a KAS server client. // Client a KAS server client.
type Client struct { type Client struct {
login string login string
password string
authEndpoint string
apiEndpoint string
HTTPClient *http.Client
floodTime time.Time floodTime time.Time
muFloodTime sync.Mutex
baseURL string
HTTPClient *http.Client
} }
// NewClient creates a new Client. // NewClient creates a new Client.
func NewClient(login string, password string) *Client { func NewClient(login string) *Client {
return &Client{ return &Client{
login: login, login: login,
password: password, baseURL: apiEndpoint,
authEndpoint: authEndpoint,
apiEndpoint: apiEndpoint,
HTTPClient: &http.Client{Timeout: 10 * time.Second}, HTTPClient: &http.Client{Timeout: 10 * time.Second},
} }
} }
// Authentication Creates a credential token.
// - sessionLifetime: Validity of the token in seconds.
// - sessionUpdateLifetime: with `true` the session is extended with every request.
func (c Client) Authentication(sessionLifetime int, sessionUpdateLifetime bool) (string, error) {
sul := "N"
if sessionUpdateLifetime {
sul = "Y"
}
ar := AuthRequest{
Login: c.login,
AuthData: c.password,
AuthType: "plain",
SessionLifetime: sessionLifetime,
SessionUpdateLifetime: sul,
}
body, err := json.Marshal(ar)
if err != nil {
return "", fmt.Errorf("request marshal: %w", err)
}
payload := []byte(strings.TrimSpace(fmt.Sprintf(kasAuthEnvelope, body)))
req, err := http.NewRequest(http.MethodPost, c.authEndpoint, bytes.NewReader(payload))
if err != nil {
return "", fmt.Errorf("request creation: %w", err)
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return "", fmt.Errorf("request execution: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("invalid status code: %d %s", resp.StatusCode, string(data))
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("response read: %w", err)
}
var e KasAuthEnvelope
decoder := xml.NewTokenDecoder(Trimmer{decoder: xml.NewDecoder(bytes.NewReader(data))})
err = decoder.Decode(&e)
if err != nil {
return "", fmt.Errorf("response xml decode: %w", err)
}
if e.Body.Fault != nil {
return "", e.Body.Fault
}
return e.Body.KasAuthResponse.Return.Text, nil
}
// GetDNSSettings Reading out the DNS settings of a zone. // GetDNSSettings Reading out the DNS settings of a zone.
// - zone: host zone. // - zone: host zone.
// - recordID: the ID of the resource record (optional). // - recordID: the ID of the resource record (optional).
func (c *Client) GetDNSSettings(credentialToken, zone, recordID string) ([]ReturnInfo, error) { func (c *Client) GetDNSSettings(ctx context.Context, zone, recordID string) ([]ReturnInfo, error) {
requestParams := map[string]string{"zone_host": zone} requestParams := map[string]string{"zone_host": zone}
if recordID != "" { if recordID != "" {
requestParams["record_id"] = recordID requestParams["record_id"] = recordID
} }
item, err := c.do(credentialToken, "get_dns_settings", requestParams) req, err := c.newRequest(ctx, "get_dns_settings", requestParams)
if err != nil { if err != nil {
return nil, err return nil, err
} }
raw := getValue(item)
var g GetDNSSettingsAPIResponse var g GetDNSSettingsAPIResponse
err = mapstructure.Decode(raw, &g) err = c.do(req, &g)
if err != nil { if err != nil {
return nil, fmt.Errorf("response struct decode: %w", err) return nil, err
} }
c.updateFloodTime(g.Response.KasFloodDelay) c.updateFloodTime(g.Response.KasFloodDelay)
@ -130,18 +68,16 @@ func (c *Client) GetDNSSettings(credentialToken, zone, recordID string) ([]Retur
} }
// AddDNSSettings Creation of a DNS resource record. // AddDNSSettings Creation of a DNS resource record.
func (c *Client) AddDNSSettings(credentialToken string, record DNSRequest) (string, error) { func (c *Client) AddDNSSettings(ctx context.Context, record DNSRequest) (string, error) {
item, err := c.do(credentialToken, "add_dns_settings", record) req, err := c.newRequest(ctx, "add_dns_settings", record)
if err != nil { if err != nil {
return "", err return "", err
} }
raw := getValue(item)
var g AddDNSSettingsAPIResponse var g AddDNSSettingsAPIResponse
err = mapstructure.Decode(raw, &g) err = c.do(req, &g)
if err != nil { if err != nil {
return "", fmt.Errorf("response struct decode: %w", err) return "", err
} }
c.updateFloodTime(g.Response.KasFloodDelay) c.updateFloodTime(g.Response.KasFloodDelay)
@ -150,20 +86,18 @@ func (c *Client) AddDNSSettings(credentialToken string, record DNSRequest) (stri
} }
// DeleteDNSSettings Deleting a DNS Resource Record. // DeleteDNSSettings Deleting a DNS Resource Record.
func (c *Client) DeleteDNSSettings(credentialToken, recordID string) (bool, error) { func (c *Client) DeleteDNSSettings(ctx context.Context, recordID string) (bool, error) {
requestParams := map[string]string{"record_id": recordID} requestParams := map[string]string{"record_id": recordID}
item, err := c.do(credentialToken, "delete_dns_settings", requestParams) req, err := c.newRequest(ctx, "delete_dns_settings", requestParams)
if err != nil { if err != nil {
return false, err return false, err
} }
raw := getValue(item)
var g DeleteDNSSettingsAPIResponse var g DeleteDNSSettingsAPIResponse
err = mapstructure.Decode(raw, &g) err = c.do(req, &g)
if err != nil { if err != nil {
return false, fmt.Errorf("response struct decode: %w", err) return false, err
} }
c.updateFloodTime(g.Response.KasFloodDelay) c.updateFloodTime(g.Response.KasFloodDelay)
@ -171,65 +105,72 @@ func (c *Client) DeleteDNSSettings(credentialToken, recordID string) (bool, erro
return g.Response.ReturnInfo, nil return g.Response.ReturnInfo, nil
} }
func (c Client) do(credentialToken, action string, requestParams interface{}) (*Item, error) { func (c *Client) newRequest(ctx context.Context, action string, requestParams any) (*http.Request, error) {
time.Sleep(time.Until(c.floodTime))
ar := KasRequest{ ar := KasRequest{
Login: c.login, Login: c.login,
AuthType: "session", AuthType: "session",
AuthData: credentialToken, AuthData: getToken(ctx),
Action: action, Action: action,
RequestParams: requestParams, RequestParams: requestParams,
} }
body, err := json.Marshal(ar) body, err := json.Marshal(ar)
if err != nil { if err != nil {
return nil, fmt.Errorf("request marshal: %w", err) return nil, fmt.Errorf("failed to create request JSON body: %w", err)
} }
payload := []byte(strings.TrimSpace(fmt.Sprintf(kasAPIEnvelope, body))) payload := []byte(strings.TrimSpace(fmt.Sprintf(kasAPIEnvelope, body)))
req, err := http.NewRequest(http.MethodPost, c.apiEndpoint, bytes.NewReader(payload)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL, bytes.NewReader(payload))
if err != nil { if err != nil {
return nil, fmt.Errorf("request creation: %w", err) return nil, fmt.Errorf("unable to create request: %w", err)
} }
return req, nil
}
func (c *Client) do(req *http.Request, result any) error {
c.muFloodTime.Lock()
time.Sleep(time.Until(c.floodTime))
c.muFloodTime.Unlock()
resp, err := c.HTTPClient.Do(req) resp, err := c.HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("request execution: %w", err) return errutils.NewHTTPDoError(req, err)
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body) return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
return nil, fmt.Errorf("invalid status code: %d %s", resp.StatusCode, string(data))
} }
data, err := io.ReadAll(resp.Body) envlp, err := decodeXML[KasAPIResponseEnvelope](resp.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("response read: %w", err) return err
} }
var e KasAPIResponseEnvelope if envlp.Body.Fault != nil {
decoder := xml.NewTokenDecoder(Trimmer{decoder: xml.NewDecoder(bytes.NewReader(data))}) return envlp.Body.Fault
err = decoder.Decode(&e) }
raw := getValue(envlp.Body.KasAPIResponse.Return)
err = mapstructure.Decode(raw, result)
if err != nil { if err != nil {
return nil, fmt.Errorf("response xml decode: %w", err) return fmt.Errorf("response struct decode: %w", err)
} }
if e.Body.Fault != nil { return nil
return nil, e.Body.Fault
}
return e.Body.KasAPIResponse.Return, nil
} }
func (c *Client) updateFloodTime(delay float64) { func (c *Client) updateFloodTime(delay float64) {
c.muFloodTime.Lock()
c.floodTime = time.Now().Add(time.Duration(delay * float64(time.Second))) c.floodTime = time.Now().Add(time.Duration(delay * float64(time.Second)))
c.muFloodTime.Unlock()
} }
func getValue(item *Item) interface{} { func getValue(item *Item) any {
switch { switch {
case item.Raw != "": case item.Raw != "":
v, _ := strconv.ParseBool(item.Raw) v, _ := strconv.ParseBool(item.Raw)
@ -253,7 +194,7 @@ func getValue(item *Item) interface{} {
return getValue(item.Value) return getValue(item.Value)
case len(item.Items) > 0 && item.Type == "SOAP-ENC:Array": case len(item.Items) > 0 && item.Type == "SOAP-ENC:Array":
var v []interface{} var v []any
for _, i := range item.Items { for _, i := range item.Items {
v = append(v, getValue(i)) v = append(v, getValue(i))
} }
@ -261,7 +202,7 @@ func getValue(item *Item) interface{} {
return v return v
case len(item.Items) > 0: case len(item.Items) > 0:
v := map[string]interface{}{} v := map[string]any{}
for _, i := range item.Items { for _, i := range item.Items {
v[getKey(i)] = getValue(i) v[getKey(i)] = getValue(i)
} }

View file

@ -13,36 +13,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestClient_Authentication(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc("/", testHandler("auth.xml"))
client := NewClient("user", "secret")
client.authEndpoint = server.URL
credentialToken, err := client.Authentication(60, false)
require.NoError(t, err)
assert.Equal(t, "593959ca04f0de9689b586c6a647d15d", credentialToken)
}
func TestClient_Authentication_error(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc("/", testHandler("auth_fault.xml"))
client := NewClient("user", "secret")
client.authEndpoint = server.URL
_, err := client.Authentication(60, false)
require.Error(t, err)
}
func TestClient_GetDNSSettings(t *testing.T) { func TestClient_GetDNSSettings(t *testing.T) {
mux := http.NewServeMux() mux := http.NewServeMux()
server := httptest.NewServer(mux) server := httptest.NewServer(mux)
@ -50,12 +20,10 @@ func TestClient_GetDNSSettings(t *testing.T) {
mux.HandleFunc("/", testHandler("get_dns_settings.xml")) mux.HandleFunc("/", testHandler("get_dns_settings.xml"))
client := NewClient("user", "secret") client := NewClient("user")
client.apiEndpoint = server.URL client.baseURL = server.URL
token := "sha1secret" records, err := client.GetDNSSettings(mockContext(), "example.com", "")
records, err := client.GetDNSSettings(token, "example.com", "")
require.NoError(t, err) require.NoError(t, err)
expected := []ReturnInfo{ expected := []ReturnInfo{
@ -134,10 +102,8 @@ func TestClient_AddDNSSettings(t *testing.T) {
mux.HandleFunc("/", testHandler("add_dns_settings.xml")) mux.HandleFunc("/", testHandler("add_dns_settings.xml"))
client := NewClient("user", "secret") client := NewClient("user")
client.apiEndpoint = server.URL client.baseURL = server.URL
token := "sha1secret"
record := DNSRequest{ record := DNSRequest{
ZoneHost: "42cnc.de.", ZoneHost: "42cnc.de.",
@ -146,7 +112,7 @@ func TestClient_AddDNSSettings(t *testing.T) {
RecordData: "abcdefgh", RecordData: "abcdefgh",
} }
recordID, err := client.AddDNSSettings(token, record) recordID, err := client.AddDNSSettings(mockContext(), record)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "57347444", recordID) assert.Equal(t, "57347444", recordID)
@ -159,12 +125,10 @@ func TestClient_DeleteDNSSettings(t *testing.T) {
mux.HandleFunc("/", testHandler("delete_dns_settings.xml")) mux.HandleFunc("/", testHandler("delete_dns_settings.xml"))
client := NewClient("user", "secret") client := NewClient("user")
client.apiEndpoint = server.URL client.baseURL = server.URL
token := "sha1secret" r, err := client.DeleteDNSSettings(mockContext(), "57347450")
r, err := client.DeleteDNSSettings(token, "57347450")
require.NoError(t, err) require.NoError(t, err)
assert.True(t, r) assert.True(t, r)

View file

@ -0,0 +1,104 @@
package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
// authEndpoint represents the Identity API endpoint to call.
const authEndpoint = "https://kasapi.kasserver.com/soap/KasAuth.php"
type token string
const tokenKey token = "token"
// Identifier generates credential tokens.
type Identifier struct {
login string
password string
authEndpoint string
HTTPClient *http.Client
}
// NewIdentifier creates a new Identifier.
func NewIdentifier(login string, password string) *Identifier {
return &Identifier{
login: login,
password: password,
authEndpoint: authEndpoint,
HTTPClient: &http.Client{Timeout: 10 * time.Second},
}
}
// Authentication Creates a credential token.
// - sessionLifetime: Validity of the token in seconds.
// - sessionUpdateLifetime: with `true` the session is extended with every request.
func (c *Identifier) Authentication(ctx context.Context, sessionLifetime int, sessionUpdateLifetime bool) (string, error) {
sul := "N"
if sessionUpdateLifetime {
sul = "Y"
}
ar := AuthRequest{
Login: c.login,
AuthData: c.password,
AuthType: "plain",
SessionLifetime: sessionLifetime,
SessionUpdateLifetime: sul,
}
body, err := json.Marshal(ar)
if err != nil {
return "", fmt.Errorf("failed to create request JSON body: %w", err)
}
payload := []byte(strings.TrimSpace(fmt.Sprintf(kasAuthEnvelope, body)))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.authEndpoint, bytes.NewReader(payload))
if err != nil {
return "", fmt.Errorf("unable to create request: %w", err)
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return "", errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
envlp, err := decodeXML[KasAuthEnvelope](resp.Body)
if err != nil {
return "", err
}
if envlp.Body.Fault != nil {
return "", envlp.Body.Fault
}
return envlp.Body.KasAuthResponse.Return.Text, nil
}
func WithContext(ctx context.Context, credential string) context.Context {
return context.WithValue(ctx, tokenKey, credential)
}
func getToken(ctx context.Context) string {
credential, ok := ctx.Value(tokenKey).(string)
if !ok {
return ""
}
return credential
}

View file

@ -0,0 +1,45 @@
package internal
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func mockContext() context.Context {
return context.WithValue(context.Background(), tokenKey, "593959ca04f0de9689b586c6a647d15d")
}
func TestIdentifier_Authentication(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc("/", testHandler("auth.xml"))
client := NewIdentifier("user", "secret")
client.authEndpoint = server.URL
credentialToken, err := client.Authentication(context.Background(), 60, false)
require.NoError(t, err)
assert.Equal(t, "593959ca04f0de9689b586c6a647d15d", credentialToken)
}
func TestIdentifier_Authentication_error(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc("/", testHandler("auth_fault.xml"))
client := NewIdentifier("user", "secret")
client.authEndpoint = server.URL
_, err := client.Authentication(context.Background(), 60, false)
require.Error(t, err)
}

View file

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"io"
) )
// Trimmer trim all XML fields. // Trimmer trim all XML fields.
@ -44,3 +45,18 @@ type Item struct {
Value *Item `xml:"value" json:"value,omitempty"` Value *Item `xml:"value" json:"value,omitempty"`
Items []*Item `xml:"item" json:"item,omitempty"` Items []*Item `xml:"item" json:"item,omitempty"`
} }
func decodeXML[T any](reader io.Reader) (*T, error) {
raw, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var result T
err = xml.NewTokenDecoder(Trimmer{decoder: xml.NewDecoder(bytes.NewReader(raw))}).Decode(&result)
if err != nil {
return nil, fmt.Errorf("decode XML response: %w", err)
}
return &result, nil
}

View file

@ -35,7 +35,7 @@ type KasRequest struct {
// Action API function. // Action API function.
Action string `json:"kas_action,omitempty"` Action string `json:"kas_action,omitempty"`
// RequestParams Parameters to the API function. // RequestParams Parameters to the API function.
RequestParams interface{} `json:"KasRequestParams,omitempty"` RequestParams any `json:"KasRequestParams,omitempty"`
} }
type DNSRequest struct { type DNSRequest struct {
@ -64,7 +64,7 @@ type GetDNSSettingsResponse struct {
} }
type ReturnInfo struct { type ReturnInfo struct {
ID interface{} `json:"record_id,omitempty" mapstructure:"record_id"` ID any `json:"record_id,omitempty" mapstructure:"record_id"`
Zone string `json:"record_zone,omitempty" mapstructure:"record_zone"` Zone string `json:"record_zone,omitempty" mapstructure:"record_zone"`
Name string `json:"record_name,omitempty" mapstructure:"record_name"` Name string `json:"record_name,omitempty" mapstructure:"record_name"`
Type string `json:"record_type,omitempty" mapstructure:"record_type"` Type string `json:"record_type,omitempty" mapstructure:"record_type"`

View file

@ -2,6 +2,7 @@
package arvancloud package arvancloud
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -108,11 +109,13 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
authZone, err := getZone(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return err return fmt.Errorf("arvancloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
authZone = dns01.UnFqdn(authZone)
subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)
if err != nil { if err != nil {
return fmt.Errorf("arvancloud: %w", err) return fmt.Errorf("arvancloud: %w", err)
@ -131,7 +134,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
}, },
} }
newRecord, err := d.client.CreateRecord(authZone, record) newRecord, err := d.client.CreateRecord(context.Background(), authZone, record)
if err != nil { if err != nil {
return fmt.Errorf("arvancloud: failed to add TXT record: fqdn=%s: %w", info.EffectiveFQDN, err) return fmt.Errorf("arvancloud: failed to add TXT record: fqdn=%s: %w", info.EffectiveFQDN, err)
} }
@ -147,11 +150,13 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
authZone, err := getZone(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return err return fmt.Errorf("arvancloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
authZone = dns01.UnFqdn(authZone)
// gets the record's unique ID from when we created it // gets the record's unique ID from when we created it
d.recordIDsMu.Lock() d.recordIDsMu.Lock()
recordID, ok := d.recordIDs[token] recordID, ok := d.recordIDs[token]
@ -160,7 +165,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return fmt.Errorf("arvancloud: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) return fmt.Errorf("arvancloud: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token)
} }
if err := d.client.DeleteRecord(authZone, recordID); err != nil { if err := d.client.DeleteRecord(context.Background(), authZone, recordID); err != nil {
return fmt.Errorf("arvancloud: failed to delate TXT record: id=%s: %w", recordID, err) return fmt.Errorf("arvancloud: failed to delate TXT record: id=%s: %w", recordID, err)
} }
@ -171,12 +176,3 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return nil return nil
} }
func getZone(fqdn string) (string, error) {
authZone, err := dns01.FindZoneByFqdn(fqdn)
if err != nil {
return "", err
}
return dns01.UnFqdn(authZone), nil
}

View file

@ -2,39 +2,45 @@ package internal
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
) )
// defaultBaseURL represents the API endpoint to call. // defaultBaseURL represents the API endpoint to call.
const defaultBaseURL = "https://napi.arvancloud.ir" const defaultBaseURL = "https://napi.arvancloud.ir"
const authHeader = "Authorization" const authorizationHeader = "Authorization"
// Client the ArvanCloud client. // Client the ArvanCloud client.
type Client struct { type Client struct {
HTTPClient *http.Client
BaseURL string
apiKey string apiKey string
baseURL *url.URL
HTTPClient *http.Client
} }
// NewClient Creates a new ArvanCloud client. // NewClient Creates a new Client.
func NewClient(apiKey string) *Client { func NewClient(apiKey string) *Client {
baseURL, _ := url.Parse(defaultBaseURL)
return &Client{ return &Client{
HTTPClient: http.DefaultClient,
BaseURL: defaultBaseURL,
apiKey: apiKey, apiKey: apiKey,
baseURL: baseURL,
HTTPClient: &http.Client{Timeout: 5 * time.Second},
} }
} }
// GetTxtRecord gets a TXT record. // GetTxtRecord gets a TXT record.
func (c *Client) GetTxtRecord(domain, name, value string) (*DNSRecord, error) { func (c *Client) GetTxtRecord(ctx context.Context, domain, name, value string) (*DNSRecord, error) {
records, err := c.getRecords(domain, name) records, err := c.getRecords(ctx, domain, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -49,11 +55,8 @@ func (c *Client) GetTxtRecord(domain, name, value string) (*DNSRecord, error) {
} }
// https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.list // https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.list
func (c *Client) getRecords(domain, search string) ([]DNSRecord, error) { func (c *Client) getRecords(ctx context.Context, domain, search string) ([]DNSRecord, error) {
endpoint, err := c.createEndpoint("cdn", "4.0", "domains", domain, "dns-records") endpoint := c.baseURL.JoinPath("cdn", "4.0", "domains", domain, "dns-records")
if err != nil {
return nil, fmt.Errorf("failed to create endpoint: %w", err)
}
if search != "" { if search != "" {
query := endpoint.Query() query := endpoint.Query()
@ -61,123 +64,110 @@ func (c *Client) getRecords(domain, search string) ([]DNSRecord, error) {
endpoint.RawQuery = query.Encode() endpoint.RawQuery = query.Encode()
} }
resp, err := c.do(http.MethodGet, endpoint.String(), nil) req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = resp.Body.Close() }() response := &apiResponse[[]DNSRecord]{}
err = c.do(req, http.StatusOK, response)
body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err) return nil, fmt.Errorf("could not get records %s: Domain: %s: %w", search, domain, err)
} }
if resp.StatusCode != http.StatusOK { return response.Data, nil
return nil, fmt.Errorf("could not get records %s: Domain: %s; Status: %s; Body: %s",
search, domain, resp.Status, string(body))
}
response := &apiResponse{}
err = json.Unmarshal(body, response)
if err != nil {
return nil, fmt.Errorf("failed to decode response body: %w", err)
}
var records []DNSRecord
err = json.Unmarshal(response.Data, &records)
if err != nil {
return nil, fmt.Errorf("failed to decode records: %w", err)
}
return records, nil
} }
// CreateRecord creates a DNS record. // CreateRecord creates a DNS record.
// https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.create // https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.create
func (c *Client) CreateRecord(domain string, record DNSRecord) (*DNSRecord, error) { func (c *Client) CreateRecord(ctx context.Context, domain string, record DNSRecord) (*DNSRecord, error) {
reqBody, err := json.Marshal(record) endpoint := c.baseURL.JoinPath("cdn", "4.0", "domains", domain, "dns-records")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
if err != nil { if err != nil {
return nil, err return nil, err
} }
endpoint, err := c.createEndpoint("cdn", "4.0", "domains", domain, "dns-records") response := &apiResponse[*DNSRecord]{}
err = c.do(req, http.StatusCreated, response)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create endpoint: %w", err) return nil, fmt.Errorf("could not create record; Domain: %s: %w", domain, err)
} }
resp, err := c.do(http.MethodPost, endpoint.String(), bytes.NewReader(reqBody)) return response.Data, nil
if err != nil {
return nil, err
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("could not create record %s; Domain: %s; Status: %s; Body: %s", string(reqBody), domain, resp.Status, string(body))
}
response := &apiResponse{}
err = json.Unmarshal(body, response)
if err != nil {
return nil, fmt.Errorf("failed to decode response body: %w", err)
}
var newRecord DNSRecord
err = json.Unmarshal(response.Data, &newRecord)
if err != nil {
return nil, fmt.Errorf("failed to decode record: %w", err)
}
return &newRecord, nil
} }
// DeleteRecord deletes a DNS record. // DeleteRecord deletes a DNS record.
// https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.remove // https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.remove
func (c *Client) DeleteRecord(domain, id string) error { func (c *Client) DeleteRecord(ctx context.Context, domain, id string) error {
endpoint, err := c.createEndpoint("cdn", "4.0", "domains", domain, "dns-records", id) endpoint := c.baseURL.JoinPath("cdn", "4.0", "domains", domain, "dns-records", id)
if err != nil {
return fmt.Errorf("failed to create endpoint: %w", err)
}
resp, err := c.do(http.MethodDelete, endpoint.String(), nil) req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil)
if err != nil { if err != nil {
return err return err
} }
if resp.StatusCode != http.StatusOK { err = c.do(req, http.StatusOK, nil)
body, _ := io.ReadAll(resp.Body) if err != nil {
return fmt.Errorf("could not delete record %s; Domain: %s; Status: %s; Body: %s", id, domain, resp.Status, string(body)) return fmt.Errorf("could not delete record %s; Domain: %s: %w", id, domain, err)
} }
return nil return nil
} }
func (c *Client) do(method, endpoint string, body io.Reader) (*http.Response, error) { func (c *Client) do(req *http.Request, expectedStatus int, result any) error {
req, err := http.NewRequest(method, endpoint, body) req.Header.Set(authorizationHeader, c.apiKey)
resp, err := c.HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != expectedStatus {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
err = json.Unmarshal(raw, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
} }
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
if body != nil {
if payload != nil {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
} }
req.Header.Set(authHeader, c.apiKey)
return c.HTTPClient.Do(req) return req, nil
}
func (c *Client) createEndpoint(parts ...string) (*url.URL, error) {
baseURL, err := url.Parse(c.BaseURL)
if err != nil {
return nil, err
}
return baseURL.JoinPath(parts...), nil
} }
func equalsTXTRecord(record DNSRecord, name, value string) bool { func equalsTXTRecord(record DNSRecord, name, value string) bool {
@ -189,7 +179,7 @@ func equalsTXTRecord(record DNSRecord, name, value string) bool {
return false return false
} }
data, ok := record.Value.(map[string]interface{}) data, ok := record.Value.(map[string]any)
if !ok { if !ok {
return false return false
} }

View file

@ -1,10 +1,12 @@
package internal package internal
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"os" "os"
"testing" "testing"
@ -12,21 +14,34 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestClient_GetTxtRecord(t *testing.T) { func setupTest(t *testing.T, apiKey string) (*Client, *http.ServeMux) {
t.Helper()
mux := http.NewServeMux() mux := http.NewServeMux()
server := httptest.NewServer(mux) server := httptest.NewServer(mux)
t.Cleanup(server.Close) t.Cleanup(server.Close)
const domain = "example.com" client := NewClient(apiKey)
client.baseURL, _ = url.Parse(server.URL)
client.HTTPClient = server.Client()
return client, mux
}
func TestClient_GetTxtRecord(t *testing.T) {
const apiKey = "myKeyA" const apiKey = "myKeyA"
client, mux := setupTest(t, apiKey)
const domain = "example.com"
mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records", func(rw http.ResponseWriter, req *http.Request) { mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet { if req.Method != http.MethodGet {
http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed) http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed)
return return
} }
auth := req.Header.Get(authHeader) auth := req.Header.Get(authorizationHeader)
if auth != apiKey { if auth != apiKey {
http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized) http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized)
return return
@ -46,20 +61,16 @@ func TestClient_GetTxtRecord(t *testing.T) {
} }
}) })
client := NewClient(apiKey) _, err := client.GetTxtRecord(context.Background(), domain, "_acme-challenge", "txtxtxt")
client.BaseURL = server.URL
_, err := client.GetTxtRecord(domain, "_acme-challenge", "txtxtxt")
require.NoError(t, err) require.NoError(t, err)
} }
func TestClient_CreateRecord(t *testing.T) { func TestClient_CreateRecord(t *testing.T) {
mux := http.NewServeMux() const apiKey = "myKeyB"
server := httptest.NewServer(mux)
t.Cleanup(server.Close) client, mux := setupTest(t, apiKey)
const domain = "example.com" const domain = "example.com"
const apiKey = "myKeyB"
mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records", func(rw http.ResponseWriter, req *http.Request) { mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost { if req.Method != http.MethodPost {
@ -67,7 +78,7 @@ func TestClient_CreateRecord(t *testing.T) {
return return
} }
auth := req.Header.Get(authHeader) auth := req.Header.Get(authorizationHeader)
if auth != apiKey { if auth != apiKey {
http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized) http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized)
return return
@ -88,9 +99,6 @@ func TestClient_CreateRecord(t *testing.T) {
} }
}) })
client := NewClient(apiKey)
client.BaseURL = server.URL
record := DNSRecord{ record := DNSRecord{
Name: "_acme-challenge", Name: "_acme-challenge",
Type: "txt", Type: "txt",
@ -98,7 +106,7 @@ func TestClient_CreateRecord(t *testing.T) {
TTL: 600, TTL: 600,
} }
newRecord, err := client.CreateRecord(domain, record) newRecord, err := client.CreateRecord(context.Background(), domain, record)
require.NoError(t, err) require.NoError(t, err)
expected := &DNSRecord{ expected := &DNSRecord{
@ -119,12 +127,11 @@ func TestClient_CreateRecord(t *testing.T) {
} }
func TestClient_DeleteRecord(t *testing.T) { func TestClient_DeleteRecord(t *testing.T) {
mux := http.NewServeMux() const apiKey = "myKeyC"
server := httptest.NewServer(mux)
t.Cleanup(server.Close) client, mux := setupTest(t, apiKey)
const domain = "example.com" const domain = "example.com"
const apiKey = "myKeyC"
const recordID = "recordId" const recordID = "recordId"
mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records/"+recordID, func(rw http.ResponseWriter, req *http.Request) { mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records/"+recordID, func(rw http.ResponseWriter, req *http.Request) {
@ -133,16 +140,13 @@ func TestClient_DeleteRecord(t *testing.T) {
return return
} }
auth := req.Header.Get(authHeader) auth := req.Header.Get(authorizationHeader)
if auth != apiKey { if auth != apiKey {
http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized) http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized)
return return
} }
}) })
client := NewClient(apiKey) err := client.DeleteRecord(context.Background(), domain, recordID)
client.BaseURL = server.URL
err := client.DeleteRecord(domain, recordID)
require.NoError(t, err) require.NoError(t, err)
} }

View file

@ -1,17 +1,15 @@
package internal package internal
import "encoding/json" type apiResponse[T any] struct {
type apiResponse struct {
Message string `json:"message"` Message string `json:"message"`
Data json.RawMessage `json:"data"` Data T `json:"data"`
} }
// DNSRecord a DNS record. // DNSRecord a DNS record.
type DNSRecord struct { type DNSRecord struct {
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Type string `json:"type"` Type string `json:"type"`
Value interface{} `json:"value,omitempty"` Value any `json:"value,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
TTL int `json:"ttl,omitempty"` TTL int `json:"ttl,omitempty"`
UpstreamHTTPS string `json:"upstream_https,omitempty"` UpstreamHTTPS string `json:"upstream_https,omitempty"`

View file

@ -108,7 +108,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("aurora: could not determine zone for domain %q: %w", domain, err) return fmt.Errorf("aurora: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
// 1. Aurora will happily create the TXT record when it is provided a fqdn, // 1. Aurora will happily create the TXT record when it is provided a fqdn,
@ -155,24 +155,24 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
d.recordIDsMu.Unlock() d.recordIDsMu.Unlock()
if !ok { if !ok {
return fmt.Errorf("unknown recordID for %q", info.EffectiveFQDN) return fmt.Errorf("aurora: unknown recordID for %q", info.EffectiveFQDN)
} }
authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(info.EffectiveFQDN)) authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(info.EffectiveFQDN))
if err != nil { if err != nil {
return fmt.Errorf("could not determine zone for domain %q: %w", domain, err) return fmt.Errorf("aurora: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
authZone = dns01.UnFqdn(authZone) authZone = dns01.UnFqdn(authZone)
zone, err := d.getZoneInformationByName(authZone) zone, err := d.getZoneInformationByName(authZone)
if err != nil { if err != nil {
return err return fmt.Errorf("aurora: %w", err)
} }
_, _, err = d.client.DeleteRecord(zone.ID, recordID) _, _, err = d.client.DeleteRecord(zone.ID, recordID)
if err != nil { if err != nil {
return err return fmt.Errorf("aurora: %w", err)
} }
d.recordIDsMu.Lock() d.recordIDsMu.Lock()

View file

@ -2,6 +2,7 @@
package autodns package autodns
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -10,6 +11,7 @@ import (
"github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/platform/config/env" "github.com/go-acme/lego/v4/platform/config/env"
"github.com/go-acme/lego/v4/providers/dns/autodns/internal"
) )
// Environment variables names. // Environment variables names.
@ -27,11 +29,6 @@ const (
EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT" EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT"
) )
const (
defaultEndpointContext int = 4
defaultTTL int = 600
)
// Config is used to configure the creation of the DNSProvider. // Config is used to configure the creation of the DNSProvider.
type Config struct { type Config struct {
Endpoint *url.URL Endpoint *url.URL
@ -46,12 +43,12 @@ type Config struct {
// NewDefaultConfig returns a default configuration for the DNSProvider. // NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config { func NewDefaultConfig() *Config {
endpoint, _ := url.Parse(env.GetOrDefaultString(EnvAPIEndpoint, defaultEndpoint)) endpoint, _ := url.Parse(env.GetOrDefaultString(EnvAPIEndpoint, internal.DefaultEndpoint))
return &Config{ return &Config{
Endpoint: endpoint, Endpoint: endpoint,
Context: env.GetOrDefaultInt(EnvAPIEndpointContext, defaultEndpointContext), Context: env.GetOrDefaultInt(EnvAPIEndpointContext, internal.DefaultEndpointContext),
TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), TTL: env.GetOrDefaultInt(EnvTTL, 600),
PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 2*time.Minute), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 2*time.Minute),
PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 2*time.Second), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 2*time.Second),
HTTPClient: &http.Client{ HTTPClient: &http.Client{
@ -63,6 +60,7 @@ func NewDefaultConfig() *Config {
// DNSProvider implements the challenge.Provider interface. // DNSProvider implements the challenge.Provider interface.
type DNSProvider struct { type DNSProvider struct {
config *Config config *Config
client *internal.Client
} }
// NewDNSProvider returns a DNSProvider instance configured for autoDNS. // NewDNSProvider returns a DNSProvider instance configured for autoDNS.
@ -94,7 +92,17 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("autodns: missing password") return nil, errors.New("autodns: missing password")
} }
return &DNSProvider{config: config}, nil client := internal.NewClient(config.Username, config.Password, config.Context)
if config.Endpoint != nil {
client.BaseURL = config.Endpoint
}
if config.HTTPClient != nil {
client.HTTPClient = config.HTTPClient
}
return &DNSProvider{config: config, client: client}, nil
} }
// Timeout returns the timeout and interval to use when checking for DNS propagation. // Timeout returns the timeout and interval to use when checking for DNS propagation.
@ -107,7 +115,7 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
records := []*ResourceRecord{{ records := []*internal.ResourceRecord{{
Name: info.EffectiveFQDN, Name: info.EffectiveFQDN,
TTL: int64(d.config.TTL), TTL: int64(d.config.TTL),
Type: "TXT", Type: "TXT",
@ -115,7 +123,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
}} }}
// TODO(ldez) replace domain by FQDN to follow CNAME. // TODO(ldez) replace domain by FQDN to follow CNAME.
_, err := d.addTxtRecord(domain, records) _, err := d.client.AddTxtRecords(context.Background(), domain, records)
if err != nil { if err != nil {
return fmt.Errorf("autodns: %w", err) return fmt.Errorf("autodns: %w", err)
} }
@ -127,7 +135,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
records := []*ResourceRecord{{ records := []*internal.ResourceRecord{{
Name: info.EffectiveFQDN, Name: info.EffectiveFQDN,
TTL: int64(d.config.TTL), TTL: int64(d.config.TTL),
Type: "TXT", Type: "TXT",
@ -135,7 +143,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
}} }}
// TODO(ldez) replace domain by FQDN to follow CNAME. // TODO(ldez) replace domain by FQDN to follow CNAME.
if err := d.removeTXTRecord(domain, records); err != nil { if err := d.client.RemoveTXTRecords(context.Background(), domain, records); err != nil {
return fmt.Errorf("autodns: %w", err) return fmt.Errorf("autodns: %w", err)
} }

View file

@ -1,159 +0,0 @@
package autodns
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
)
const (
defaultEndpoint = "https://api.autodns.com/v1/"
)
type ResponseMessage struct {
Text string `json:"text"`
Messages []string `json:"messages"`
Objects []string `json:"objects"`
Code string `json:"code"`
Status string `json:"status"`
}
type ResponseStatus struct {
Code string `json:"code"`
Text string `json:"text"`
Type string `json:"type"`
}
type ResponseObject struct {
Type string `json:"type"`
Value string `json:"value"`
Summary int32 `json:"summary"`
Data string
}
type DataZoneResponse struct {
STID string `json:"stid"`
CTID string `json:"ctid"`
Messages []*ResponseMessage `json:"messages"`
Status *ResponseStatus `json:"status"`
Object interface{} `json:"object"`
Data []*Zone `json:"data"`
}
// ResourceRecord holds a resource record.
type ResourceRecord struct {
Name string `json:"name"`
TTL int64 `json:"ttl"`
Type string `json:"type"`
Value string `json:"value"`
Pref int32 `json:"pref,omitempty"`
}
// Zone is an autodns zone record with all for us relevant fields.
type Zone struct {
Name string `json:"origin"`
ResourceRecords []*ResourceRecord `json:"resourceRecords"`
Action string `json:"action"`
VirtualNameServer string `json:"virtualNameServer"`
}
type ZoneStream struct {
Adds []*ResourceRecord `json:"adds"`
Removes []*ResourceRecord `json:"rems"`
}
func (d *DNSProvider) addTxtRecord(domain string, records []*ResourceRecord) (*Zone, error) {
zoneStream := &ZoneStream{Adds: records}
return d.makeZoneUpdateRequest(zoneStream, domain)
}
func (d *DNSProvider) removeTXTRecord(domain string, records []*ResourceRecord) error {
zoneStream := &ZoneStream{Removes: records}
_, err := d.makeZoneUpdateRequest(zoneStream, domain)
return err
}
func (d *DNSProvider) makeZoneUpdateRequest(zoneStream *ZoneStream, domain string) (*Zone, error) {
reqBody := &bytes.Buffer{}
if err := json.NewEncoder(reqBody).Encode(zoneStream); err != nil {
return nil, err
}
endpoint := d.config.Endpoint.JoinPath("zone", domain, "_stream")
req, err := d.makeRequest(http.MethodPost, endpoint.String(), reqBody)
if err != nil {
return nil, err
}
var resp *Zone
if err := d.sendRequest(req, &resp); err != nil {
return nil, err
}
return resp, nil
}
func (d *DNSProvider) makeRequest(method, endpoint string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequest(method, endpoint, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Domainrobot-Context", strconv.Itoa(d.config.Context))
req.SetBasicAuth(d.config.Username, d.config.Password)
return req, nil
}
func (d *DNSProvider) sendRequest(req *http.Request, result interface{}) error {
resp, err := d.config.HTTPClient.Do(req)
if err != nil {
return err
}
if err = checkResponse(resp); err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
err = json.Unmarshal(raw, result)
if err != nil {
return fmt.Errorf("unmarshaling %T error [status code=%d]: %w: %s", result, resp.StatusCode, err, string(raw))
}
return err
}
func checkResponse(resp *http.Response) error {
if resp.StatusCode < http.StatusBadRequest {
return nil
}
if resp.Body == nil {
return fmt.Errorf("response body is nil, status code=%d", resp.StatusCode)
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("unable to read body: status code=%d, error=%w", resp.StatusCode, err)
}
return fmt.Errorf("status code=%d: %s", resp.StatusCode, string(raw))
}

View file

@ -0,0 +1,132 @@
package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
// DefaultEndpoint default API endpoint.
const DefaultEndpoint = "https://api.autodns.com/v1/"
// DefaultEndpointContext default API endpoint context.
const DefaultEndpointContext int = 4
// Client the Autodns API client.
type Client struct {
username string
password string
context int
BaseURL *url.URL
HTTPClient *http.Client
}
// NewClient creates a new Client.
func NewClient(username string, password string, clientContext int) *Client {
baseURL, _ := url.Parse(DefaultEndpoint)
return &Client{
username: username,
password: password,
context: clientContext,
BaseURL: baseURL,
HTTPClient: &http.Client{Timeout: 5 * time.Second},
}
}
// AddTxtRecords adds TXT records.
func (c *Client) AddTxtRecords(ctx context.Context, domain string, records []*ResourceRecord) (*Zone, error) {
zoneStream := &ZoneStream{Adds: records}
return c.updateZone(ctx, domain, zoneStream)
}
// RemoveTXTRecords removes TXT records.
func (c *Client) RemoveTXTRecords(ctx context.Context, domain string, records []*ResourceRecord) error {
zoneStream := &ZoneStream{Removes: records}
_, err := c.updateZone(ctx, domain, zoneStream)
return err
}
// https://github.com/InterNetX/domainrobot-api/blob/bdc8fe92a2f32fcbdb29e30bf6006ab446f81223/src/domainrobot.json#L21090
func (c *Client) updateZone(ctx context.Context, domain string, zoneStream *ZoneStream) (*Zone, error) {
endpoint := c.BaseURL.JoinPath("zone", domain, "_stream")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, zoneStream)
if err != nil {
return nil, err
}
var zone *Zone
if err := c.do(req, &zone); err != nil {
return nil, err
}
return zone, nil
}
func (c *Client) do(req *http.Request, result any) error {
req.Header.Set("X-Domainrobot-Context", strconv.Itoa(c.context))
req.SetBasicAuth(c.username, c.password)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode/100 != 2 {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
err = json.Unmarshal(raw, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}

View file

@ -0,0 +1,96 @@
package internal
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTest(t *testing.T, method, pattern string, status int, file string) *Client {
t.Helper()
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc(pattern, func(rw http.ResponseWriter, req *http.Request) {
if req.Method != method {
http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusBadRequest)
return
}
apiUser, apiKey, ok := req.BasicAuth()
if apiUser != "user" || apiKey != "secret" || !ok {
http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
if file == "" {
rw.WriteHeader(status)
return
}
open, err := os.Open(filepath.Join("fixtures", file))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
defer func() { _ = open.Close() }()
rw.WriteHeader(status)
_, err = io.Copy(rw, open)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
})
client := NewClient("user", "secret", 123)
client.HTTPClient = server.Client()
client.BaseURL, _ = url.Parse(server.URL)
return client
}
func TestClient_AddTxtRecords(t *testing.T) {
client := setupTest(t, http.MethodPost, "/zone/example.com/_stream", http.StatusOK, "add-record.json")
records := []*ResourceRecord{{}}
zone, err := client.AddTxtRecords(context.Background(), "example.com", records)
require.NoError(t, err)
expected := &Zone{
Name: "example.com",
ResourceRecords: []*ResourceRecord{{
Name: "example.com",
TTL: 120,
Type: "TXT",
Value: "txt",
Pref: 1,
}},
Action: "xxx",
VirtualNameServer: "yyy",
}
assert.Equal(t, expected, zone)
}
func TestClient_RemoveTXTRecords(t *testing.T) {
client := setupTest(t, http.MethodPost, "/zone/example.com/_stream", http.StatusOK, "add-record.json")
records := []*ResourceRecord{{}}
err := client.RemoveTXTRecords(context.Background(), "example.com", records)
require.NoError(t, err)
}

View file

@ -0,0 +1,14 @@
{
"origin": "example.com",
"resourceRecords": [
{
"name": "example.com",
"ttl": 120,
"type": "TXT",
"value": "txt",
"pref": 1
}
],
"action": "xxx",
"virtualNameServer": "yyy"
}

View file

@ -0,0 +1,14 @@
{
"origin": "example.com",
"resourceRecords": [
{
"name": "example.com",
"ttl": 120,
"type": "TXT",
"value": "txt",
"pref": 1
}
],
"action": "xxx",
"virtualNameServer": "yyy"
}

View file

@ -0,0 +1,57 @@
package internal
type ResponseMessage struct {
Text string `json:"text"`
Messages []string `json:"messages"`
Objects []string `json:"objects"`
Code string `json:"code"`
Status string `json:"status"`
}
type ResponseStatus struct {
Code string `json:"code"`
Text string `json:"text"`
Type string `json:"type"`
}
type ResponseObject struct {
Type string `json:"type"`
Value string `json:"value"`
Summary int32 `json:"summary"`
Data string
}
type DataZoneResponse struct {
STID string `json:"stid"`
CTID string `json:"ctid"`
Messages []*ResponseMessage `json:"messages"`
Status *ResponseStatus `json:"status"`
Object any `json:"object"`
Data []*Zone `json:"data"`
}
// ResourceRecord holds a resource record.
// https://help.internetx.com/display/APIXMLEN/Resource+Record+Object
type ResourceRecord struct {
Name string `json:"name"`
TTL int64 `json:"ttl"`
Type string `json:"type"`
Value string `json:"value"`
Pref int32 `json:"pref,omitempty"`
}
// Zone is an autodns zone record with all for us relevant fields.
// https://help.internetx.com/display/APIXMLEN/Zone+Object
type Zone struct {
Name string `json:"origin"`
ResourceRecords []*ResourceRecord `json:"resourceRecords"`
Action string `json:"action"`
VirtualNameServer string `json:"virtualNameServer"`
}
// ZoneStream body of the requests.
// https://github.com/InterNetX/domainrobot-api/blob/bdc8fe92a2f32fcbdb29e30bf6006ab446f81223/src/domainrobot.json#L35914-L35932
type ZoneStream struct {
Adds []*ResourceRecord `json:"adds"`
Removes []*ResourceRecord `json:"rems"`
}

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"time" "time"
"github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest"
@ -14,6 +15,7 @@ import (
"github.com/Azure/go-autorest/autorest/azure/auth" "github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/platform/config/env" "github.com/go-acme/lego/v4/platform/config/env"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
) )
const defaultMetadataEndpoint = "http://169.254.169.254" const defaultMetadataEndpoint = "http://169.254.169.254"
@ -122,7 +124,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
} }
if config.HTTPClient == nil { if config.HTTPClient == nil {
config.HTTPClient = http.DefaultClient config.HTTPClient = &http.Client{Timeout: 5 * time.Second}
} }
authorizer, err := getAuthorizer(config) authorizer, err := getAuthorizer(config)
@ -208,8 +210,12 @@ func getMetadata(config *Config, field string) (string, error) {
metadataEndpoint = defaultMetadataEndpoint metadataEndpoint = defaultMetadataEndpoint
} }
resource := fmt.Sprintf("%s/metadata/instance/compute/%s", metadataEndpoint, field) endpoint, err := url.JoinPath(metadataEndpoint, "metadata", "instance", "compute", field)
req, err := http.NewRequest(http.MethodGet, resource, nil) if err != nil {
return "", err
}
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -223,14 +229,15 @@ func getMetadata(config *Config, field string) (string, error) {
resp, err := config.HTTPClient.Do(req) resp, err := config.HTTPClient.Do(req)
if err != nil { if err != nil {
return "", err return "", errutils.NewHTTPDoError(req, err)
} }
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body) defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", err return "", errutils.NewReadResponseError(req, resp.StatusCode, err)
} }
return string(respBody), nil return string(raw), nil
} }

View file

@ -118,7 +118,7 @@ func (d *dnsProviderPrivate) getHostedZoneID(ctx context.Context, fqdn string) (
authZone, err := dns01.FindZoneByFqdn(fqdn) authZone, err := dns01.FindZoneByFqdn(fqdn)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err)
} }
dc := privatedns.NewPrivateZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) dc := privatedns.NewPrivateZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID)

View file

@ -118,7 +118,7 @@ func (d *dnsProviderPublic) getHostedZoneID(ctx context.Context, fqdn string) (s
authZone, err := dns01.FindZoneByFqdn(fqdn) authZone, err := dns01.FindZoneByFqdn(fqdn)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err)
} }
dc := dns.NewZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) dc := dns.NewZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID)

View file

@ -2,6 +2,7 @@
package bluecat package bluecat
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -97,7 +98,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("bluecat: credentials missing") return nil, errors.New("bluecat: credentials missing")
} }
client := internal.NewClient(config.BaseURL) client := internal.NewClient(config.BaseURL, config.UserName, config.Password)
if config.HTTPClient != nil { if config.HTTPClient != nil {
client.HTTPClient = config.HTTPClient client.HTTPClient = config.HTTPClient
@ -112,17 +113,17 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
err := d.client.Login(d.config.UserName, d.config.Password) ctx, err := d.client.CreateAuthenticatedContext(context.Background())
if err != nil { if err != nil {
return fmt.Errorf("bluecat: login: %w", err) return fmt.Errorf("bluecat: login: %w", err)
} }
viewID, err := d.client.LookupViewID(d.config.ConfigName, d.config.DNSView) viewID, err := d.client.LookupViewID(ctx, d.config.ConfigName, d.config.DNSView)
if err != nil { if err != nil {
return fmt.Errorf("bluecat: lookupViewID: %w", err) return fmt.Errorf("bluecat: lookupViewID: %w", err)
} }
parentZoneID, name, err := d.client.LookupParentZoneID(viewID, info.EffectiveFQDN) parentZoneID, name, err := d.client.LookupParentZoneID(ctx, viewID, info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("bluecat: lookupParentZoneID: %w", err) return fmt.Errorf("bluecat: lookupParentZoneID: %w", err)
} }
@ -137,17 +138,17 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
Properties: fmt.Sprintf("ttl=%d|absoluteName=%s|txt=%s|", d.config.TTL, info.EffectiveFQDN, info.Value), Properties: fmt.Sprintf("ttl=%d|absoluteName=%s|txt=%s|", d.config.TTL, info.EffectiveFQDN, info.Value),
} }
_, err = d.client.AddEntity(parentZoneID, txtRecord) _, err = d.client.AddEntity(ctx, parentZoneID, txtRecord)
if err != nil { if err != nil {
return fmt.Errorf("bluecat: add TXT record: %w", err) return fmt.Errorf("bluecat: add TXT record: %w", err)
} }
err = d.client.Deploy(parentZoneID) err = d.client.Deploy(ctx, parentZoneID)
if err != nil { if err != nil {
return fmt.Errorf("bluecat: deploy: %w", err) return fmt.Errorf("bluecat: deploy: %w", err)
} }
err = d.client.Logout() err = d.client.Logout(ctx)
if err != nil { if err != nil {
return fmt.Errorf("bluecat: logout: %w", err) return fmt.Errorf("bluecat: logout: %w", err)
} }
@ -159,37 +160,37 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
err := d.client.Login(d.config.UserName, d.config.Password) ctx, err := d.client.CreateAuthenticatedContext(context.Background())
if err != nil { if err != nil {
return fmt.Errorf("bluecat: login: %w", err) return fmt.Errorf("bluecat: login: %w", err)
} }
viewID, err := d.client.LookupViewID(d.config.ConfigName, d.config.DNSView) viewID, err := d.client.LookupViewID(ctx, d.config.ConfigName, d.config.DNSView)
if err != nil { if err != nil {
return fmt.Errorf("bluecat: lookupViewID: %w", err) return fmt.Errorf("bluecat: lookupViewID: %w", err)
} }
parentZoneID, name, err := d.client.LookupParentZoneID(viewID, info.EffectiveFQDN) parentZoneID, name, err := d.client.LookupParentZoneID(ctx, viewID, info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("bluecat: lookupParentZoneID: %w", err) return fmt.Errorf("bluecat: lookupParentZoneID: %w", err)
} }
txtRecord, err := d.client.GetEntityByName(parentZoneID, name, internal.TXTType) txtRecord, err := d.client.GetEntityByName(ctx, parentZoneID, name, internal.TXTType)
if err != nil { if err != nil {
return fmt.Errorf("bluecat: get TXT record: %w", err) return fmt.Errorf("bluecat: get TXT record: %w", err)
} }
err = d.client.Delete(txtRecord.ID) err = d.client.Delete(ctx, txtRecord.ID)
if err != nil { if err != nil {
return fmt.Errorf("bluecat: delete TXT record: %w", err) return fmt.Errorf("bluecat: delete TXT record: %w", err)
} }
err = d.client.Deploy(parentZoneID) err = d.client.Deploy(ctx, parentZoneID)
if err != nil { if err != nil {
return fmt.Errorf("bluecat: deploy: %w", err) return fmt.Errorf("bluecat: deploy: %w", err)
} }
err = d.client.Logout() err = d.client.Logout(ctx)
if err != nil { if err != nil {
return fmt.Errorf("bluecat: logout: %w", err) return fmt.Errorf("bluecat: logout: %w", err)
} }

View file

@ -2,14 +2,18 @@ package internal
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
) )
// Object types. // Object types.
@ -20,153 +24,88 @@ const (
TXTType = "TXTRecord" TXTType = "TXTRecord"
) )
const authorizationHeader = "Authorization"
type Client struct { type Client struct {
HTTPClient *http.Client username string
password string
baseURL string
token string
tokenExp *regexp.Regexp tokenExp *regexp.Regexp
baseURL *url.URL
HTTPClient *http.Client
} }
func NewClient(baseURL string) *Client { func NewClient(baseURL string, username, password string) *Client {
bu, _ := url.Parse(baseURL)
return &Client{ return &Client{
HTTPClient: &http.Client{Timeout: 30 * time.Second}, username: username,
baseURL: baseURL, password: password,
tokenExp: regexp.MustCompile("BAMAuthToken: [^ ]+"), tokenExp: regexp.MustCompile("BAMAuthToken: [^ ]+"),
baseURL: bu,
HTTPClient: &http.Client{Timeout: 30 * time.Second},
} }
} }
// Login Logs in as API user.
// Authenticates and receives a token to be used in for subsequent requests.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/GET/v1/login/9.1.0
func (c *Client) Login(username, password string) error {
queryArgs := map[string]string{
"username": username,
"password": password,
}
resp, err := c.sendRequest(http.MethodGet, "login", nil, queryArgs)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return &APIError{
StatusCode: resp.StatusCode,
Resource: "login",
Message: string(data),
}
}
authBytes, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
authResp := string(authBytes)
if strings.Contains(authResp, "Authentication Error") {
return fmt.Errorf("request failed: %s", strings.Trim(authResp, `"`))
}
// Upon success, API responds with "Session Token-> BAMAuthToken: dQfuRMTUxNjc3MjcyNDg1ODppcGFybXM= <- for User : username"
c.token = c.tokenExp.FindString(authResp)
return nil
}
// Logout Logs out of the current API session.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/GET/v1/logout/9.1.0
func (c *Client) Logout() error {
if c.token == "" {
// nothing to do
return nil
}
resp, err := c.sendRequest(http.MethodGet, "logout", nil, nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return &APIError{
StatusCode: resp.StatusCode,
Resource: "logout",
Message: string(data),
}
}
authBytes, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
authResp := string(authBytes)
if !strings.Contains(authResp, "successfully") {
return fmt.Errorf("request failed to delete session: %s", strings.Trim(authResp, `"`))
}
c.token = ""
return nil
}
// Deploy the DNS config for the specified entity to the authoritative servers. // Deploy the DNS config for the specified entity to the authoritative servers.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/POST/v1/quickDeploy/9.1.0 // https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/POST/v1/quickDeploy/9.5.0
func (c *Client) Deploy(entityID uint) error { func (c *Client) Deploy(ctx context.Context, entityID uint) error {
queryArgs := map[string]string{ endpoint := c.createEndpoint("quickDeploy")
"entityId": strconv.FormatUint(uint64(entityID), 10),
}
resp, err := c.sendRequest(http.MethodPost, "quickDeploy", nil, queryArgs) q := endpoint.Query()
q.Set("entityId", strconv.FormatUint(uint64(entityID), 10))
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, nil)
if err != nil { if err != nil {
return err return err
} }
defer resp.Body.Close()
resp, err := c.doAuthenticated(ctx, req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
// The API doc says that 201 is expected but in the reality 200 is return. // The API doc says that 201 is expected but in the reality 200 is return.
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body) return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
return &APIError{
StatusCode: resp.StatusCode,
Resource: "quickDeploy",
Message: string(data),
}
} }
return nil return nil
} }
// AddEntity A generic method for adding configurations, DNS zones, and DNS resource records. // AddEntity A generic method for adding configurations, DNS zones, and DNS resource records.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/POST/v1/addEntity/9.1.0 // https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/POST/v1/addEntity/9.5.0
func (c *Client) AddEntity(parentID uint, entity Entity) (uint64, error) { func (c *Client) AddEntity(ctx context.Context, parentID uint, entity Entity) (uint64, error) {
queryArgs := map[string]string{ endpoint := c.createEndpoint("addEntity")
"parentId": strconv.FormatUint(uint64(parentID), 10),
}
resp, err := c.sendRequest(http.MethodPost, "addEntity", entity, queryArgs) q := endpoint.Query()
q.Set("parentId", strconv.FormatUint(uint64(parentID), 10))
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, entity)
if err != nil { if err != nil {
return 0, err return 0, err
} }
defer resp.Body.Close()
resp, err := c.doAuthenticated(ctx, req)
if err != nil {
return 0, errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body) return 0, errutils.NewUnexpectedResponseStatusCodeError(req, resp)
return 0, &APIError{
StatusCode: resp.StatusCode,
Resource: "addEntity",
Message: string(data),
}
} }
addTxtBytes, _ := io.ReadAll(resp.Body) raw, _ := io.ReadAll(resp.Body)
// addEntity responds only with body text containing the ID of the created record // addEntity responds only with body text containing the ID of the created record
addTxtResp := string(addTxtBytes) addTxtResp := string(raw)
id, err := strconv.ParseUint(addTxtResp, 10, 64) id, err := strconv.ParseUint(addTxtResp, 10, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("addEntity request failed: %s", addTxtResp) return 0, fmt.Errorf("addEntity request failed: %s", addTxtResp)
@ -176,73 +115,84 @@ func (c *Client) AddEntity(parentID uint, entity Entity) (uint64, error) {
} }
// GetEntityByName Returns objects from the database referenced by their database ID and with its properties fields populated. // GetEntityByName Returns objects from the database referenced by their database ID and with its properties fields populated.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/GET/v1/getEntityById/9.1.0 // https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/GET/v1/getEntityById/9.5.0
func (c *Client) GetEntityByName(parentID uint, name, objType string) (*EntityResponse, error) { func (c *Client) GetEntityByName(ctx context.Context, parentID uint, name, objType string) (*EntityResponse, error) {
queryArgs := map[string]string{ endpoint := c.createEndpoint("getEntityByName")
"parentId": strconv.FormatUint(uint64(parentID), 10),
"name": name,
"type": objType,
}
resp, err := c.sendRequest(http.MethodGet, "getEntityByName", nil, queryArgs) q := endpoint.Query()
q.Set("parentId", strconv.FormatUint(uint64(parentID), 10))
q.Set("name", name)
q.Set("type", objType)
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close()
resp, err := c.doAuthenticated(ctx, req)
if err != nil {
return nil, errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body) return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp)
return nil, &APIError{
StatusCode: resp.StatusCode,
Resource: "getEntityByName",
Message: string(data),
}
} }
var txtRec EntityResponse raw, err := io.ReadAll(resp.Body)
if err = json.NewDecoder(resp.Body).Decode(&txtRec); err != nil { if err != nil {
return nil, fmt.Errorf("JSON decode: %w", err) return nil, errutils.NewReadResponseError(req, resp.StatusCode, err)
} }
return &txtRec, nil var entity EntityResponse
err = json.Unmarshal(raw, &entity)
if err != nil {
return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return &entity, nil
} }
// Delete Deletes an object using the generic delete method. // Delete Deletes an object using the generic delete method.
// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/DELETE/v1/delete/9.1.0 // https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/DELETE/v1/delete/9.5.0
func (c *Client) Delete(objectID uint) error { func (c *Client) Delete(ctx context.Context, objectID uint) error {
queryArgs := map[string]string{ endpoint := c.createEndpoint("delete")
"objectId": strconv.FormatUint(uint64(objectID), 10),
}
resp, err := c.sendRequest(http.MethodDelete, "delete", nil, queryArgs) q := endpoint.Query()
q.Set("objectId", strconv.FormatUint(uint64(objectID), 10))
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil)
if err != nil { if err != nil {
return err return err
} }
defer resp.Body.Close() resp, err := c.doAuthenticated(ctx, req)
if err != nil {
// The API doc says that 204 is expected but in the reality 200 is return. return errutils.NewHTTPDoError(req, err)
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
return &APIError{
StatusCode: resp.StatusCode,
Resource: "delete",
Message: string(data),
} }
defer func() { _ = resp.Body.Close() }()
// The API doc says that 204 is expected but in the reality 200 is returned.
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
} }
return nil return nil
} }
// LookupViewID Find the DNS view with the given name within. // LookupViewID Find the DNS view with the given name within.
func (c *Client) LookupViewID(configName, viewName string) (uint, error) { func (c *Client) LookupViewID(ctx context.Context, configName, viewName string) (uint, error) {
// Lookup the entity ID of the configuration named in our properties. // Lookup the entity ID of the configuration named in our properties.
conf, err := c.GetEntityByName(0, configName, ConfigType) conf, err := c.GetEntityByName(ctx, 0, configName, ConfigType)
if err != nil { if err != nil {
return 0, err return 0, err
} }
view, err := c.GetEntityByName(conf.ID, viewName, ViewType) view, err := c.GetEntityByName(ctx, conf.ID, viewName, ViewType)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -252,7 +202,7 @@ func (c *Client) LookupViewID(configName, viewName string) (uint, error) {
// LookupParentZoneID Return the entityId of the parent zone by recursing from the root view. // LookupParentZoneID Return the entityId of the parent zone by recursing from the root view.
// Also return the simple name of the host. // Also return the simple name of the host.
func (c *Client) LookupParentZoneID(viewID uint, fqdn string) (uint, string, error) { func (c *Client) LookupParentZoneID(ctx context.Context, viewID uint, fqdn string) (uint, string, error) {
if fqdn == "" { if fqdn == "" {
return viewID, "", nil return viewID, "", nil
} }
@ -263,7 +213,7 @@ func (c *Client) LookupParentZoneID(viewID uint, fqdn string) (uint, string, err
parentViewID := viewID parentViewID := viewID
for i := len(zones) - 1; i > -1; i-- { for i := len(zones) - 1; i > -1; i-- {
zone, err := c.GetEntityByName(parentViewID, zones[i], ZoneType) zone, err := c.GetEntityByName(ctx, parentViewID, zones[i], ZoneType)
if err != nil { if err != nil {
return 0, "", fmt.Errorf("could not find zone named %s: %w", name, err) return 0, "", fmt.Errorf("could not find zone named %s: %w", name, err)
} }
@ -282,32 +232,39 @@ func (c *Client) LookupParentZoneID(viewID uint, fqdn string) (uint, string, err
return parentViewID, name, nil return parentViewID, name, nil
} }
// Send a REST request, using query parameters specified. func (c *Client) createEndpoint(resource string) *url.URL {
// The Authorization header will be set if we have an active auth token. return c.baseURL.JoinPath("Services", "REST", "v1", resource)
func (c *Client) sendRequest(method, resource string, payload interface{}, queryParams map[string]string) (*http.Response, error) { }
url := fmt.Sprintf("%s/Services/REST/v1/%s", c.baseURL, resource)
body, err := json.Marshal(payload) func (c *Client) doAuthenticated(ctx context.Context, req *http.Request) (*http.Response, error) {
if err != nil { tok := getToken(ctx)
return nil, err if tok != "" {
req.Header.Set(authorizationHeader, tok)
} }
req, err := http.NewRequest(method, url, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
if c.token != "" {
req.Header.Set("Authorization", c.token)
}
q := req.URL.Query()
for k, v := range queryParams {
q.Set(k, v)
}
req.URL.RawQuery = q.Encode()
return c.HTTPClient.Do(req) return c.HTTPClient.Do(req)
} }
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}

View file

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -15,7 +16,8 @@ func TestClient_LookupParentZoneID(t *testing.T) {
server := httptest.NewServer(mux) server := httptest.NewServer(mux)
t.Cleanup(server.Close) t.Cleanup(server.Close)
client := NewClient(server.URL) client := NewClient(server.URL, "user", "secret")
client.HTTPClient = server.Client()
mux.HandleFunc("/Services/REST/v1/getEntityByName", func(rw http.ResponseWriter, req *http.Request) { mux.HandleFunc("/Services/REST/v1/getEntityByName", func(rw http.ResponseWriter, req *http.Request) {
query := req.URL.Query() query := req.URL.Query()
@ -33,7 +35,7 @@ func TestClient_LookupParentZoneID(t *testing.T) {
http.Error(rw, "{}", http.StatusOK) http.Error(rw, "{}", http.StatusOK)
}) })
parentID, name, err := client.LookupParentZoneID(2, "foo.example.com") parentID, name, err := client.LookupParentZoneID(context.Background(), 2, "foo.example.com")
require.NoError(t, err) require.NoError(t, err)
assert.EqualValues(t, 2, parentID) assert.EqualValues(t, 2, parentID)

View file

@ -0,0 +1,115 @@
package internal
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
type token string
const tokenKey token = "token"
// login Logs in as API user.
// Authenticates and receives a token to be used in for subsequent requests.
// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/GET/v1/login/9.5.0
func (c *Client) login(ctx context.Context) (string, error) {
endpoint := c.createEndpoint("login")
q := endpoint.Query()
q.Set("username", c.username)
q.Set("password", c.password)
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return "", err
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return "", errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return "", errutils.NewReadResponseError(req, resp.StatusCode, err)
}
authResp := string(raw)
if strings.Contains(authResp, "Authentication Error") {
return "", fmt.Errorf("request failed: %s", strings.Trim(authResp, `"`))
}
// Upon success, API responds with "Session Token-> BAMAuthToken: dQfuRMTUxNjc3MjcyNDg1ODppcGFybXM= <- for User : username"
tok := c.tokenExp.FindString(authResp)
return tok, nil
}
// Logout Logs out of the current API session.
// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/GET/v1/logout/9.5.0
func (c *Client) Logout(ctx context.Context) error {
if getToken(ctx) == "" {
// nothing to do
return nil
}
endpoint := c.createEndpoint("logout")
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return err
}
resp, err := c.doAuthenticated(ctx, req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
authResp := string(raw)
if !strings.Contains(authResp, "successfully") {
return fmt.Errorf("request failed to delete session: %s", strings.Trim(authResp, `"`))
}
return nil
}
func (c *Client) CreateAuthenticatedContext(ctx context.Context) (context.Context, error) {
tok, err := c.login(ctx)
if err != nil {
return nil, err
}
return context.WithValue(ctx, tokenKey, tok), nil
}
func getToken(ctx context.Context) string {
tok, ok := ctx.Value(tokenKey).(string)
if !ok {
return ""
}
return tok
}

View file

@ -0,0 +1,59 @@
package internal
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const fakeToken = "BAMAuthToken: dQfuRMTUxNjc3MjcyNDg1ODppcGFybXM="
func TestClient_CreateAuthenticatedContext(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
client := NewClient(server.URL, "user", "secret")
client.HTTPClient = server.Client()
mux.HandleFunc("/Services/REST/v1/login", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return
}
query := req.URL.Query()
if query.Get("username") != "user" {
http.Error(rw, fmt.Sprintf("invalid username %s", query.Get("username")), http.StatusUnauthorized)
return
}
if query.Get("password") != "secret" {
http.Error(rw, fmt.Sprintf("invalid password %s", query.Get("password")), http.StatusUnauthorized)
return
}
_, _ = fmt.Fprint(rw, fakeToken)
})
mux.HandleFunc("/Services/REST/v1/delete", func(rw http.ResponseWriter, req *http.Request) {
authorization := req.Header.Get(authorizationHeader)
if authorization != fakeToken {
http.Error(rw, fmt.Sprintf("invalid credential: %s", authorization), http.StatusUnauthorized)
return
}
})
ctx, err := client.CreateAuthenticatedContext(context.Background())
require.NoError(t, err)
at := getToken(ctx)
assert.Equal(t, fakeToken, at)
err = client.Delete(ctx, 123)
require.NoError(t, err)
}

View file

@ -1,7 +1,5 @@
package internal package internal
import "fmt"
// Entity JSON body for Bluecat entity requests. // Entity JSON body for Bluecat entity requests.
type Entity struct { type Entity struct {
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
@ -17,13 +15,3 @@ type EntityResponse struct {
Type string `json:"type"` Type string `json:"type"`
Properties string `json:"properties"` Properties string `json:"properties"`
} }
type APIError struct {
StatusCode int
Resource string
Message string
}
func (a APIError) Error() string {
return fmt.Sprintf("resource: %s, status code: %d, message: %s", a.Resource, a.StatusCode, a.Message)
}

View file

@ -1,9 +1,11 @@
package brandit package brandit
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"sync" "sync"
"time" "time"
@ -12,8 +14,6 @@ import (
"github.com/go-acme/lego/v4/providers/dns/brandit/internal" "github.com/go-acme/lego/v4/providers/dns/brandit/internal"
) )
const defaultTTL = 600
// Environment variables names. // Environment variables names.
const ( const (
envNamespace = "BRANDIT_" envNamespace = "BRANDIT_"
@ -25,7 +25,6 @@ const (
EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT" EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
EnvPollingInterval = envNamespace + "POLLING_INTERVAL" EnvPollingInterval = envNamespace + "POLLING_INTERVAL"
EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT" EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT"
DefaultBrandItPropagationTimeout = 600 * time.Second
) )
// Config is used to configure the creation of the DNSProvider. // Config is used to configure the creation of the DNSProvider.
@ -42,8 +41,8 @@ type Config struct {
// NewDefaultConfig returns a default configuration for the DNSProvider. // NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config { func NewDefaultConfig() *Config {
return &Config{ return &Config{
TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), TTL: env.GetOrDefaultInt(EnvTTL, 600),
PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, DefaultBrandItPropagationTimeout), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 10*time.Minute),
PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval),
HTTPClient: &http.Client{ HTTPClient: &http.Client{
Timeout: env.GetOrDefaultSecond(EnvHTTPTimeout, 30*time.Second), Timeout: env.GetOrDefaultSecond(EnvHTTPTimeout, 30*time.Second),
@ -97,13 +96,19 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
}, nil }, nil
} }
// Timeout returns the timeout and interval to use when checking for DNS propagation.
// Adjusting here to cope with spikes in propagation times.
func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
return d.config.PropagationTimeout, d.config.PollingInterval
}
// Present creates a TXT record using the specified parameters. // Present creates a TXT record using the specified parameters.
func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("brandit: %w", err) return fmt.Errorf("brandit: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)
@ -111,6 +116,8 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
return fmt.Errorf("brandit: %w", err) return fmt.Errorf("brandit: %w", err)
} }
ctx := context.Background()
record := internal.Record{ record := internal.Record{
Type: "TXT", Type: "TXT",
Name: subDomain, Name: subDomain,
@ -119,18 +126,18 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
} }
// find the account associated with the domain // find the account associated with the domain
account, err := d.client.StatusDomain(dns01.UnFqdn(authZone)) account, err := d.client.StatusDomain(ctx, dns01.UnFqdn(authZone))
if err != nil { if err != nil {
return fmt.Errorf("brandit: status domain: %w", err) return fmt.Errorf("brandit: status domain: %w", err)
} }
// Find the next record id // Find the next record id
recordID, err := d.client.ListRecords(account.Response.Registrar[0], dns01.UnFqdn(authZone)) recordID, err := d.client.ListRecords(ctx, account.Registrar[0], dns01.UnFqdn(authZone))
if err != nil { if err != nil {
return fmt.Errorf("brandit: list records: %w", err) return fmt.Errorf("brandit: list records: %w", err)
} }
result, err := d.client.AddRecord(dns01.UnFqdn(authZone), account.Response.Registrar[0], fmt.Sprint(recordID.Response.Total[0]), record) result, err := d.client.AddRecord(ctx, dns01.UnFqdn(authZone), account.Registrar[0], strconv.Itoa(recordID.Total[0]), record)
if err != nil { if err != nil {
return fmt.Errorf("brandit: add record: %w", err) return fmt.Errorf("brandit: add record: %w", err)
} }
@ -148,7 +155,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("brandit: %w", err) return fmt.Errorf("brandit: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
// gets the record's unique ID // gets the record's unique ID
@ -159,25 +166,27 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return fmt.Errorf("brandit: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) return fmt.Errorf("brandit: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token)
} }
ctx := context.Background()
// find the account associated with the domain // find the account associated with the domain
account, err := d.client.StatusDomain(dns01.UnFqdn(authZone)) account, err := d.client.StatusDomain(ctx, dns01.UnFqdn(authZone))
if err != nil { if err != nil {
return fmt.Errorf("brandit: status domain: %w", err) return fmt.Errorf("brandit: status domain: %w", err)
} }
records, err := d.client.ListRecords(account.Response.Registrar[0], dns01.UnFqdn(authZone)) records, err := d.client.ListRecords(ctx, account.Registrar[0], dns01.UnFqdn(authZone))
if err != nil { if err != nil {
return fmt.Errorf("brandit: list records: %w", err) return fmt.Errorf("brandit: list records: %w", err)
} }
var recordID int var recordID int
for i, r := range records.Response.RR { for i, r := range records.RR {
if r == dnsRecord { if r == dnsRecord {
recordID = i recordID = i
} }
} }
_, err = d.client.DeleteRecord(dns01.UnFqdn(authZone), account.Response.Registrar[0], dnsRecord, fmt.Sprint(recordID)) err = d.client.DeleteRecord(ctx, dns01.UnFqdn(authZone), account.Registrar[0], dnsRecord, strconv.Itoa(recordID))
if err != nil { if err != nil {
return fmt.Errorf("brandit: delete record: %w", err) return fmt.Errorf("brandit: delete record: %w", err)
} }
@ -189,9 +198,3 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return nil return nil
} }
// Timeout returns the timeout and interval to use when checking for DNS propagation.
// Adjusting here to cope with spikes in propagation times.
func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
return d.config.PropagationTimeout, d.config.PollingInterval
}

View file

@ -1,4 +1,4 @@
Name = "BRANDIT" Name = "Brandit"
Description = '''''' Description = ''''''
URL = "https://www.brandit.com/" URL = "https://www.brandit.com/"
Code = "brandit" Code = "brandit"

View file

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
@ -12,6 +13,8 @@ import (
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
) )
const defaultBaseURL = "https://portal.brandit.com/api/v3/" const defaultBaseURL = "https://portal.brandit.com/api/v3/"
@ -20,7 +23,8 @@ const defaultBaseURL = "https://portal.brandit.com/api/v3/"
type Client struct { type Client struct {
apiUsername string apiUsername string
apiKey string apiKey string
BaseURL string
baseURL string
HTTPClient *http.Client HTTPClient *http.Client
} }
@ -33,70 +37,69 @@ func NewClient(apiUsername, apiKey string) (*Client, error) {
return &Client{ return &Client{
apiUsername: apiUsername, apiUsername: apiUsername,
apiKey: apiKey, apiKey: apiKey,
BaseURL: defaultBaseURL, baseURL: defaultBaseURL,
HTTPClient: &http.Client{Timeout: 10 * time.Second}, HTTPClient: &http.Client{Timeout: 10 * time.Second},
}, nil }, nil
} }
// ListRecords lists all records. // ListRecords lists all records.
// https://portal.brandit.com/apidocv3#listDNSRR // https://portal.brandit.com/apidocv3#listDNSRR
func (c *Client) ListRecords(account, dnsZone string) (*ListRecords, error) { func (c *Client) ListRecords(ctx context.Context, account, dnsZone string) (*ListRecordsResponse, error) {
// Create a new query
query := url.Values{} query := url.Values{}
query.Add("command", "listDNSRR") query.Add("command", "listDNSRR")
query.Add("account", account) query.Add("account", account)
query.Add("dnszone", dnsZone) query.Add("dnszone", dnsZone)
result := &ListRecords{} result := &Response[*ListRecordsResponse]{}
err := c.do(query, result) err := c.do(ctx, query, result)
if err != nil { if err != nil {
return nil, fmt.Errorf("do: %w", err) return nil, err
} }
for len(result.Response.RR) < result.Response.Total[0] { for len(result.Response.RR) < result.Response.Total[0] {
query.Add("first", fmt.Sprint(result.Response.Last[0]+1)) query.Add("first", fmt.Sprint(result.Response.Last[0]+1))
tmp := &ListRecords{} tmp := &Response[*ListRecordsResponse]{}
err := c.do(query, tmp) err := c.do(ctx, query, tmp)
if err != nil { if err != nil {
return nil, fmt.Errorf("do: %w", err) return nil, err
} }
result.Response.RR = append(result.Response.RR, tmp.Response.RR...) result.Response.RR = append(result.Response.RR, tmp.Response.RR...)
result.Response.Last = tmp.Response.Last result.Response.Last = tmp.Response.Last
} }
return result, nil return result.Response, nil
} }
// AddRecord adds a DNS record. // AddRecord adds a DNS record.
// https://portal.brandit.com/apidocv3#addDNSRR // https://portal.brandit.com/apidocv3#addDNSRR
func (c *Client) AddRecord(domainName, account, newRecordID string, record Record) (*AddRecord, error) { func (c *Client) AddRecord(ctx context.Context, domainName, account, newRecordID string, record Record) (*AddRecord, error) {
// Create a new query value := strings.Join([]string{record.Name, fmt.Sprint(record.TTL), "IN", record.Type, record.Content}, " ")
query := url.Values{} query := url.Values{}
query.Add("command", "addDNSRR") query.Add("command", "addDNSRR")
query.Add("account", account) query.Add("account", account)
query.Add("dnszone", domainName) query.Add("dnszone", domainName)
query.Add("rrdata", strings.Join([]string{record.Name, fmt.Sprint(record.TTL), "IN", record.Type, record.Content}, " ")) query.Add("rrdata", value)
query.Add("key", newRecordID) query.Add("key", newRecordID)
result := &AddRecord{} result := &AddRecord{}
err := c.do(query, result) err := c.do(ctx, query, result)
if err != nil { if err != nil {
return nil, fmt.Errorf("do: %w", err) return nil, err
} }
result.Record = strings.Join([]string{record.Name, fmt.Sprint(record.TTL), "IN", record.Type, record.Content}, " ")
result.Record = value
return result, nil return result, nil
} }
// DeleteRecord deletes a DNS record. // DeleteRecord deletes a DNS record.
// https://portal.brandit.com/apidocv3#deleteDNSRR // https://portal.brandit.com/apidocv3#deleteDNSRR
func (c *Client) DeleteRecord(domainName, account, dnsRecord, recordID string) (*DeleteRecord, error) { func (c *Client) DeleteRecord(ctx context.Context, domainName, account, dnsRecord, recordID string) error {
// Create a new query
query := url.Values{} query := url.Values{}
query.Add("command", "deleteDNSRR") query.Add("command", "deleteDNSRR")
query.Add("account", account) query.Add("account", account)
@ -104,68 +107,70 @@ func (c *Client) DeleteRecord(domainName, account, dnsRecord, recordID string) (
query.Add("rrdata", dnsRecord) query.Add("rrdata", dnsRecord)
query.Add("key", recordID) query.Add("key", recordID)
result := &DeleteRecord{} return c.do(ctx, query, nil)
err := c.do(query, result)
if err != nil {
return nil, fmt.Errorf("do: %w", err)
}
return result, nil
} }
// StatusDomain returns the status of a domain and account associated with it. // StatusDomain returns the status of a domain and account associated with it.
// https://portal.brandit.com/apidocv3#statusDomain // https://portal.brandit.com/apidocv3#statusDomain
func (c *Client) StatusDomain(domain string) (*StatusDomain, error) { func (c *Client) StatusDomain(ctx context.Context, domain string) (*StatusResponse, error) {
// Create a new query
query := url.Values{} query := url.Values{}
query.Add("command", "statusDomain") query.Add("command", "statusDomain")
query.Add("domain", domain) query.Add("domain", domain)
result := &StatusDomain{} result := &Response[*StatusResponse]{}
err := c.do(query, result) err := c.do(ctx, query, result)
if err != nil { if err != nil {
return nil, fmt.Errorf("do: %w", err) return nil, err
} }
return result, nil return result.Response, nil
} }
func (c *Client) do(query url.Values, result any) error { func (c *Client) do(ctx context.Context, query url.Values, result any) error {
// Add signature values, err := sign(c.apiUsername, c.apiKey, query)
v, err := sign(c.apiUsername, c.apiKey, query)
if err != nil {
return fmt.Errorf("signature: %w", err)
}
resp, err := c.HTTPClient.PostForm(c.BaseURL, v)
if err != nil { if err != nil {
return err return err
} }
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL, strings.NewReader(values.Encode()))
if err != nil {
return fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.HTTPClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body) raw, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return fmt.Errorf("read response body: %w", err) return errutils.NewReadResponseError(req, resp.StatusCode, err)
} }
// Unmarshal the error response, because the API returns a 200 OK even if there is an error. // Unmarshal the error response, because the API returns a 200 OK even if there is an error.
var apiError APIError var apiError APIError
err = json.Unmarshal(raw, &apiError) err = json.Unmarshal(raw, &apiError)
if err != nil { if err != nil {
return fmt.Errorf("unmarshal error response: %w %s", err, string(raw)) return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
} }
if apiError.Code > 299 || apiError.Status != "success" { if apiError.Code > 299 || apiError.Status != "success" {
return apiError return apiError
} }
if result == nil {
return nil
}
err = json.Unmarshal(raw, result) err = json.Unmarshal(raw, result)
if err != nil { if err != nil {
return fmt.Errorf("unmarshal response body: %w %s", err, string(raw)) return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
} }
return nil return nil

View file

@ -1,30 +1,32 @@
package internal package internal
import ( import (
"context"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func setupTest(t *testing.T, file string) *Client { func setupTest(t *testing.T, filename string) *Client {
t.Helper() t.Helper()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
open, err := os.Open(file) file, err := os.Open(filepath.Join("fixtures", filename))
if err != nil { if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError) http.Error(rw, err.Error(), http.StatusInternalServerError)
return return
} }
defer func() { _ = open.Close() }() defer func() { _ = file.Close() }()
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
_, err = io.Copy(rw, open) _, err = io.Copy(rw, file)
if err != nil { if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError) http.Error(rw, err.Error(), http.StatusInternalServerError)
return return
@ -36,19 +38,18 @@ func setupTest(t *testing.T, file string) *Client {
require.NoError(t, err) require.NoError(t, err)
client.HTTPClient = server.Client() client.HTTPClient = server.Client()
client.BaseURL = server.URL client.baseURL = server.URL
return client return client
} }
func TestClient_StatusDomain(t *testing.T) { func TestClient_StatusDomain(t *testing.T) {
client := setupTest(t, "./fixtures/status-domain.json") client := setupTest(t, "status-domain.json")
domain, err := client.StatusDomain("example.com") domain, err := client.StatusDomain(context.Background(), "example.com")
require.NoError(t, err) require.NoError(t, err)
expected := &StatusDomain{ expected := &StatusResponse{
Response: StatusResponse{
RenewalMode: []string{"DEFAULT"}, RenewalMode: []string{"DEFAULT"},
Status: []string{"clientTransferProhibited"}, Status: []string{"clientTransferProhibited"},
TransferLock: []int{1}, TransferLock: []int{1},
@ -73,23 +74,25 @@ func TestClient_StatusDomain(t *testing.T) {
OwnerContact: []string{"example"}, OwnerContact: []string{"example"},
CreatedBy: []string{"example"}, CreatedBy: []string{"example"},
TransferMode: []string{"auto"}, TransferMode: []string{"auto"},
},
Code: 200,
Status: "success",
Error: "",
} }
assert.Equal(t, expected, domain) assert.Equal(t, expected, domain)
} }
func TestClient_ListRecords(t *testing.T) { func TestClient_StatusDomain_error(t *testing.T) {
client := setupTest(t, "./fixtures/list-records.json") client := setupTest(t, "error.json")
resp, err := client.ListRecords("example", "example.com") _, err := client.StatusDomain(context.Background(), "example.com")
require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."})
}
func TestClient_ListRecords(t *testing.T) {
client := setupTest(t, "list-records.json")
resp, err := client.ListRecords(context.Background(), "example", "example.com")
require.NoError(t, err) require.NoError(t, err)
expected := &ListRecords{ expected := &ListRecordsResponse{
Response: ListRecordsResponse{
Limit: []int{100}, Limit: []int{100},
Column: []string{"rr"}, Column: []string{"rr"},
Count: []int{1}, Count: []int{1},
@ -97,17 +100,20 @@ func TestClient_ListRecords(t *testing.T) {
Total: []int{1}, Total: []int{1},
RR: []string{"example.com. 600 IN TXT txttxttxt"}, RR: []string{"example.com. 600 IN TXT txttxttxt"},
Last: []int{0}, Last: []int{0},
},
Code: 200,
Status: "success",
Error: "",
} }
assert.Equal(t, expected, resp) assert.Equal(t, expected, resp)
} }
func TestClient_ListRecords_error(t *testing.T) {
client := setupTest(t, "error.json")
_, err := client.ListRecords(context.Background(), "example", "example.com")
require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."})
}
func TestClient_AddRecord(t *testing.T) { func TestClient_AddRecord(t *testing.T) {
client := setupTest(t, "./fixtures/add-record.json") client := setupTest(t, "add-record.json")
testRecord := Record{ testRecord := Record{
ID: 2565, ID: 2565,
@ -116,7 +122,7 @@ func TestClient_AddRecord(t *testing.T) {
Content: "txttxttxt", Content: "txttxttxt",
TTL: 600, TTL: 600,
} }
resp, err := client.AddRecord("example.com", "test", "2565", testRecord) resp, err := client.AddRecord(context.Background(), "example.com", "test", "2565", testRecord)
require.NoError(t, err) require.NoError(t, err)
expected := &AddRecord{ expected := &AddRecord{
@ -133,17 +139,31 @@ func TestClient_AddRecord(t *testing.T) {
assert.Equal(t, expected, resp) assert.Equal(t, expected, resp)
} }
func TestClient_DeleteRecord(t *testing.T) { func TestClient_AddRecord_error(t *testing.T) {
client := setupTest(t, "./fixtures/delete-record.json") client := setupTest(t, "error.json")
resp, err := client.DeleteRecord("example.com", "test", "example.com 600 IN TXT txttxttxt", "2374") testRecord := Record{
require.NoError(t, err) ID: 2565,
Type: "TXT",
expected := &DeleteRecord{ Name: "example.com",
Code: 200, Content: "txttxttxt",
Status: "success", TTL: 600,
Error: "",
} }
assert.Equal(t, expected, resp) _, err := client.AddRecord(context.Background(), "example.com", "test", "2565", testRecord)
require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."})
}
func TestClient_DeleteRecord(t *testing.T) {
client := setupTest(t, "delete-record.json")
err := client.DeleteRecord(context.Background(), "example.com", "test", "example.com 600 IN TXT txttxttxt", "2374")
require.NoError(t, err)
}
func TestClient_DeleteRecord_error(t *testing.T) {
client := setupTest(t, "error.json")
err := client.DeleteRecord(context.Background(), "example.com", "test", "example.com 600 IN TXT txttxttxt", "2374")
require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."})
} }

View file

@ -0,0 +1,5 @@
{
"code": 402,
"status": "error",
"error": "Invalid user."
}

View file

@ -2,8 +2,8 @@ package internal
import "fmt" import "fmt"
type StatusDomain struct { type Response[T any] struct {
Response StatusResponse `json:"response,omitempty"` Response T `json:"response,omitempty"`
Code int `json:"code"` Code int `json:"code"`
Status string `json:"status"` Status string `json:"status"`
Error string `json:"error"` Error string `json:"error"`
@ -36,13 +36,6 @@ type StatusResponse struct {
TransferMode []string `json:"transfermode"` TransferMode []string `json:"transfermode"`
} }
type ListRecords struct {
Response ListRecordsResponse `json:"response,omitempty"`
Code int `json:"code"`
Status string `json:"status"`
Error string `json:"error"`
}
type ListRecordsResponse struct { type ListRecordsResponse struct {
Limit []int `json:"limit,omitempty"` Limit []int `json:"limit,omitempty"`
Column []string `json:"column,omitempty"` Column []string `json:"column,omitempty"`
@ -83,9 +76,3 @@ type Record struct {
Content string `json:"content,omitempty"` Content string `json:"content,omitempty"`
TTL int `json:"ttl,omitempty"` // default 600 TTL int `json:"ttl,omitempty"` // default 600
} }
type DeleteRecord struct {
Code int `json:"code"`
Status string `json:"status"`
Error string `json:"error"`
}

View file

@ -2,15 +2,16 @@
package checkdomain package checkdomain
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"time" "time"
"github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/platform/config/env" "github.com/go-acme/lego/v4/platform/config/env"
"github.com/go-acme/lego/v4/providers/dns/checkdomain/internal"
) )
// Environment variables names. // Environment variables names.
@ -26,11 +27,6 @@ const (
EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT" EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT"
) )
const (
defaultEndpoint = "https://api.checkdomain.de"
defaultTTL = 300
)
// Config is used to configure the creation of the DNSProvider. // Config is used to configure the creation of the DNSProvider.
type Config struct { type Config struct {
Endpoint *url.URL Endpoint *url.URL
@ -44,7 +40,7 @@ type Config struct {
// NewDefaultConfig returns a default configuration for the DNSProvider. // NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config { func NewDefaultConfig() *Config {
return &Config{ return &Config{
TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), TTL: env.GetOrDefaultInt(EnvTTL, 300),
PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 5*time.Minute), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 5*time.Minute),
PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 7*time.Second), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 7*time.Second),
HTTPClient: &http.Client{ HTTPClient: &http.Client{
@ -56,9 +52,7 @@ func NewDefaultConfig() *Config {
// DNSProvider implements the challenge.Provider interface. // DNSProvider implements the challenge.Provider interface.
type DNSProvider struct { type DNSProvider struct {
config *Config config *Config
client *internal.Client
domainIDMu sync.Mutex
domainIDMapping map[string]int
} }
// NewDNSProvider returns a DNSProvider instance configured for CheckDomain. // NewDNSProvider returns a DNSProvider instance configured for CheckDomain.
@ -71,7 +65,7 @@ func NewDNSProvider() (*DNSProvider, error) {
config := NewDefaultConfig() config := NewDefaultConfig()
config.Token = values[EnvToken] config.Token = values[EnvToken]
endpoint, err := url.Parse(env.GetOrDefaultString(EnvEndpoint, defaultEndpoint)) endpoint, err := url.Parse(env.GetOrDefaultString(EnvEndpoint, internal.DefaultEndpoint))
if err != nil { if err != nil {
return nil, fmt.Errorf("checkdomain: invalid %s: %w", EnvEndpoint, err) return nil, fmt.Errorf("checkdomain: invalid %s: %w", EnvEndpoint, err)
} }
@ -89,32 +83,33 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("checkdomain: missing token") return nil, errors.New("checkdomain: missing token")
} }
if config.HTTPClient == nil { client := internal.NewClient(internal.OAuthStaticAccessToken(config.HTTPClient, config.Token))
config.HTTPClient = http.DefaultClient
if config.Endpoint != nil {
client.BaseURL = config.Endpoint
} }
return &DNSProvider{ return &DNSProvider{config: config, client: client}, nil
config: config,
domainIDMapping: make(map[string]int),
}, nil
} }
// Present creates a TXT record to fulfill the dns-01 challenge. // Present creates a TXT record to fulfill the dns-01 challenge.
func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) Present(domain, token, keyAuth string) error {
ctx := context.Background()
// TODO(ldez) replace domain by FQDN to follow CNAME. // TODO(ldez) replace domain by FQDN to follow CNAME.
domainID, err := d.getDomainIDByName(domain) domainID, err := d.client.GetDomainIDByName(ctx, domain)
if err != nil { if err != nil {
return fmt.Errorf("checkdomain: %w", err) return fmt.Errorf("checkdomain: %w", err)
} }
err = d.checkNameservers(domainID) err = d.client.CheckNameservers(ctx, domainID)
if err != nil { if err != nil {
return fmt.Errorf("checkdomain: %w", err) return fmt.Errorf("checkdomain: %w", err)
} }
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
err = d.createRecord(domainID, &Record{ err = d.client.CreateRecord(ctx, domainID, &internal.Record{
Name: info.EffectiveFQDN, Name: info.EffectiveFQDN,
TTL: d.config.TTL, TTL: d.config.TTL,
Type: "TXT", Type: "TXT",
@ -130,28 +125,28 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
// CleanUp removes the TXT record previously created. // CleanUp removes the TXT record previously created.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
ctx := context.Background()
// TODO(ldez) replace domain by FQDN to follow CNAME. // TODO(ldez) replace domain by FQDN to follow CNAME.
domainID, err := d.getDomainIDByName(domain) domainID, err := d.client.GetDomainIDByName(ctx, domain)
if err != nil { if err != nil {
return fmt.Errorf("checkdomain: %w", err) return fmt.Errorf("checkdomain: %w", err)
} }
err = d.checkNameservers(domainID) err = d.client.CheckNameservers(ctx, domainID)
if err != nil { if err != nil {
return fmt.Errorf("checkdomain: %w", err) return fmt.Errorf("checkdomain: %w", err)
} }
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
err = d.deleteTXTRecord(domainID, info.EffectiveFQDN, info.Value) defer d.client.CleanCache(info.EffectiveFQDN)
err = d.client.DeleteTXTRecord(ctx, domainID, info.EffectiveFQDN, info.Value)
if err != nil { if err != nil {
return fmt.Errorf("checkdomain: %w", err) return fmt.Errorf("checkdomain: %w", err)
} }
d.domainIDMu.Lock()
delete(d.domainIDMapping, info.EffectiveFQDN)
d.domainIDMu.Unlock()
return nil return nil
} }

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/go-acme/lego/v4/platform/tester" "github.com/go-acme/lego/v4/platform/tester"
"github.com/go-acme/lego/v4/providers/dns/checkdomain/internal"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -83,7 +84,7 @@ func TestNewDNSProviderConfig(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
config := NewDefaultConfig() config := NewDefaultConfig()
config.Endpoint, _ = url.Parse(defaultEndpoint) config.Endpoint, _ = url.Parse(internal.DefaultEndpoint)
if test.token != "" { if test.token != "" {
config.Token = test.token config.Token = test.token

View file

@ -1,416 +0,0 @@
package checkdomain
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
)
const (
ns1 = "ns.checkdomain.de"
ns2 = "ns2.checkdomain.de"
)
const domainNotFound = -1
// max page limit that the checkdomain api allows.
const maxLimit = 100
// max integer value.
const maxInt = int((^uint(0)) >> 1)
type (
// Some fields have been omitted from the structs
// because they are not required for this application.
DomainListingResponse struct {
Page int `json:"page"`
Limit int `json:"limit"`
Pages int `json:"pages"`
Total int `json:"total"`
Embedded EmbeddedDomainList `json:"_embedded"`
}
EmbeddedDomainList struct {
Domains []*Domain `json:"domains"`
}
Domain struct {
ID int `json:"id"`
Name string `json:"name"`
}
DomainResponse struct {
ID int `json:"id"`
Name string `json:"name"`
Created string `json:"created"`
PaidUp string `json:"payed_up"`
Active bool `json:"active"`
}
NameserverResponse struct {
General NameserverGeneral `json:"general"`
Nameservers []*Nameserver `json:"nameservers"`
SOA NameserverSOA `json:"soa"`
}
NameserverGeneral struct {
IPv4 string `json:"ip_v4"`
IPv6 string `json:"ip_v6"`
IncludeWWW bool `json:"include_www"`
}
NameserverSOA struct {
Mail string `json:"mail"`
Refresh int `json:"refresh"`
Retry int `json:"retry"`
Expiry int `json:"expiry"`
TTL int `json:"ttl"`
}
Nameserver struct {
Name string `json:"name"`
}
RecordListingResponse struct {
Page int `json:"page"`
Limit int `json:"limit"`
Pages int `json:"pages"`
Total int `json:"total"`
Embedded EmbeddedRecordList `json:"_embedded"`
}
EmbeddedRecordList struct {
Records []*Record `json:"records"`
}
Record struct {
Name string `json:"name"`
Value string `json:"value"`
TTL int `json:"ttl"`
Priority int `json:"priority"`
Type string `json:"type"`
}
)
func (d *DNSProvider) getDomainIDByName(name string) (int, error) {
// Load from cache if exists
d.domainIDMu.Lock()
id, ok := d.domainIDMapping[name]
d.domainIDMu.Unlock()
if ok {
return id, nil
}
// Find out by querying API
domains, err := d.listDomains()
if err != nil {
return domainNotFound, err
}
// Linear search over all registered domains
for _, domain := range domains {
if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) {
d.domainIDMu.Lock()
d.domainIDMapping[name] = domain.ID
d.domainIDMu.Unlock()
return domain.ID, nil
}
}
return domainNotFound, errors.New("domain not found")
}
func (d *DNSProvider) listDomains() ([]*Domain, error) {
req, err := d.makeRequest(http.MethodGet, "/v1/domains", http.NoBody)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
// Checkdomain also provides a query param 'query' which allows filtering domains for a string.
// But that functionality is kinda broken,
// so we scan through the whole list of registered domains to later find the one that is of interest to us.
q := req.URL.Query()
q.Set("limit", strconv.Itoa(maxLimit))
currentPage := 1
totalPages := maxInt
var domainList []*Domain
for currentPage <= totalPages {
q.Set("page", strconv.Itoa(currentPage))
req.URL.RawQuery = q.Encode()
var res DomainListingResponse
if err := d.sendRequest(req, &res); err != nil {
return nil, fmt.Errorf("failed to send domain listing request: %w", err)
}
// This is the first response,
// so we update totalPages and allocate the slice memory.
if totalPages == maxInt {
totalPages = res.Pages
domainList = make([]*Domain, 0, res.Total)
}
domainList = append(domainList, res.Embedded.Domains...)
currentPage++
}
return domainList, nil
}
func (d *DNSProvider) getNameserverInfo(domainID int) (*NameserverResponse, error) {
req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d/nameservers", domainID), http.NoBody)
if err != nil {
return nil, err
}
res := &NameserverResponse{}
if err := d.sendRequest(req, res); err != nil {
return nil, err
}
return res, nil
}
func (d *DNSProvider) checkNameservers(domainID int) error {
info, err := d.getNameserverInfo(domainID)
if err != nil {
return err
}
var found1, found2 bool
for _, item := range info.Nameservers {
switch item.Name {
case ns1:
found1 = true
case ns2:
found2 = true
}
}
if !found1 || !found2 {
return errors.New("not using checkdomain nameservers, can not update records")
}
return nil
}
func (d *DNSProvider) createRecord(domainID int, record *Record) error {
bs, err := json.Marshal(record)
if err != nil {
return fmt.Errorf("encoding record failed: %w", err)
}
req, err := d.makeRequest(http.MethodPost, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), bytes.NewReader(bs))
if err != nil {
return err
}
return d.sendRequest(req, nil)
}
// Checkdomain doesn't seem provide a way to delete records but one can replace all records at once.
// The current solution is to fetch all records and then use that list minus the record deleted as the new record list.
// TODO: Simplify this function once Checkdomain do provide the functionality.
func (d *DNSProvider) deleteTXTRecord(domainID int, recordName, recordValue string) error {
domainInfo, err := d.getDomainInfo(domainID)
if err != nil {
return err
}
nsInfo, err := d.getNameserverInfo(domainID)
if err != nil {
return err
}
allRecords, err := d.listRecords(domainID, "")
if err != nil {
return err
}
recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".")
var recordsToKeep []*Record
// Find and delete matching records
for _, record := range allRecords {
if skipRecord(recordName, recordValue, record, nsInfo) {
continue
}
// Checkdomain API can return records without any TTL set (indicated by the value of 0).
// The API Call to replace the records would fail if we wouldn't specify a value.
// Thus, we use the default TTL queried beforehand
if record.TTL == 0 {
record.TTL = nsInfo.SOA.TTL
}
recordsToKeep = append(recordsToKeep, record)
}
return d.replaceRecords(domainID, recordsToKeep)
}
func (d *DNSProvider) getDomainInfo(domainID int) (*DomainResponse, error) {
req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d", domainID), http.NoBody)
if err != nil {
return nil, err
}
var res DomainResponse
err = d.sendRequest(req, &res)
if err != nil {
return nil, err
}
return &res, nil
}
func (d *DNSProvider) listRecords(domainID int, recordType string) ([]*Record, error) {
req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), http.NoBody)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
q := req.URL.Query()
q.Set("limit", strconv.Itoa(maxLimit))
if recordType != "" {
q.Set("type", recordType)
}
currentPage := 1
totalPages := maxInt
var recordList []*Record
for currentPage <= totalPages {
q.Set("page", strconv.Itoa(currentPage))
req.URL.RawQuery = q.Encode()
var res RecordListingResponse
if err := d.sendRequest(req, &res); err != nil {
return nil, fmt.Errorf("failed to send record listing request: %w", err)
}
// This is the first response, so we update totalPages and allocate the slice memory.
if totalPages == maxInt {
totalPages = res.Pages
recordList = make([]*Record, 0, res.Total)
}
recordList = append(recordList, res.Embedded.Records...)
currentPage++
}
return recordList, nil
}
func (d *DNSProvider) replaceRecords(domainID int, records []*Record) error {
bs, err := json.Marshal(records)
if err != nil {
return fmt.Errorf("encoding record failed: %w", err)
}
req, err := d.makeRequest(http.MethodPut, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), bytes.NewReader(bs))
if err != nil {
return err
}
return d.sendRequest(req, nil)
}
func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool {
// Skip empty records
if record.Value == "" {
return true
}
// Skip some special records, otherwise we would get a "Nameserver update failed"
if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") {
return true
}
nameMatch := recordName == "" || record.Name == recordName
valueMatch := recordValue == "" || record.Value == recordValue
// Skip our matching record
if record.Type == "TXT" && nameMatch && valueMatch {
return true
}
return false
}
func (d *DNSProvider) makeRequest(method, resource string, body io.Reader) (*http.Request, error) {
uri, err := d.config.Endpoint.Parse(resource)
if err != nil {
return nil, err
}
req, err := http.NewRequest(method, uri.String(), body)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+d.config.Token)
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}
func (d *DNSProvider) sendRequest(req *http.Request, result interface{}) error {
resp, err := d.config.HTTPClient.Do(req)
if err != nil {
return err
}
if err = checkResponse(resp); err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
err = json.Unmarshal(raw, result)
if err != nil {
return fmt.Errorf("unmarshaling %T error [status code=%d]: %w: %s", result, resp.StatusCode, err, string(raw))
}
return nil
}
func checkResponse(resp *http.Response) error {
if resp.StatusCode < http.StatusBadRequest {
return nil
}
if resp.Body == nil {
return fmt.Errorf("response body is nil, status code=%d", resp.StatusCode)
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("unable to read body: status code=%d, error=%w", resp.StatusCode, err)
}
return fmt.Errorf("status code=%d: %s", resp.StatusCode, string(raw))
}

View file

@ -0,0 +1,383 @@
package internal
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
"golang.org/x/oauth2"
)
const (
ns1 = "ns.checkdomain.de"
ns2 = "ns2.checkdomain.de"
)
// DefaultEndpoint the default API endpoint.
const DefaultEndpoint = "https://api.checkdomain.de"
const domainNotFound = -1
// max page limit that the checkdomain api allows.
const maxLimit = 100
// max integer value.
const maxInt = int((^uint(0)) >> 1)
// Client the Autodns API client.
type Client struct {
domainIDMapping map[string]int
domainIDMu sync.Mutex
BaseURL *url.URL
httpClient *http.Client
}
// NewClient creates a new Client.
func NewClient(hc *http.Client) *Client {
baseURL, _ := url.Parse(DefaultEndpoint)
if hc == nil {
hc = &http.Client{Timeout: 10 * time.Second}
}
return &Client{
BaseURL: baseURL,
httpClient: hc,
domainIDMapping: make(map[string]int),
}
}
func (c *Client) GetDomainIDByName(ctx context.Context, name string) (int, error) {
// Load from cache if exists
c.domainIDMu.Lock()
id, ok := c.domainIDMapping[name]
c.domainIDMu.Unlock()
if ok {
return id, nil
}
// Find out by querying API
domains, err := c.listDomains(ctx)
if err != nil {
return domainNotFound, err
}
// Linear search over all registered domains
for _, domain := range domains {
if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) {
c.domainIDMu.Lock()
c.domainIDMapping[name] = domain.ID
c.domainIDMu.Unlock()
return domain.ID, nil
}
}
return domainNotFound, errors.New("domain not found")
}
func (c *Client) listDomains(ctx context.Context) ([]*Domain, error) {
endpoint := c.BaseURL.JoinPath("v1", "domains")
// Checkdomain also provides a query param 'query' which allows filtering domains for a string.
// But that functionality is kinda broken,
// so we scan through the whole list of registered domains to later find the one that is of interest to us.
q := endpoint.Query()
q.Set("limit", strconv.Itoa(maxLimit))
currentPage := 1
totalPages := maxInt
var domainList []*Domain
for currentPage <= totalPages {
q.Set("page", strconv.Itoa(currentPage))
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
var res DomainListingResponse
if err := c.do(req, &res); err != nil {
return nil, fmt.Errorf("failed to send domain listing request: %w", err)
}
// This is the first response,
// so we update totalPages and allocate the slice memory.
if totalPages == maxInt {
totalPages = res.Pages
domainList = make([]*Domain, 0, res.Total)
}
domainList = append(domainList, res.Embedded.Domains...)
currentPage++
}
return domainList, nil
}
func (c *Client) getNameserverInfo(ctx context.Context, domainID int) (*NameserverResponse, error) {
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers")
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
res := &NameserverResponse{}
if err := c.do(req, res); err != nil {
return nil, err
}
return res, nil
}
func (c *Client) CheckNameservers(ctx context.Context, domainID int) error {
info, err := c.getNameserverInfo(ctx, domainID)
if err != nil {
return err
}
var found1, found2 bool
for _, item := range info.Nameservers {
switch item.Name {
case ns1:
found1 = true
case ns2:
found2 = true
}
}
if !found1 || !found2 {
return errors.New("not using checkdomain nameservers, can not update records")
}
return nil
}
func (c *Client) CreateRecord(ctx context.Context, domainID int, record *Record) error {
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
if err != nil {
return err
}
return c.do(req, nil)
}
// DeleteTXTRecord Checkdomain doesn't seem provide a way to delete records but one can replace all records at once.
// The current solution is to fetch all records and then use that list minus the record deleted as the new record list.
// TODO: Simplify this function once Checkdomain do provide the functionality.
func (c *Client) DeleteTXTRecord(ctx context.Context, domainID int, recordName, recordValue string) error {
domainInfo, err := c.getDomainInfo(ctx, domainID)
if err != nil {
return err
}
nsInfo, err := c.getNameserverInfo(ctx, domainID)
if err != nil {
return err
}
allRecords, err := c.listRecords(ctx, domainID, "")
if err != nil {
return err
}
recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".")
var recordsToKeep []*Record
// Find and delete matching records
for _, record := range allRecords {
if skipRecord(recordName, recordValue, record, nsInfo) {
continue
}
// Checkdomain API can return records without any TTL set (indicated by the value of 0).
// The API Call to replace the records would fail if we wouldn't specify a value.
// Thus, we use the default TTL queried beforehand
if record.TTL == 0 {
record.TTL = nsInfo.SOA.TTL
}
recordsToKeep = append(recordsToKeep, record)
}
return c.replaceRecords(ctx, domainID, recordsToKeep)
}
func (c *Client) getDomainInfo(ctx context.Context, domainID int) (*DomainResponse, error) {
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID))
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
var res DomainResponse
err = c.do(req, &res)
if err != nil {
return nil, err
}
return &res, nil
}
func (c *Client) listRecords(ctx context.Context, domainID int, recordType string) ([]*Record, error) {
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
q := endpoint.Query()
q.Set("limit", strconv.Itoa(maxLimit))
if recordType != "" {
q.Set("type", recordType)
}
currentPage := 1
totalPages := maxInt
var recordList []*Record
for currentPage <= totalPages {
q.Set("page", strconv.Itoa(currentPage))
endpoint.RawQuery = q.Encode()
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
var res RecordListingResponse
if err := c.do(req, &res); err != nil {
return nil, fmt.Errorf("failed to send record listing request: %w", err)
}
// This is the first response, so we update totalPages and allocate the slice memory.
if totalPages == maxInt {
totalPages = res.Pages
recordList = make([]*Record, 0, res.Total)
}
recordList = append(recordList, res.Embedded.Records...)
currentPage++
}
return recordList, nil
}
func (c *Client) replaceRecords(ctx context.Context, domainID int, records []*Record) error {
endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
req, err := newJSONRequest(ctx, http.MethodPut, endpoint, records)
if err != nil {
return err
}
return c.do(req, nil)
}
func (c *Client) do(req *http.Request, result any) error {
resp, err := c.httpClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode/100 != 2 {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
err = json.Unmarshal(raw, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}
func (c *Client) CleanCache(fqdn string) {
c.domainIDMu.Lock()
delete(c.domainIDMapping, fqdn)
c.domainIDMu.Unlock()
}
func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool {
// Skip empty records
if record.Value == "" {
return true
}
// Skip some special records, otherwise we would get a "Nameserver update failed"
if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") {
return true
}
nameMatch := recordName == "" || record.Name == recordName
valueMatch := recordValue == "" || record.Value == recordValue
// Skip our matching record
if record.Type == "TXT" && nameMatch && valueMatch {
return true
}
return false
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}
func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client {
if client == nil {
client = &http.Client{Timeout: 5 * time.Second}
}
client.Transport = &oauth2.Transport{
Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}),
Base: client.Transport,
}
return client
}

View file

@ -1,6 +1,8 @@
package checkdomain package internal
import ( import (
"bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -15,32 +17,42 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func setupTestProvider(t *testing.T) (*DNSProvider, *http.ServeMux) { func setupTest(t *testing.T) (*Client, *http.ServeMux) {
t.Helper() t.Helper()
mux := http.NewServeMux() mux := http.NewServeMux()
server := httptest.NewServer(mux) server := httptest.NewServer(mux)
t.Cleanup(server.Close) t.Cleanup(server.Close)
config := NewDefaultConfig() client := NewClient(OAuthStaticAccessToken(server.Client(), "secret"))
config.Endpoint, _ = url.Parse(server.URL) client.BaseURL, _ = url.Parse(server.URL)
config.Token = "secret"
p, err := NewDNSProviderConfig(config) return client, mux
require.NoError(t, err)
return p, mux
} }
func Test_getDomainIDByName(t *testing.T) { func checkAuthorizationHeader(req *http.Request) error {
prd, handler := setupTestProvider(t) val := req.Header.Get("Authorization")
if val != "Bearer secret" {
return fmt.Errorf("invalid header value, got: %s want %s", val, "Bearer secret")
}
return nil
}
handler.HandleFunc("/v1/domains", func(rw http.ResponseWriter, req *http.Request) { func TestClient_GetDomainIDByName(t *testing.T) {
client, mux := setupTest(t)
mux.HandleFunc("/v1/domains", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet { if req.Method != http.MethodGet {
http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest) http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest)
return return
} }
err := checkAuthorizationHeader(req)
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
domainList := DomainListingResponse{ domainList := DomainListingResponse{
Embedded: EmbeddedDomainList{Domains: []*Domain{ Embedded: EmbeddedDomainList{Domains: []*Domain{
{ID: 1, Name: "test.com"}, {ID: 1, Name: "test.com"},
@ -48,28 +60,34 @@ func Test_getDomainIDByName(t *testing.T) {
}}, }},
} }
err := json.NewEncoder(rw).Encode(domainList) err = json.NewEncoder(rw).Encode(domainList)
if err != nil { if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError) http.Error(rw, err.Error(), http.StatusInternalServerError)
return return
} }
}) })
id, err := prd.getDomainIDByName("test.com") id, err := client.GetDomainIDByName(context.Background(), "test.com")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, id) assert.Equal(t, 1, id)
} }
func Test_checkNameservers(t *testing.T) { func TestClient_CheckNameservers(t *testing.T) {
prd, handler := setupTestProvider(t) client, mux := setupTest(t)
handler.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) { mux.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet { if req.Method != http.MethodGet {
http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest) http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest)
return return
} }
err := checkAuthorizationHeader(req)
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
nsResp := NameserverResponse{ nsResp := NameserverResponse{
Nameservers: []*Nameserver{ Nameservers: []*Nameserver{
{Name: ns1}, {Name: ns1},
@ -78,33 +96,39 @@ func Test_checkNameservers(t *testing.T) {
}, },
} }
err := json.NewEncoder(rw).Encode(nsResp) err = json.NewEncoder(rw).Encode(nsResp)
if err != nil { if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError) http.Error(rw, err.Error(), http.StatusInternalServerError)
return return
} }
}) })
err := prd.checkNameservers(1) err := client.CheckNameservers(context.Background(), 1)
require.NoError(t, err) require.NoError(t, err)
} }
func Test_createRecord(t *testing.T) { func TestClient_CreateRecord(t *testing.T) {
prd, handler := setupTestProvider(t) client, mux := setupTest(t)
handler.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) { mux.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost { if req.Method != http.MethodPost {
http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest) http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest)
return return
} }
err := checkAuthorizationHeader(req)
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
content, err := io.ReadAll(req.Body) content, err := io.ReadAll(req.Body)
if err != nil { if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest) http.Error(rw, err.Error(), http.StatusBadRequest)
return return
} }
if string(content) != `{"name":"test.com","value":"value","ttl":300,"priority":0,"type":"TXT"}` { if string(bytes.TrimSpace(content)) != `{"name":"test.com","value":"value","ttl":300,"priority":0,"type":"TXT"}` {
http.Error(rw, "invalid request body: "+string(content), http.StatusBadRequest) http.Error(rw, "invalid request body: "+string(content), http.StatusBadRequest)
return return
} }
@ -117,12 +141,12 @@ func Test_createRecord(t *testing.T) {
Value: "value", Value: "value",
} }
err := prd.createRecord(1, record) err := client.CreateRecord(context.Background(), 1, record)
require.NoError(t, err) require.NoError(t, err)
} }
func Test_deleteTXTRecord(t *testing.T) { func TestClient_DeleteTXTRecord(t *testing.T) {
prd, handler := setupTestProvider(t) client, mux := setupTest(t)
domainName := "lego.test" domainName := "lego.test"
recordValue := "test" recordValue := "test"
@ -158,20 +182,26 @@ func Test_deleteTXTRecord(t *testing.T) {
}, },
} }
handler.HandleFunc("/v1/domains/1", func(rw http.ResponseWriter, req *http.Request) { mux.HandleFunc("/v1/domains/1", func(rw http.ResponseWriter, req *http.Request) {
err := checkAuthorizationHeader(req)
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
resp := DomainResponse{ resp := DomainResponse{
ID: 1, ID: 1,
Name: domainName, Name: domainName,
} }
err := json.NewEncoder(rw).Encode(resp) err = json.NewEncoder(rw).Encode(resp)
if err != nil { if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError) http.Error(rw, err.Error(), http.StatusInternalServerError)
return return
} }
}) })
handler.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) { mux.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet { if req.Method != http.MethodGet {
http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest) http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest)
return return
@ -188,7 +218,7 @@ func Test_deleteTXTRecord(t *testing.T) {
} }
}) })
handler.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) { mux.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) {
switch req.Method { switch req.Method {
case http.MethodGet: case http.MethodGet:
resp := RecordListingResponse{ resp := RecordListingResponse{
@ -226,6 +256,6 @@ func Test_deleteTXTRecord(t *testing.T) {
}) })
info := dns01.GetChallengeInfo(domainName, "abc") info := dns01.GetChallengeInfo(domainName, "abc")
err := prd.deleteTXTRecord(1, info.EffectiveFQDN, recordValue) err := client.DeleteTXTRecord(context.Background(), 1, info.EffectiveFQDN, recordValue)
require.NoError(t, err) require.NoError(t, err)
} }

View file

@ -0,0 +1,73 @@
package internal
// Some fields have been omitted from the structs
// because they are not required for this application.
type DomainListingResponse struct {
Page int `json:"page"`
Limit int `json:"limit"`
Pages int `json:"pages"`
Total int `json:"total"`
Embedded EmbeddedDomainList `json:"_embedded"`
}
type EmbeddedDomainList struct {
Domains []*Domain `json:"domains"`
}
type Domain struct {
ID int `json:"id"`
Name string `json:"name"`
}
type DomainResponse struct {
ID int `json:"id"`
Name string `json:"name"`
Created string `json:"created"`
PaidUp string `json:"payed_up"`
Active bool `json:"active"`
}
type NameserverResponse struct {
General NameserverGeneral `json:"general"`
Nameservers []*Nameserver `json:"nameservers"`
SOA NameserverSOA `json:"soa"`
}
type NameserverGeneral struct {
IPv4 string `json:"ip_v4"`
IPv6 string `json:"ip_v6"`
IncludeWWW bool `json:"include_www"`
}
type NameserverSOA struct {
Mail string `json:"mail"`
Refresh int `json:"refresh"`
Retry int `json:"retry"`
Expiry int `json:"expiry"`
TTL int `json:"ttl"`
}
type Nameserver struct {
Name string `json:"name"`
}
type RecordListingResponse struct {
Page int `json:"page"`
Limit int `json:"limit"`
Pages int `json:"pages"`
Total int `json:"total"`
Embedded EmbeddedRecordList `json:"_embedded"`
}
type EmbeddedRecordList struct {
Records []*Record `json:"records"`
}
type Record struct {
Name string `json:"name"`
Value string `json:"value"`
TTL int `json:"ttl"`
Priority int `json:"priority"`
Type string `json:"type"`
}

View file

@ -93,11 +93,13 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
zone, err := getZone(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("civo: failed to find zone: fqdn=%s: %w", info.EffectiveFQDN, err) return fmt.Errorf("civo: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
zone := dns01.UnFqdn(authZone)
dnsDomain, err := d.client.GetDNSDomain(zone) dnsDomain, err := d.client.GetDNSDomain(zone)
if err != nil { if err != nil {
return fmt.Errorf("civo: %w", err) return fmt.Errorf("civo: %w", err)
@ -125,11 +127,13 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
zone, err := getZone(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("civo: failed to find zone: fqdn=%s: %w", info.EffectiveFQDN, err) return fmt.Errorf("civo: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
zone := dns01.UnFqdn(authZone)
dnsDomain, err := d.client.GetDNSDomain(zone) dnsDomain, err := d.client.GetDNSDomain(zone)
if err != nil { if err != nil {
return fmt.Errorf("civo: %w", err) return fmt.Errorf("civo: %w", err)
@ -166,12 +170,3 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
return d.config.PropagationTimeout, d.config.PollingInterval return d.config.PropagationTimeout, d.config.PollingInterval
} }
func getZone(fqdn string) (string, error) {
authZone, err := dns01.FindZoneByFqdn(fqdn)
if err != nil {
return "", err
}
return dns01.UnFqdn(authZone), nil
}

View file

@ -2,6 +2,7 @@
package clouddns package clouddns
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -89,10 +90,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
client.HTTPClient = config.HTTPClient client.HTTPClient = config.HTTPClient
} }
return &DNSProvider{ return &DNSProvider{client: client, config: config}, nil
client: client,
config: config,
}, nil
} }
// Timeout returns the timeout and interval to use when checking for DNS propagation. // Timeout returns the timeout and interval to use when checking for DNS propagation.
@ -107,12 +105,17 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("clouddns: %w", err) return fmt.Errorf("clouddns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
err = d.client.AddRecord(authZone, info.EffectiveFQDN, info.Value) ctx, err := d.client.CreateAuthenticatedContext(context.Background())
if err != nil { if err != nil {
return fmt.Errorf("clouddns: %w", err) return err
}
err = d.client.AddRecord(ctx, authZone, info.EffectiveFQDN, info.Value)
if err != nil {
return fmt.Errorf("clouddns: add record: %w", err)
} }
return nil return nil
@ -124,12 +127,17 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("clouddns: %w", err) return fmt.Errorf("clouddns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
err = d.client.DeleteRecord(authZone, info.EffectiveFQDN) ctx, err := d.client.CreateAuthenticatedContext(context.Background())
if err != nil { if err != nil {
return fmt.Errorf("clouddns: %w", err) return err
}
err = d.client.DeleteRecord(ctx, authZone, info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("clouddns: delete record: %w", err)
} }
return nil return nil

View file

@ -2,117 +2,127 @@ package internal
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
) )
const ( const apiBaseURL = "https://admin.vshosting.cloud/clouddns"
apiBaseURL = "https://admin.vshosting.cloud/clouddns"
loginURL = "https://admin.vshosting.cloud/api/public/auth/login" const authorizationHeader = "Authorization"
)
// Client handles all communication with CloudDNS API. // Client handles all communication with CloudDNS API.
type Client struct { type Client struct {
AccessToken string clientID string
ClientID string email string
Email string password string
Password string ttl int
TTL int
HTTPClient *http.Client
apiBaseURL string apiBaseURL *url.URL
loginURL string
loginURL *url.URL
HTTPClient *http.Client
} }
// NewClient returns a Client instance configured to handle CloudDNS API communication. // NewClient returns a Client instance configured to handle CloudDNS API communication.
func NewClient(clientID, email, password string, ttl int) *Client { func NewClient(clientID, email, password string, ttl int) *Client {
baseURL, _ := url.Parse(apiBaseURL)
loginBaseURL, _ := url.Parse(loginURL)
return &Client{ return &Client{
ClientID: clientID, clientID: clientID,
Email: email, email: email,
Password: password, password: password,
TTL: ttl, ttl: ttl,
HTTPClient: &http.Client{}, apiBaseURL: baseURL,
apiBaseURL: apiBaseURL, loginURL: loginBaseURL,
loginURL: loginURL, HTTPClient: &http.Client{Timeout: 5 * time.Second},
} }
} }
// AddRecord is a high level method to add a new record into CloudDNS zone. // AddRecord is a high level method to add a new record into CloudDNS zone.
func (c *Client) AddRecord(zone, recordName, recordValue string) error { func (c *Client) AddRecord(ctx context.Context, zone, recordName, recordValue string) error {
domain, err := c.getDomain(zone) domain, err := c.getDomain(ctx, zone)
if err != nil { if err != nil {
return err return err
} }
record := Record{DomainID: domain.ID, Name: recordName, Value: recordValue, Type: "TXT"} record := Record{DomainID: domain.ID, Name: recordName, Value: recordValue, Type: "TXT"}
err = c.addTxtRecord(record) err = c.addTxtRecord(ctx, record)
if err != nil { if err != nil {
return err return err
} }
return c.publishRecords(domain.ID) return c.publishRecords(ctx, domain.ID)
} }
// DeleteRecord is a high level method to remove a record from zone. // DeleteRecord is a high level method to remove a record from zone.
func (c *Client) DeleteRecord(zone, recordName string) error { func (c *Client) DeleteRecord(ctx context.Context, zone, recordName string) error {
domain, err := c.getDomain(zone) domain, err := c.getDomain(ctx, zone)
if err != nil { if err != nil {
return err return err
} }
record, err := c.getRecord(domain.ID, recordName) record, err := c.getRecord(ctx, domain.ID, recordName)
if err != nil { if err != nil {
return err return err
} }
err = c.deleteRecord(record) err = c.deleteRecord(ctx, record)
if err != nil { if err != nil {
return err return err
} }
return c.publishRecords(domain.ID) return c.publishRecords(ctx, domain.ID)
} }
func (c *Client) addTxtRecord(record Record) error { func (c *Client) addTxtRecord(ctx context.Context, record Record) error {
body, err := json.Marshal(record) endpoint := c.apiBaseURL.JoinPath("record-txt")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
if err != nil { if err != nil {
return err return err
} }
_, err = c.doAPIRequest(http.MethodPost, "record-txt", bytes.NewReader(body)) return c.do(req, nil)
return err
} }
func (c *Client) deleteRecord(record Record) error { func (c *Client) deleteRecord(ctx context.Context, record Record) error {
endpoint := fmt.Sprintf("record/%s", record.ID) endpoint := c.apiBaseURL.JoinPath("record", record.ID)
_, err := c.doAPIRequest(http.MethodDelete, endpoint, nil)
req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err return err
}
return c.do(req, nil)
} }
func (c *Client) getDomain(zone string) (Domain, error) { func (c *Client) getDomain(ctx context.Context, zone string) (Domain, error) {
searchQuery := SearchQuery{ searchQuery := SearchQuery{
Search: []Search{ Search: []Search{
{Name: "clientId", Operator: "eq", Value: c.ClientID}, {Name: "clientId", Operator: "eq", Value: c.clientID},
{Name: "domainName", Operator: "eq", Value: zone}, {Name: "domainName", Operator: "eq", Value: zone},
}, },
} }
body, err := json.Marshal(searchQuery) endpoint := c.apiBaseURL.JoinPath("domain", "search")
if err != nil {
return Domain{}, err
}
resp, err := c.doAPIRequest(http.MethodPost, "domain/search", bytes.NewReader(body)) req, err := newJSONRequest(ctx, http.MethodPost, endpoint, searchQuery)
if err != nil { if err != nil {
return Domain{}, err return Domain{}, err
} }
var result SearchResponse var result SearchResponse
err = json.Unmarshal(resp, &result) err = c.do(req, &result)
if err != nil { if err != nil {
return Domain{}, err return Domain{}, err
} }
@ -124,15 +134,16 @@ func (c *Client) getDomain(zone string) (Domain, error) {
return result.Items[0], nil return result.Items[0], nil
} }
func (c *Client) getRecord(domainID, recordName string) (Record, error) { func (c *Client) getRecord(ctx context.Context, domainID, recordName string) (Record, error) {
endpoint := fmt.Sprintf("domain/%s", domainID) endpoint := c.apiBaseURL.JoinPath("domain", domainID)
resp, err := c.doAPIRequest(http.MethodGet, endpoint, nil)
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return Record{}, err return Record{}, err
} }
var result DomainInfo var result DomainInfo
err = json.Unmarshal(resp, &result) err = c.do(req, &result)
if err != nil { if err != nil {
return Record{}, err return Record{}, err
} }
@ -146,116 +157,85 @@ func (c *Client) getRecord(domainID, recordName string) (Record, error) {
return Record{}, fmt.Errorf("record not found: domainID %s, name %s", domainID, recordName) return Record{}, fmt.Errorf("record not found: domainID %s, name %s", domainID, recordName)
} }
func (c *Client) publishRecords(domainID string) error { func (c *Client) publishRecords(ctx context.Context, domainID string) error {
body, err := json.Marshal(DomainInfo{SoaTTL: c.TTL}) endpoint := c.apiBaseURL.JoinPath("domain", domainID, "publish")
payload := DomainInfo{SoaTTL: c.ttl}
req, err := newJSONRequest(ctx, http.MethodPut, endpoint, payload)
if err != nil { if err != nil {
return err return err
} }
endpoint := fmt.Sprintf("domain/%s/publish", domainID) return c.do(req, nil)
_, err = c.doAPIRequest(http.MethodPut, endpoint, bytes.NewReader(body))
return err
} }
func (c *Client) login() error { func (c *Client) do(req *http.Request, result any) error {
authorization := Authorization{Email: c.Email, Password: c.Password} at := getAccessToken(req.Context())
if at != "" {
body, err := json.Marshal(authorization) req.Header.Set(authorizationHeader, "Bearer "+at)
if err != nil {
return err
} }
req, err := http.NewRequest(http.MethodPost, c.loginURL, bytes.NewReader(body)) resp, err := c.HTTPClient.Do(req)
if err != nil { if err != nil {
return err return errutils.NewHTTPDoError(req, err)
} }
req.Header.Set("Content-Type", "application/json") defer func() { _ = resp.Body.Close() }()
content, err := c.doRequest(req) if resp.StatusCode/100 != 2 {
if err != nil { return parseError(req, resp)
return err
} }
var result AuthResponse if result == nil {
err = json.Unmarshal(content, &result) return nil
if err != nil {
return err
} }
c.AccessToken = result.Auth.AccessToken raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
err = json.Unmarshal(raw, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil return nil
} }
func (c *Client) doAPIRequest(method, endpoint string, body io.Reader) ([]byte, error) { func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
if c.AccessToken == "" { buf := new(bytes.Buffer)
err := c.login()
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to create request JSON body: %w", err)
} }
} }
url := fmt.Sprintf("%s/%s", c.apiBaseURL, endpoint) req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
req, err := c.newRequest(method, url, body)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("unable to create request: %w", err)
} }
content, err := c.doRequest(req) req.Header.Set("Accept", "application/json")
if err != nil {
return nil, err
}
return content, nil
}
func (c *Client) newRequest(method, reqURL string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequest(method, reqURL, body)
if err != nil {
return nil, err
}
if payload != nil {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.AccessToken)) }
return req, nil return req, nil
} }
func (c *Client) doRequest(req *http.Request) ([]byte, error) { func parseError(req *http.Request, resp *http.Response) error {
resp, err := c.HTTPClient.Do(req) raw, _ := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest { var response APIError
return nil, readError(req, resp) err := json.Unmarshal(raw, &response)
if err != nil {
return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw)
} }
content, err := io.ReadAll(resp.Body) return fmt.Errorf("[status code %d] %w", resp.StatusCode, response.Error)
if err != nil {
return nil, err
}
return content, nil
}
func readError(req *http.Request, resp *http.Response) error {
content, err := io.ReadAll(resp.Body)
if err != nil {
return errors.New(toUnreadableBodyMessage(req, content))
}
var errInfo APIError
err = json.Unmarshal(content, &errInfo)
if err != nil {
return fmt.Errorf("APIError unmarshaling error: %w: %s", err, toUnreadableBodyMessage(req, content))
}
return fmt.Errorf("HTTP %d: code %v: %s", resp.StatusCode, errInfo.Error.Code, errInfo.Error.Message)
}
func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string {
return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody))
} }

View file

@ -1,16 +1,33 @@
package internal package internal
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestClient_AddRecord(t *testing.T) { func setupTest(t *testing.T) (*Client, *http.ServeMux) {
t.Helper()
mux := http.NewServeMux() mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
client := NewClient("clientID", "email@example.com", "secret", 300)
client.HTTPClient = server.Client()
client.apiBaseURL, _ = url.Parse(server.URL + "/api")
client.loginURL, _ = url.Parse(server.URL + "/login")
return client, mux
}
func TestClient_AddRecord(t *testing.T) {
client, mux := setupTest(t)
mux.HandleFunc("/api/domain/search", func(rw http.ResponseWriter, req *http.Request) { mux.HandleFunc("/api/domain/search", func(rw http.ResponseWriter, req *http.Request) {
response := SearchResponse{ response := SearchResponse{
@ -45,19 +62,12 @@ func TestClient_AddRecord(t *testing.T) {
} }
}) })
server := httptest.NewServer(mux) err := client.AddRecord(context.Background(), "example.com", "_acme-challenge.example.com", "txt")
t.Cleanup(server.Close)
client := NewClient("clientID", "email@example.com", "secret", 300)
client.apiBaseURL = server.URL + "/api"
client.loginURL = server.URL + "/login"
err := client.AddRecord("example.com", "_acme-challenge.example.com", "txt")
require.NoError(t, err) require.NoError(t, err)
} }
func TestClient_DeleteRecord(t *testing.T) { func TestClient_DeleteRecord(t *testing.T) {
mux := http.NewServeMux() client, mux := setupTest(t)
mux.HandleFunc("/api/domain/search", func(rw http.ResponseWriter, req *http.Request) { mux.HandleFunc("/api/domain/search", func(rw http.ResponseWriter, req *http.Request) {
response := SearchResponse{ response := SearchResponse{
@ -114,13 +124,9 @@ func TestClient_DeleteRecord(t *testing.T) {
} }
}) })
server := httptest.NewServer(mux) ctx, err := client.CreateAuthenticatedContext(context.Background())
t.Cleanup(server.Close) require.NoError(t, err)
client := NewClient("clientID", "email@example.com", "secret", 300) err = client.DeleteRecord(ctx, "example.com", "_acme-challenge.example.com")
client.apiBaseURL = server.URL + "/api"
client.loginURL = server.URL + "/login"
err := client.DeleteRecord("example.com", "_acme-challenge.example.com")
require.NoError(t, err) require.NoError(t, err)
} }

View file

@ -0,0 +1,47 @@
package internal
import (
"context"
"net/http"
)
const loginURL = "https://admin.vshosting.cloud/api/public/auth/login"
type token string
const accessTokenKey token = "accessToken"
func (c *Client) login(ctx context.Context) (*AuthResponse, error) {
authorization := Authorization{Email: c.email, Password: c.password}
req, err := newJSONRequest(ctx, http.MethodPost, c.loginURL, authorization)
if err != nil {
return nil, err
}
var result AuthResponse
err = c.do(req, &result)
if err != nil {
return nil, err
}
return &result, nil
}
func (c *Client) CreateAuthenticatedContext(ctx context.Context) (context.Context, error) {
tok, err := c.login(ctx)
if err != nil {
return nil, err
}
return context.WithValue(ctx, accessTokenKey, tok.Auth.AccessToken), nil
}
func getAccessToken(ctx context.Context) string {
tok, ok := ctx.Value(accessTokenKey).(string)
if !ok {
return ""
}
return tok
}

View file

@ -0,0 +1,46 @@
package internal
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClient_CreateAuthenticatedContext(t *testing.T) {
client, mux := setupTest(t)
mux.HandleFunc("/login", func(rw http.ResponseWriter, req *http.Request) {
response := AuthResponse{
Auth: Auth{
AccessToken: "at",
RefreshToken: "",
},
}
err := json.NewEncoder(rw).Encode(response)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
})
mux.HandleFunc("/api/record/xxx", func(rw http.ResponseWriter, req *http.Request) {
authorization := req.Header.Get(authorizationHeader)
if authorization != "Bearer at" {
http.Error(rw, "invalid credential: "+authorization, http.StatusUnauthorized)
return
}
})
ctx, err := client.CreateAuthenticatedContext(context.Background())
require.NoError(t, err)
at := getAccessToken(ctx)
assert.Equal(t, "at", at)
err = client.deleteRecord(ctx, Record{ID: "xxx"})
require.NoError(t, err)
}

View file

@ -1,5 +1,7 @@
package internal package internal
import "fmt"
type APIError struct { type APIError struct {
Error ErrorContent `json:"error"` Error ErrorContent `json:"error"`
} }
@ -9,6 +11,10 @@ type ErrorContent struct {
Message string `json:"message,omitempty"` Message string `json:"message,omitempty"`
} }
func (e ErrorContent) Error() string {
return fmt.Sprintf("%d: %s", e.Code, e.Message)
}
type Authorization struct { type Authorization struct {
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
Password string `json:"password,omitempty"` Password string `json:"password,omitempty"`

View file

@ -126,7 +126,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("cloudflare: %w", err) return fmt.Errorf("cloudflare: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
zoneID, err := d.client.ZoneIDByName(authZone) zoneID, err := d.client.ZoneIDByName(authZone)
@ -165,7 +165,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("cloudflare: %w", err) return fmt.Errorf("cloudflare: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
zoneID, err := d.client.ZoneIDByName(authZone) zoneID, err := d.client.ZoneIDByName(authZone)

View file

@ -2,6 +2,7 @@
package cloudns package cloudns
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -104,29 +105,33 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
zone, err := d.client.GetZone(info.EffectiveFQDN) ctx := context.Background()
zone, err := d.client.GetZone(ctx, info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("ClouDNS: %w", err) return fmt.Errorf("ClouDNS: %w", err)
} }
err = d.client.AddTxtRecord(zone.Name, info.EffectiveFQDN, info.Value, d.config.TTL) err = d.client.AddTxtRecord(ctx, zone.Name, info.EffectiveFQDN, info.Value, d.config.TTL)
if err != nil { if err != nil {
return fmt.Errorf("ClouDNS: %w", err) return fmt.Errorf("ClouDNS: %w", err)
} }
return d.waitNameservers(domain, zone) return d.waitNameservers(ctx, domain, zone)
} }
// CleanUp removes the TXT records matching the specified parameters. // CleanUp removes the TXT records matching the specified parameters.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
zone, err := d.client.GetZone(info.EffectiveFQDN) ctx := context.Background()
zone, err := d.client.GetZone(ctx, info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("ClouDNS: %w", err) return fmt.Errorf("ClouDNS: %w", err)
} }
records, err := d.client.ListTxtRecords(zone.Name, info.EffectiveFQDN) records, err := d.client.ListTxtRecords(ctx, zone.Name, info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("ClouDNS: %w", err) return fmt.Errorf("ClouDNS: %w", err)
} }
@ -136,7 +141,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
} }
for _, record := range records { for _, record := range records {
err = d.client.RemoveTxtRecord(record.ID, zone.Name) err = d.client.RemoveTxtRecord(ctx, record.ID, zone.Name)
if err != nil { if err != nil {
return fmt.Errorf("ClouDNS: %w", err) return fmt.Errorf("ClouDNS: %w", err)
} }
@ -153,9 +158,9 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
// waitNameservers At the time of writing 4 servers are found as authoritative, but 8 are reported during the sync. // waitNameservers At the time of writing 4 servers are found as authoritative, but 8 are reported during the sync.
// If this is not done, the secondary verification done by Let's Encrypt server will fail quire a bit. // If this is not done, the secondary verification done by Let's Encrypt server will fail quire a bit.
func (d *DNSProvider) waitNameservers(domain string, zone *internal.Zone) error { func (d *DNSProvider) waitNameservers(ctx context.Context, domain string, zone *internal.Zone) error {
return wait.For("Nameserver sync on "+domain, d.config.PropagationTimeout, d.config.PollingInterval, func() (bool, error) { return wait.For("Nameserver sync on "+domain, d.config.PropagationTimeout, d.config.PollingInterval, func() (bool, error) {
syncProgress, err := d.client.GetUpdateStatus(zone.Name) syncProgress, err := d.client.GetUpdateStatus(ctx, zone.Name)
if err != nil { if err != nil {
return false, err return false, err
} }

View file

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -8,8 +9,10 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"time"
"github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
) )
const defaultBaseURL = "https://api.cloudns.net/dns/" const defaultBaseURL = "https://api.cloudns.net/dns/"
@ -19,8 +22,9 @@ type Client struct {
authID string authID string
subAuthID string subAuthID string
authPassword string authPassword string
HTTPClient *http.Client
BaseURL *url.URL BaseURL *url.URL
HTTPClient *http.Client
} }
// NewClient creates a ClouDNS client. // NewClient creates a ClouDNS client.
@ -42,16 +46,16 @@ func NewClient(authID, subAuthID, authPassword string) (*Client, error) {
authID: authID, authID: authID,
subAuthID: subAuthID, subAuthID: subAuthID,
authPassword: authPassword, authPassword: authPassword,
HTTPClient: &http.Client{},
BaseURL: baseURL, BaseURL: baseURL,
HTTPClient: &http.Client{Timeout: 10 * time.Second},
}, nil }, nil
} }
// GetZone Get domain name information for a FQDN. // GetZone Get domain name information for a FQDN.
func (c *Client) GetZone(authFQDN string) (*Zone, error) { func (c *Client) GetZone(ctx context.Context, authFQDN string) (*Zone, error) {
authZone, err := dns01.FindZoneByFqdn(authFQDN) authZone, err := dns01.FindZoneByFqdn(authFQDN)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("could not find zone for FQDN %q: %w", authFQDN, err)
} }
authZoneName := dns01.UnFqdn(authZone) authZoneName := dns01.UnFqdn(authZone)
@ -62,16 +66,21 @@ func (c *Client) GetZone(authFQDN string) (*Zone, error) {
q.Set("domain-name", authZoneName) q.Set("domain-name", authZoneName)
endpoint.RawQuery = q.Encode() endpoint.RawQuery = q.Encode()
result, err := c.doRequest(http.MethodGet, endpoint) req, err := c.newRequest(ctx, http.MethodGet, endpoint)
if err != nil {
return nil, err
}
rawMessage, err := c.do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var zone Zone var zone Zone
if len(result) > 0 { if len(rawMessage) > 0 {
if err = json.Unmarshal(result, &zone); err != nil { if err = json.Unmarshal(rawMessage, &zone); err != nil {
return nil, fmt.Errorf("failed to unmarshal zone: %w", err) return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
} }
} }
@ -83,7 +92,7 @@ func (c *Client) GetZone(authFQDN string) (*Zone, error) {
} }
// FindTxtRecord returns the TXT record a zone ID and a FQDN. // FindTxtRecord returns the TXT record a zone ID and a FQDN.
func (c *Client) FindTxtRecord(zoneName, fqdn string) (*TXTRecord, error) { func (c *Client) FindTxtRecord(ctx context.Context, zoneName, fqdn string) (*TXTRecord, error) {
subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName) subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -97,19 +106,24 @@ func (c *Client) FindTxtRecord(zoneName, fqdn string) (*TXTRecord, error) {
q.Set("type", "TXT") q.Set("type", "TXT")
endpoint.RawQuery = q.Encode() endpoint.RawQuery = q.Encode()
result, err := c.doRequest(http.MethodGet, endpoint) req, err := c.newRequest(ctx, http.MethodGet, endpoint)
if err != nil {
return nil, err
}
rawMessage, err := c.do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// the API returns [] when there is no records. // the API returns [] when there is no records.
if string(result) == "[]" { if string(rawMessage) == "[]" {
return nil, nil return nil, nil
} }
var records map[string]TXTRecord var records map[string]TXTRecord
if err = json.Unmarshal(result, &records); err != nil { if err = json.Unmarshal(rawMessage, &records); err != nil {
return nil, fmt.Errorf("failed to unmarshall TXT records: %w: %s", err, string(result)) return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
} }
for _, record := range records { for _, record := range records {
@ -122,7 +136,7 @@ func (c *Client) FindTxtRecord(zoneName, fqdn string) (*TXTRecord, error) {
} }
// ListTxtRecords returns the TXT records a zone ID and a FQDN. // ListTxtRecords returns the TXT records a zone ID and a FQDN.
func (c *Client) ListTxtRecords(zoneName, fqdn string) ([]TXTRecord, error) { func (c *Client) ListTxtRecords(ctx context.Context, zoneName, fqdn string) ([]TXTRecord, error) {
subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName) subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -136,19 +150,24 @@ func (c *Client) ListTxtRecords(zoneName, fqdn string) ([]TXTRecord, error) {
q.Set("type", "TXT") q.Set("type", "TXT")
endpoint.RawQuery = q.Encode() endpoint.RawQuery = q.Encode()
result, err := c.doRequest(http.MethodGet, endpoint) req, err := c.newRequest(ctx, http.MethodGet, endpoint)
if err != nil {
return nil, err
}
rawMessage, err := c.do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// the API returns [] when there is no records. // the API returns [] when there is no records.
if string(result) == "[]" { if string(rawMessage) == "[]" {
return nil, nil return nil, nil
} }
var raw map[string]TXTRecord var raw map[string]TXTRecord
if err = json.Unmarshal(result, &raw); err != nil { if err = json.Unmarshal(rawMessage, &raw); err != nil {
return nil, fmt.Errorf("failed to unmarshall TXT records: %w: %s", err, string(result)) return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
} }
var records []TXTRecord var records []TXTRecord
@ -162,7 +181,7 @@ func (c *Client) ListTxtRecords(zoneName, fqdn string) ([]TXTRecord, error) {
} }
// AddTxtRecord adds a TXT record. // AddTxtRecord adds a TXT record.
func (c *Client) AddTxtRecord(zoneName, fqdn, value string, ttl int) error { func (c *Client) AddTxtRecord(ctx context.Context, zoneName, fqdn, value string, ttl int) error {
subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName) subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName)
if err != nil { if err != nil {
return err return err
@ -178,14 +197,19 @@ func (c *Client) AddTxtRecord(zoneName, fqdn, value string, ttl int) error {
q.Set("record-type", "TXT") q.Set("record-type", "TXT")
endpoint.RawQuery = q.Encode() endpoint.RawQuery = q.Encode()
raw, err := c.doRequest(http.MethodPost, endpoint) req, err := c.newRequest(ctx, http.MethodPost, endpoint)
if err != nil {
return err
}
rawMessage, err := c.do(req)
if err != nil { if err != nil {
return err return err
} }
resp := apiResponse{} resp := apiResponse{}
if err = json.Unmarshal(raw, &resp); err != nil { if err = json.Unmarshal(rawMessage, &resp); err != nil {
return fmt.Errorf("failed to unmarshal API response: %w: %s", err, string(raw)) return errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
} }
if resp.Status != "Success" { if resp.Status != "Success" {
@ -196,7 +220,7 @@ func (c *Client) AddTxtRecord(zoneName, fqdn, value string, ttl int) error {
} }
// RemoveTxtRecord removes a TXT record. // RemoveTxtRecord removes a TXT record.
func (c *Client) RemoveTxtRecord(recordID int, zoneName string) error { func (c *Client) RemoveTxtRecord(ctx context.Context, recordID int, zoneName string) error {
endpoint := c.BaseURL.JoinPath("delete-record.json") endpoint := c.BaseURL.JoinPath("delete-record.json")
q := endpoint.Query() q := endpoint.Query()
@ -204,14 +228,19 @@ func (c *Client) RemoveTxtRecord(recordID int, zoneName string) error {
q.Set("record-id", strconv.Itoa(recordID)) q.Set("record-id", strconv.Itoa(recordID))
endpoint.RawQuery = q.Encode() endpoint.RawQuery = q.Encode()
raw, err := c.doRequest(http.MethodPost, endpoint) req, err := c.newRequest(ctx, http.MethodPost, endpoint)
if err != nil {
return err
}
rawMessage, err := c.do(req)
if err != nil { if err != nil {
return err return err
} }
resp := apiResponse{} resp := apiResponse{}
if err = json.Unmarshal(raw, &resp); err != nil { if err = json.Unmarshal(rawMessage, &resp); err != nil {
return fmt.Errorf("failed to unmarshal API response: %w: %s", err, string(raw)) return errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
} }
if resp.Status != "Success" { if resp.Status != "Success" {
@ -222,26 +251,31 @@ func (c *Client) RemoveTxtRecord(recordID int, zoneName string) error {
} }
// GetUpdateStatus gets sync progress of all CloudDNS NS servers. // GetUpdateStatus gets sync progress of all CloudDNS NS servers.
func (c *Client) GetUpdateStatus(zoneName string) (*SyncProgress, error) { func (c *Client) GetUpdateStatus(ctx context.Context, zoneName string) (*SyncProgress, error) {
endpoint := c.BaseURL.JoinPath("update-status.json") endpoint := c.BaseURL.JoinPath("update-status.json")
q := endpoint.Query() q := endpoint.Query()
q.Set("domain-name", zoneName) q.Set("domain-name", zoneName)
endpoint.RawQuery = q.Encode() endpoint.RawQuery = q.Encode()
result, err := c.doRequest(http.MethodGet, endpoint) req, err := c.newRequest(ctx, http.MethodGet, endpoint)
if err != nil {
return nil, err
}
rawMessage, err := c.do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// the API returns [] when there is no records. // the API returns [] when there is no records.
if string(result) == "[]" { if string(rawMessage) == "[]" {
return nil, errors.New("no nameservers records returned") return nil, errors.New("no nameservers records returned")
} }
var records []UpdateRecord var records []UpdateRecord
if err = json.Unmarshal(result, &records); err != nil { if err = json.Unmarshal(rawMessage, &records); err != nil {
return nil, fmt.Errorf("failed to unmarshal UpdateRecord: %w: %s", err, string(result)) return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err)
} }
updatedCount := 0 updatedCount := 0
@ -254,33 +288,8 @@ func (c *Client) GetUpdateStatus(zoneName string) (*SyncProgress, error) {
return &SyncProgress{Complete: updatedCount == len(records), Updated: updatedCount, Total: len(records)}, nil return &SyncProgress{Complete: updatedCount == len(records), Updated: updatedCount, Total: len(records)}, nil
} }
func (c *Client) doRequest(method string, uri *url.URL) (json.RawMessage, error) { func (c *Client) newRequest(ctx context.Context, method string, endpoint *url.URL) (*http.Request, error) {
req, err := c.buildRequest(method, uri) q := endpoint.Query()
if err != nil {
return nil, err
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
content, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.New(toUnreadableBodyMessage(req, content))
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("invalid code (%d), error: %s", resp.StatusCode, content)
}
return content, nil
}
func (c *Client) buildRequest(method string, uri *url.URL) (*http.Request, error) {
q := uri.Query()
if c.subAuthID != "" { if c.subAuthID != "" {
q.Set("sub-auth-id", c.subAuthID) q.Set("sub-auth-id", c.subAuthID)
@ -290,18 +299,34 @@ func (c *Client) buildRequest(method string, uri *url.URL) (*http.Request, error
q.Set("auth-password", c.authPassword) q.Set("auth-password", c.authPassword)
uri.RawQuery = q.Encode() endpoint.RawQuery = q.Encode()
req, err := http.NewRequest(method, uri.String(), nil) req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid request: %w", err) return nil, fmt.Errorf("unable to create request: %w", err)
} }
return req, nil return req, nil
} }
func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string { func (c *Client) do(req *http.Request) (json.RawMessage, error) {
return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody)) resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errutils.NewReadResponseError(req, resp.StatusCode, err)
}
return raw, nil
} }
// Rounds the given TTL in seconds to the next accepted value. // Rounds the given TTL in seconds to the next accepted value.

View file

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -11,6 +12,21 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func setupTest(t *testing.T, subAuthID string, handler http.HandlerFunc) *Client {
t.Helper()
server := httptest.NewServer(handler)
t.Cleanup(server.Close)
client, err := NewClient("myAuthID", subAuthID, "myAuthPassword")
require.NoError(t, err)
client.BaseURL, _ = url.Parse(server.URL)
client.HTTPClient = server.Client()
return client
}
func handlerMock(method string, jsonData []byte) http.HandlerFunc { func handlerMock(method string, jsonData []byte) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) { return func(rw http.ResponseWriter, req *http.Request) {
if req.Method != method { if req.Method != method {
@ -109,22 +125,16 @@ func TestClient_GetZone(t *testing.T) {
authFQDN: "_acme-challenge.foo.com.", authFQDN: "_acme-challenge.foo.com.",
apiResponse: `[{}]`, apiResponse: `[{}]`,
expected: expected{ expected: expected{
errorMsg: "failed to unmarshal zone: json: cannot unmarshal array into Go value of type internal.Zone", errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type internal.Zone",
}, },
}, },
} }
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(handlerMock(http.MethodGet, []byte(test.apiResponse))) client := setupTest(t, "", handlerMock(http.MethodGet, []byte(test.apiResponse)))
t.Cleanup(server.Close)
client, err := NewClient("myAuthID", "", "myAuthPassword") zone, err := client.GetZone(context.Background(), test.authFQDN)
require.NoError(t, err)
client.BaseURL, _ = url.Parse(server.URL)
zone, err := client.GetZone(test.authFQDN)
if test.expected.errorMsg != "" { if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg) require.EqualError(t, err, test.expected.errorMsg)
@ -222,22 +232,16 @@ func TestClient_FindTxtRecord(t *testing.T) {
zoneName: "example.com", zoneName: "example.com",
apiResponse: `[{}]`, apiResponse: `[{}]`,
expected: expected{ expected: expected{
errorMsg: "failed to unmarshall TXT records: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord: [{}]", errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord",
}, },
}, },
} }
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(handlerMock(http.MethodGet, []byte(test.apiResponse))) client := setupTest(t, "", handlerMock(http.MethodGet, []byte(test.apiResponse)))
t.Cleanup(server.Close)
client, err := NewClient("myAuthID", "", "myAuthPassword") txtRecord, err := client.FindTxtRecord(context.Background(), test.zoneName, test.authFQDN)
require.NoError(t, err)
client.BaseURL, _ = url.Parse(server.URL)
txtRecord, err := client.FindTxtRecord(test.zoneName, test.authFQDN)
if test.expected.errorMsg != "" { if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg) require.EqualError(t, err, test.expected.errorMsg)
@ -337,22 +341,16 @@ func TestClient_ListTxtRecord(t *testing.T) {
zoneName: "example.com", zoneName: "example.com",
apiResponse: `[{}]`, apiResponse: `[{}]`,
expected: expected{ expected: expected{
errorMsg: "failed to unmarshall TXT records: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord: [{}]", errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord",
}, },
}, },
} }
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(handlerMock(http.MethodGet, []byte(test.apiResponse))) client := setupTest(t, "", handlerMock(http.MethodGet, []byte(test.apiResponse)))
t.Cleanup(server.Close)
client, err := NewClient("myAuthID", "", "myAuthPassword") txtRecords, err := client.ListTxtRecords(context.Background(), test.zoneName, test.authFQDN)
require.NoError(t, err)
client.BaseURL, _ = url.Parse(server.URL)
txtRecords, err := client.ListTxtRecords(test.zoneName, test.authFQDN)
if test.expected.errorMsg != "" { if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg) require.EqualError(t, err, test.expected.errorMsg)
@ -440,14 +438,14 @@ func TestClient_AddTxtRecord(t *testing.T) {
apiResponse: `[{}]`, apiResponse: `[{}]`,
expected: expected{ expected: expected{
query: `auth-id=myAuthID&auth-password=myAuthPassword&domain-name=bar.com&host=_acme-challenge&record=TXTtxtTXTtxtTXTtxtTXTtxt&record-type=TXT&ttl=300`, query: `auth-id=myAuthID&auth-password=myAuthPassword&domain-name=bar.com&host=_acme-challenge&record=TXTtxtTXTtxtTXTtxtTXTtxt&record-type=TXT&ttl=300`,
errorMsg: "failed to unmarshal API response: json: cannot unmarshal array into Go value of type internal.apiResponse: [{}]", errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type internal.apiResponse",
}, },
}, },
} }
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { client := setupTest(t, test.subAuthID, func(rw http.ResponseWriter, req *http.Request) {
if test.expected.query != req.URL.RawQuery { if test.expected.query != req.URL.RawQuery {
msg := fmt.Sprintf("got: %s, want: %s", test.expected.query, req.URL.RawQuery) msg := fmt.Sprintf("got: %s, want: %s", test.expected.query, req.URL.RawQuery)
http.Error(rw, msg, http.StatusBadRequest) http.Error(rw, msg, http.StatusBadRequest)
@ -455,15 +453,9 @@ func TestClient_AddTxtRecord(t *testing.T) {
} }
handlerMock(http.MethodPost, []byte(test.apiResponse))(rw, req) handlerMock(http.MethodPost, []byte(test.apiResponse))(rw, req)
})) })
t.Cleanup(server.Close)
client, err := NewClient(test.authID, test.subAuthID, "myAuthPassword") err := client.AddTxtRecord(context.Background(), test.zoneName, test.authFQDN, test.value, test.ttl)
require.NoError(t, err)
client.BaseURL, _ = url.Parse(server.URL)
err = client.AddTxtRecord(test.zoneName, test.authFQDN, test.value, test.ttl)
if test.expected.errorMsg != "" { if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg) require.EqualError(t, err, test.expected.errorMsg)
@ -513,7 +505,7 @@ func TestClient_RemoveTxtRecord(t *testing.T) {
apiResponse: `[{}]`, apiResponse: `[{}]`,
expected: expected{ expected: expected{
query: `auth-id=myAuthID&auth-password=myAuthPassword&domain-name=foo-plus.com&record-id=44`, query: `auth-id=myAuthID&auth-password=myAuthPassword&domain-name=foo-plus.com&record-id=44`,
errorMsg: "failed to unmarshal API response: json: cannot unmarshal array into Go value of type internal.apiResponse: [{}]", errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type internal.apiResponse",
}, },
}, },
} }
@ -536,7 +528,7 @@ func TestClient_RemoveTxtRecord(t *testing.T) {
client.BaseURL, _ = url.Parse(server.URL) client.BaseURL, _ = url.Parse(server.URL)
err = client.RemoveTxtRecord(test.id, test.zoneName) err = client.RemoveTxtRecord(context.Background(), test.id, test.zoneName)
if test.expected.errorMsg != "" { if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg) require.EqualError(t, err, test.expected.errorMsg)
@ -592,7 +584,7 @@ func TestClient_GetUpdateStatus(t *testing.T) {
authFQDN: "_acme-challenge.foo.com.", authFQDN: "_acme-challenge.foo.com.",
zoneName: "test-zone", zoneName: "test-zone",
apiResponse: `[x]`, apiResponse: `[x]`,
expected: expected{errorMsg: "failed to unmarshal UpdateRecord: invalid character 'x' looking for beginning of value: [x]"}, expected: expected{errorMsg: "unable to unmarshal response: [status code: 200] body: [x] error: invalid character 'x' looking for beginning of value"},
}, },
} }
@ -606,7 +598,7 @@ func TestClient_GetUpdateStatus(t *testing.T) {
client.BaseURL, _ = url.Parse(server.URL) client.BaseURL, _ = url.Parse(server.URL)
syncProgress, err := client.GetUpdateStatus(test.zoneName) syncProgress, err := client.GetUpdateStatus(context.Background(), test.zoneName)
if test.expected.errorMsg != "" { if test.expected.errorMsg != "" {
require.EqualError(t, err, test.expected.errorMsg) require.EqualError(t, err, test.expected.errorMsg)

View file

@ -2,6 +2,7 @@
package cloudxns package cloudxns
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -59,7 +60,7 @@ type DNSProvider struct {
func NewDNSProvider() (*DNSProvider, error) { func NewDNSProvider() (*DNSProvider, error) {
values, err := env.Get(EnvAPIKey, EnvSecretKey) values, err := env.Get(EnvAPIKey, EnvSecretKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("CloudXNS: %w", err) return nil, fmt.Errorf("cloudxns: %w", err)
} }
config := NewDefaultConfig() config := NewDefaultConfig()
@ -72,15 +73,17 @@ func NewDNSProvider() (*DNSProvider, error) {
// NewDNSProviderConfig return a DNSProvider instance configured for CloudXNS. // NewDNSProviderConfig return a DNSProvider instance configured for CloudXNS.
func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
if config == nil { if config == nil {
return nil, errors.New("CloudXNS: the configuration of the DNS provider is nil") return nil, errors.New("cloudxns: the configuration of the DNS provider is nil")
} }
client, err := internal.NewClient(config.APIKey, config.SecretKey) client, err := internal.NewClient(config.APIKey, config.SecretKey)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("cloudxns: %w", err)
} }
if config.HTTPClient != nil {
client.HTTPClient = config.HTTPClient client.HTTPClient = config.HTTPClient
}
return &DNSProvider{client: client, config: config}, nil return &DNSProvider{client: client, config: config}, nil
} }
@ -89,29 +92,43 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) Present(domain, token, keyAuth string) error {
challengeInfo := dns01.GetChallengeInfo(domain, keyAuth) challengeInfo := dns01.GetChallengeInfo(domain, keyAuth)
info, err := d.client.GetDomainInformation(challengeInfo.EffectiveFQDN) ctx := context.Background()
info, err := d.client.GetDomainInformation(ctx, challengeInfo.EffectiveFQDN)
if err != nil { if err != nil {
return err return fmt.Errorf("cloudxns: %w", err)
} }
return d.client.AddTxtRecord(info, challengeInfo.EffectiveFQDN, challengeInfo.Value, d.config.TTL) err = d.client.AddTxtRecord(ctx, info, challengeInfo.EffectiveFQDN, challengeInfo.Value, d.config.TTL)
if err != nil {
return fmt.Errorf("cloudxns: %w", err)
}
return nil
} }
// CleanUp removes the TXT record matching the specified parameters. // CleanUp removes the TXT record matching the specified parameters.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
challengeInfo := dns01.GetChallengeInfo(domain, keyAuth) challengeInfo := dns01.GetChallengeInfo(domain, keyAuth)
info, err := d.client.GetDomainInformation(challengeInfo.EffectiveFQDN) ctx := context.Background()
info, err := d.client.GetDomainInformation(ctx, challengeInfo.EffectiveFQDN)
if err != nil { if err != nil {
return err return fmt.Errorf("cloudxns: %w", err)
} }
record, err := d.client.FindTxtRecord(info.ID, challengeInfo.EffectiveFQDN) record, err := d.client.FindTxtRecord(ctx, info.ID, challengeInfo.EffectiveFQDN)
if err != nil { if err != nil {
return err return fmt.Errorf("cloudxns: %w", err)
} }
return d.client.RemoveTxtRecord(record.RecordID, info.ID) err = d.client.RemoveTxtRecord(ctx, record.RecordID, info.ID)
if err != nil {
return fmt.Errorf("cloudxns: %w", err)
}
return nil
} }
// Timeout returns the timeout and interval to use when checking for DNS propagation. // Timeout returns the timeout and interval to use when checking for DNS propagation.

View file

@ -34,7 +34,7 @@ func TestNewDNSProvider(t *testing.T) {
EnvAPIKey: "", EnvAPIKey: "",
EnvSecretKey: "", EnvSecretKey: "",
}, },
expected: "CloudXNS: some credentials information are missing: CLOUDXNS_API_KEY,CLOUDXNS_SECRET_KEY", expected: "cloudxns: some credentials information are missing: CLOUDXNS_API_KEY,CLOUDXNS_SECRET_KEY",
}, },
{ {
desc: "missing API key", desc: "missing API key",
@ -42,7 +42,7 @@ func TestNewDNSProvider(t *testing.T) {
EnvAPIKey: "", EnvAPIKey: "",
EnvSecretKey: "456", EnvSecretKey: "456",
}, },
expected: "CloudXNS: some credentials information are missing: CLOUDXNS_API_KEY", expected: "cloudxns: some credentials information are missing: CLOUDXNS_API_KEY",
}, },
{ {
desc: "missing secret key", desc: "missing secret key",
@ -50,7 +50,7 @@ func TestNewDNSProvider(t *testing.T) {
EnvAPIKey: "123", EnvAPIKey: "123",
EnvSecretKey: "", EnvSecretKey: "",
}, },
expected: "CloudXNS: some credentials information are missing: CLOUDXNS_SECRET_KEY", expected: "cloudxns: some credentials information are missing: CLOUDXNS_SECRET_KEY",
}, },
} }
@ -89,17 +89,17 @@ func TestNewDNSProviderConfig(t *testing.T) {
}, },
{ {
desc: "missing credentials", desc: "missing credentials",
expected: "CloudXNS: credentials missing: apiKey", expected: "cloudxns: credentials missing: apiKey",
}, },
{ {
desc: "missing api key", desc: "missing api key",
secretKey: "456", secretKey: "456",
expected: "CloudXNS: credentials missing: apiKey", expected: "cloudxns: credentials missing: apiKey",
}, },
{ {
desc: "missing secret key", desc: "missing secret key",
apiKey: "123", apiKey: "123",
expected: "CloudXNS: credentials missing: secretKey", expected: "cloudxns: credentials missing: secretKey",
}, },
} }

View file

@ -2,6 +2,7 @@ package internal
import ( import (
"bytes" "bytes"
"context"
"crypto/md5" "crypto/md5"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@ -9,83 +10,63 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"time" "time"
"github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
) )
const defaultBaseURL = "https://www.cloudxns.net/api2/" const defaultBaseURL = "https://www.cloudxns.net/api2/"
type apiResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data,omitempty"`
}
// Data Domain information.
type Data struct {
ID string `json:"id"`
Domain string `json:"domain"`
TTL int `json:"ttl,omitempty"`
}
// TXTRecord a TXT record.
type TXTRecord struct {
ID int `json:"domain_id,omitempty"`
RecordID string `json:"record_id,omitempty"`
Host string `json:"host"`
Value string `json:"value"`
Type string `json:"type"`
LineID int `json:"line_id,string"`
TTL int `json:"ttl,string"`
}
// NewClient creates a CloudXNS client.
func NewClient(apiKey, secretKey string) (*Client, error) {
if apiKey == "" {
return nil, errors.New("CloudXNS: credentials missing: apiKey")
}
if secretKey == "" {
return nil, errors.New("CloudXNS: credentials missing: secretKey")
}
return &Client{
apiKey: apiKey,
secretKey: secretKey,
HTTPClient: &http.Client{},
BaseURL: defaultBaseURL,
}, nil
}
// Client CloudXNS client. // Client CloudXNS client.
type Client struct { type Client struct {
apiKey string apiKey string
secretKey string secretKey string
baseURL *url.URL
HTTPClient *http.Client HTTPClient *http.Client
BaseURL string }
// NewClient creates a CloudXNS client.
func NewClient(apiKey, secretKey string) (*Client, error) {
if apiKey == "" {
return nil, errors.New("credentials missing: apiKey")
}
if secretKey == "" {
return nil, errors.New("credentials missing: secretKey")
}
baseURL, _ := url.Parse(defaultBaseURL)
return &Client{
apiKey: apiKey,
secretKey: secretKey,
baseURL: baseURL,
HTTPClient: &http.Client{Timeout: 10 * time.Second},
}, nil
} }
// GetDomainInformation Get domain name information for a FQDN. // GetDomainInformation Get domain name information for a FQDN.
func (c *Client) GetDomainInformation(fqdn string) (*Data, error) { func (c *Client) GetDomainInformation(ctx context.Context, fqdn string) (*Data, error) {
authZone, err := dns01.FindZoneByFqdn(fqdn) endpoint := c.baseURL.JoinPath("domain")
req, err := c.newRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
result, err := c.doRequest(http.MethodGet, "domain", nil) authZone, err := dns01.FindZoneByFqdn(fqdn)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("cloudflare: could not find zone for FQDN %q: %w", fqdn, err)
} }
var domains []Data var domains []Data
if len(result) > 0 { err = c.do(req, &domains)
err = json.Unmarshal(result, &domains)
if err != nil { if err != nil {
return nil, fmt.Errorf("CloudXNS: domains unmarshaling error: %w", err) return nil, err
}
} }
for _, data := range domains { for _, data := range domains {
@ -94,20 +75,28 @@ func (c *Client) GetDomainInformation(fqdn string) (*Data, error) {
} }
} }
return nil, fmt.Errorf("CloudXNS: zone %s not found for domain %s", authZone, fqdn) return nil, fmt.Errorf("zone %s not found for domain %s", authZone, fqdn)
} }
// FindTxtRecord return the TXT record a zone ID and a FQDN. // FindTxtRecord return the TXT record a zone ID and a FQDN.
func (c *Client) FindTxtRecord(zoneID, fqdn string) (*TXTRecord, error) { func (c *Client) FindTxtRecord(ctx context.Context, zoneID, fqdn string) (*TXTRecord, error) {
result, err := c.doRequest(http.MethodGet, fmt.Sprintf("record/%s?host_id=0&offset=0&row_num=2000", zoneID), nil) endpoint := c.baseURL.JoinPath("record", zoneID)
query := endpoint.Query()
query.Set("host_id", "0")
query.Set("offset", "0")
query.Set("row_num", "2000")
endpoint.RawQuery = query.Encode()
req, err := c.newRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var records []TXTRecord var records []TXTRecord
err = json.Unmarshal(result, &records) err = c.do(req, &records)
if err != nil { if err != nil {
return nil, fmt.Errorf("CloudXNS: TXT record unmarshaling error: %w", err) return nil, err
} }
for _, record := range records { for _, record := range records {
@ -116,22 +105,24 @@ func (c *Client) FindTxtRecord(zoneID, fqdn string) (*TXTRecord, error) {
} }
} }
return nil, fmt.Errorf("CloudXNS: no existing record found for %q", fqdn) return nil, fmt.Errorf("no existing record found for %q", fqdn)
} }
// AddTxtRecord add a TXT record. // AddTxtRecord add a TXT record.
func (c *Client) AddTxtRecord(info *Data, fqdn, value string, ttl int) error { func (c *Client) AddTxtRecord(ctx context.Context, info *Data, fqdn, value string, ttl int) error {
id, err := strconv.Atoi(info.ID) id, err := strconv.Atoi(info.ID)
if err != nil { if err != nil {
return fmt.Errorf("CloudXNS: invalid zone ID: %w", err) return fmt.Errorf("invalid zone ID: %w", err)
} }
endpoint := c.baseURL.JoinPath("record")
subDomain, err := dns01.ExtractSubDomain(fqdn, info.Domain) subDomain, err := dns01.ExtractSubDomain(fqdn, info.Domain)
if err != nil { if err != nil {
return fmt.Errorf("CloudXNS: %w", err) return err
} }
payload := TXTRecord{ record := TXTRecord{
ID: id, ID: id,
Host: subDomain, Host: subDomain,
Value: value, Value: value,
@ -140,74 +131,91 @@ func (c *Client) AddTxtRecord(info *Data, fqdn, value string, ttl int) error {
TTL: ttl, TTL: ttl,
} }
body, err := json.Marshal(payload) req, err := c.newRequest(ctx, http.MethodPost, endpoint, record)
if err != nil { if err != nil {
return fmt.Errorf("CloudXNS: record unmarshaling error: %w", err) return err
} }
_, err = c.doRequest(http.MethodPost, "record", body) return c.do(req, nil)
return err
} }
// RemoveTxtRecord remove a TXT record. // RemoveTxtRecord remove a TXT record.
func (c *Client) RemoveTxtRecord(recordID, zoneID string) error { func (c *Client) RemoveTxtRecord(ctx context.Context, recordID, zoneID string) error {
_, err := c.doRequest(http.MethodDelete, fmt.Sprintf("record/%s/%s", recordID, zoneID), nil) endpoint := c.baseURL.JoinPath("record", recordID, zoneID)
return err
}
func (c *Client) doRequest(method, uri string, body []byte) (json.RawMessage, error) { req, err := c.newRequest(ctx, http.MethodDelete, endpoint, nil)
req, err := c.buildRequest(method, uri, body)
if err != nil { if err != nil {
return nil, err return err
} }
return c.do(req, nil)
}
func (c *Client) do(req *http.Request, result any) error {
resp, err := c.HTTPClient.Do(req) resp, err := c.HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("CloudXNS: %w", err) return errutils.NewHTTPDoError(req, err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
content, err := io.ReadAll(resp.Body) raw, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("CloudXNS: %s", toUnreadableBodyMessage(req, content)) return errutils.NewReadResponseError(req, resp.StatusCode, err)
} }
var r apiResponse var response apiResponse
err = json.Unmarshal(content, &r) err = json.Unmarshal(raw, &response)
if err != nil { if err != nil {
return nil, fmt.Errorf("CloudXNS: response unmashaling error: %w: %s", err, toUnreadableBodyMessage(req, content)) return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
} }
if r.Code != 1 { if response.Code != 1 {
return nil, fmt.Errorf("CloudXNS: invalid code (%v), error: %s", r.Code, r.Message) return fmt.Errorf("[status code %d] invalid code (%v) error: %s", resp.StatusCode, response.Code, response.Message)
} }
return r.Data, nil
if result == nil {
return nil
}
if len(response.Data) == 0 {
return nil
}
err = json.Unmarshal(response.Data, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
} }
func (c *Client) buildRequest(method, uri string, body []byte) (*http.Request, error) { func (c *Client) newRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
url := c.BaseURL + uri buf := new(bytes.Buffer)
req, err := http.NewRequest(method, url, bytes.NewReader(body)) if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil { if err != nil {
return nil, fmt.Errorf("CloudXNS: invalid request: %w", err) return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
} }
requestDate := time.Now().Format(time.RFC1123Z) requestDate := time.Now().Format(time.RFC1123Z)
req.Header.Set("API-KEY", c.apiKey) req.Header.Set("API-KEY", c.apiKey)
req.Header.Set("API-REQUEST-DATE", requestDate) req.Header.Set("API-REQUEST-DATE", requestDate)
req.Header.Set("API-HMAC", c.hmac(url, requestDate, string(body))) req.Header.Set("API-HMAC", c.hmac(endpoint.String(), requestDate, buf.String()))
req.Header.Set("API-FORMAT", "json") req.Header.Set("API-FORMAT", "json")
return req, nil return req, nil
} }
func (c *Client) hmac(url, date, body string) string { func (c *Client) hmac(endpoint, date, body string) string {
sum := md5.Sum([]byte(c.apiKey + url + body + date + c.secretKey)) sum := md5.Sum([]byte(c.apiKey + endpoint + body + date + c.secretKey))
return hex.EncodeToString(sum[:]) return hex.EncodeToString(sum[:])
} }
func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string {
return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody))
}

View file

@ -1,19 +1,35 @@
package internal package internal
import ( import (
"bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func handlerMock(method string, response *apiResponse, data interface{}) http.Handler { func setupTest(t *testing.T, handler http.HandlerFunc) *Client {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { t.Helper()
server := httptest.NewServer(handler)
t.Cleanup(server.Close)
client, _ := NewClient("myKey", "mySecret")
client.baseURL, _ = url.Parse(server.URL + "/")
client.HTTPClient = server.Client()
return client
}
func handlerMock(method string, response *apiResponse, data interface{}) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
if req.Method != method { if req.Method != method {
content, err := json.Marshal(apiResponse{ content, err := json.Marshal(apiResponse{
Code: 999, // random code only for the test Code: 999, // random code only for the test
@ -47,10 +63,10 @@ func handlerMock(method string, response *apiResponse, data interface{}) http.Ha
http.Error(rw, err.Error(), http.StatusInternalServerError) http.Error(rw, err.Error(), http.StatusInternalServerError)
return return
} }
}) }
} }
func TestClientGetDomainInformation(t *testing.T) { func TestClient_GetDomainInformation(t *testing.T) {
type result struct { type result struct {
domain *Data domain *Data
error bool error bool
@ -106,13 +122,9 @@ func TestClientGetDomainInformation(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(handlerMock(http.MethodGet, test.response, test.data)) client := setupTest(t, handlerMock(http.MethodGet, test.response, test.data))
t.Cleanup(server.Close)
client, _ := NewClient("myKey", "mySecret") domain, err := client.GetDomainInformation(context.Background(), test.fqdn)
client.BaseURL = server.URL + "/"
domain, err := client.GetDomainInformation(test.fqdn)
if test.expected.error { if test.expected.error {
require.Error(t, err) require.Error(t, err)
@ -124,7 +136,7 @@ func TestClientGetDomainInformation(t *testing.T) {
} }
} }
func TestClientFindTxtRecord(t *testing.T) { func TestClient_FindTxtRecord(t *testing.T) {
type result struct { type result struct {
txtRecord *TXTRecord txtRecord *TXTRecord
error bool error bool
@ -210,13 +222,9 @@ func TestClientFindTxtRecord(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
server := httptest.NewServer(handlerMock(http.MethodGet, test.response, test.txtRecords)) client := setupTest(t, handlerMock(http.MethodGet, test.response, test.txtRecords))
t.Cleanup(server.Close)
client, _ := NewClient("myKey", "mySecret") txtRecord, err := client.FindTxtRecord(context.Background(), test.zoneID, test.fqdn)
client.BaseURL = server.URL + "/"
txtRecord, err := client.FindTxtRecord(test.zoneID, test.fqdn)
if test.expected.error { if test.expected.error {
require.Error(t, err) require.Error(t, err)
@ -228,7 +236,7 @@ func TestClientFindTxtRecord(t *testing.T) {
} }
} }
func TestClientAddTxtRecord(t *testing.T) { func TestClient_AddTxtRecord(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
domain *Data domain *Data
@ -267,21 +275,17 @@ func TestClientAddTxtRecord(t *testing.T) {
Code: 1, Code: 1,
} }
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { client := setupTest(t, func(rw http.ResponseWriter, req *http.Request) {
assert.NotNil(t, req.Body) assert.NotNil(t, req.Body)
content, err := io.ReadAll(req.Body) content, err := io.ReadAll(req.Body)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.expected, string(content)) assert.Equal(t, test.expected, string(bytes.TrimSpace(content)))
handlerMock(http.MethodPost, response, nil).ServeHTTP(rw, req) handlerMock(http.MethodPost, response, nil).ServeHTTP(rw, req)
})) })
t.Cleanup(server.Close)
client, _ := NewClient("myKey", "mySecret") err := client.AddTxtRecord(context.Background(), test.domain, test.fqdn, test.value, test.ttl)
client.BaseURL = server.URL + "/"
err := client.AddTxtRecord(test.domain, test.fqdn, test.value, test.ttl)
require.NoError(t, err) require.NoError(t, err)
}) })
} }

View file

@ -0,0 +1,28 @@
package internal
import "encoding/json"
type apiResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data,omitempty"`
}
// Data Domain information.
type Data struct {
ID string `json:"id"`
Domain string `json:"domain"`
TTL int `json:"ttl,omitempty"`
}
// TXTRecord a TXT record.
type TXTRecord struct {
ID int `json:"domain_id,omitempty"`
RecordID string `json:"record_id,omitempty"`
Host string `json:"host"`
Value string `json:"value"`
Type string `json:"type"`
LineID int `json:"line_id,string"`
TTL int `json:"ttl,string"`
}

View file

@ -2,6 +2,7 @@
package conoha package conoha
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -85,6 +86,15 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("conoha: some credentials information are missing") return nil, errors.New("conoha: some credentials information are missing")
} }
identifier, err := internal.NewIdentifier(config.Region)
if err != nil {
return nil, fmt.Errorf("conoha: failed to create identity client: %w", err)
}
if config.HTTPClient != nil {
identifier.HTTPClient = config.HTTPClient
}
auth := internal.Auth{ auth := internal.Auth{
TenantID: config.TenantID, TenantID: config.TenantID,
PasswordCredentials: internal.PasswordCredentials{ PasswordCredentials: internal.PasswordCredentials{
@ -93,11 +103,20 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
}, },
} }
client, err := internal.NewClient(config.Region, auth, config.HTTPClient) tokens, err := identifier.GetToken(context.TODO(), auth)
if err != nil {
return nil, fmt.Errorf("conoha: failed to login: %w", err)
}
client, err := internal.NewClient(config.Region, tokens.Access.Token.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("conoha: failed to create client: %w", err) return nil, fmt.Errorf("conoha: failed to create client: %w", err)
} }
if config.HTTPClient != nil {
client.HTTPClient = config.HTTPClient
}
return &DNSProvider{config: config, client: client}, nil return &DNSProvider{config: config, client: client}, nil
} }
@ -107,10 +126,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return err return fmt.Errorf("conoha: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
id, err := d.client.GetDomainID(authZone) ctx := context.Background()
id, err := d.client.GetDomainID(ctx, authZone)
if err != nil { if err != nil {
return fmt.Errorf("conoha: failed to get domain ID: %w", err) return fmt.Errorf("conoha: failed to get domain ID: %w", err)
} }
@ -122,7 +143,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
TTL: d.config.TTL, TTL: d.config.TTL,
} }
err = d.client.CreateRecord(id, record) err = d.client.CreateRecord(ctx, id, record)
if err != nil { if err != nil {
return fmt.Errorf("conoha: failed to create record: %w", err) return fmt.Errorf("conoha: failed to create record: %w", err)
} }
@ -136,20 +157,22 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return err return fmt.Errorf("conoha: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
domID, err := d.client.GetDomainID(authZone) ctx := context.Background()
domID, err := d.client.GetDomainID(ctx, authZone)
if err != nil { if err != nil {
return fmt.Errorf("conoha: failed to get domain ID: %w", err) return fmt.Errorf("conoha: failed to get domain ID: %w", err)
} }
recID, err := d.client.GetRecordID(domID, info.EffectiveFQDN, "TXT", info.Value) recID, err := d.client.GetRecordID(ctx, domID, info.EffectiveFQDN, "TXT", info.Value)
if err != nil { if err != nil {
return fmt.Errorf("conoha: failed to get record ID: %w", err) return fmt.Errorf("conoha: failed to get record ID: %w", err)
} }
err = d.client.DeleteRecord(domID, recID) err = d.client.DeleteRecord(ctx, domID, recID)
if err != nil { if err != nil {
return fmt.Errorf("conoha: failed to delete record: %w", err) return fmt.Errorf("conoha: failed to delete record: %w", err)
} }

View file

@ -29,7 +29,7 @@ func TestNewDNSProvider(t *testing.T) {
EnvAPIUsername: "api_username", EnvAPIUsername: "api_username",
EnvAPIPassword: "api_password", EnvAPIPassword: "api_password",
}, },
expected: `conoha: failed to create client: failed to login: HTTP request failed with status code 401: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`, expected: `conoha: failed to login: unexpected status code: [status code: 401] body: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`,
}, },
{ {
desc: "missing credentials", desc: "missing credentials",
@ -99,7 +99,7 @@ func TestNewDNSProviderConfig(t *testing.T) {
}{ }{
{ {
desc: "complete credentials, but login failed", desc: "complete credentials, but login failed",
expected: `conoha: failed to create client: failed to login: HTTP request failed with status code 401: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`, expected: `conoha: failed to login: unexpected status code: [status code: 401] body: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`,
tenant: "tenant_id", tenant: "tenant_id",
username: "api_username", username: "api_username",
password: "api_password", password: "api_password",

View file

@ -2,121 +2,45 @@ package internal
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
) )
const ( const dnsServiceBaseURL = "https://dns-service.%s.conoha.io"
identityBaseURL = "https://identity.%s.conoha.io"
dnsServiceBaseURL = "https://dns-service.%s.conoha.io"
)
// IdentityRequest is an authentication request body.
type IdentityRequest struct {
Auth Auth `json:"auth"`
}
// Auth is an authentication information.
type Auth struct {
TenantID string `json:"tenantId"`
PasswordCredentials PasswordCredentials `json:"passwordCredentials"`
}
// PasswordCredentials is API-user's credentials.
type PasswordCredentials struct {
Username string `json:"username"`
Password string `json:"password"`
}
// IdentityResponse is an authentication response body.
type IdentityResponse struct {
Access Access `json:"access"`
}
// Access is an identity information.
type Access struct {
Token Token `json:"token"`
}
// Token is an api access token.
type Token struct {
ID string `json:"id"`
}
// DomainListResponse is a response of a domain listing request.
type DomainListResponse struct {
Domains []Domain `json:"domains"`
}
// Domain is a hosted domain entry.
type Domain struct {
ID string `json:"id"`
Name string `json:"name"`
}
// RecordListResponse is a response of record listing request.
type RecordListResponse struct {
Records []Record `json:"records"`
}
// Record is a record entry.
type Record struct {
ID string `json:"id,omitempty"`
Name string `json:"name"`
Type string `json:"type"`
Data string `json:"data"`
TTL int `json:"ttl"`
}
// Client is a ConoHa API client. // Client is a ConoHa API client.
type Client struct { type Client struct {
token string token string
endpoint string
httpClient *http.Client baseURL *url.URL
HTTPClient *http.Client
} }
// NewClient returns a client instance logged into the ConoHa service. // NewClient returns a client instance logged into the ConoHa service.
func NewClient(region string, auth Auth, httpClient *http.Client) (*Client, error) { func NewClient(region string, token string) (*Client, error) {
if httpClient == nil { baseURL, err := url.Parse(fmt.Sprintf(dnsServiceBaseURL, region))
httpClient = &http.Client{}
}
c := &Client{httpClient: httpClient}
c.endpoint = fmt.Sprintf(identityBaseURL, region)
identity, err := c.getIdentity(auth)
if err != nil {
return nil, fmt.Errorf("failed to login: %w", err)
}
c.token = identity.Access.Token.ID
c.endpoint = fmt.Sprintf(dnsServiceBaseURL, region)
return c, nil
}
func (c *Client) getIdentity(auth Auth) (*IdentityResponse, error) {
req := &IdentityRequest{Auth: auth}
identity := &IdentityResponse{}
err := c.do(http.MethodPost, "/v2.0/tokens", req, identity)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return identity, nil return &Client{
token: token,
baseURL: baseURL,
HTTPClient: &http.Client{Timeout: 5 * time.Second},
}, nil
} }
// GetDomainID returns an ID of specified domain. // GetDomainID returns an ID of specified domain.
func (c *Client) GetDomainID(domainName string) (string, error) { func (c *Client) GetDomainID(ctx context.Context, domainName string) (string, error) {
domainList := &DomainListResponse{} domainList, err := c.getDomains(ctx)
err := c.do(http.MethodGet, "/v1/domains", nil, domainList)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -126,14 +50,32 @@ func (c *Client) GetDomainID(domainName string) (string, error) {
return domain.ID, nil return domain.ID, nil
} }
} }
return "", fmt.Errorf("no such domain: %s", domainName) return "", fmt.Errorf("no such domain: %s", domainName)
} }
// GetRecordID returns an ID of specified record. // https://www.conoha.jp/docs/paas-dns-list-domains.php
func (c *Client) GetRecordID(domainID, recordName, recordType, data string) (string, error) { func (c *Client) getDomains(ctx context.Context) (*DomainListResponse, error) {
recordList := &RecordListResponse{} endpoint := c.baseURL.JoinPath("v1", "domains")
err := c.do(http.MethodGet, fmt.Sprintf("/v1/domains/%s/records", domainID), nil, recordList) req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
domainList := &DomainListResponse{}
err = c.do(req, domainList)
if err != nil {
return nil, err
}
return domainList, nil
}
// GetRecordID returns an ID of specified record.
func (c *Client) GetRecordID(ctx context.Context, domainID, recordName, recordType, data string) (string, error) {
recordList, err := c.getRecords(ctx, domainID)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -143,63 +85,119 @@ func (c *Client) GetRecordID(domainID, recordName, recordType, data string) (str
return record.ID, nil return record.ID, nil
} }
} }
return "", errors.New("no such record") return "", errors.New("no such record")
} }
// https://www.conoha.jp/docs/paas-dns-list-records-in-a-domain.php
func (c *Client) getRecords(ctx context.Context, domainID string) (*RecordListResponse, error) {
endpoint := c.baseURL.JoinPath("v1", "domains", domainID, "records")
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
recordList := &RecordListResponse{}
err = c.do(req, recordList)
if err != nil {
return nil, err
}
return recordList, nil
}
// CreateRecord adds new record. // CreateRecord adds new record.
func (c *Client) CreateRecord(domainID string, record Record) error { func (c *Client) CreateRecord(ctx context.Context, domainID string, record Record) error {
return c.do(http.MethodPost, fmt.Sprintf("/v1/domains/%s/records", domainID), record, nil) _, err := c.createRecord(ctx, domainID, record)
return err
}
// https://www.conoha.jp/docs/paas-dns-create-record.php
func (c *Client) createRecord(ctx context.Context, domainID string, record Record) (*Record, error) {
endpoint := c.baseURL.JoinPath("v1", "domains", domainID, "records")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
if err != nil {
return nil, err
}
newRecord := &Record{}
err = c.do(req, newRecord)
if err != nil {
return nil, err
}
return newRecord, nil
} }
// DeleteRecord removes specified record. // DeleteRecord removes specified record.
func (c *Client) DeleteRecord(domainID, recordID string) error { // https://www.conoha.jp/docs/paas-dns-delete-a-record.php
return c.do(http.MethodDelete, fmt.Sprintf("/v1/domains/%s/records/%s", domainID, recordID), nil, nil) func (c *Client) DeleteRecord(ctx context.Context, domainID, recordID string) error {
endpoint := c.baseURL.JoinPath("v1", "domains", domainID, "records", recordID)
req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err
}
return c.do(req, nil)
} }
func (c *Client) do(method, path string, payload, result interface{}) error { func (c *Client) do(req *http.Request, result any) error {
body := bytes.NewReader(nil) if c.token != "" {
if payload != nil {
bodyBytes, err := json.Marshal(payload)
if err != nil {
return err
}
body = bytes.NewReader(bodyBytes)
}
req, err := http.NewRequest(method, c.endpoint+path, body)
if err != nil {
return err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Auth-Token", c.token) req.Header.Set("X-Auth-Token", c.token)
resp, err := c.httpClient.Do(req)
if err != nil {
return err
} }
resp, err := c.HTTPClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
respBody, err := io.ReadAll(resp.Body) return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return err return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
defer resp.Body.Close()
return fmt.Errorf("HTTP request failed with status code %d: %s", resp.StatusCode, string(respBody))
} }
if result != nil { err = json.Unmarshal(raw, result)
respBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return err return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
defer resp.Body.Close()
return json.Unmarshal(respBody, result)
} }
return nil return nil
} }
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}

View file

@ -1,30 +1,71 @@
package internal package internal
import ( import (
"bytes"
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func setupTest(t *testing.T) (*http.ServeMux, *Client) { func setupTest(t *testing.T) (*Client, *http.ServeMux) {
t.Helper() t.Helper()
mux := http.NewServeMux() mux := http.NewServeMux()
server := httptest.NewServer(mux) server := httptest.NewServer(mux)
t.Cleanup(server.Close) t.Cleanup(server.Close)
client := &Client{ client, err := NewClient("tyo1", "secret")
token: "secret", require.NoError(t, err)
endpoint: server.URL,
httpClient: server.Client(), client.HTTPClient = server.Client()
client.baseURL, _ = url.Parse(server.URL)
return client, mux
}
func writeFixtureHandler(method, filename string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
if req.Method != method {
http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return
} }
return mux, client writeFixture(rw, filename)
}
}
func writeBodyHandler(method, content string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
if req.Method != method {
http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return
}
_, err := fmt.Fprint(rw, content)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
}
}
func writeFixture(rw http.ResponseWriter, filename string) {
file, err := os.Open(filepath.Join("fixtures", filename))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
defer func() { _ = file.Close() }()
_, _ = io.Copy(rw, file)
} }
func TestClient_GetDomainID(t *testing.T) { func TestClient_GetDomainID(t *testing.T) {
@ -42,91 +83,30 @@ func TestClient_GetDomainID(t *testing.T) {
{ {
desc: "success", desc: "success",
domainName: "domain1.com.", domainName: "domain1.com.",
handler: func(rw http.ResponseWriter, req *http.Request) { handler: writeFixtureHandler(http.MethodGet, "domains_GET.json"),
if req.Method != http.MethodGet {
http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
return
}
content := `
{
"domains":[
{
"id": "09494b72-b65b-4297-9efb-187f65a0553e",
"name": "domain1.com.",
"ttl": 3600,
"serial": 1351800668,
"email": "nsadmin@example.org",
"gslb": 0,
"created_at": "2012-11-01T20:11:08.000000",
"updated_at": null,
"description": "memo"
},
{
"id": "cf661142-e577-40b5-b3eb-75795cdc0cd7",
"name": "domain2.com.",
"ttl": 7200,
"serial": 1351800670,
"email": "nsadmin2@example.org",
"gslb": 1,
"created_at": "2012-11-01T20:11:08.000000",
"updated_at": "2012-12-01T20:11:08.000000",
"description": "memomemo"
}
]
}
`
_, err := fmt.Fprint(rw, content)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
},
expected: expected{domainID: "09494b72-b65b-4297-9efb-187f65a0553e"}, expected: expected{domainID: "09494b72-b65b-4297-9efb-187f65a0553e"},
}, },
{ {
desc: "non existing domain", desc: "non existing domain",
domainName: "domain1.com.", domainName: "domain1.com.",
handler: func(rw http.ResponseWriter, req *http.Request) { handler: writeBodyHandler(http.MethodGet, "{}"),
if req.Method != http.MethodGet {
http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
return
}
_, err := fmt.Fprint(rw, "{}")
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
},
expected: expected{error: true}, expected: expected{error: true},
}, },
{ {
desc: "marshaling error", desc: "marshaling error",
domainName: "domain1.com.", domainName: "domain1.com.",
handler: func(rw http.ResponseWriter, req *http.Request) { handler: writeBodyHandler(http.MethodGet, "[]"),
if req.Method != http.MethodGet {
http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
return
}
_, err := fmt.Fprint(rw, "[]")
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
},
expected: expected{error: true}, expected: expected{error: true},
}, },
} }
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
mux, client := setupTest(t) client, mux := setupTest(t)
mux.Handle("/v1/domains", test.handler) mux.Handle("/v1/domains", test.handler)
domainID, err := client.GetDomainID(test.domainName) domainID, err := client.GetDomainID(context.Background(), test.domainName)
if test.expected.error { if test.expected.error {
require.Error(t, err) require.Error(t, err)
@ -142,13 +122,13 @@ func TestClient_CreateRecord(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
handler http.HandlerFunc handler http.HandlerFunc
expectError bool assert require.ErrorAssertionFunc
}{ }{
{ {
desc: "success", desc: "success",
handler: func(rw http.ResponseWriter, req *http.Request) { handler: func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost { if req.Method != http.MethodPost {
http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed) http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return return
} }
@ -157,31 +137,34 @@ func TestClient_CreateRecord(t *testing.T) {
http.Error(rw, err.Error(), http.StatusBadRequest) http.Error(rw, err.Error(), http.StatusBadRequest)
return return
} }
defer req.Body.Close() defer func() { _ = req.Body.Close() }()
if string(raw) != `{"name":"lego.com.","type":"TXT","data":"txtTXTtxt","ttl":300}` { if string(bytes.TrimSpace(raw)) != `{"name":"lego.com.","type":"TXT","data":"txtTXTtxt","ttl":300}` {
http.Error(rw, fmt.Sprintf("invalid request body: %s", string(raw)), http.StatusBadRequest) http.Error(rw, fmt.Sprintf("invalid request body: %s", string(raw)), http.StatusBadRequest)
return return
} }
writeFixture(rw, "domains-records_POST.json")
}, },
assert: require.NoError,
}, },
{ {
desc: "bad request", desc: "bad request",
handler: func(rw http.ResponseWriter, req *http.Request) { handler: func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost { if req.Method != http.MethodPost {
http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed) http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return return
} }
http.Error(rw, "OOPS", http.StatusBadRequest) http.Error(rw, "OOPS", http.StatusBadRequest)
}, },
expectError: true, assert: require.Error,
}, },
} }
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
mux, client := setupTest(t) client, mux := setupTest(t)
mux.Handle("/v1/domains/lego/records", test.handler) mux.Handle("/v1/domains/lego/records", test.handler)
@ -194,13 +177,36 @@ func TestClient_CreateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
err := client.CreateRecord(domainID, record) err := client.CreateRecord(context.Background(), domainID, record)
test.assert(t, err)
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
}) })
} }
} }
func TestClient_GetRecordID(t *testing.T) {
client, mux := setupTest(t)
mux.HandleFunc("/v1/domains/89acac79-38e7-497d-807c-a011e1310438/records",
writeFixtureHandler(http.MethodGet, "domains-records_GET.json"))
recordID, err := client.GetRecordID(context.Background(), "89acac79-38e7-497d-807c-a011e1310438", "www.example.com.", "A", "15.185.172.153")
require.NoError(t, err)
assert.Equal(t, "2e32e609-3a4f-45ba-bdef-e50eacd345ad", recordID)
}
func TestClient_DeleteRecord(t *testing.T) {
client, mux := setupTest(t)
mux.HandleFunc("/v1/domains/89acac79-38e7-497d-807c-a011e1310438/records/2e32e609-3a4f-45ba-bdef-e50eacd345ad", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodDelete {
http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest)
return
}
rw.WriteHeader(http.StatusOK)
})
err := client.DeleteRecord(context.Background(), "89acac79-38e7-497d-807c-a011e1310438", "2e32e609-3a4f-45ba-bdef-e50eacd345ad")
require.NoError(t, err)
}

View file

@ -0,0 +1,43 @@
{
"records": [
{
"id": "2e32e609-3a4f-45ba-bdef-e50eacd345ad",
"name": "www.example.com.",
"type": "A",
"ttl": 3600,
"created_at": "2012-11-02T19:56:26.000000",
"updated_at": "2012-11-04T13:22:36.000000",
"data": "15.185.172.153",
"domain_id": "89acac79-38e7-497d-807c-a011e1310438",
"version": 1,
"gslb_region": "JP",
"gslb_weight": 250,
"gslb_check": 12300
},
{
"id": "8e9ecf3e-fb92-4a3a-a8ae-7596f167bea3",
"name": "host1.example.com.",
"type": "A",
"ttl": 3600,
"created_at": "2012-11-04T13:57:50.000000",
"updated_at": null,
"data": "15.185.172.154",
"domain_id": "89acac79-38e7-497d-807c-a011e1310438",
"version": 1,
"gslb_region": "US",
"gslb_weight": 220,
"gslb_check": 12200
},
{
"id": "4ad19089-3e62-40f8-9482-17cc8ccb92cb",
"name": "web.example.com.",
"type": "CNAME",
"ttl": 3600,
"created_at": "2012-11-04T13:58:16.393735",
"updated_at": null,
"data": "www.example.com.",
"domain_id": "89acac79-38e7-497d-807c-a011e1310438",
"version": 1
}
]
}

View file

@ -0,0 +1,13 @@
{
"id": "2e32e609-3a4f-45ba-bdef-e50eacd345ad",
"name": "www.example.com.",
"type": "A",
"created_at": "2012-11-02T19:56:26.366792",
"updated_at": null,
"domain_id": "89acac79-38e7-497d-807c-a011e1310438",
"ttl": null,
"data": "192.0.2.3",
"gslb_check": 1,
"gslb_region": "JP",
"gslb_weight": 250
}

View file

@ -0,0 +1,26 @@
{
"domains":[
{
"id": "09494b72-b65b-4297-9efb-187f65a0553e",
"name": "domain1.com.",
"ttl": 3600,
"serial": 1351800668,
"email": "nsadmin@example.org",
"gslb": 0,
"created_at": "2012-11-01T20:11:08.000000",
"updated_at": null,
"description": "memo"
},
{
"id": "cf661142-e577-40b5-b3eb-75795cdc0cd7",
"name": "domain2.com.",
"ttl": 7200,
"serial": 1351800670,
"email": "nsadmin2@example.org",
"gslb": 1,
"created_at": "2012-11-01T20:11:08.000000",
"updated_at": "2012-12-01T20:11:08.000000",
"description": "memomemo"
}
]
}

View file

@ -0,0 +1,17 @@
{
"access": {
"token": {
"issued_at": "2015-05-19T07:08:21.927295",
"expires": "2015-05-20T07:08:21Z",
"id": "sample00d88246078f2bexample788f7",
"tenant": {
"name": "example00000000",
"enabled": true,
"tyo1_image_size": "550GB"
},
"endpoints_links": [],
"type": "mailhosting",
"name": "Mail Hosting Service"
}
}
}

View file

@ -0,0 +1,82 @@
package internal
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
)
const identityBaseURL = "https://identity.%s.conoha.io"
type Identifier struct {
baseURL *url.URL
HTTPClient *http.Client
}
// NewIdentifier creates a new Identifier.
func NewIdentifier(region string) (*Identifier, error) {
baseURL, err := url.Parse(fmt.Sprintf(identityBaseURL, region))
if err != nil {
return nil, err
}
return &Identifier{
baseURL: baseURL,
HTTPClient: &http.Client{Timeout: 5 * time.Second},
}, nil
}
// GetToken gets valid token information.
// https://www.conoha.jp/docs/identity-post_tokens.php
func (c *Identifier) GetToken(ctx context.Context, auth Auth) (*IdentityResponse, error) {
endpoint := c.baseURL.JoinPath("v2.0", "tokens")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, &IdentityRequest{Auth: auth})
if err != nil {
return nil, err
}
identity := &IdentityResponse{}
err = c.do(req, identity)
if err != nil {
return nil, err
}
return identity, nil
}
func (c *Identifier) do(req *http.Request, result any) error {
resp, err := c.HTTPClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
err = json.Unmarshal(raw, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}

View file

@ -0,0 +1,41 @@
package internal
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewClient(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
identifier, err := NewIdentifier("tyo1")
require.NoError(t, err)
identifier.HTTPClient = server.Client()
identifier.baseURL, _ = url.Parse(server.URL)
mux.HandleFunc("/v2.0/tokens", writeFixtureHandler(http.MethodPost, "tokens_POST.json"))
auth := Auth{
TenantID: "487727e3921d44e3bfe7ebb337bf085e",
PasswordCredentials: PasswordCredentials{
Username: "ConoHa",
Password: "paSSword123456#$%",
},
}
token, err := identifier.GetToken(context.Background(), auth)
require.NoError(t, err)
expected := &IdentityResponse{Access: Access{Token: Token{ID: "sample00d88246078f2bexample788f7"}}}
assert.Equal(t, expected, token)
}

View file

@ -0,0 +1,58 @@
package internal
// IdentityRequest is an authentication request body.
type IdentityRequest struct {
Auth Auth `json:"auth"`
}
// Auth is an authentication information.
type Auth struct {
TenantID string `json:"tenantId"`
PasswordCredentials PasswordCredentials `json:"passwordCredentials"`
}
// PasswordCredentials is API-user's credentials.
type PasswordCredentials struct {
Username string `json:"username"`
Password string `json:"password"`
}
// IdentityResponse is an authentication response body.
type IdentityResponse struct {
Access Access `json:"access"`
}
// Access is an identity information.
type Access struct {
Token Token `json:"token"`
}
// Token is an api access token.
type Token struct {
ID string `json:"id"`
}
// DomainListResponse is a response of a domain listing request.
type DomainListResponse struct {
Domains []Domain `json:"domains"`
}
// Domain is a hosted domain entry.
type Domain struct {
ID string `json:"id"`
Name string `json:"name"`
}
// RecordListResponse is a response of record listing request.
type RecordListResponse struct {
Records []Record `json:"records"`
}
// Record is a record entry.
type Record struct {
ID string `json:"id,omitempty"`
Name string `json:"name"`
Type string `json:"type"`
Data string `json:"data"`
TTL int `json:"ttl"`
}

View file

@ -2,6 +2,7 @@
package constellix package constellix
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -101,10 +102,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("constellix: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) return fmt.Errorf("constellix: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
dom, err := d.client.Domains.GetByName(dns01.UnFqdn(authZone)) ctx := context.Background()
dom, err := d.client.Domains.GetByName(ctx, dns01.UnFqdn(authZone))
if err != nil { if err != nil {
return fmt.Errorf("constellix: failed to get domain (%s): %w", authZone, err) return fmt.Errorf("constellix: failed to get domain (%s): %w", authZone, err)
} }
@ -114,7 +117,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
return fmt.Errorf("constellix: %w", err) return fmt.Errorf("constellix: %w", err)
} }
records, err := d.client.TxtRecords.Search(dom.ID, internal.Exact, recordName) records, err := d.client.TxtRecords.Search(ctx, dom.ID, internal.Exact, recordName)
if err != nil { if err != nil {
return fmt.Errorf("constellix: failed to search TXT records: %w", err) return fmt.Errorf("constellix: failed to search TXT records: %w", err)
} }
@ -125,10 +128,10 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
// TXT record entry already existing // TXT record entry already existing
if len(records) == 1 { if len(records) == 1 {
return d.appendRecordValue(dom, records[0].ID, info.Value) return d.appendRecordValue(ctx, dom, records[0].ID, info.Value)
} }
err = d.createRecord(dom, info.EffectiveFQDN, recordName, info.Value) err = d.createRecord(ctx, dom, info.EffectiveFQDN, recordName, info.Value)
if err != nil { if err != nil {
return fmt.Errorf("constellix: %w", err) return fmt.Errorf("constellix: %w", err)
} }
@ -142,10 +145,12 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("constellix: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) return fmt.Errorf("constellix: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
dom, err := d.client.Domains.GetByName(dns01.UnFqdn(authZone)) ctx := context.Background()
dom, err := d.client.Domains.GetByName(ctx, dns01.UnFqdn(authZone))
if err != nil { if err != nil {
return fmt.Errorf("constellix: failed to get domain (%s): %w", authZone, err) return fmt.Errorf("constellix: failed to get domain (%s): %w", authZone, err)
} }
@ -155,7 +160,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return fmt.Errorf("constellix: %w", err) return fmt.Errorf("constellix: %w", err)
} }
records, err := d.client.TxtRecords.Search(dom.ID, internal.Exact, recordName) records, err := d.client.TxtRecords.Search(ctx, dom.ID, internal.Exact, recordName)
if err != nil { if err != nil {
return fmt.Errorf("constellix: failed to search TXT records: %w", err) return fmt.Errorf("constellix: failed to search TXT records: %w", err)
} }
@ -168,7 +173,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return nil return nil
} }
record, err := d.client.TxtRecords.Get(dom.ID, records[0].ID) record, err := d.client.TxtRecords.Get(ctx, dom.ID, records[0].ID)
if err != nil { if err != nil {
return fmt.Errorf("constellix: failed to get TXT records: %w", err) return fmt.Errorf("constellix: failed to get TXT records: %w", err)
} }
@ -179,14 +184,14 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
// only 1 record value, the whole record must be deleted. // only 1 record value, the whole record must be deleted.
if len(record.Value) == 1 { if len(record.Value) == 1 {
_, err = d.client.TxtRecords.Delete(dom.ID, record.ID) _, err = d.client.TxtRecords.Delete(ctx, dom.ID, record.ID)
if err != nil { if err != nil {
return fmt.Errorf("constellix: failed to delete TXT records: %w", err) return fmt.Errorf("constellix: failed to delete TXT records: %w", err)
} }
return nil return nil
} }
err = d.removeRecordValue(dom, record, info.Value) err = d.removeRecordValue(ctx, dom, record, info.Value)
if err != nil { if err != nil {
return fmt.Errorf("constellix: %w", err) return fmt.Errorf("constellix: %w", err)
} }
@ -194,7 +199,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return nil return nil
} }
func (d *DNSProvider) createRecord(dom internal.Domain, fqdn, recordName, value string) error { func (d *DNSProvider) createRecord(ctx context.Context, dom internal.Domain, fqdn, recordName, value string) error {
request := internal.RecordRequest{ request := internal.RecordRequest{
Name: recordName, Name: recordName,
TTL: d.config.TTL, TTL: d.config.TTL,
@ -203,7 +208,7 @@ func (d *DNSProvider) createRecord(dom internal.Domain, fqdn, recordName, value
}, },
} }
_, err := d.client.TxtRecords.Create(dom.ID, request) _, err := d.client.TxtRecords.Create(ctx, dom.ID, request)
if err != nil { if err != nil {
return fmt.Errorf("failed to create TXT record %s: %w", fqdn, err) return fmt.Errorf("failed to create TXT record %s: %w", fqdn, err)
} }
@ -211,8 +216,8 @@ func (d *DNSProvider) createRecord(dom internal.Domain, fqdn, recordName, value
return nil return nil
} }
func (d *DNSProvider) appendRecordValue(dom internal.Domain, recordID int64, value string) error { func (d *DNSProvider) appendRecordValue(ctx context.Context, dom internal.Domain, recordID int64, value string) error {
record, err := d.client.TxtRecords.Get(dom.ID, recordID) record, err := d.client.TxtRecords.Get(ctx, dom.ID, recordID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get TXT records: %w", err) return fmt.Errorf("failed to get TXT records: %w", err)
} }
@ -227,7 +232,7 @@ func (d *DNSProvider) appendRecordValue(dom internal.Domain, recordID int64, val
RoundRobin: append(record.RoundRobin, internal.RecordValue{Value: fmt.Sprintf(`%q`, value)}), RoundRobin: append(record.RoundRobin, internal.RecordValue{Value: fmt.Sprintf(`%q`, value)}),
} }
_, err = d.client.TxtRecords.Update(dom.ID, record.ID, request) _, err = d.client.TxtRecords.Update(ctx, dom.ID, record.ID, request)
if err != nil { if err != nil {
return fmt.Errorf("failed to update TXT records: %w", err) return fmt.Errorf("failed to update TXT records: %w", err)
} }
@ -235,7 +240,7 @@ func (d *DNSProvider) appendRecordValue(dom internal.Domain, recordID int64, val
return nil return nil
} }
func (d *DNSProvider) removeRecordValue(dom internal.Domain, record *internal.Record, value string) error { func (d *DNSProvider) removeRecordValue(ctx context.Context, dom internal.Domain, record *internal.Record, value string) error {
request := internal.RecordRequest{ request := internal.RecordRequest{
Name: record.Name, Name: record.Name,
TTL: record.TTL, TTL: record.TTL,
@ -247,7 +252,7 @@ func (d *DNSProvider) removeRecordValue(dom internal.Domain, record *internal.Re
} }
} }
_, err := d.client.TxtRecords.Update(dom.ID, record.ID, request) _, err := d.client.TxtRecords.Update(ctx, dom.ID, record.ID, request)
if err != nil { if err != nil {
return fmt.Errorf("failed to update TXT records: %w", err) return fmt.Errorf("failed to update TXT records: %w", err)
} }

View file

@ -6,6 +6,9 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
) )
const ( const (
@ -28,7 +31,7 @@ type Client struct {
// NewClient Creates a Constellix client. // NewClient Creates a Constellix client.
func NewClient(httpClient *http.Client) *Client { func NewClient(httpClient *http.Client) *Client {
if httpClient == nil { if httpClient == nil {
httpClient = http.DefaultClient httpClient = &http.Client{Timeout: 5 * time.Second}
} }
client := &Client{ client := &Client{
@ -48,13 +51,15 @@ type service struct {
} }
// do sends an API request and returns the API response. // do sends an API request and returns the API response.
func (c *Client) do(req *http.Request, v interface{}) error { func (c *Client) do(req *http.Request, result any) error {
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := c.HTTPClient.Do(req) resp, err := c.HTTPClient.Do(req)
if err != nil { if err != nil {
return err return errutils.NewHTTPDoError(req, err)
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
err = checkResponse(resp) err = checkResponse(resp)
@ -64,11 +69,11 @@ func (c *Client) do(req *http.Request, v interface{}) error {
raw, err := io.ReadAll(resp.Body) raw, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return fmt.Errorf("failed to read body: %w", err) return errutils.NewReadResponseError(req, resp.StatusCode, err)
} }
if err = json.Unmarshal(raw, v); err != nil { if err = json.Unmarshal(raw, result); err != nil {
return fmt.Errorf("unmarshaling %T error: %w: %s", v, err, string(raw)) return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
} }
return nil return nil
@ -83,21 +88,21 @@ func checkResponse(resp *http.Response) error {
return nil return nil
} }
data, err := io.ReadAll(resp.Body) raw, err := io.ReadAll(resp.Body)
if err == nil && data != nil { if err == nil && raw != nil {
msg := &APIError{StatusCode: resp.StatusCode} errAPI := &APIError{StatusCode: resp.StatusCode}
if json.Unmarshal(data, msg) != nil { if json.Unmarshal(raw, errAPI) != nil {
return fmt.Errorf("API error: status code: %d: %v", resp.StatusCode, string(data)) return fmt.Errorf("API error: status code: %d: %v", resp.StatusCode, string(raw))
} }
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusNotFound: case http.StatusNotFound:
return &NotFound{APIError: msg} return &NotFound{APIError: errAPI}
case http.StatusBadRequest: case http.StatusBadRequest:
return &BadRequest{APIError: msg} return &BadRequest{APIError: errAPI}
default: default:
return msg return errAPI
} }
} }

View file

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -13,15 +14,15 @@ type DomainService service
// GetAll domains. // GetAll domains.
// https://api-docs.constellix.com/?version=latest#484c3f21-d724-4ee4-a6fa-ab22c8eb9e9b // https://api-docs.constellix.com/?version=latest#484c3f21-d724-4ee4-a6fa-ab22c8eb9e9b
func (s *DomainService) GetAll(params *PaginationParameters) ([]Domain, error) { func (s *DomainService) GetAll(ctx context.Context, params *PaginationParameters) ([]Domain, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains") endpoint, err := s.client.createEndpoint(defaultVersion, "domains")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err) return nil, fmt.Errorf("failed to create request endpoint: %w", err)
} }
req, err := http.NewRequest(http.MethodGet, endpoint, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("unable to create request: %w", err)
} }
if params != nil { if params != nil {
@ -42,8 +43,8 @@ func (s *DomainService) GetAll(params *PaginationParameters) ([]Domain, error) {
} }
// GetByName Gets domain by name. // GetByName Gets domain by name.
func (s *DomainService) GetByName(domainName string) (Domain, error) { func (s *DomainService) GetByName(ctx context.Context, domainName string) (Domain, error) {
domains, err := s.Search(Exact, domainName) domains, err := s.Search(ctx, Exact, domainName)
if err != nil { if err != nil {
return Domain{}, err return Domain{}, err
} }
@ -61,15 +62,15 @@ func (s *DomainService) GetByName(domainName string) (Domain, error) {
// Search searches for a domain by name. // Search searches for a domain by name.
// https://api-docs.constellix.com/?version=latest#3d7b2679-2209-49f3-b011-b7d24e512008 // https://api-docs.constellix.com/?version=latest#3d7b2679-2209-49f3-b011-b7d24e512008
func (s *DomainService) Search(filter searchFilter, value string) ([]Domain, error) { func (s *DomainService) Search(ctx context.Context, filter searchFilter, value string) ([]Domain, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", "search") endpoint, err := s.client.createEndpoint(defaultVersion, "domains", "search")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err) return nil, fmt.Errorf("failed to create request endpoint: %w", err)
} }
req, err := http.NewRequest(http.MethodGet, endpoint, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("unable to create request: %w", err)
} }
query := req.URL.Query() query := req.URL.Query()

View file

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -47,7 +48,7 @@ func TestDomainService_GetAll(t *testing.T) {
} }
}) })
data, err := client.Domains.GetAll(nil) data, err := client.Domains.GetAll(context.Background(), nil)
require.NoError(t, err) require.NoError(t, err)
expected := []Domain{ expected := []Domain{
@ -83,7 +84,7 @@ func TestDomainService_Search(t *testing.T) {
} }
}) })
data, err := client.Domains.Search(Exact, "lego.wtf") data, err := client.Domains.Search(context.Background(), Exact, "lego.wtf")
require.NoError(t, err) require.NoError(t, err)
expected := []Domain{ expected := []Domain{

View file

@ -2,6 +2,7 @@ package internal
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -14,20 +15,20 @@ type TxtRecordService service
// Create a TXT record. // Create a TXT record.
// https://api-docs.constellix.com/?version=latest#22e24d5b-9ec0-49a7-b2b0-5ff0a28e71be // https://api-docs.constellix.com/?version=latest#22e24d5b-9ec0-49a7-b2b0-5ff0a28e71be
func (s *TxtRecordService) Create(domainID int64, record RecordRequest) ([]Record, error) { func (s *TxtRecordService) Create(ctx context.Context, domainID int64, record RecordRequest) ([]Record, error) {
body, err := json.Marshal(record)
if err != nil {
return nil, fmt.Errorf("failed to marshall request body: %w", err)
}
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt") endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err) return nil, fmt.Errorf("failed to create request endpoint: %w", err)
} }
req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body)) body, err := json.Marshal(record)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
} }
var records []Record var records []Record
@ -41,15 +42,15 @@ func (s *TxtRecordService) Create(domainID int64, record RecordRequest) ([]Recor
// GetAll TXT records. // GetAll TXT records.
// https://api-docs.constellix.com/?version=latest#e7103c53-2ad8-4bc8-b5b3-4c22c4b571b2 // https://api-docs.constellix.com/?version=latest#e7103c53-2ad8-4bc8-b5b3-4c22c4b571b2
func (s *TxtRecordService) GetAll(domainID int64) ([]Record, error) { func (s *TxtRecordService) GetAll(ctx context.Context, domainID int64) ([]Record, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt") endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err) return nil, fmt.Errorf("failed to create endpoint: %w", err)
} }
req, err := http.NewRequest(http.MethodGet, endpoint, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("unable to create request: %w", err)
} }
var records []Record var records []Record
@ -63,15 +64,15 @@ func (s *TxtRecordService) GetAll(domainID int64) ([]Record, error) {
// Get a TXT record. // Get a TXT record.
// https://api-docs.constellix.com/?version=latest#e7103c53-2ad8-4bc8-b5b3-4c22c4b571b2 // https://api-docs.constellix.com/?version=latest#e7103c53-2ad8-4bc8-b5b3-4c22c4b571b2
func (s *TxtRecordService) Get(domainID, recordID int64) (*Record, error) { func (s *TxtRecordService) Get(ctx context.Context, domainID, recordID int64) (*Record, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10)) endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err) return nil, fmt.Errorf("failed to create request endpoint: %w", err)
} }
req, err := http.NewRequest(http.MethodGet, endpoint, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("unable to create request: %w", err)
} }
var records Record var records Record
@ -85,20 +86,20 @@ func (s *TxtRecordService) Get(domainID, recordID int64) (*Record, error) {
// Update a TXT record. // Update a TXT record.
// https://api-docs.constellix.com/?version=latest#d4e9ab2e-fac0-45a6-b0e4-cf62a2d2e3da // https://api-docs.constellix.com/?version=latest#d4e9ab2e-fac0-45a6-b0e4-cf62a2d2e3da
func (s *TxtRecordService) Update(domainID, recordID int64, record RecordRequest) (*SuccessMessage, error) { func (s *TxtRecordService) Update(ctx context.Context, domainID, recordID int64, record RecordRequest) (*SuccessMessage, error) {
body, err := json.Marshal(record)
if err != nil {
return nil, fmt.Errorf("failed to marshall request body: %w", err)
}
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10)) endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err) return nil, fmt.Errorf("failed to create request endpoint: %w", err)
} }
req, err := http.NewRequest(http.MethodPut, endpoint, bytes.NewReader(body)) body, err := json.Marshal(record)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
} }
var msg SuccessMessage var msg SuccessMessage
@ -112,15 +113,15 @@ func (s *TxtRecordService) Update(domainID, recordID int64, record RecordRequest
// Delete a TXT record. // Delete a TXT record.
// https://api-docs.constellix.com/?version=latest#135947f7-d6c8-481a-83c7-4d387b0bdf9e // https://api-docs.constellix.com/?version=latest#135947f7-d6c8-481a-83c7-4d387b0bdf9e
func (s *TxtRecordService) Delete(domainID, recordID int64) (*SuccessMessage, error) { func (s *TxtRecordService) Delete(ctx context.Context, domainID, recordID int64) (*SuccessMessage, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10)) endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err) return nil, fmt.Errorf("failed to create request endpoint: %w", err)
} }
req, err := http.NewRequest(http.MethodDelete, endpoint, nil) req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("unable to create request: %w", err)
} }
var msg *SuccessMessage var msg *SuccessMessage
@ -134,15 +135,15 @@ func (s *TxtRecordService) Delete(domainID, recordID int64) (*SuccessMessage, er
// Search searches for a TXT record by name. // Search searches for a TXT record by name.
// https://api-docs.constellix.com/?version=latest#81003e4f-bd3f-413f-a18d-6d9d18f10201 // https://api-docs.constellix.com/?version=latest#81003e4f-bd3f-413f-a18d-6d9d18f10201
func (s *TxtRecordService) Search(domainID int64, filter searchFilter, value string) ([]Record, error) { func (s *TxtRecordService) Search(ctx context.Context, domainID int64, filter searchFilter, value string) ([]Record, error) {
endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", "search") endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", "search")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request endpoint: %w", err) return nil, fmt.Errorf("failed to create request endpoint: %w", err)
} }
req, err := http.NewRequest(http.MethodGet, endpoint, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("unable to create request: %w", err)
} }
query := req.URL.Query() query := req.URL.Query()

View file

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@ -34,7 +35,7 @@ func TestTxtRecordService_Create(t *testing.T) {
} }
}) })
records, err := client.TxtRecords.Create(12345, RecordRequest{}) records, err := client.TxtRecords.Create(context.Background(), 12345, RecordRequest{})
require.NoError(t, err) require.NoError(t, err)
recordsJSON, err := json.Marshal(records) recordsJSON, err := json.Marshal(records)
@ -69,7 +70,7 @@ func TestTxtRecordService_GetAll(t *testing.T) {
} }
}) })
records, err := client.TxtRecords.GetAll(12345) records, err := client.TxtRecords.GetAll(context.Background(), 12345)
require.NoError(t, err) require.NoError(t, err)
recordsJSON, err := json.Marshal(records) recordsJSON, err := json.Marshal(records)
@ -104,7 +105,7 @@ func TestTxtRecordService_Get(t *testing.T) {
} }
}) })
record, err := client.TxtRecords.Get(12345, 6789) record, err := client.TxtRecords.Get(context.Background(), 12345, 6789)
require.NoError(t, err) require.NoError(t, err)
expected := &Record{ expected := &Record{
@ -145,7 +146,7 @@ func TestTxtRecordService_Update(t *testing.T) {
} }
}) })
msg, err := client.TxtRecords.Update(12345, 6789, RecordRequest{}) msg, err := client.TxtRecords.Update(context.Background(), 12345, 6789, RecordRequest{})
require.NoError(t, err) require.NoError(t, err)
expected := &SuccessMessage{Success: "Record updated successfully"} expected := &SuccessMessage{Success: "Record updated successfully"}
@ -168,7 +169,7 @@ func TestTxtRecordService_Delete(t *testing.T) {
} }
}) })
msg, err := client.TxtRecords.Delete(12345, 6789) msg, err := client.TxtRecords.Delete(context.Background(), 12345, 6789)
require.NoError(t, err) require.NoError(t, err)
expected := &SuccessMessage{Success: "Record deleted successfully"} expected := &SuccessMessage{Success: "Record deleted successfully"}
@ -198,7 +199,7 @@ func TestTxtRecordService_Search(t *testing.T) {
} }
}) })
records, err := client.TxtRecords.Search(12345, Exact, "test") records, err := client.TxtRecords.Search(context.Background(), 12345, Exact, "test")
require.NoError(t, err) require.NoError(t, err)
recordsJSON, err := json.Marshal(records) recordsJSON, err := json.Marshal(records)

View file

@ -106,7 +106,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("desec: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) return fmt.Errorf("desec: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
recordName, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) recordName, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)
@ -156,7 +156,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("desec: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) return fmt.Errorf("desec: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
recordName, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) recordName, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)

View file

@ -128,12 +128,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("designate: couldn't get zone ID in Present: %w", err) return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
zoneID, err := d.getZoneID(authZone) zoneID, err := d.getZoneID(authZone)
if err != nil { if err != nil {
return fmt.Errorf("designate: %w", err) return fmt.Errorf("designate: couldn't get zone ID in Present: %w", err)
} }
// use mutex to prevent race condition between creating the record and verifying it // use mutex to prevent race condition between creating the record and verifying it
@ -168,7 +168,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return err return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
zoneID, err := d.getZoneID(authZone) zoneID, err := d.getZoneID(authZone)

View file

@ -286,6 +286,9 @@ func setupTestProvider(t *testing.T) string {
t.Helper() t.Helper()
mux := http.NewServeMux() mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte(`{ _, _ = w.Write([]byte(`{
"access": { "access": {
@ -319,9 +322,6 @@ func setupTestProvider(t *testing.T) string {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
return server.URL return server.URL
} }

View file

@ -1,131 +0,0 @@
package digitalocean
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/go-acme/lego/v4/challenge/dns01"
)
const defaultBaseURL = "https://api.digitalocean.com"
// txtRecordResponse represents a response from DO's API after making a TXT record.
type txtRecordResponse struct {
DomainRecord record `json:"domain_record"`
}
type record struct {
ID int `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Name string `json:"name,omitempty"`
Data string `json:"data,omitempty"`
TTL int `json:"ttl,omitempty"`
}
type apiError struct {
ID string `json:"id"`
Message string `json:"message"`
}
func (d *DNSProvider) removeTxtRecord(domain string, recordID int) error {
authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(domain))
if err != nil {
return fmt.Errorf("could not determine zone for domain %q: %w", domain, err)
}
reqURL := fmt.Sprintf("%s/v2/domains/%s/records/%d", d.config.BaseURL, dns01.UnFqdn(authZone), recordID)
req, err := d.newRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
resp, err := d.config.HTTPClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
return readError(req, resp)
}
return nil
}
func (d *DNSProvider) addTxtRecord(fqdn, value string) (*txtRecordResponse, error) {
authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(fqdn))
if err != nil {
return nil, fmt.Errorf("could not determine zone for domain %q: %w", fqdn, err)
}
reqData := record{Type: "TXT", Name: fqdn, Data: value, TTL: d.config.TTL}
body, err := json.Marshal(reqData)
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/v2/domains/%s/records", d.config.BaseURL, dns01.UnFqdn(authZone))
req, err := d.newRequest(http.MethodPost, reqURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
resp, err := d.config.HTTPClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
return nil, readError(req, resp)
}
content, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.New(toUnreadableBodyMessage(req, content))
}
// Everything looks good; but we'll need the ID later to delete the record
respData := &txtRecordResponse{}
err = json.Unmarshal(content, respData)
if err != nil {
return nil, fmt.Errorf("%w: %s", err, toUnreadableBodyMessage(req, content))
}
return respData, nil
}
func (d *DNSProvider) newRequest(method, reqURL string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequest(method, reqURL, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", d.config.AuthToken))
return req, nil
}
func readError(req *http.Request, resp *http.Response) error {
content, err := io.ReadAll(resp.Body)
if err != nil {
return errors.New(toUnreadableBodyMessage(req, content))
}
var errInfo apiError
err = json.Unmarshal(content, &errInfo)
if err != nil {
return fmt.Errorf("apiError unmarshaling error: %w: %s", err, toUnreadableBodyMessage(req, content))
}
return fmt.Errorf("HTTP %d: %s: %s", resp.StatusCode, errInfo.ID, errInfo.Message)
}
func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string {
return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody))
}

View file

@ -2,14 +2,17 @@
package digitalocean package digitalocean
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"sync" "sync"
"time" "time"
"github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/platform/config/env" "github.com/go-acme/lego/v4/platform/config/env"
"github.com/go-acme/lego/v4/providers/dns/digitalocean/internal"
) )
// Environment variables names. // Environment variables names.
@ -38,7 +41,7 @@ type Config struct {
// NewDefaultConfig returns a default configuration for the DNSProvider. // NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config { func NewDefaultConfig() *Config {
return &Config{ return &Config{
BaseURL: env.GetOrDefaultString(EnvAPIUrl, defaultBaseURL), BaseURL: env.GetOrDefaultString(EnvAPIUrl, internal.DefaultBaseURL),
TTL: env.GetOrDefaultInt(EnvTTL, 30), TTL: env.GetOrDefaultInt(EnvTTL, 30),
PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 60*time.Second), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 60*time.Second),
PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 5*time.Second), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 5*time.Second),
@ -51,6 +54,8 @@ func NewDefaultConfig() *Config {
// DNSProvider implements the challenge.Provider interface. // DNSProvider implements the challenge.Provider interface.
type DNSProvider struct { type DNSProvider struct {
config *Config config *Config
client *internal.Client
recordIDs map[string]int recordIDs map[string]int
recordIDsMu sync.Mutex recordIDsMu sync.Mutex
} }
@ -80,12 +85,19 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("digitalocean: credentials missing") return nil, errors.New("digitalocean: credentials missing")
} }
if config.BaseURL == "" { client := internal.NewClient(internal.OAuthStaticAccessToken(config.HTTPClient, config.AuthToken))
config.BaseURL = defaultBaseURL
if config.BaseURL != "" {
var err error
client.BaseURL, err = url.Parse(config.BaseURL)
if err != nil {
return nil, fmt.Errorf("digitalocean: %w", err)
}
} }
return &DNSProvider{ return &DNSProvider{
config: config, config: config,
client: client,
recordIDs: make(map[string]int), recordIDs: make(map[string]int),
}, nil }, nil
} }
@ -100,7 +112,14 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
respData, err := d.addTxtRecord(info.EffectiveFQDN, info.Value) authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(info.EffectiveFQDN))
if err != nil {
return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
}
record := internal.Record{Type: "TXT", Name: info.EffectiveFQDN, Data: info.Value, TTL: d.config.TTL}
respData, err := d.client.AddTxtRecord(context.Background(), authZone, record)
if err != nil { if err != nil {
return fmt.Errorf("digitalocean: %w", err) return fmt.Errorf("digitalocean: %w", err)
} }
@ -118,7 +137,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("digitalocean: %w", err) return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err)
} }
// get the record's unique ID from when we created it // get the record's unique ID from when we created it
@ -129,7 +148,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return fmt.Errorf("digitalocean: unknown record ID for '%s'", info.EffectiveFQDN) return fmt.Errorf("digitalocean: unknown record ID for '%s'", info.EffectiveFQDN)
} }
err = d.removeTxtRecord(authZone, recordID) err = d.client.RemoveTxtRecord(context.Background(), authZone, recordID)
if err != nil { if err != nil {
return fmt.Errorf("digitalocean: %w", err) return fmt.Errorf("digitalocean: %w", err)
} }

View file

@ -1,6 +1,7 @@
package digitalocean package digitalocean
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -115,6 +116,7 @@ func TestDNSProvider_Present(t *testing.T) {
mux.HandleFunc("/v2/domains/example.com/records", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/v2/domains/example.com/records", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method, "method") assert.Equal(t, http.MethodPost, r.Method, "method")
assert.Equal(t, "application/json", r.Header.Get("Accept"), "Accept")
assert.Equal(t, "application/json", r.Header.Get("Content-Type"), "Content-Type") assert.Equal(t, "application/json", r.Header.Get("Content-Type"), "Content-Type")
assert.Equal(t, "Bearer asdf1234", r.Header.Get("Authorization"), "Authorization") assert.Equal(t, "Bearer asdf1234", r.Header.Get("Authorization"), "Authorization")
@ -125,7 +127,7 @@ func TestDNSProvider_Present(t *testing.T) {
} }
expectedReqBody := `{"type":"TXT","name":"_acme-challenge.example.com.","data":"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI","ttl":30}` expectedReqBody := `{"type":"TXT","name":"_acme-challenge.example.com.","data":"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI","ttl":30}`
assert.Equal(t, expectedReqBody, string(reqBody)) assert.Equal(t, expectedReqBody, string(bytes.TrimSpace(reqBody)))
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
_, err = fmt.Fprintf(w, `{ _, err = fmt.Fprintf(w, `{
@ -157,7 +159,7 @@ func TestDNSProvider_CleanUp(t *testing.T) {
assert.Equal(t, "/v2/domains/example.com/records/1234567", r.URL.Path, "Path") assert.Equal(t, "/v2/domains/example.com/records/1234567", r.URL.Path, "Path")
// NOTE: Even though the body is empty, DigitalOcean API docs still show setting this Content-Type... assert.Equal(t, "application/json", r.Header.Get("Accept"), "Accept")
assert.Equal(t, "application/json", r.Header.Get("Content-Type"), "Content-Type") assert.Equal(t, "application/json", r.Header.Get("Content-Type"), "Content-Type")
assert.Equal(t, "Bearer asdf1234", r.Header.Get("Authorization"), "Authorization") assert.Equal(t, "Bearer asdf1234", r.Header.Get("Authorization"), "Authorization")

View file

@ -0,0 +1,142 @@
package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"time"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
"golang.org/x/oauth2"
)
// DefaultBaseURL default API endpoint.
const DefaultBaseURL = "https://api.digitalocean.com"
// Client the Digital Ocean API client.
type Client struct {
BaseURL *url.URL
httpClient *http.Client
}
// NewClient creates a new Client.
func NewClient(hc *http.Client) *Client {
baseURL, _ := url.Parse(DefaultBaseURL)
if hc == nil {
hc = &http.Client{Timeout: 5 * time.Second}
}
return &Client{BaseURL: baseURL, httpClient: hc}
}
func (c *Client) AddTxtRecord(ctx context.Context, zone string, record Record) (*TxtRecordResponse, error) {
endpoint := c.BaseURL.JoinPath("v2", "domains", dns01.UnFqdn(zone), "records")
req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
if err != nil {
return nil, err
}
respData := &TxtRecordResponse{}
err = c.do(req, respData)
if err != nil {
return nil, err
}
return respData, nil
}
func (c *Client) RemoveTxtRecord(ctx context.Context, zone string, recordID int) error {
endpoint := c.BaseURL.JoinPath("v2", "domains", dns01.UnFqdn(zone), "records", strconv.Itoa(recordID))
req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err
}
return c.do(req, nil)
}
func (c *Client) do(req *http.Request, result any) error {
resp, err := c.httpClient.Do(req)
if err != nil {
return errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= http.StatusBadRequest {
return parseError(req, resp)
}
if result == nil {
return nil
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
return errutils.NewReadResponseError(req, resp.StatusCode, err)
}
err = json.Unmarshal(raw, result)
if err != nil {
return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
}
return nil
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
// NOTE: Even though the body is empty, DigitalOcean API docs still show setting this Content-Type...
req.Header.Set("Content-Type", "application/json")
return req, nil
}
func parseError(req *http.Request, resp *http.Response) error {
raw, _ := io.ReadAll(resp.Body)
var errInfo APIError
err := json.Unmarshal(raw, &errInfo)
if err != nil {
return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw)
}
return fmt.Errorf("[status code %d] %w", resp.StatusCode, errInfo)
}
func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client {
if client == nil {
client = &http.Client{Timeout: 5 * time.Second}
}
client.Transport = &oauth2.Transport{
Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}),
Base: client.Transport,
}
return client
}

View file

@ -0,0 +1,139 @@
package internal
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTest(t *testing.T, pattern string, handler http.HandlerFunc) *Client {
t.Helper()
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
client := NewClient(OAuthStaticAccessToken(server.Client(), "secret"))
client.BaseURL, _ = url.Parse(server.URL)
mux.HandleFunc(pattern, handler)
return client
}
func checkHeader(req *http.Request, name, value string) error {
val := req.Header.Get(name)
if val != value {
return fmt.Errorf("invalid header value, got: %s want %s", val, value)
}
return nil
}
func writeFixture(rw http.ResponseWriter, filename string) {
file, err := os.Open(filepath.Join("fixtures", filename))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
defer func() { _ = file.Close() }()
_, _ = io.Copy(rw, file)
}
func TestClient_AddTxtRecord(t *testing.T) {
client := setupTest(t, "/v2/domains/example.com/records", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed)
return
}
err := checkHeader(req, "Accept", "application/json")
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
err = checkHeader(req, "Content-Type", "application/json")
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
err = checkHeader(req, "Authorization", "Bearer secret")
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
reqBody, err := io.ReadAll(req.Body)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
expectedReqBody := `{"type":"TXT","name":"_acme-challenge.example.com.","data":"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI","ttl":30}`
if expectedReqBody != string(bytes.TrimSpace(reqBody)) {
http.Error(rw, fmt.Sprintf("unexpected request body: %s", string(bytes.TrimSpace(reqBody))), http.StatusBadRequest)
return
}
rw.WriteHeader(http.StatusCreated)
writeFixture(rw, "domains-records_POST.json")
})
record := Record{
Type: "TXT",
Name: "_acme-challenge.example.com.",
Data: "w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI",
TTL: 30,
}
newRecord, err := client.AddTxtRecord(context.Background(), "example.com", record)
require.NoError(t, err)
expected := &TxtRecordResponse{DomainRecord: Record{
ID: 1234567,
Type: "TXT",
Name: "_acme-challenge",
Data: "w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI",
TTL: 0,
}}
assert.Equal(t, expected, newRecord)
}
func TestClient_RemoveTxtRecord(t *testing.T) {
client := setupTest(t, "/v2/domains/example.com/records/1234567", func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodDelete {
http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed)
return
}
err := checkHeader(req, "Accept", "application/json")
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
err = checkHeader(req, "Authorization", "Bearer secret")
if err != nil {
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
rw.WriteHeader(http.StatusNoContent)
})
err := client.RemoveTxtRecord(context.Background(), "example.com", 1234567)
require.NoError(t, err)
}

View file

@ -0,0 +1,11 @@
{
"domain_record": {
"id": 1234567,
"type": "TXT",
"name": "_acme-challenge",
"data": "w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI",
"priority": null,
"port": null,
"weight": null
}
}

View file

@ -0,0 +1,25 @@
package internal
import "fmt"
// TxtRecordResponse represents a response from DO's API after making a TXT record.
type TxtRecordResponse struct {
DomainRecord Record `json:"domain_record"`
}
type Record struct {
ID int `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Name string `json:"name,omitempty"`
Data string `json:"data,omitempty"`
TTL int `json:"ttl,omitempty"`
}
type APIError struct {
ID string `json:"id"`
Message string `json:"message"`
}
func (a APIError) Error() string {
return fmt.Sprintf("%s: %s", a.ID, a.Message)
}

View file

@ -126,7 +126,7 @@ import (
// NewDNSChallengeProviderByName Factory for DNS providers. // NewDNSChallengeProviderByName Factory for DNS providers.
func NewDNSChallengeProviderByName(name string) (challenge.Provider, error) { func NewDNSChallengeProviderByName(name string) (challenge.Provider, error) {
switch name { switch name {
case "acme-dns": case "acme-dns": // TODO(ldez): remove "-" in v5
return acmedns.NewDNSProvider() return acmedns.NewDNSProvider()
case "alidns": case "alidns":
return alidns.NewDNSProvider() return alidns.NewDNSProvider()

View file

@ -2,6 +2,7 @@
package dnshomede package dnshomede
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -99,7 +100,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
func (d *DNSProvider) Present(domain, _, keyAuth string) error { func (d *DNSProvider) Present(domain, _, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
err := d.client.Add(dns01.UnFqdn(info.EffectiveFQDN), info.Value) err := d.client.Add(context.Background(), dns01.UnFqdn(info.EffectiveFQDN), info.Value)
if err != nil { if err != nil {
return fmt.Errorf("dnshomede: %w", err) return fmt.Errorf("dnshomede: %w", err)
} }
@ -111,7 +112,7 @@ func (d *DNSProvider) Present(domain, _, keyAuth string) error {
func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) info := dns01.GetChallengeInfo(domain, keyAuth)
err := d.client.Remove(dns01.UnFqdn(info.EffectiveFQDN), info.Value) err := d.client.Remove(context.Background(), dns01.UnFqdn(info.EffectiveFQDN), info.Value)
if err != nil { if err != nil {
return fmt.Errorf("dnshomede: %w", err) return fmt.Errorf("dnshomede: %w", err)
} }

View file

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -9,6 +10,8 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
) )
const ( const (
@ -22,8 +25,8 @@ const defaultBaseURL = "https://www.dnshome.de/dyndns.php"
// Client the dnsHome.de client. // Client the dnsHome.de client.
type Client struct { type Client struct {
HTTPClient *http.Client
baseURL string baseURL string
HTTPClient *http.Client
credentials map[string]string credentials map[string]string
credMu sync.Mutex credMu sync.Mutex
@ -40,75 +43,48 @@ func NewClient(credentials map[string]string) *Client {
// Add adds a TXT record. // Add adds a TXT record.
// only one TXT record for ACME is allowed, so it will update the "current" TXT record. // only one TXT record for ACME is allowed, so it will update the "current" TXT record.
func (c *Client) Add(hostname, value string) error { func (c *Client) Add(ctx context.Context, hostname, value string) error {
domain := strings.TrimPrefix(hostname, "_acme-challenge.") domain := strings.TrimPrefix(hostname, "_acme-challenge.")
c.credMu.Lock() return c.doAction(ctx, domain, addAction, value)
password, ok := c.credentials[domain]
c.credMu.Unlock()
if !ok {
return fmt.Errorf("domain %s not found in credentials, check your credentials map", domain)
}
return c.do(url.UserPassword(domain, password), addAction, value)
} }
// Remove removes a TXT record. // Remove removes a TXT record.
// only one TXT record for ACME is allowed, so it will remove "all" the TXT records. // only one TXT record for ACME is allowed, so it will remove "all" the TXT records.
func (c *Client) Remove(hostname, value string) error { func (c *Client) Remove(ctx context.Context, hostname, value string) error {
domain := strings.TrimPrefix(hostname, "_acme-challenge.") domain := strings.TrimPrefix(hostname, "_acme-challenge.")
c.credMu.Lock() return c.doAction(ctx, domain, removeAction, value)
password, ok := c.credentials[domain]
c.credMu.Unlock()
if !ok {
return fmt.Errorf("domain %s not found in credentials, check your credentials map", domain)
}
return c.do(url.UserPassword(domain, password), removeAction, value)
} }
func (c *Client) do(userInfo *url.Userinfo, action, value string) error { func (c *Client) doAction(ctx context.Context, domain, action, value string) error {
if len(value) < 12 { endpoint, err := c.createEndpoint(domain, action, value)
return fmt.Errorf("the TXT value must have more than 12 characters: %s", value)
}
apiEndpoint, err := url.Parse(c.baseURL)
if err != nil { if err != nil {
return err return err
} }
apiEndpoint.User = userInfo req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), http.NoBody)
query := apiEndpoint.Query()
query.Set("acme", action)
query.Set("txt", value)
apiEndpoint.RawQuery = query.Encode()
req, err := http.NewRequest(http.MethodPost, apiEndpoint.String(), http.NoBody)
if err != nil { if err != nil {
return err return fmt.Errorf("unable to create request: %w", err)
} }
resp, err := c.HTTPClient.Do(req) resp, err := c.HTTPClient.Do(req)
if err != nil { if err != nil {
return err return errutils.NewHTTPDoError(req, err)
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
all, _ := io.ReadAll(resp.Body) return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
return fmt.Errorf("%d: %s", resp.StatusCode, string(all))
} }
all, err := io.ReadAll(resp.Body) raw, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return err return errutils.NewReadResponseError(req, resp.StatusCode, err)
} }
output := string(all) output := string(raw)
if !strings.HasPrefix(output, successCode) { if !strings.HasPrefix(output, successCode) {
return errors.New(output) return errors.New(output)
@ -116,3 +92,31 @@ func (c *Client) do(userInfo *url.Userinfo, action, value string) error {
return nil return nil
} }
func (c *Client) createEndpoint(domain, action, value string) (*url.URL, error) {
if len(value) < 12 {
return nil, fmt.Errorf("the TXT value must have more than 12 characters: %s", value)
}
endpoint, err := url.Parse(c.baseURL)
if err != nil {
return nil, err
}
c.credMu.Lock()
password, ok := c.credentials[domain]
c.credMu.Unlock()
if !ok {
return nil, fmt.Errorf("domain %s not found in credentials, check your credentials map", domain)
}
endpoint.User = url.UserPassword(domain, password)
query := endpoint.Query()
query.Set("acme", action)
query.Set("txt", value)
endpoint.RawQuery = query.Encode()
return endpoint, nil
}

View file

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -9,79 +10,55 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestClient_Add(t *testing.T) { func setupTest(t *testing.T, credentials map[string]string, handler http.HandlerFunc) *Client {
txtValue := "123456789012" t.Helper()
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/", handlerMock(addAction, txtValue))
server := httptest.NewServer(mux) server := httptest.NewServer(mux)
t.Cleanup(server.Close)
credentials := map[string]string{ mux.HandleFunc("/", handler)
"example.org": "secret",
}
client := NewClient(credentials) client := NewClient(credentials)
client.HTTPClient = server.Client() client.HTTPClient = server.Client()
client.baseURL = server.URL client.baseURL = server.URL
err := client.Add("example.org", txtValue) return client
}
func TestClient_Add(t *testing.T) {
txtValue := "123456789012"
client := setupTest(t, map[string]string{"example.org": "secret"}, handlerMock(addAction, txtValue))
err := client.Add(context.Background(), "example.org", txtValue)
require.NoError(t, err) require.NoError(t, err)
} }
func TestClient_Add_error(t *testing.T) { func TestClient_Add_error(t *testing.T) {
txtValue := "123456789012" txtValue := "123456789012"
mux := http.NewServeMux() client := setupTest(t, map[string]string{"example.com": "secret"}, handlerMock(addAction, txtValue))
mux.HandleFunc("/", handlerMock(addAction, txtValue))
server := httptest.NewServer(mux)
credentials := map[string]string{ err := client.Add(context.Background(), "example.org", txtValue)
"example.com": "secret",
}
client := NewClient(credentials)
client.HTTPClient = server.Client()
client.baseURL = server.URL
err := client.Add("example.org", txtValue)
require.Error(t, err) require.Error(t, err)
} }
func TestClient_Remove(t *testing.T) { func TestClient_Remove(t *testing.T) {
txtValue := "ABCDEFGHIJKL" txtValue := "ABCDEFGHIJKL"
mux := http.NewServeMux() client := setupTest(t, map[string]string{"example.org": "secret"}, handlerMock(removeAction, txtValue))
mux.HandleFunc("/", handlerMock(removeAction, txtValue))
server := httptest.NewServer(mux)
credentials := map[string]string{ err := client.Remove(context.Background(), "example.org", txtValue)
"example.org": "secret",
}
client := NewClient(credentials)
client.HTTPClient = server.Client()
client.baseURL = server.URL
err := client.Remove("example.org", txtValue)
require.NoError(t, err) require.NoError(t, err)
} }
func TestClient_Remove_error(t *testing.T) { func TestClient_Remove_error(t *testing.T) {
txtValue := "ABCDEFGHIJKL" txtValue := "ABCDEFGHIJKL"
mux := http.NewServeMux() client := setupTest(t, map[string]string{"example.com": "secret"}, handlerMock(removeAction, txtValue))
mux.HandleFunc("/", handlerMock(removeAction, txtValue))
server := httptest.NewServer(mux)
credentials := map[string]string{ err := client.Remove(context.Background(), "example.org", txtValue)
"example.com": "secret",
}
client := NewClient(credentials)
client.HTTPClient = server.Client()
client.baseURL = server.URL
err := client.Remove("example.org", txtValue)
require.Error(t, err) require.Error(t, err)
} }

View file

@ -149,7 +149,7 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
func (d *DNSProvider) getHostedZone(domain string) (string, error) { func (d *DNSProvider) getHostedZone(domain string) (string, error) {
authZone, err := dns01.FindZoneByFqdn(domain) authZone, err := dns01.FindZoneByFqdn(domain)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("could not find zone for FQDN %q: %w", domain, err)
} }
accountID, err := d.getAccountID() accountID, err := d.getAccountID()

View file

@ -2,10 +2,12 @@
package dnsmadeeasy package dnsmadeeasy
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
@ -86,12 +88,12 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
var baseURL string var baseURL string
if config.Sandbox { if config.Sandbox {
baseURL = "https://api.sandbox.dnsmadeeasy.com/V2.0" baseURL = internal.DefaultSandboxBaseURL
} else {
if config.BaseURL == "" {
baseURL = internal.DefaultProdBaseURL
} else { } else {
if len(config.BaseURL) > 0 {
baseURL = config.BaseURL baseURL = config.BaseURL
} else {
baseURL = "https://api.dnsmadeeasy.com/V2.0"
} }
} }
@ -101,7 +103,10 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
} }
client.HTTPClient = config.HTTPClient client.HTTPClient = config.HTTPClient
client.BaseURL = baseURL client.BaseURL, err = url.Parse(baseURL)
if err != nil {
return nil, err
}
return &DNSProvider{ return &DNSProvider{
client: client, client: client,
@ -115,11 +120,13 @@ func (d *DNSProvider) Present(domainName, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to find zone for %s: %w", info.EffectiveFQDN, err) return fmt.Errorf("dnsmadeeasy: could not find zone for domain %q (%s): %w", domainName, info.EffectiveFQDN, err)
} }
ctx := context.Background()
// fetch the domain details // fetch the domain details
domain, err := d.client.GetDomain(authZone) domain, err := d.client.GetDomain(ctx, authZone)
if err != nil { if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to get domain for zone %s: %w", authZone, err) return fmt.Errorf("dnsmadeeasy: unable to get domain for zone %s: %w", authZone, err)
} }
@ -128,7 +135,7 @@ func (d *DNSProvider) Present(domainName, token, keyAuth string) error {
name := strings.Replace(info.EffectiveFQDN, "."+authZone, "", 1) name := strings.Replace(info.EffectiveFQDN, "."+authZone, "", 1)
record := &internal.Record{Type: "TXT", Name: name, Value: info.Value, TTL: d.config.TTL} record := &internal.Record{Type: "TXT", Name: name, Value: info.Value, TTL: d.config.TTL}
err = d.client.CreateRecord(domain, record) err = d.client.CreateRecord(ctx, domain, record)
if err != nil { if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to create record for %s: %w", name, err) return fmt.Errorf("dnsmadeeasy: unable to create record for %s: %w", name, err)
} }
@ -141,18 +148,20 @@ func (d *DNSProvider) CleanUp(domainName, token, keyAuth string) error {
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil { if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to find zone for %s: %w", info.EffectiveFQDN, err) return fmt.Errorf("dnsmadeeasy: could not find zone for domain %q (%s): %w", domainName, info.EffectiveFQDN, err)
} }
ctx := context.Background()
// fetch the domain details // fetch the domain details
domain, err := d.client.GetDomain(authZone) domain, err := d.client.GetDomain(ctx, authZone)
if err != nil { if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to get domain for zone %s: %w", authZone, err) return fmt.Errorf("dnsmadeeasy: unable to get domain for zone %s: %w", authZone, err)
} }
// find matching records // find matching records
name := strings.Replace(info.EffectiveFQDN, "."+authZone, "", 1) name := strings.Replace(info.EffectiveFQDN, "."+authZone, "", 1)
records, err := d.client.GetRecords(domain, name, "TXT") records, err := d.client.GetRecords(ctx, domain, name, "TXT")
if err != nil { if err != nil {
return fmt.Errorf("dnsmadeeasy: unable to get records for domain %s: %w", domain.Name, err) return fmt.Errorf("dnsmadeeasy: unable to get records for domain %s: %w", domain.Name, err)
} }
@ -160,7 +169,7 @@ func (d *DNSProvider) CleanUp(domainName, token, keyAuth string) error {
// delete records // delete records
var lastError error var lastError error
for _, record := range *records { for _, record := range *records {
err = d.client.DeleteRecord(record) err = d.client.DeleteRecord(ctx, record)
if err != nil { if err != nil {
lastError = fmt.Errorf("dnsmadeeasy: unable to delete record [id=%d, name=%s]: %w", record.ID, record.Name, err) lastError = fmt.Errorf("dnsmadeeasy: unable to delete record [id=%d, name=%s]: %w", record.ID, record.Name, err)
} }

Some files were not shown because too many files have changed in this diff Show more