diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c51ff9b8..0c987a72 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -74,5 +74,6 @@ owners to license your work under the terms of the [MIT License](LICENSE). | RFC2136 | `rfc2136` | [documentation](https://tools.ietf.org/html/rfc2136) | - | | Route 53 | `route53` | [documentation](https://docs.aws.amazon.com/Route53/latest/APIReference/API_Operations_Amazon_Route_53.html) | [Go client](https://github.com/aws/aws-sdk-go/aws) | | Sakura Cloud | `sakuracloud` | [documentation](https://developer.sakura.ad.jp/cloud/api/1.1/) | [Go client](https://github.com/sacloud/libsacloud) | +| Stackpath | `stackpath` | [documentation](https://developer.stackpath.com/en/api/dns/#tag/Zone) | - | | VegaDNS | `vegadns` | [documentation](https://github.com/shupp/VegaDNS-API) | [Go client](https://github.com/OpenDNS/vegadns2client) | | Vultr | `vultr` | [documentation](https://www.vultr.com/api/#dns) | [Go client](https://github.com/JamesClonk/vultr) | diff --git a/Gopkg.lock b/Gopkg.lock index 428fe61a..5efa5285 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -445,10 +445,11 @@ [[projects]] branch = "master" - digest = "1:bc2b221d465bb28ce46e8d472ecdc424b9a9b541bd61d8c311c5f29c8dd75b1b" + digest = "1:b20a60bb1085d4c535af064faf3e74b4e185781b58bba1ad7406cd9733d82403" name = "golang.org/x/oauth2" packages = [ ".", + "clientcredentials", "google", "internal", "jws", @@ -615,7 +616,9 @@ "github.com/urfave/cli", "golang.org/x/crypto/ocsp", "golang.org/x/net/context", + "golang.org/x/net/publicsuffix", "golang.org/x/oauth2", + "golang.org/x/oauth2/clientcredentials", "golang.org/x/oauth2/google", "google.golang.org/api/dns/v1", "gopkg.in/ns1/ns1-go.v2/rest", diff --git a/cli.go b/cli.go index 09549ece..102bb5a4 100644 --- a/cli.go +++ b/cli.go @@ -241,6 +241,7 @@ Here is an example bash command using the CloudFlare DNS provider: fmt.Fprintln(w, "\trfc2136:\tRFC2136_TSIG_KEY, RFC2136_TSIG_SECRET,\n\t\tRFC2136_TSIG_ALGORITHM, RFC2136_NAMESERVER") fmt.Fprintln(w, "\troute53:\tAWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION, AWS_HOSTED_ZONE_ID") fmt.Fprintln(w, "\tsakuracloud:\tSAKURACLOUD_ACCESS_TOKEN, SAKURACLOUD_ACCESS_TOKEN_SECRET") + fmt.Fprintln(w, "\tstackpath:\tSTACKPATH_CLIENT_ID, STACKPATH_CLIENT_SECRET, STACKPATH_STACK_ID") fmt.Fprintln(w, "\tvegadns:\tSECRET_VEGADNS_KEY, SECRET_VEGADNS_SECRET, VEGADNS_URL") fmt.Fprintln(w, "\tvultr:\tVULTR_API_KEY") fmt.Fprintln(w) @@ -282,6 +283,7 @@ Here is an example bash command using the CloudFlare DNS provider: fmt.Fprintln(w, "\trfc2136:\tRFC2136_POLLING_INTERVAL, RFC2136_PROPAGATION_TIMEOUT, RFC2136_TTL") fmt.Fprintln(w, "\troute53:\tAWS_POLLING_INTERVAL, AWS_PROPAGATION_TIMEOUT, AWS_TTL") fmt.Fprintln(w, "\tsakuracloud:\tSAKURACLOUD_POLLING_INTERVAL, SAKURACLOUD_PROPAGATION_TIMEOUT, SAKURACLOUD_TTL") + fmt.Fprintln(w, "\tstackpath:\tSTACKPATH_POLLING_INTERVAL, STACKPATH_PROPAGATION_TIMEOUT, STACKPATH_TTL") fmt.Fprintln(w, "\tvegadns:\tVEGADNS_POLLING_INTERVAL, VEGADNS_PROPAGATION_TIMEOUT, VEGADNS_TTL") fmt.Fprintln(w, "\tvultr:\tVULTR_POLLING_INTERVAL, VULTR_PROPAGATION_TIMEOUT, VULTR_TTL, VULTR_HTTP_TIMEOUT") diff --git a/providers/dns/dns_providers.go b/providers/dns/dns_providers.go index a3a5f1ef..37734663 100644 --- a/providers/dns/dns_providers.go +++ b/providers/dns/dns_providers.go @@ -42,6 +42,7 @@ import ( "github.com/xenolf/lego/providers/dns/rfc2136" "github.com/xenolf/lego/providers/dns/route53" "github.com/xenolf/lego/providers/dns/sakuracloud" + "github.com/xenolf/lego/providers/dns/stackpath" "github.com/xenolf/lego/providers/dns/vegadns" "github.com/xenolf/lego/providers/dns/vultr" ) @@ -127,6 +128,8 @@ func NewDNSChallengeProviderByName(name string) (acme.ChallengeProvider, error) return rfc2136.NewDNSProvider() case "sakuracloud": return sakuracloud.NewDNSProvider() + case "stackpath": + return stackpath.NewDNSProvider() case "vegadns": return vegadns.NewDNSProvider() case "vultr": diff --git a/providers/dns/stackpath/client.go b/providers/dns/stackpath/client.go new file mode 100644 index 00000000..495d8c55 --- /dev/null +++ b/providers/dns/stackpath/client.go @@ -0,0 +1,217 @@ +package stackpath + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "path" + + "github.com/xenolf/lego/acme" + "golang.org/x/net/publicsuffix" +) + +// Zones is the response struct from the Stackpath api GetZones +type Zones struct { + Zones []Zone `json:"zones"` +} + +// Zone a DNS zone representation +type Zone struct { + ID string + Domain string +} + +// Records is the response struct from the Stackpath api GetZoneRecords +type Records struct { + Records []Record `json:"records"` +} + +// Record a DNS record representation +type Record struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Type string `json:"type"` + TTL int `json:"ttl"` + Data string `json:"data"` +} + +// ErrorResponse the API error response representation +type ErrorResponse struct { + Code int `json:"code"` + Message string `json:"error"` +} + +func (e *ErrorResponse) Error() string { + return fmt.Sprintf("%d %s", e.Code, e.Message) +} + +// https://developer.stackpath.com/en/api/dns/#operation/GetZones +func (d *DNSProvider) getZones(domain string) (*Zone, error) { + domain = acme.UnFqdn(domain) + tld, err := publicsuffix.EffectiveTLDPlusOne(domain) + if err != nil { + return nil, err + } + + req, err := d.newRequest(http.MethodGet, "/zones", nil) + if err != nil { + return nil, err + } + + query := req.URL.Query() + query.Add("page_request.filter", fmt.Sprintf("domain='%s'", tld)) + req.URL.RawQuery = query.Encode() + + var zones Zones + err = d.do(req, &zones) + if err != nil { + return nil, err + } + + if len(zones.Zones) == 0 { + return nil, fmt.Errorf("did not find zone with domain %s", domain) + } + + return &zones.Zones[0], nil +} + +// https://developer.stackpath.com/en/api/dns/#operation/GetZoneRecords +func (d *DNSProvider) getZoneRecords(name string, zone *Zone) ([]Record, error) { + u := fmt.Sprintf("/zones/%s/records", zone.ID) + req, err := d.newRequest(http.MethodGet, u, nil) + if err != nil { + return nil, err + } + + query := req.URL.Query() + query.Add("page_request.filter", fmt.Sprintf("name='%s' and type='TXT'", name)) + req.URL.RawQuery = query.Encode() + + var records Records + err = d.do(req, &records) + if err != nil { + return nil, err + } + + if len(records.Records) == 0 { + return nil, fmt.Errorf("did not find record with name %s", name) + } + + return records.Records, nil +} + +// https://developer.stackpath.com/en/api/dns/#operation/CreateZoneRecord +func (d *DNSProvider) createZoneRecord(zone *Zone, record Record) error { + u := fmt.Sprintf("/zones/%s/records", zone.ID) + req, err := d.newRequest(http.MethodPost, u, record) + if err != nil { + return err + } + + return d.do(req, nil) +} + +// https://developer.stackpath.com/en/api/dns/#operation/DeleteZoneRecord +func (d *DNSProvider) deleteZoneRecord(zone *Zone, record Record) error { + u := fmt.Sprintf("/zones/%s/records/%s", zone.ID, record.ID) + req, err := d.newRequest(http.MethodDelete, u, nil) + if err != nil { + return err + } + + return d.do(req, nil) +} + +func (d *DNSProvider) newRequest(method, urlStr string, body interface{}) (*http.Request, error) { + u, err := d.BaseURL.Parse(path.Join(d.config.StackID, urlStr)) + if err != nil { + return nil, err + } + + if body == nil { + var req *http.Request + req, err = http.NewRequest(method, u.String(), nil) + if err != nil { + return nil, err + } + + return req, nil + } + + reqBody, err := json.Marshal(body) + if err != nil { + return nil, err + } + + req, err := http.NewRequest(method, u.String(), bytes.NewBuffer(reqBody)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + + return req, nil +} + +func (d *DNSProvider) do(req *http.Request, v interface{}) error { + resp, err := d.client.Do(req) + if err != nil { + return err + } + + err = checkResponse(resp) + if err != nil { + return err + } + + if v == nil { + return nil + } + + raw, err := readBody(resp) + if err != nil { + return fmt.Errorf("failed to read body: %v", err) + } + + err = json.Unmarshal(raw, v) + if err != nil { + return fmt.Errorf("unmarshaling error: %v: %s", err, string(raw)) + } + + return nil +} + +func checkResponse(resp *http.Response) error { + if resp.StatusCode > 299 { + data, err := readBody(resp) + if err != nil { + return &ErrorResponse{Code: resp.StatusCode, Message: err.Error()} + } + + errResp := &ErrorResponse{} + err = json.Unmarshal(data, errResp) + if err != nil { + return &ErrorResponse{Code: resp.StatusCode, Message: fmt.Sprintf("unmarshaling error: %v: %s", err, string(data))} + } + return errResp + } + + return nil +} + +func readBody(resp *http.Response) ([]byte, error) { + if resp.Body == nil { + return nil, fmt.Errorf("response body is nil") + } + + defer resp.Body.Close() + + rawBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return rawBody, nil +} diff --git a/providers/dns/stackpath/stackpath.go b/providers/dns/stackpath/stackpath.go new file mode 100644 index 00000000..4c247def --- /dev/null +++ b/providers/dns/stackpath/stackpath.go @@ -0,0 +1,150 @@ +// Package stackpath implements a DNS provider for solving the DNS-01 challenge using Stackpath DNS. +// https://developer.stackpath.com/en/api/dns/ +package stackpath + +import ( + "context" + "errors" + "fmt" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/xenolf/lego/acme" + "github.com/xenolf/lego/platform/config/env" + "golang.org/x/oauth2/clientcredentials" +) + +const ( + defaultBaseURL = "https://gateway.stackpath.com/dns/v1/stacks/" + defaultAuthURL = "https://gateway.stackpath.com/identity/v1/oauth2/token" +) + +// Config is used to configure the creation of the DNSProvider +type Config struct { + ClientID string + ClientSecret string + StackID string + TTL int + PropagationTimeout time.Duration + PollingInterval time.Duration +} + +// NewDefaultConfig returns a default configuration for the DNSProvider +func NewDefaultConfig() *Config { + return &Config{ + TTL: env.GetOrDefaultInt("STACKPATH_TTL", 120), + PropagationTimeout: env.GetOrDefaultSecond("STACKPATH_PROPAGATION_TIMEOUT", acme.DefaultPropagationTimeout), + PollingInterval: env.GetOrDefaultSecond("STACKPATH_POLLING_INTERVAL", acme.DefaultPollingInterval), + } +} + +// DNSProvider is an implementation of the acme.ChallengeProvider interface. +type DNSProvider struct { + BaseURL *url.URL + client *http.Client + config *Config +} + +// NewDNSProvider returns a DNSProvider instance configured for Stackpath. +// Credentials must be passed in the environment variables: +// STACKPATH_CLIENT_ID, STACKPATH_CLIENT_SECRET, and STACKPATH_STACK_ID. +func NewDNSProvider() (*DNSProvider, error) { + values, err := env.Get("STACKPATH_CLIENT_ID", "STACKPATH_CLIENT_SECRET", "STACKPATH_STACK_ID") + if err != nil { + return nil, fmt.Errorf("stackpath: %v", err) + } + + config := NewDefaultConfig() + config.ClientID = values["STACKPATH_CLIENT_ID"] + config.ClientSecret = values["STACKPATH_CLIENT_SECRET"] + config.StackID = values["STACKPATH_STACK_ID"] + + return NewDNSProviderConfig(config) +} + +// NewDNSProviderConfig return a DNSProvider instance configured for Stackpath. +func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { + if config == nil { + return nil, errors.New("stackpath: the configuration of the DNS provider is nil") + } + + if len(config.ClientID) == 0 || len(config.ClientSecret) == 0 { + return nil, errors.New("stackpath: credentials missing") + } + + if len(config.StackID) == 0 { + return nil, errors.New("stackpath: stack id missing") + } + + baseURL, _ := url.Parse(defaultBaseURL) + + return &DNSProvider{ + BaseURL: baseURL, + client: getOathClient(config), + config: config, + }, nil +} + +func getOathClient(config *Config) *http.Client { + oathConfig := &clientcredentials.Config{ + TokenURL: defaultAuthURL, + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + } + + return oathConfig.Client(context.Background()) +} + +// Present creates a TXT record to fulfill the dns-01 challenge +func (d *DNSProvider) Present(domain, token, keyAuth string) error { + zone, err := d.getZones(domain) + if err != nil { + return fmt.Errorf("stackpath: %v", err) + } + + fqdn, value, _ := acme.DNS01Record(domain, keyAuth) + parts := strings.Split(fqdn, ".") + + record := Record{ + Name: parts[0], + Type: "TXT", + TTL: d.config.TTL, + Data: value, + } + + return d.createZoneRecord(zone, record) +} + +// CleanUp removes the TXT record matching the specified parameters +func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { + zone, err := d.getZones(domain) + if err != nil { + return fmt.Errorf("stackpath: %v", err) + } + + fqdn, _, _ := acme.DNS01Record(domain, keyAuth) + parts := strings.Split(fqdn, ".") + + records, err := d.getZoneRecords(parts[0], zone) + if err != nil { + return err + } + + for _, record := range records { + err = d.deleteZoneRecord(zone, record) + if err != nil { + log.Printf("stackpath: failed to delete TXT record: %v", err) + } + } + + return nil +} + +// Timeout returns the timeout and interval to use when checking for DNS propagation. +// Adjusting here to cope with spikes in propagation times. +func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { + return d.config.PropagationTimeout, d.config.PollingInterval +} diff --git a/providers/dns/stackpath/stackpath_test.go b/providers/dns/stackpath/stackpath_test.go new file mode 100644 index 00000000..9ee42dc2 --- /dev/null +++ b/providers/dns/stackpath/stackpath_test.go @@ -0,0 +1,315 @@ +package stackpath + +import ( + "net/http" + "net/http/httptest" + "net/url" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + stackpathLiveTest bool + stackpathClientID string + stackpathClientSecret string + stackpathStackID string + stackpathDomain string +) + +func init() { + stackpathClientID = os.Getenv("STACKPATH_CLIENT_ID") + stackpathClientSecret = os.Getenv("STACKPATH_CLIENT_SECRET") + stackpathStackID = os.Getenv("STACKPATH_STACK_ID") + stackpathDomain = os.Getenv("STACKPATH_DOMAIN") + + if len(stackpathClientID) > 0 && + len(stackpathClientSecret) > 0 && + len(stackpathStackID) > 0 && + len(stackpathDomain) > 0 { + stackpathLiveTest = true + } +} + +func restoreEnv() { + os.Setenv("STACKPATH_CLIENT_ID", stackpathClientID) + os.Setenv("STACKPATH_CLIENT_SECRET", stackpathClientSecret) + os.Setenv("STACKPATH_STACK_ID", stackpathStackID) + os.Setenv("STACKPATH_DOMAIN", stackpathDomain) +} + +func TestLivePresent(t *testing.T) { + if !stackpathLiveTest { + t.Skip("skipping live test") + } + + provider, err := NewDNSProvider() + require.NoError(t, err) + + err = provider.Present(stackpathDomain, "", "123d==") + require.NoError(t, err) +} + +func TestLiveCleanUp(t *testing.T) { + if !stackpathLiveTest { + t.Skip("skipping live test") + } + + time.Sleep(time.Second * 1) + + provider, err := NewDNSProvider() + require.NoError(t, err) + + err = provider.CleanUp(stackpathDomain, "", "123d==") + require.NoError(t, err) +} + +func TestNewDNSProvider(t *testing.T) { + testCases := []struct { + desc string + envVars map[string]string + expected string + }{ + { + desc: "success", + envVars: map[string]string{ + "STACKPATH_CLIENT_ID": "test@example.com", + "STACKPATH_CLIENT_SECRET": "123", + "STACKPATH_STACK_ID": "ID", + }, + }, + { + desc: "missing credentials", + envVars: map[string]string{ + "STACKPATH_CLIENT_ID": "", + "STACKPATH_CLIENT_SECRET": "", + "STACKPATH_STACK_ID": "", + }, + expected: "stackpath: some credentials information are missing: STACKPATH_CLIENT_ID,STACKPATH_CLIENT_SECRET,STACKPATH_STACK_ID", + }, + { + desc: "missing client id", + envVars: map[string]string{ + "STACKPATH_CLIENT_ID": "", + "STACKPATH_CLIENT_SECRET": "123", + "STACKPATH_STACK_ID": "ID", + }, + expected: "stackpath: some credentials information are missing: STACKPATH_CLIENT_ID", + }, + { + desc: "missing client secret", + envVars: map[string]string{ + "STACKPATH_CLIENT_ID": "test@example.com", + "STACKPATH_CLIENT_SECRET": "", + "STACKPATH_STACK_ID": "ID", + }, + expected: "stackpath: some credentials information are missing: STACKPATH_CLIENT_SECRET", + }, + { + desc: "missing stack id", + envVars: map[string]string{ + "STACKPATH_CLIENT_ID": "test@example.com", + "STACKPATH_CLIENT_SECRET": "123", + "STACKPATH_STACK_ID": "", + }, + expected: "stackpath: some credentials information are missing: STACKPATH_STACK_ID", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + defer restoreEnv() + for key, value := range test.envVars { + if len(value) == 0 { + os.Unsetenv(key) + } else { + os.Setenv(key, value) + } + } + + p, err := NewDNSProvider() + + if len(test.expected) == 0 { + assert.NoError(t, err) + assert.NotNil(t, p) + } else { + require.EqualError(t, err, test.expected) + } + }) + } +} + +func TestNewDNSProviderConfig(t *testing.T) { + testCases := map[string]struct { + config *Config + expectedErr string + }{ + "no_config": { + config: nil, + expectedErr: "stackpath: the configuration of the DNS provider is nil", + }, + "no_client_id": { + config: &Config{ + ClientSecret: "secret", + StackID: "stackID", + }, + expectedErr: "stackpath: credentials missing", + }, + "no_client_secret": { + config: &Config{ + ClientID: "clientID", + StackID: "stackID", + }, + expectedErr: "stackpath: credentials missing", + }, + "no_stack_id": { + config: &Config{ + ClientID: "clientID", + ClientSecret: "secret", + }, + expectedErr: "stackpath: stack id missing", + }, + } + + for desc, test := range testCases { + test := test + t.Run(desc, func(t *testing.T) { + t.Parallel() + + p, err := NewDNSProviderConfig(test.config) + require.EqualError(t, err, test.expectedErr) + assert.Nil(t, p) + }) + } +} + +func setupMockAPITest() (*DNSProvider, *http.ServeMux, func()) { + apiHandler := http.NewServeMux() + server := httptest.NewServer(apiHandler) + + config := NewDefaultConfig() + config.ClientID = "CLIENT_ID" + config.ClientSecret = "CLIENT_SECRET" + config.StackID = "STACK_ID" + + provider, err := NewDNSProviderConfig(config) + if err != nil { + panic(err) + } + + provider.client = http.DefaultClient + provider.BaseURL, _ = url.Parse(server.URL + "/") + + return provider, apiHandler, server.Close +} + +func TestDNSProvider_getZoneRecords(t *testing.T) { + provider, mux, tearDown := setupMockAPITest() + defer tearDown() + + mux.HandleFunc("/STACK_ID/zones/A/records", func(w http.ResponseWriter, req *http.Request) { + content := ` + { + "records": [ + {"id":"1","name":"foo1","type":"TXT","ttl":120,"data":"txtTXTtxt"}, + {"id":"2","name":"foo2","type":"TXT","ttl":121,"data":"TXTtxtTXT"} + ] + }` + + _, err := w.Write([]byte(content)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) + + records, err := provider.getZoneRecords("foo1", &Zone{ID: "A", Domain: "test"}) + require.NoError(t, err) + + expected := []Record{ + {ID: "1", Name: "foo1", Type: "TXT", TTL: 120, Data: "txtTXTtxt"}, + {ID: "2", Name: "foo2", Type: "TXT", TTL: 121, Data: "TXTtxtTXT"}, + } + + assert.Equal(t, expected, records) +} + +func TestDNSProvider_getZoneRecords_apiError(t *testing.T) { + provider, mux, tearDown := setupMockAPITest() + defer tearDown() + + mux.HandleFunc("/STACK_ID/zones/A/records", func(w http.ResponseWriter, req *http.Request) { + content := ` +{ + "code": 401, + "error": "an unauthorized request is attempted." +}` + + w.WriteHeader(http.StatusUnauthorized) + _, err := w.Write([]byte(content)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) + + _, err := provider.getZoneRecords("foo1", &Zone{ID: "A", Domain: "test"}) + + expected := &ErrorResponse{Code: 401, Message: "an unauthorized request is attempted."} + assert.Equal(t, expected, err) +} + +func TestDNSProvider_getZones(t *testing.T) { + provider, mux, tearDown := setupMockAPITest() + defer tearDown() + + mux.HandleFunc("/STACK_ID/zones", func(w http.ResponseWriter, req *http.Request) { + content := ` +{ + "pageInfo": { + "totalCount": "5", + "hasPreviousPage": false, + "hasNextPage": false, + "startCursor": "1", + "endCursor": "1" + }, + "zones": [ + { + "stackId": "my_stack", + "accountId": "my_account", + "id": "A", + "domain": "foo.com", + "version": "1", + "labels": { + "property1": "val1", + "property2": "val2" + }, + "created": "2018-10-07T02:31:49Z", + "updated": "2018-10-07T02:31:49Z", + "nameservers": [ + "1.1.1.1" + ], + "verified": "2018-10-07T02:31:49Z", + "status": "ACTIVE", + "disabled": false + } + ] +}` + + _, err := w.Write([]byte(content)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) + + zone, err := provider.getZones("sub.foo.com") + require.NoError(t, err) + + expected := &Zone{ID: "A", Domain: "foo.com"} + + assert.Equal(t, expected, zone) +} diff --git a/vendor/golang.org/x/oauth2/clientcredentials/clientcredentials.go b/vendor/golang.org/x/oauth2/clientcredentials/clientcredentials.go new file mode 100644 index 00000000..c4e840d2 --- /dev/null +++ b/vendor/golang.org/x/oauth2/clientcredentials/clientcredentials.go @@ -0,0 +1,109 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package clientcredentials implements the OAuth2.0 "client credentials" token flow, +// also known as the "two-legged OAuth 2.0". +// +// This should be used when the client is acting on its own behalf or when the client +// is the resource owner. It may also be used when requesting access to protected +// resources based on an authorization previously arranged with the authorization +// server. +// +// See https://tools.ietf.org/html/rfc6749#section-4.4 +package clientcredentials // import "golang.org/x/oauth2/clientcredentials" + +import ( + "fmt" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/context" + "golang.org/x/oauth2" + "golang.org/x/oauth2/internal" +) + +// Config describes a 2-legged OAuth2 flow, with both the +// client application information and the server's endpoint URLs. +type Config struct { + // ClientID is the application's ID. + ClientID string + + // ClientSecret is the application's secret. + ClientSecret string + + // TokenURL is the resource server's token endpoint + // URL. This is a constant specific to each server. + TokenURL string + + // Scope specifies optional requested permissions. + Scopes []string + + // EndpointParams specifies additional parameters for requests to the token endpoint. + EndpointParams url.Values +} + +// Token uses client credentials to retrieve a token. +// The HTTP client to use is derived from the context. +// If nil, http.DefaultClient is used. +func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) { + return c.TokenSource(ctx).Token() +} + +// Client returns an HTTP client using the provided token. +// The token will auto-refresh as necessary. The underlying +// HTTP transport will be obtained using the provided context. +// The returned client and its Transport should not be modified. +func (c *Config) Client(ctx context.Context) *http.Client { + return oauth2.NewClient(ctx, c.TokenSource(ctx)) +} + +// TokenSource returns a TokenSource that returns t until t expires, +// automatically refreshing it as necessary using the provided context and the +// client ID and client secret. +// +// Most users will use Config.Client instead. +func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { + source := &tokenSource{ + ctx: ctx, + conf: c, + } + return oauth2.ReuseTokenSource(nil, source) +} + +type tokenSource struct { + ctx context.Context + conf *Config +} + +// Token refreshes the token by using a new client credentials request. +// tokens received this way do not include a refresh token +func (c *tokenSource) Token() (*oauth2.Token, error) { + v := url.Values{ + "grant_type": {"client_credentials"}, + } + if len(c.conf.Scopes) > 0 { + v.Set("scope", strings.Join(c.conf.Scopes, " ")) + } + for k, p := range c.conf.EndpointParams { + if _, ok := v[k]; ok { + return nil, fmt.Errorf("oauth2: cannot overwrite parameter %q", k) + } + v[k] = p + } + tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v) + if err != nil { + if rErr, ok := err.(*internal.RetrieveError); ok { + return nil, (*oauth2.RetrieveError)(rErr) + } + return nil, err + } + t := &oauth2.Token{ + AccessToken: tk.AccessToken, + TokenType: tk.TokenType, + RefreshToken: tk.RefreshToken, + Expiry: tk.Expiry, + } + return t.WithExtra(tk.Raw), nil +}