From 6bc93456ad70982700dbe5a5135fc47cddd767b1 Mon Sep 17 00:00:00 2001
From: akillcool <akillcool@outlook.com>
Date: Thu, 19 Mar 2020 16:26:48 +0800
Subject: [PATCH] alicloud: add support for domain with punycode (#1088)

---
 providers/dns/alidns/alidns.go | 42 +++++++++++++++++++++++++---------
 1 file changed, 31 insertions(+), 11 deletions(-)

diff --git a/providers/dns/alidns/alidns.go b/providers/dns/alidns/alidns.go
index 8274f4a8..21da91c2 100644
--- a/providers/dns/alidns/alidns.go
+++ b/providers/dns/alidns/alidns.go
@@ -13,6 +13,7 @@ import (
 	"github.com/aliyun/alibaba-cloud-sdk-go/services/alidns"
 	"github.com/go-acme/lego/v3/challenge/dns01"
 	"github.com/go-acme/lego/v3/platform/config/env"
+	"golang.org/x/net/idna"
 )
 
 const defaultRegionID = "cn-hangzhou"
@@ -114,7 +115,10 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
 		return fmt.Errorf("alicloud: %w", err)
 	}
 
-	recordAttributes := d.newTxtRecord(zoneName, fqdn, value)
+	recordAttributes, err := d.newTxtRecord(zoneName, fqdn, value)
+	if err != nil {
+		return err
+	}
 
 	_, err = d.client.AddDomainRecord(recordAttributes)
 	if err != nil {
@@ -178,7 +182,7 @@ func (d *DNSProvider) getHostedZone(domain string) (string, error) {
 
 	var hostedZone alidns.Domain
 	for _, zone := range domains {
-		if zone.DomainName == dns01.UnFqdn(authZone) {
+		if zone.DomainName == dns01.UnFqdn(authZone) || zone.PunyCode == dns01.UnFqdn(authZone) {
 			hostedZone = zone
 		}
 	}
@@ -190,14 +194,21 @@ func (d *DNSProvider) getHostedZone(domain string) (string, error) {
 	return hostedZone.DomainName, nil
 }
 
-func (d *DNSProvider) newTxtRecord(zone, fqdn, value string) *alidns.AddDomainRecordRequest {
+func (d *DNSProvider) newTxtRecord(zone, fqdn, value string) (*alidns.AddDomainRecordRequest, error) {
 	request := alidns.CreateAddDomainRecordRequest()
 	request.Type = "TXT"
 	request.DomainName = zone
-	request.RR = d.extractRecordName(fqdn, zone)
+
+	var err error
+	request.RR, err = d.extractRecordName(fqdn, zone)
+	if err != nil {
+		return nil, err
+	}
+
 	request.Value = value
 	request.TTL = requests.NewInteger(d.config.TTL)
-	return request
+
+	return request, nil
 }
 
 func (d *DNSProvider) findTxtRecords(domain, fqdn string) ([]alidns.Record, error) {
@@ -217,7 +228,11 @@ func (d *DNSProvider) findTxtRecords(domain, fqdn string) ([]alidns.Record, erro
 		return records, fmt.Errorf("API call has failed: %w", err)
 	}
 
-	recordName := d.extractRecordName(fqdn, zoneName)
+	recordName, err := d.extractRecordName(fqdn, zoneName)
+	if err != nil {
+		return nil, err
+	}
+
 	for _, record := range result.DomainRecords.Record {
 		if record.RR == recordName {
 			records = append(records, record)
@@ -226,10 +241,15 @@ func (d *DNSProvider) findTxtRecords(domain, fqdn string) ([]alidns.Record, erro
 	return records, nil
 }
 
-func (d *DNSProvider) extractRecordName(fqdn, domain string) string {
-	name := dns01.UnFqdn(fqdn)
-	if idx := strings.Index(name, "."+domain); idx != -1 {
-		return name[:idx]
+func (d *DNSProvider) extractRecordName(fqdn, domain string) (string, error) {
+	asciiDomain, err := idna.ToASCII(domain)
+	if err != nil {
+		return "", fmt.Errorf("fail to convert punycode: %w", err)
 	}
-	return name
+
+	name := dns01.UnFqdn(fqdn)
+	if idx := strings.Index(name, "."+asciiDomain); idx != -1 {
+		return name[:idx], nil
+	}
+	return name, nil
 }