lego/providers/dns/yandex/internal/client_test.go
2024-03-20 20:30:35 +01:00

328 lines
7.2 KiB
Go

package internal
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupTest starts an httptest server backed by a fresh ServeMux and returns a
// Client created with the token "lego" whose HTTP client and base URL point at
// that server. The server is closed via t.Cleanup when the test finishes.
// Callers register the endpoint handlers they need on the returned mux.
func setupTest(t *testing.T) (*Client, *http.ServeMux) {
	t.Helper()

	mux := http.NewServeMux()
	server := httptest.NewServer(mux)
	t.Cleanup(server.Close)

	client, err := NewClient("lego")
	require.NoError(t, err)

	client.HTTPClient = server.Client()

	// Check the parse error instead of discarding it: a nil baseURL would only
	// surface later as a confusing panic inside the client under test.
	client.baseURL, err = url.Parse(server.URL)
	require.NoError(t, err)

	return client, mux
}
// TestAddRecord checks that Client.AddRecord sends a POST form to /add carrying
// the PDD token header, returns the created record when the API answers
// Success "ok", and returns an error and a nil record when the API reports an
// error.
func TestAddRecord(t *testing.T) {
	// Each handler receives the subtest's *testing.T so that failures are
	// reported against the right case. Handlers run on the httptest server
	// goroutine, where t.FailNow (and therefore require.*) must not be called,
	// so they only use assert.* and answer with an HTTP error when a
	// precondition does not hold.
	testCases := []struct {
		desc        string
		handler     func(t *testing.T, w http.ResponseWriter, r *http.Request)
		data        Record
		expectError bool
	}{
		{
			desc: "success",
			handler: func(t *testing.T, w http.ResponseWriter, r *http.Request) {
				if !assert.Equal(t, http.MethodPost, r.Method) {
					http.Error(w, "unsupported method: "+r.Method, http.StatusMethodNotAllowed)
					return
				}
				assert.Equal(t, "lego", r.Header.Get(pddTokenHeader))

				err := r.ParseForm()
				if err != nil {
					http.Error(w, err.Error(), http.StatusBadRequest)
					return
				}
				assert.Equal(t, `content=txtTXTtxtTXTtxtTXT&domain=example.com&subdomain=foo&ttl=300&type=TXT`, r.PostForm.Encode())

				response := AddResponse{
					Domain: "example.com",
					Record: &Record{
						ID:        1,
						Type:      "TXT",
						Domain:    "example.com",
						SubDomain: "foo",
						FQDN:      "foo.example.com.",
						Content:   "txtTXTtxtTXTtxtTXT",
						TTL:       300,
					},
					BaseResponse: BaseResponse{
						Success: "ok",
					},
				}

				err = json.NewEncoder(w).Encode(response)
				if err != nil {
					http.Error(w, err.Error(), http.StatusInternalServerError)
				}
			},
			data: Record{
				Domain:    "example.com",
				Type:      "TXT",
				Content:   "txtTXTtxtTXTtxtTXT",
				SubDomain: "foo",
				TTL:       300,
			},
		},
		{
			desc: "error",
			handler: func(t *testing.T, w http.ResponseWriter, r *http.Request) {
				if !assert.Equal(t, http.MethodPost, r.Method) {
					http.Error(w, "unsupported method: "+r.Method, http.StatusMethodNotAllowed)
					return
				}
				assert.Equal(t, "lego", r.Header.Get(pddTokenHeader))

				err := r.ParseForm()
				if err != nil {
					http.Error(w, err.Error(), http.StatusBadRequest)
					return
				}
				assert.Equal(t, `content=txtTXTtxtTXTtxtTXT&domain=example.com&subdomain=foo&ttl=300&type=TXT`, r.PostForm.Encode())

				response := AddResponse{
					Domain: "example.com",
					BaseResponse: BaseResponse{
						Success: "error",
						Error:   "bad things",
					},
				}

				err = json.NewEncoder(w).Encode(response)
				if err != nil {
					http.Error(w, err.Error(), http.StatusInternalServerError)
				}
			},
			data: Record{
				Domain:    "example.com",
				Type:      "TXT",
				Content:   "txtTXTtxtTXTtxtTXT",
				SubDomain: "foo",
				TTL:       300,
			},
			expectError: true,
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			client, mux := setupTest(t)

			// Bind the handler to this subtest's t, not the parent test's.
			mux.HandleFunc("/add", func(w http.ResponseWriter, r *http.Request) {
				test.handler(t, w, r)
			})

			record, err := client.AddRecord(context.Background(), test.data)
			if test.expectError {
				require.Error(t, err)
				require.Nil(t, record)
				return
			}

			require.NoError(t, err)
			require.NotNil(t, record)
		})
	}
}
// TestRemoveRecord checks that Client.RemoveRecord sends a POST form with the
// domain and record_id to /del carrying the PDD token header. On Success "ok"
// it returns the removed record ID (6 here); when the API reports an error it
// returns an error and ID 0.
func TestRemoveRecord(t *testing.T) {
	// Each handler receives the subtest's *testing.T so that failures are
	// reported against the right case. Handlers run on the httptest server
	// goroutine, where t.FailNow (and therefore require.*) must not be called,
	// so they only use assert.* and answer with an HTTP error when a
	// precondition does not hold.
	testCases := []struct {
		desc        string
		handler     func(t *testing.T, w http.ResponseWriter, r *http.Request)
		data        Record
		expectError bool
	}{
		{
			desc: "success",
			handler: func(t *testing.T, w http.ResponseWriter, r *http.Request) {
				if !assert.Equal(t, http.MethodPost, r.Method) {
					http.Error(w, "unsupported method: "+r.Method, http.StatusMethodNotAllowed)
					return
				}
				assert.Equal(t, "lego", r.Header.Get(pddTokenHeader))

				err := r.ParseForm()
				if err != nil {
					http.Error(w, err.Error(), http.StatusBadRequest)
					return
				}
				assert.Equal(t, `domain=example.com&record_id=6`, r.PostForm.Encode())

				response := RemoveResponse{
					Domain:   "example.com",
					RecordID: 6,
					BaseResponse: BaseResponse{
						Success: "ok",
					},
				}

				err = json.NewEncoder(w).Encode(response)
				if err != nil {
					http.Error(w, err.Error(), http.StatusInternalServerError)
				}
			},
			data: Record{
				ID:     6,
				Domain: "example.com",
			},
		},
		{
			desc: "error",
			handler: func(t *testing.T, w http.ResponseWriter, r *http.Request) {
				if !assert.Equal(t, http.MethodPost, r.Method) {
					http.Error(w, "unsupported method: "+r.Method, http.StatusMethodNotAllowed)
					return
				}
				assert.Equal(t, "lego", r.Header.Get(pddTokenHeader))

				err := r.ParseForm()
				if err != nil {
					http.Error(w, err.Error(), http.StatusBadRequest)
					return
				}
				assert.Equal(t, `domain=example.com&record_id=6`, r.PostForm.Encode())

				response := RemoveResponse{
					Domain:   "example.com",
					RecordID: 6,
					BaseResponse: BaseResponse{
						Success: "error",
						Error:   "bad things",
					},
				}

				err = json.NewEncoder(w).Encode(response)
				if err != nil {
					http.Error(w, err.Error(), http.StatusInternalServerError)
				}
			},
			data: Record{
				ID:     6,
				Domain: "example.com",
			},
			expectError: true,
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			client, mux := setupTest(t)

			// Bind the handler to this subtest's t, not the parent test's.
			mux.HandleFunc("/del", func(w http.ResponseWriter, r *http.Request) {
				test.handler(t, w, r)
			})

			id, err := client.RemoveRecord(context.Background(), test.data)
			if test.expectError {
				require.Error(t, err)
				require.Equal(t, 0, id)
				return
			}

			require.NoError(t, err)
			require.Equal(t, 6, id)
		})
	}
}
func TestGetRecords(t *testing.T) {
testCases := []struct {
desc string
handler http.HandlerFunc
domain string
expectError bool
}{
{
desc: "success",
handler: func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "lego", r.Header.Get(pddTokenHeader))
assert.Equal(t, "domain=example.com", r.URL.RawQuery)
response := ListResponse{
Domain: "example.com",
Records: []Record{
{
ID: 1,
Type: "TXT",
Domain: "example.com",
SubDomain: "foo",
FQDN: "foo.example.com.",
Content: "txtTXTtxtTXTtxtTXT",
TTL: 300,
},
{
ID: 2,
Type: "NS",
Domain: "example.com",
SubDomain: "foo",
FQDN: "foo.example.com.",
Content: "bar",
TTL: 300,
},
},
BaseResponse: BaseResponse{
Success: "ok",
},
}
err := json.NewEncoder(w).Encode(response)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
},
domain: "example.com",
},
{
desc: "error",
handler: func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "lego", r.Header.Get(pddTokenHeader))
assert.Equal(t, "domain=example.com", r.URL.RawQuery)
response := ListResponse{
Domain: "example.com",
BaseResponse: BaseResponse{
Success: "error",
Error: "bad things",
},
}
err := json.NewEncoder(w).Encode(response)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
},
domain: "example.com",
expectError: true,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
client, mux := setupTest(t)
mux.HandleFunc("/list", test.handler)
records, err := client.GetRecords(context.Background(), test.domain)
if test.expectError {
require.Error(t, err)
require.Empty(t, records)
} else {
require.NoError(t, err)
require.Len(t, records, 2)
}
})
}
}