lego/providers/dns/loopia/loopia_mock_test.go
2024-02-08 03:16:48 +01:00

236 lines
6 KiB
Go

package loopia
import (
"context"
"errors"
"testing"
"github.com/go-acme/lego/v4/providers/dns/loopia/internal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// Shared fixtures for the Present and CleanUp tests below.
const (
	// exampleSubDomain is the record label the mocked client expects
	// Present to add and look up under exampleDomain.
	exampleSubDomain = "_acme-challenge"

	// exampleDomain is the domain passed to Present; the stubbed zone lookup
	// in the Present test resolves it to itself (with a trailing dot).
	exampleDomain = "example.com"

	// exampleRdata is the TXT value the Present test expects to be written
	// for token "token" and key "key" (presumably the DNS-01 key
	// authorization digest — confirm against dns01.GetChallengeInfo).
	exampleRdata = "LHDhK3oGRvkiefQnx7OOczTY5Tic_xZ6HcMOc_gmtoM"
)
// TestDNSProvider_Present drives DNSProvider.Present against a mocked Loopia
// client. It covers the success path (the new record's ID is remembered under
// the token) and each failure point: adding the TXT record, listing the TXT
// records, and not finding the just-added record in that listing.
func TestDNSProvider_Present(t *testing.T) {
	// Stub zone resolution so no real DNS lookups are performed.
	fakeZoneLookup := func(_ string) (string, error) {
		return exampleDomain + ".", nil
	}

	cases := []struct {
		desc                        string
		getTXTRecordsError          error
		getTXTRecordsReturn         []internal.RecordObj
		addTXTRecordError           error
		callAddTXTRecord            bool
		callGetTXTRecords           bool
		expectedError               string
		expectedInProgressTokenInfo int
	}{
		{
			desc:                        "Present OK",
			getTXTRecordsReturn:         []internal.RecordObj{{Type: "TXT", Rdata: exampleRdata, RecordID: 12345678}},
			callAddTXTRecord:            true,
			callGetTXTRecords:           true,
			expectedInProgressTokenInfo: 12345678,
		},
		{
			desc:              "AddTXTRecord fails",
			addTXTRecordError: errors.New("unknown error: 'ADDTXT'"),
			callAddTXTRecord:  true,
			expectedError:     "loopia: failed to add TXT record: unknown error: 'ADDTXT'",
		},
		{
			desc:               "GetTXTRecords fails",
			getTXTRecordsError: errors.New("unknown error: 'GETTXT'"),
			callAddTXTRecord:   true,
			callGetTXTRecords:  true,
			expectedError:      "loopia: failed to get TXT records: unknown error: 'GETTXT'",
		},
		{
			desc:              "Failed to get ID",
			callAddTXTRecord:  true,
			callGetTXTRecords: true,
			expectedError:     "loopia: failed to find the stored TXT record",
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			cfg := NewDefaultConfig()
			cfg.APIUser = "apiuser"
			cfg.APIPassword = "password"

			provider, err := NewDNSProviderConfig(cfg)
			require.NoError(t, err)

			// Swap in the test doubles after construction.
			mc := &mockedClient{}
			provider.findZoneByFqdn = fakeZoneLookup
			provider.client = mc

			if tc.callAddTXTRecord {
				mc.On("AddTXTRecord", exampleDomain, exampleSubDomain, cfg.TTL, exampleRdata).Return(tc.addTXTRecordError)
			}
			if tc.callGetTXTRecords {
				mc.On("GetTXTRecords", exampleDomain, exampleSubDomain).Return(tc.getTXTRecordsReturn, tc.getTXTRecordsError)
			}

			err = provider.Present(exampleDomain, "token", "key")
			mc.AssertExpectations(t)

			if tc.expectedError != "" {
				require.Error(t, err)
				assert.EqualError(t, err, tc.expectedError)
				return
			}

			require.NoError(t, err)
			assert.Equal(t, tc.expectedInProgressTokenInfo, provider.inProgressInfo["token"])
		})
	}
}
func TestDNSProvider_Cleanup(t *testing.T) {
mockedFindZoneByFqdn := func(fqdn string) (string, error) {
return "example.com.", nil
}
testCases := []struct {
desc string
getTXTRecordsError error
getTXTRecordsReturn []internal.RecordObj
removeTXTRecordError error
removeSubdomainError error
callAddTXTRecord bool
callGetTXTRecords bool
callRemoveSubdomain bool
expectedError string
}{
{
desc: "Cleanup Ok",
callAddTXTRecord: true,
callGetTXTRecords: true,
callRemoveSubdomain: true,
},
{
desc: "removeTXTRecord failed",
removeTXTRecordError: errors.New("authentication error"),
callAddTXTRecord: true,
expectedError: "loopia: failed to remove TXT record: authentication error",
},
{
desc: "removeSubdomain failed",
removeSubdomainError: errors.New(`unknown error: "UNKNOWN_ERROR"`),
callAddTXTRecord: true,
callGetTXTRecords: true,
callRemoveSubdomain: true,
expectedError: `loopia: failed to remove subdomain: unknown error: "UNKNOWN_ERROR"`,
},
{
desc: "Don't call removeSubdomain when records",
getTXTRecordsReturn: []internal.RecordObj{{Type: "TXT", Rdata: "LEFTOVER"}},
callAddTXTRecord: true,
callGetTXTRecords: true,
callRemoveSubdomain: false,
},
{
desc: "getTXTRecords failed",
getTXTRecordsError: errors.New(`unknown error: "UNKNOWN_ERROR"`),
callAddTXTRecord: true,
callGetTXTRecords: true,
callRemoveSubdomain: false,
expectedError: `loopia: failed to get TXT records: unknown error: "UNKNOWN_ERROR"`,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
config := NewDefaultConfig()
config.APIUser = "apiuser"
config.APIPassword = "password"
client := &mockedClient{}
provider, err := NewDNSProviderConfig(config)
require.NoError(t, err)
provider.findZoneByFqdn = mockedFindZoneByFqdn
provider.client = client
provider.inProgressInfo["token"] = 12345678
if test.callAddTXTRecord {
client.On("RemoveTXTRecord", "example.com", "_acme-challenge", 12345678).Return(test.removeTXTRecordError)
}
if test.callGetTXTRecords {
client.On("GetTXTRecords", "example.com", "_acme-challenge").Return(test.getTXTRecordsReturn, test.getTXTRecordsError)
}
if test.callRemoveSubdomain {
client.On("RemoveSubdomain", "example.com", "_acme-challenge").Return(test.removeSubdomainError)
}
err = provider.CleanUp("example.com", "token", "key")
client.AssertExpectations(t)
if test.expectedError == "" {
require.NoError(t, err)
} else {
require.Error(t, err)
assert.EqualError(t, err, test.expectedError)
}
})
}
}
// mockedClient is a testify mock standing in for the Loopia API client.
// Each method records its arguments (the context is not recorded) and returns
// whatever the test configured via On(...).Return(...).
type mockedClient struct {
	mock.Mock
}

