forked from TrueCloudLab/certificates
Merge branch 'master' into herman/acme-da-tpm
This commit is contained in:
commit
dfc56f21b8
12 changed files with 134 additions and 419 deletions
|
@ -26,7 +26,12 @@ import (
|
||||||
const azureOIDCBaseURL = "https://login.microsoftonline.com"
|
const azureOIDCBaseURL = "https://login.microsoftonline.com"
|
||||||
|
|
||||||
//nolint:gosec // azureIdentityTokenURL is the URL to get the identity token for an instance.
|
//nolint:gosec // azureIdentityTokenURL is the URL to get the identity token for an instance.
|
||||||
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F"
|
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token"
|
||||||
|
|
||||||
|
const azureIdentityTokenAPIVersion = "2018-02-01"
|
||||||
|
|
||||||
|
// azureInstanceComputeURL is the URL to get the instance compute metadata.
|
||||||
|
const azureInstanceComputeURL = "http://169.254.169.254/metadata/instance/compute/azEnvironment"
|
||||||
|
|
||||||
// azureDefaultAudience is the default audience used.
|
// azureDefaultAudience is the default audience used.
|
||||||
const azureDefaultAudience = "https://management.azure.com/"
|
const azureDefaultAudience = "https://management.azure.com/"
|
||||||
|
@ -35,15 +40,27 @@ const azureDefaultAudience = "https://management.azure.com/"
|
||||||
// Using case insensitive as resourceGroups appears as resourcegroups.
|
// Using case insensitive as resourceGroups appears as resourcegroups.
|
||||||
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
|
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
|
||||||
|
|
||||||
|
// azureEnvironments is the list of all Azure environments.
|
||||||
|
var azureEnvironments = map[string]string{
|
||||||
|
"AzurePublicCloud": "https://management.azure.com/",
|
||||||
|
"AzureCloud": "https://management.azure.com/",
|
||||||
|
"AzureUSGovernmentCloud": "https://management.usgovcloudapi.net/",
|
||||||
|
"AzureUSGovernment": "https://management.usgovcloudapi.net/",
|
||||||
|
"AzureChinaCloud": "https://management.chinacloudapi.cn/",
|
||||||
|
"AzureGermanCloud": "https://management.microsoftazure.de/",
|
||||||
|
}
|
||||||
|
|
||||||
type azureConfig struct {
|
type azureConfig struct {
|
||||||
oidcDiscoveryURL string
|
oidcDiscoveryURL string
|
||||||
identityTokenURL string
|
identityTokenURL string
|
||||||
|
instanceComputeURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAzureConfig(tenantID string) *azureConfig {
|
func newAzureConfig(tenantID string) *azureConfig {
|
||||||
return &azureConfig{
|
return &azureConfig{
|
||||||
oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration",
|
oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration",
|
||||||
identityTokenURL: azureIdentityTokenURL,
|
identityTokenURL: azureIdentityTokenURL,
|
||||||
|
instanceComputeURL: azureInstanceComputeURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,6 +120,7 @@ type Azure struct {
|
||||||
oidcConfig openIDConfiguration
|
oidcConfig openIDConfiguration
|
||||||
keyStore *keyStore
|
keyStore *keyStore
|
||||||
ctl *Controller
|
ctl *Controller
|
||||||
|
environment string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the provisioner unique identifier.
|
// GetID returns the provisioner unique identifier.
|
||||||
|
@ -167,11 +185,30 @@ func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) {
|
||||||
// Initialize the config if this method is used from the cli.
|
// Initialize the config if this method is used from the cli.
|
||||||
p.assertConfig()
|
p.assertConfig()
|
||||||
|
|
||||||
|
// default to AzurePublicCloud to keep existing behavior
|
||||||
|
identityTokenResource := azureEnvironments["AzurePublicCloud"]
|
||||||
|
|
||||||
|
var err error
|
||||||
|
p.environment, err = p.getAzureEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "error getting azure environment")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resource, ok := azureEnvironments[p.environment]; ok {
|
||||||
|
identityTokenResource = resource
|
||||||
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody)
|
req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "error creating request")
|
return "", errors.Wrap(err, "error creating request")
|
||||||
}
|
}
|
||||||
req.Header.Set("Metadata", "true")
|
req.Header.Set("Metadata", "true")
|
||||||
|
|
||||||
|
query := req.URL.Query()
|
||||||
|
query.Add("resource", identityTokenResource)
|
||||||
|
query.Add("api-version", azureIdentityTokenAPIVersion)
|
||||||
|
req.URL.RawQuery = query.Encode()
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "error getting identity token, are you in a Azure VM?")
|
return "", errors.Wrap(err, "error getting identity token, are you in a Azure VM?")
|
||||||
|
@ -444,3 +481,37 @@ func (p *Azure) assertConfig() {
|
||||||
p.config = newAzureConfig(p.TenantID)
|
p.config = newAzureConfig(p.TenantID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getAzureEnvironment returns the Azure environment for the current instance
|
||||||
|
func (p *Azure) getAzureEnvironment() (string, error) {
|
||||||
|
if p.environment != "" {
|
||||||
|
return p.environment, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", p.config.instanceComputeURL, http.NoBody)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "error creating request")
|
||||||
|
}
|
||||||
|
req.Header.Add("Metadata", "True")
|
||||||
|
|
||||||
|
query := req.URL.Query()
|
||||||
|
query.Add("format", "text")
|
||||||
|
query.Add("api-version", "2021-02-01")
|
||||||
|
req.URL.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "error getting azure instance environment, are you in a Azure VM?")
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
b, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "error reading azure environment response")
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return "", errors.Errorf("error getting azure environment: status=%d, response=%s", resp.StatusCode, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
|
@ -100,7 +100,14 @@ func TestAzure_GetIdentityToken(t *testing.T) {
|
||||||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srvIdentity := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
wantResource := r.URL.Query().Get("want_resource")
|
||||||
|
resource := r.URL.Query().Get("resource")
|
||||||
|
if wantResource == "" || resource != wantResource {
|
||||||
|
http.Error(w, fmt.Sprintf("Azure query param resource = %s, wantResource %s", resource, wantResource), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
case "/bad-request":
|
case "/bad-request":
|
||||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||||
|
@ -111,29 +118,58 @@ func TestAzure_GetIdentityToken(t *testing.T) {
|
||||||
fmt.Fprintf(w, `{"access_token":"%s"}`, t1)
|
fmt.Fprintf(w, `{"access_token":"%s"}`, t1)
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
defer srv.Close()
|
defer srvIdentity.Close()
|
||||||
|
|
||||||
|
srvInstance := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/bad-request":
|
||||||
|
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||||
|
case "/AzureChinaCloud":
|
||||||
|
w.Header().Add("Content-Type", "text/plain")
|
||||||
|
w.Write([]byte("AzureChinaCloud"))
|
||||||
|
case "/AzureGermanCloud":
|
||||||
|
w.Header().Add("Content-Type", "text/plain")
|
||||||
|
w.Write([]byte("AzureGermanCloud"))
|
||||||
|
case "/AzureUSGovernmentCloud":
|
||||||
|
w.Header().Add("Content-Type", "text/plain")
|
||||||
|
w.Write([]byte("AzureUSGovernmentCloud"))
|
||||||
|
default:
|
||||||
|
w.Header().Add("Content-Type", "text/plain")
|
||||||
|
w.Write([]byte("AzurePublicCloud"))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srvInstance.Close()
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
subject string
|
subject string
|
||||||
caURL string
|
caURL string
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
azure *Azure
|
azure *Azure
|
||||||
args args
|
args args
|
||||||
identityTokenURL string
|
identityTokenURL string
|
||||||
want string
|
instanceComputeURL string
|
||||||
wantErr bool
|
wantEnvironment string
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{"subject", "caURL"}, srv.URL, t1, false},
|
{"ok", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false},
|
||||||
{"fail request", p1, args{"subject", "caURL"}, srv.URL + "/bad-request", "", true},
|
{"ok azure china", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false},
|
||||||
{"fail unmarshal", p1, args{"subject", "caURL"}, srv.URL + "/bad-json", "", true},
|
{"ok azure germany", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureGermanCloud", t1, false},
|
||||||
{"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", "", true},
|
{"ok azure us gov", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureUSGovernmentCloud", t1, false},
|
||||||
{"fail connect", p1, args{"subject", "caURL"}, "foobarzar", "", true},
|
{"fail instance request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL + "/bad-request", "AzurePublicCloud", "", true},
|
||||||
|
{"fail request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||||
|
{"fail unmarshal", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-json", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||||
|
{"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||||
|
{"fail connect", p1, args{"subject", "caURL"}, "foobarzar", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tt.azure.config.identityTokenURL = tt.identityTokenURL
|
// reset environment between tests to avoid caching issues
|
||||||
|
p1.environment = ""
|
||||||
|
tt.azure.config.identityTokenURL = tt.identityTokenURL + "?want_resource=" + azureEnvironments[tt.wantEnvironment]
|
||||||
|
tt.azure.config.instanceComputeURL = tt.instanceComputeURL + "/" + tt.wantEnvironment
|
||||||
got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL)
|
got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
|
@ -338,8 +338,6 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
||||||
case *validityValidator:
|
case *validityValidator:
|
||||||
assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
|
assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
|
||||||
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
|
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
|
||||||
case emailOnlyIdentity:
|
|
||||||
assert.Equals(t, string(v), "name@smallstep.com")
|
|
||||||
case *x509NamePolicyValidator:
|
case *x509NamePolicyValidator:
|
||||||
assert.Equals(t, nil, v.policyEngine)
|
assert.Equals(t, nil, v.policyEngine)
|
||||||
case *WebhookController:
|
case *WebhookController:
|
||||||
|
|
|
@ -83,31 +83,6 @@ type AttestationData struct {
|
||||||
PermanentIdentifier string
|
PermanentIdentifier string
|
||||||
}
|
}
|
||||||
|
|
||||||
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
|
|
||||||
// SAN provided is the given email address.
|
|
||||||
type emailOnlyIdentity string
|
|
||||||
|
|
||||||
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
|
|
||||||
switch {
|
|
||||||
case len(req.DNSNames) > 0:
|
|
||||||
return errs.Forbidden("certificate request cannot contain DNS names")
|
|
||||||
case len(req.IPAddresses) > 0:
|
|
||||||
return errs.Forbidden("certificate request cannot contain IP addresses")
|
|
||||||
case len(req.URIs) > 0:
|
|
||||||
return errs.Forbidden("certificate request cannot contain URIs")
|
|
||||||
case len(req.EmailAddresses) == 0:
|
|
||||||
return errs.Forbidden("certificate request does not contain any email address")
|
|
||||||
case len(req.EmailAddresses) > 1:
|
|
||||||
return errs.Forbidden("certificate request contains too many email addresses")
|
|
||||||
case req.EmailAddresses[0] == "":
|
|
||||||
return errs.Forbidden("certificate request cannot contain an empty email address")
|
|
||||||
case req.EmailAddresses[0] != string(e):
|
|
||||||
return errs.Forbidden("certificate request does not contain the valid email address - got %s, want %s", req.EmailAddresses[0], e)
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultPublicKeyValidator validates the public key of a certificate request.
|
// defaultPublicKeyValidator validates the public key of a certificate request.
|
||||||
type defaultPublicKeyValidator struct{}
|
type defaultPublicKeyValidator struct{}
|
||||||
|
|
||||||
|
|
|
@ -16,38 +16,6 @@ import (
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_emailOnlyIdentity_Valid(t *testing.T) {
|
|
||||||
uri, err := url.Parse("https://example.com/1.0/getUser")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *x509.CertificateRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
e emailOnlyIdentity
|
|
||||||
args args
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com"}}}, false},
|
|
||||||
{"DNSNames", "name@smallstep.com", args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, true},
|
|
||||||
{"IPAddresses", "name@smallstep.com", args{&x509.CertificateRequest{IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}}}, true},
|
|
||||||
{"URIs", "name@smallstep.com", args{&x509.CertificateRequest{URIs: []*url.URL{uri}}}, true},
|
|
||||||
{"no-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{}}}, true},
|
|
||||||
{"empty-email", "", args{&x509.CertificateRequest{EmailAddresses: []string{""}}}, true},
|
|
||||||
{"multiple-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com", "foo@smallstep.com"}}}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if err := tt.e.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("emailOnlyIdentity.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_defaultPublicKeyValidator_Valid(t *testing.T) {
|
func Test_defaultPublicKeyValidator_Valid(t *testing.T) {
|
||||||
_shortRSA, err := pemutil.Read("./testdata/certs/short-rsa.csr")
|
_shortRSA, err := pemutil.Read("./testdata/certs/short-rsa.csr")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
|
@ -125,35 +125,6 @@ func (o SignSSHOptions) match(got SignSSHOptions) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sshCertPrincipalsModifier is an SSHCertModifier that sets the
|
|
||||||
// principals to the SSH certificate.
|
|
||||||
type sshCertPrincipalsModifier []string
|
|
||||||
|
|
||||||
// Modify the ValidPrincipals value of the cert.
|
|
||||||
func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
|
||||||
cert.ValidPrincipals = []string(o)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sshCertKeyIDModifier is an SSHCertModifier that sets the given
|
|
||||||
// Key ID in the SSH certificate.
|
|
||||||
type sshCertKeyIDModifier string
|
|
||||||
|
|
||||||
func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
|
||||||
cert.KeyId = string(m)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sshCertTypeModifier is an SSHCertModifier that sets the
|
|
||||||
// certificate type.
|
|
||||||
type sshCertTypeModifier string
|
|
||||||
|
|
||||||
// Modify sets the CertType for the ssh certificate.
|
|
||||||
func (m sshCertTypeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
|
||||||
cert.CertType = sshCertTypeUInt32(string(m))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sshCertValidAfterModifier is an SSHCertModifier that sets the
|
// sshCertValidAfterModifier is an SSHCertModifier that sets the
|
||||||
// ValidAfter in the SSH certificate.
|
// ValidAfter in the SSH certificate.
|
||||||
type sshCertValidAfterModifier uint64
|
type sshCertValidAfterModifier uint64
|
||||||
|
@ -172,51 +143,6 @@ func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptio
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sshCertDefaultsModifier implements a SSHCertModifier that
|
|
||||||
// modifies the certificate with the given options if they are not set.
|
|
||||||
type sshCertDefaultsModifier SignSSHOptions
|
|
||||||
|
|
||||||
// Modify implements the SSHCertModifier interface.
|
|
||||||
func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
|
||||||
if cert.CertType == 0 {
|
|
||||||
cert.CertType = sshCertTypeUInt32(m.CertType)
|
|
||||||
}
|
|
||||||
if len(cert.ValidPrincipals) == 0 {
|
|
||||||
cert.ValidPrincipals = m.Principals
|
|
||||||
}
|
|
||||||
if cert.ValidAfter == 0 && !m.ValidAfter.IsZero() {
|
|
||||||
cert.ValidAfter = uint64(m.ValidAfter.Unix())
|
|
||||||
}
|
|
||||||
if cert.ValidBefore == 0 && !m.ValidBefore.IsZero() {
|
|
||||||
cert.ValidBefore = uint64(m.ValidBefore.Unix())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sshDefaultExtensionModifier implements an SSHCertModifier that sets
|
|
||||||
// the default extensions in an SSH certificate.
|
|
||||||
type sshDefaultExtensionModifier struct{}
|
|
||||||
|
|
||||||
func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
|
||||||
switch cert.CertType {
|
|
||||||
// Default to no extensions for HostCert.
|
|
||||||
case ssh.HostCert:
|
|
||||||
return nil
|
|
||||||
case ssh.UserCert:
|
|
||||||
if cert.Extensions == nil {
|
|
||||||
cert.Extensions = make(map[string]string)
|
|
||||||
}
|
|
||||||
cert.Extensions["permit-X11-forwarding"] = ""
|
|
||||||
cert.Extensions["permit-agent-forwarding"] = ""
|
|
||||||
cert.Extensions["permit-port-forwarding"] = ""
|
|
||||||
cert.Extensions["permit-pty"] = ""
|
|
||||||
cert.Extensions["permit-user-rc"] = ""
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return errs.BadRequest("ssh certificate has an unknown type '%d'", cert.CertType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// sshDefaultDuration is an SSHCertModifier that sets the certificate
|
// sshDefaultDuration is an SSHCertModifier that sets the certificate
|
||||||
// ValidAfter and ValidBefore if they have not been set. It will fail if a
|
// ValidAfter and ValidBefore if they have not been set. It will fail if a
|
||||||
// CertType has not been set or is not valid.
|
// CertType has not been set or is not valid.
|
||||||
|
|
|
@ -202,97 +202,6 @@ func TestSSHOptions_Match(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
modifier sshCertPrincipalsModifier
|
|
||||||
cert *ssh.Certificate
|
|
||||||
expected []string
|
|
||||||
}
|
|
||||||
tests := map[string]func() test{
|
|
||||||
"ok": func() test {
|
|
||||||
a := []string{"foo", "bar"}
|
|
||||||
return test{
|
|
||||||
modifier: sshCertPrincipalsModifier(a),
|
|
||||||
cert: new(ssh.Certificate),
|
|
||||||
expected: a,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run()
|
|
||||||
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
|
||||||
assert.Equals(t, tc.cert.ValidPrincipals, tc.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
modifier sshCertKeyIDModifier
|
|
||||||
cert *ssh.Certificate
|
|
||||||
expected string
|
|
||||||
}
|
|
||||||
tests := map[string]func() test{
|
|
||||||
"ok": func() test {
|
|
||||||
a := "foo"
|
|
||||||
return test{
|
|
||||||
modifier: sshCertKeyIDModifier(a),
|
|
||||||
cert: new(ssh.Certificate),
|
|
||||||
expected: a,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run()
|
|
||||||
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
|
||||||
assert.Equals(t, tc.cert.KeyId, tc.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_sshCertTypeModifier_Modify(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
modifier sshCertTypeModifier
|
|
||||||
cert *ssh.Certificate
|
|
||||||
expected uint32
|
|
||||||
}
|
|
||||||
tests := map[string]func() test{
|
|
||||||
"ok/user": func() test {
|
|
||||||
return test{
|
|
||||||
modifier: sshCertTypeModifier("user"),
|
|
||||||
cert: new(ssh.Certificate),
|
|
||||||
expected: ssh.UserCert,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/host": func() test {
|
|
||||||
return test{
|
|
||||||
modifier: sshCertTypeModifier("host"),
|
|
||||||
cert: new(ssh.Certificate),
|
|
||||||
expected: ssh.HostCert,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/default": func() test {
|
|
||||||
return test{
|
|
||||||
modifier: sshCertTypeModifier("foo"),
|
|
||||||
cert: new(ssh.Certificate),
|
|
||||||
expected: 0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run()
|
|
||||||
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
|
||||||
assert.Equals(t, tc.cert.CertType, tc.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
|
func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
modifier sshCertValidAfterModifier
|
modifier sshCertValidAfterModifier
|
||||||
|
@ -318,176 +227,6 @@ func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
modifier sshCertDefaultsModifier
|
|
||||||
cert *ssh.Certificate
|
|
||||||
valid func(*ssh.Certificate)
|
|
||||||
}
|
|
||||||
tests := map[string]func() test{
|
|
||||||
"ok/changes": func() test {
|
|
||||||
n := time.Now()
|
|
||||||
va := NewTimeDuration(n.Add(1 * time.Minute))
|
|
||||||
vb := NewTimeDuration(n.Add(5 * time.Minute))
|
|
||||||
so := SignSSHOptions{
|
|
||||||
Principals: []string{"foo", "bar"},
|
|
||||||
CertType: "host",
|
|
||||||
ValidAfter: va,
|
|
||||||
ValidBefore: vb,
|
|
||||||
}
|
|
||||||
return test{
|
|
||||||
modifier: sshCertDefaultsModifier(so),
|
|
||||||
cert: new(ssh.Certificate),
|
|
||||||
valid: func(cert *ssh.Certificate) {
|
|
||||||
assert.Equals(t, cert.ValidPrincipals, so.Principals)
|
|
||||||
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
|
|
||||||
assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix()))
|
|
||||||
assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix()))
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/no-changes": func() test {
|
|
||||||
n := time.Now()
|
|
||||||
so := SignSSHOptions{
|
|
||||||
Principals: []string{"foo", "bar"},
|
|
||||||
CertType: "host",
|
|
||||||
ValidAfter: NewTimeDuration(n.Add(15 * time.Minute)),
|
|
||||||
ValidBefore: NewTimeDuration(n.Add(25 * time.Minute)),
|
|
||||||
}
|
|
||||||
return test{
|
|
||||||
modifier: sshCertDefaultsModifier(so),
|
|
||||||
cert: &ssh.Certificate{
|
|
||||||
CertType: uint32(ssh.UserCert),
|
|
||||||
ValidPrincipals: []string{"zap", "zoop"},
|
|
||||||
ValidAfter: 15,
|
|
||||||
ValidBefore: 25,
|
|
||||||
},
|
|
||||||
valid: func(cert *ssh.Certificate) {
|
|
||||||
assert.Equals(t, cert.ValidPrincipals, []string{"zap", "zoop"})
|
|
||||||
assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
|
|
||||||
assert.Equals(t, cert.ValidAfter, uint64(15))
|
|
||||||
assert.Equals(t, cert.ValidBefore, uint64(25))
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run()
|
|
||||||
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
|
||||||
tc.valid(tc.cert)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
modifier sshDefaultExtensionModifier
|
|
||||||
cert *ssh.Certificate
|
|
||||||
valid func(*ssh.Certificate)
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
tests := map[string]func() test{
|
|
||||||
"fail/unexpected-cert-type": func() test {
|
|
||||||
cert := &ssh.Certificate{CertType: 3}
|
|
||||||
return test{
|
|
||||||
modifier: sshDefaultExtensionModifier{},
|
|
||||||
cert: cert,
|
|
||||||
err: errors.New("ssh certificate has an unknown type '3'"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/host": func() test {
|
|
||||||
cert := &ssh.Certificate{CertType: ssh.HostCert}
|
|
||||||
return test{
|
|
||||||
modifier: sshDefaultExtensionModifier{},
|
|
||||||
cert: cert,
|
|
||||||
valid: func(cert *ssh.Certificate) {
|
|
||||||
assert.Len(t, 0, cert.Extensions)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/user/extensions-exists": func() test {
|
|
||||||
cert := &ssh.Certificate{CertType: ssh.UserCert, Permissions: ssh.Permissions{Extensions: map[string]string{
|
|
||||||
"foo": "bar",
|
|
||||||
}}}
|
|
||||||
return test{
|
|
||||||
modifier: sshDefaultExtensionModifier{},
|
|
||||||
cert: cert,
|
|
||||||
valid: func(cert *ssh.Certificate) {
|
|
||||||
val, ok := cert.Extensions["foo"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equals(t, val, "bar")
|
|
||||||
|
|
||||||
val, ok = cert.Extensions["permit-X11-forwarding"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equals(t, val, "")
|
|
||||||
|
|
||||||
val, ok = cert.Extensions["permit-agent-forwarding"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equals(t, val, "")
|
|
||||||
|
|
||||||
val, ok = cert.Extensions["permit-port-forwarding"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equals(t, val, "")
|
|
||||||
|
|
||||||
val, ok = cert.Extensions["permit-pty"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equals(t, val, "")
|
|
||||||
|
|
||||||
val, ok = cert.Extensions["permit-user-rc"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equals(t, val, "")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/user/no-extensions": func() test {
|
|
||||||
return test{
|
|
||||||
modifier: sshDefaultExtensionModifier{},
|
|
||||||
cert: &ssh.Certificate{CertType: ssh.UserCert},
|
|
||||||
valid: func(cert *ssh.Certificate) {
|
|
||||||
_, ok := cert.Extensions["foo"]
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
val, ok := cert.Extensions["permit-X11-forwarding"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equals(t, val, "")
|
|
||||||
|
|
||||||
val, ok = cert.Extensions["permit-agent-forwarding"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equals(t, val, "")
|
|
||||||
|
|
||||||
val, ok = cert.Extensions["permit-port-forwarding"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equals(t, val, "")
|
|
||||||
|
|
||||||
val, ok = cert.Extensions["permit-pty"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equals(t, val, "")
|
|
||||||
|
|
||||||
val, ok = cert.Extensions["permit-user-rc"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equals(t, val, "")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run()
|
|
||||||
if err := tc.modifier.Modify(tc.cert, SignSSHOptions{}); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
tc.valid(tc.cert)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_sshCertDefaultValidator_Valid(t *testing.T) {
|
func Test_sshCertDefaultValidator_Valid(t *testing.T) {
|
||||||
pub, _, err := keyutil.GenerateDefaultKeyPair()
|
pub, _, err := keyutil.GenerateDefaultKeyPair()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
|
@ -665,6 +665,9 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
||||||
AccessToken: tok,
|
AccessToken: tok,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
case "/metadata/instance/compute/azEnvironment":
|
||||||
|
w.Header().Add("Content-Type", "text/plain")
|
||||||
|
w.Write([]byte("AzurePublicCloud"))
|
||||||
default:
|
default:
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
}
|
}
|
||||||
|
@ -672,6 +675,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
||||||
srv.Start()
|
srv.Start()
|
||||||
az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration"
|
az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration"
|
||||||
az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token"
|
az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token"
|
||||||
|
az.config.instanceComputeURL = srv.URL + "/metadata/instance/compute/azEnvironment"
|
||||||
return az, srv, nil
|
return az, srv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -790,8 +790,6 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
||||||
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())
|
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())
|
||||||
case sshCertValidBeforeModifier:
|
case sshCertValidBeforeModifier:
|
||||||
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix())
|
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix())
|
||||||
case sshCertDefaultsModifier:
|
|
||||||
assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert})
|
|
||||||
case *sshLimitDuration:
|
case *sshLimitDuration:
|
||||||
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||||
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
|
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
|
||||||
|
|
|
@ -85,7 +85,7 @@ Requires **--insecure** flag.`,
|
||||||
},
|
},
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "pidfile",
|
Name: "pidfile",
|
||||||
Usage: "that path to the <file> to write the process ID.",
|
Usage: "the path to the <file> to write the process ID.",
|
||||||
},
|
},
|
||||||
cli.BoolFlag{
|
cli.BoolFlag{
|
||||||
Name: "insecure",
|
Name: "insecure",
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -18,7 +18,7 @@ require (
|
||||||
github.com/hashicorp/vault/api/auth/approle v0.4.0
|
github.com/hashicorp/vault/api/auth/approle v0.4.0
|
||||||
github.com/hashicorp/vault/api/auth/kubernetes v0.4.0
|
github.com/hashicorp/vault/api/auth/kubernetes v0.4.0
|
||||||
github.com/micromdm/scep/v2 v2.1.0
|
github.com/micromdm/scep/v2 v2.1.0
|
||||||
github.com/newrelic/go-agent/v3 v3.20.4
|
github.com/newrelic/go-agent/v3 v3.21.0
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/rs/xid v1.4.0
|
github.com/rs/xid v1.4.0
|
||||||
github.com/sirupsen/logrus v1.9.0
|
github.com/sirupsen/logrus v1.9.0
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -742,8 +742,8 @@ github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzE
|
||||||
github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w=
|
github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w=
|
||||||
github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w=
|
github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w=
|
||||||
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
|
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
|
||||||
github.com/newrelic/go-agent/v3 v3.20.4 h1:fkxr0oUEYrPeXyfJC0D0BwDs1FYMe4NgUSqnzqPESI0=
|
github.com/newrelic/go-agent/v3 v3.21.0 h1:KpkoW6PnSVzEDEO0W/C9LZEZZGwAb+a9g5DN8ifvt4Y=
|
||||||
github.com/newrelic/go-agent/v3 v3.20.4/go.mod h1:rT6ZUxJc5rQbWLyCtjqQCOcfb01lKRFbc1yMQkcboWM=
|
github.com/newrelic/go-agent/v3 v3.21.0/go.mod h1:rT6ZUxJc5rQbWLyCtjqQCOcfb01lKRFbc1yMQkcboWM=
|
||||||
github.com/nishanths/predeclared v0.0.0-20200524104333-86fad755b4d3/go.mod h1:nt3d53pc1VYcphSCIaYAJtnPYnr3Zyn8fMq2wvPGPso=
|
github.com/nishanths/predeclared v0.0.0-20200524104333-86fad755b4d3/go.mod h1:nt3d53pc1VYcphSCIaYAJtnPYnr3Zyn8fMq2wvPGPso=
|
||||||
github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs=
|
github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs=
|
||||||
github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA=
|
github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA=
|
||||||
|
|
Loading…
Reference in a new issue