// RemoveTXTRecord records the call and returns the configured error.
func (c *mockedClient) RemoveTXTRecord(_ context.Context, domain, subdomain string, recordID int) error {
	args := c.Called(domain, subdomain, recordID)
	return args.Error(0)
}

// AddTXTRecord records the call and returns the configured error.
func (c *mockedClient) AddTXTRecord(_ context.Context, domain, subdomain string, ttl int, value string) error {
	args := c.Called(domain, subdomain, ttl, value)
	return args.Error(0)
}

// GetTXTRecords records the call and returns the configured records and error.
// An untyped nil first return value (e.g. Return(nil, err)) yields a nil slice
// instead of panicking on the type assertion; any other non-slice value still
// panics so misconfigured tests fail loudly.
func (c *mockedClient) GetTXTRecords(_ context.Context, domain, subdomain string) ([]internal.RecordObj, error) {
	args := c.Called(domain, subdomain)

	var records []internal.RecordObj
	if v := args.Get(0); v != nil {
		records = v.([]internal.RecordObj)
	}

	return records, args.Error(1)
}

// RemoveSubdomain records the call and returns the configured error.
func (c *mockedClient) RemoveSubdomain(_ context.Context, domain, subdomain string) error {
	args := c.Called(domain, subdomain)
	return args.Error(0)
}