forked from TrueCloudLab/certificates
Merge branch 'master' into herman/acme-da-tpm
This commit is contained in:
commit
dfc56f21b8
12 changed files with 134 additions and 419 deletions
|
@ -26,7 +26,12 @@ import (
|
|||
const azureOIDCBaseURL = "https://login.microsoftonline.com"
|
||||
|
||||
//nolint:gosec // azureIdentityTokenURL is the URL to get the identity token for an instance.
|
||||
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F"
|
||||
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token"
|
||||
|
||||
const azureIdentityTokenAPIVersion = "2018-02-01"
|
||||
|
||||
// azureInstanceComputeURL is the URL to get the instance compute metadata.
|
||||
const azureInstanceComputeURL = "http://169.254.169.254/metadata/instance/compute/azEnvironment"
|
||||
|
||||
// azureDefaultAudience is the default audience used.
|
||||
const azureDefaultAudience = "https://management.azure.com/"
|
||||
|
@ -35,15 +40,27 @@ const azureDefaultAudience = "https://management.azure.com/"
|
|||
// Using case insensitive as resourceGroups appears as resourcegroups.
|
||||
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
|
||||
|
||||
// azureEnvironments is the list of all Azure environments.
|
||||
var azureEnvironments = map[string]string{
|
||||
"AzurePublicCloud": "https://management.azure.com/",
|
||||
"AzureCloud": "https://management.azure.com/",
|
||||
"AzureUSGovernmentCloud": "https://management.usgovcloudapi.net/",
|
||||
"AzureUSGovernment": "https://management.usgovcloudapi.net/",
|
||||
"AzureChinaCloud": "https://management.chinacloudapi.cn/",
|
||||
"AzureGermanCloud": "https://management.microsoftazure.de/",
|
||||
}
|
||||
|
||||
type azureConfig struct {
|
||||
oidcDiscoveryURL string
|
||||
identityTokenURL string
|
||||
oidcDiscoveryURL string
|
||||
identityTokenURL string
|
||||
instanceComputeURL string
|
||||
}
|
||||
|
||||
func newAzureConfig(tenantID string) *azureConfig {
|
||||
return &azureConfig{
|
||||
oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration",
|
||||
identityTokenURL: azureIdentityTokenURL,
|
||||
oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration",
|
||||
identityTokenURL: azureIdentityTokenURL,
|
||||
instanceComputeURL: azureInstanceComputeURL,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -103,6 +120,7 @@ type Azure struct {
|
|||
oidcConfig openIDConfiguration
|
||||
keyStore *keyStore
|
||||
ctl *Controller
|
||||
environment string
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
|
@ -167,11 +185,30 @@ func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) {
|
|||
// Initialize the config if this method is used from the cli.
|
||||
p.assertConfig()
|
||||
|
||||
// default to AzurePublicCloud to keep existing behavior
|
||||
identityTokenResource := azureEnvironments["AzurePublicCloud"]
|
||||
|
||||
var err error
|
||||
p.environment, err = p.getAzureEnvironment()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error getting azure environment")
|
||||
}
|
||||
|
||||
if resource, ok := azureEnvironments[p.environment]; ok {
|
||||
identityTokenResource = resource
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error creating request")
|
||||
}
|
||||
req.Header.Set("Metadata", "true")
|
||||
|
||||
query := req.URL.Query()
|
||||
query.Add("resource", identityTokenResource)
|
||||
query.Add("api-version", azureIdentityTokenAPIVersion)
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error getting identity token, are you in a Azure VM?")
|
||||
|
@ -444,3 +481,37 @@ func (p *Azure) assertConfig() {
|
|||
p.config = newAzureConfig(p.TenantID)
|
||||
}
|
||||
}
|
||||
|
||||
// getAzureEnvironment returns the Azure environment for the current instance
|
||||
func (p *Azure) getAzureEnvironment() (string, error) {
|
||||
if p.environment != "" {
|
||||
return p.environment, nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", p.config.instanceComputeURL, http.NoBody)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error creating request")
|
||||
}
|
||||
req.Header.Add("Metadata", "True")
|
||||
|
||||
query := req.URL.Query()
|
||||
query.Add("format", "text")
|
||||
query.Add("api-version", "2021-02-01")
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error getting azure instance environment, are you in a Azure VM?")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error reading azure environment response")
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return "", errors.Errorf("error getting azure environment: status=%d, response=%s", resp.StatusCode, b)
|
||||
}
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
|
|
@ -100,7 +100,14 @@ func TestAzure_GetIdentityToken(t *testing.T) {
|
|||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
srvIdentity := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
wantResource := r.URL.Query().Get("want_resource")
|
||||
resource := r.URL.Query().Get("resource")
|
||||
if wantResource == "" || resource != wantResource {
|
||||
http.Error(w, fmt.Sprintf("Azure query param resource = %s, wantResource %s", resource, wantResource), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.URL.Path {
|
||||
case "/bad-request":
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
|
@ -111,29 +118,58 @@ func TestAzure_GetIdentityToken(t *testing.T) {
|
|||
fmt.Fprintf(w, `{"access_token":"%s"}`, t1)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
defer srvIdentity.Close()
|
||||
|
||||
srvInstance := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/bad-request":
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
case "/AzureChinaCloud":
|
||||
w.Header().Add("Content-Type", "text/plain")
|
||||
w.Write([]byte("AzureChinaCloud"))
|
||||
case "/AzureGermanCloud":
|
||||
w.Header().Add("Content-Type", "text/plain")
|
||||
w.Write([]byte("AzureGermanCloud"))
|
||||
case "/AzureUSGovernmentCloud":
|
||||
w.Header().Add("Content-Type", "text/plain")
|
||||
w.Write([]byte("AzureUSGovernmentCloud"))
|
||||
default:
|
||||
w.Header().Add("Content-Type", "text/plain")
|
||||
w.Write([]byte("AzurePublicCloud"))
|
||||
}
|
||||
}))
|
||||
defer srvInstance.Close()
|
||||
|
||||
type args struct {
|
||||
subject string
|
||||
caURL string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
azure *Azure
|
||||
args args
|
||||
identityTokenURL string
|
||||
want string
|
||||
wantErr bool
|
||||
name string
|
||||
azure *Azure
|
||||
args args
|
||||
identityTokenURL string
|
||||
instanceComputeURL string
|
||||
wantEnvironment string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{"subject", "caURL"}, srv.URL, t1, false},
|
||||
{"fail request", p1, args{"subject", "caURL"}, srv.URL + "/bad-request", "", true},
|
||||
{"fail unmarshal", p1, args{"subject", "caURL"}, srv.URL + "/bad-json", "", true},
|
||||
{"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", "", true},
|
||||
{"fail connect", p1, args{"subject", "caURL"}, "foobarzar", "", true},
|
||||
{"ok", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false},
|
||||
{"ok azure china", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false},
|
||||
{"ok azure germany", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureGermanCloud", t1, false},
|
||||
{"ok azure us gov", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureUSGovernmentCloud", t1, false},
|
||||
{"fail instance request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL + "/bad-request", "AzurePublicCloud", "", true},
|
||||
{"fail request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||
{"fail unmarshal", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-json", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||
{"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||
{"fail connect", p1, args{"subject", "caURL"}, "foobarzar", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.azure.config.identityTokenURL = tt.identityTokenURL
|
||||
// reset environment between tests to avoid caching issues
|
||||
p1.environment = ""
|
||||
tt.azure.config.identityTokenURL = tt.identityTokenURL + "?want_resource=" + azureEnvironments[tt.wantEnvironment]
|
||||
tt.azure.config.instanceComputeURL = tt.instanceComputeURL + "/" + tt.wantEnvironment
|
||||
got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
|
|
@ -338,8 +338,6 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
|||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
|
||||
case emailOnlyIdentity:
|
||||
assert.Equals(t, string(v), "name@smallstep.com")
|
||||
case *x509NamePolicyValidator:
|
||||
assert.Equals(t, nil, v.policyEngine)
|
||||
case *WebhookController:
|
||||
|
|
|
@ -83,31 +83,6 @@ type AttestationData struct {
|
|||
PermanentIdentifier string
|
||||
}
|
||||
|
||||
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
|
||||
// SAN provided is the given email address.
|
||||
type emailOnlyIdentity string
|
||||
|
||||
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
|
||||
switch {
|
||||
case len(req.DNSNames) > 0:
|
||||
return errs.Forbidden("certificate request cannot contain DNS names")
|
||||
case len(req.IPAddresses) > 0:
|
||||
return errs.Forbidden("certificate request cannot contain IP addresses")
|
||||
case len(req.URIs) > 0:
|
||||
return errs.Forbidden("certificate request cannot contain URIs")
|
||||
case len(req.EmailAddresses) == 0:
|
||||
return errs.Forbidden("certificate request does not contain any email address")
|
||||
case len(req.EmailAddresses) > 1:
|
||||
return errs.Forbidden("certificate request contains too many email addresses")
|
||||
case req.EmailAddresses[0] == "":
|
||||
return errs.Forbidden("certificate request cannot contain an empty email address")
|
||||
case req.EmailAddresses[0] != string(e):
|
||||
return errs.Forbidden("certificate request does not contain the valid email address - got %s, want %s", req.EmailAddresses[0], e)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// defaultPublicKeyValidator validates the public key of a certificate request.
|
||||
type defaultPublicKeyValidator struct{}
|
||||
|
||||
|
|
|
@ -16,38 +16,6 @@ import (
|
|||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
func Test_emailOnlyIdentity_Valid(t *testing.T) {
|
||||
uri, err := url.Parse("https://example.com/1.0/getUser")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
e emailOnlyIdentity
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com"}}}, false},
|
||||
{"DNSNames", "name@smallstep.com", args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, true},
|
||||
{"IPAddresses", "name@smallstep.com", args{&x509.CertificateRequest{IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}}}, true},
|
||||
{"URIs", "name@smallstep.com", args{&x509.CertificateRequest{URIs: []*url.URL{uri}}}, true},
|
||||
{"no-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{}}}, true},
|
||||
{"empty-email", "", args{&x509.CertificateRequest{EmailAddresses: []string{""}}}, true},
|
||||
{"multiple-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com", "foo@smallstep.com"}}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.e.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||
t.Errorf("emailOnlyIdentity.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_defaultPublicKeyValidator_Valid(t *testing.T) {
|
||||
_shortRSA, err := pemutil.Read("./testdata/certs/short-rsa.csr")
|
||||
assert.FatalError(t, err)
|
||||
|
|
|
@ -125,35 +125,6 @@ func (o SignSSHOptions) match(got SignSSHOptions) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshCertPrincipalsModifier is an SSHCertModifier that sets the
|
||||
// principals to the SSH certificate.
|
||||
type sshCertPrincipalsModifier []string
|
||||
|
||||
// Modify the ValidPrincipals value of the cert.
|
||||
func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||
cert.ValidPrincipals = []string(o)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertKeyIDModifier is an SSHCertModifier that sets the given
|
||||
// Key ID in the SSH certificate.
|
||||
type sshCertKeyIDModifier string
|
||||
|
||||
func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||
cert.KeyId = string(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertTypeModifier is an SSHCertModifier that sets the
|
||||
// certificate type.
|
||||
type sshCertTypeModifier string
|
||||
|
||||
// Modify sets the CertType for the ssh certificate.
|
||||
func (m sshCertTypeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||
cert.CertType = sshCertTypeUInt32(string(m))
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertValidAfterModifier is an SSHCertModifier that sets the
|
||||
// ValidAfter in the SSH certificate.
|
||||
type sshCertValidAfterModifier uint64
|
||||
|
@ -172,51 +143,6 @@ func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptio
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshCertDefaultsModifier implements a SSHCertModifier that
|
||||
// modifies the certificate with the given options if they are not set.
|
||||
type sshCertDefaultsModifier SignSSHOptions
|
||||
|
||||
// Modify implements the SSHCertModifier interface.
|
||||
func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||
if cert.CertType == 0 {
|
||||
cert.CertType = sshCertTypeUInt32(m.CertType)
|
||||
}
|
||||
if len(cert.ValidPrincipals) == 0 {
|
||||
cert.ValidPrincipals = m.Principals
|
||||
}
|
||||
if cert.ValidAfter == 0 && !m.ValidAfter.IsZero() {
|
||||
cert.ValidAfter = uint64(m.ValidAfter.Unix())
|
||||
}
|
||||
if cert.ValidBefore == 0 && !m.ValidBefore.IsZero() {
|
||||
cert.ValidBefore = uint64(m.ValidBefore.Unix())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshDefaultExtensionModifier implements an SSHCertModifier that sets
|
||||
// the default extensions in an SSH certificate.
|
||||
type sshDefaultExtensionModifier struct{}
|
||||
|
||||
func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||
switch cert.CertType {
|
||||
// Default to no extensions for HostCert.
|
||||
case ssh.HostCert:
|
||||
return nil
|
||||
case ssh.UserCert:
|
||||
if cert.Extensions == nil {
|
||||
cert.Extensions = make(map[string]string)
|
||||
}
|
||||
cert.Extensions["permit-X11-forwarding"] = ""
|
||||
cert.Extensions["permit-agent-forwarding"] = ""
|
||||
cert.Extensions["permit-port-forwarding"] = ""
|
||||
cert.Extensions["permit-pty"] = ""
|
||||
cert.Extensions["permit-user-rc"] = ""
|
||||
return nil
|
||||
default:
|
||||
return errs.BadRequest("ssh certificate has an unknown type '%d'", cert.CertType)
|
||||
}
|
||||
}
|
||||
|
||||
// sshDefaultDuration is an SSHCertModifier that sets the certificate
|
||||
// ValidAfter and ValidBefore if they have not been set. It will fail if a
|
||||
// CertType has not been set or is not valid.
|
||||
|
|
|
@ -202,97 +202,6 @@ func TestSSHOptions_Match(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshCertPrincipalsModifier
|
||||
cert *ssh.Certificate
|
||||
expected []string
|
||||
}
|
||||
tests := map[string]func() test{
|
||||
"ok": func() test {
|
||||
a := []string{"foo", "bar"}
|
||||
return test{
|
||||
modifier: sshCertPrincipalsModifier(a),
|
||||
cert: new(ssh.Certificate),
|
||||
expected: a,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
||||
assert.Equals(t, tc.cert.ValidPrincipals, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshCertKeyIDModifier
|
||||
cert *ssh.Certificate
|
||||
expected string
|
||||
}
|
||||
tests := map[string]func() test{
|
||||
"ok": func() test {
|
||||
a := "foo"
|
||||
return test{
|
||||
modifier: sshCertKeyIDModifier(a),
|
||||
cert: new(ssh.Certificate),
|
||||
expected: a,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
||||
assert.Equals(t, tc.cert.KeyId, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshCertTypeModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshCertTypeModifier
|
||||
cert *ssh.Certificate
|
||||
expected uint32
|
||||
}
|
||||
tests := map[string]func() test{
|
||||
"ok/user": func() test {
|
||||
return test{
|
||||
modifier: sshCertTypeModifier("user"),
|
||||
cert: new(ssh.Certificate),
|
||||
expected: ssh.UserCert,
|
||||
}
|
||||
},
|
||||
"ok/host": func() test {
|
||||
return test{
|
||||
modifier: sshCertTypeModifier("host"),
|
||||
cert: new(ssh.Certificate),
|
||||
expected: ssh.HostCert,
|
||||
}
|
||||
},
|
||||
"ok/default": func() test {
|
||||
return test{
|
||||
modifier: sshCertTypeModifier("foo"),
|
||||
cert: new(ssh.Certificate),
|
||||
expected: 0,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
||||
assert.Equals(t, tc.cert.CertType, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshCertValidAfterModifier
|
||||
|
@ -318,176 +227,6 @@ func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshCertDefaultsModifier
|
||||
cert *ssh.Certificate
|
||||
valid func(*ssh.Certificate)
|
||||
}
|
||||
tests := map[string]func() test{
|
||||
"ok/changes": func() test {
|
||||
n := time.Now()
|
||||
va := NewTimeDuration(n.Add(1 * time.Minute))
|
||||
vb := NewTimeDuration(n.Add(5 * time.Minute))
|
||||
so := SignSSHOptions{
|
||||
Principals: []string{"foo", "bar"},
|
||||
CertType: "host",
|
||||
ValidAfter: va,
|
||||
ValidBefore: vb,
|
||||
}
|
||||
return test{
|
||||
modifier: sshCertDefaultsModifier(so),
|
||||
cert: new(ssh.Certificate),
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
assert.Equals(t, cert.ValidPrincipals, so.Principals)
|
||||
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
|
||||
assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix()))
|
||||
assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix()))
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/no-changes": func() test {
|
||||
n := time.Now()
|
||||
so := SignSSHOptions{
|
||||
Principals: []string{"foo", "bar"},
|
||||
CertType: "host",
|
||||
ValidAfter: NewTimeDuration(n.Add(15 * time.Minute)),
|
||||
ValidBefore: NewTimeDuration(n.Add(25 * time.Minute)),
|
||||
}
|
||||
return test{
|
||||
modifier: sshCertDefaultsModifier(so),
|
||||
cert: &ssh.Certificate{
|
||||
CertType: uint32(ssh.UserCert),
|
||||
ValidPrincipals: []string{"zap", "zoop"},
|
||||
ValidAfter: 15,
|
||||
ValidBefore: 25,
|
||||
},
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
assert.Equals(t, cert.ValidPrincipals, []string{"zap", "zoop"})
|
||||
assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
|
||||
assert.Equals(t, cert.ValidAfter, uint64(15))
|
||||
assert.Equals(t, cert.ValidBefore, uint64(25))
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
||||
tc.valid(tc.cert)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshDefaultExtensionModifier
|
||||
cert *ssh.Certificate
|
||||
valid func(*ssh.Certificate)
|
||||
err error
|
||||
}
|
||||
tests := map[string]func() test{
|
||||
"fail/unexpected-cert-type": func() test {
|
||||
cert := &ssh.Certificate{CertType: 3}
|
||||
return test{
|
||||
modifier: sshDefaultExtensionModifier{},
|
||||
cert: cert,
|
||||
err: errors.New("ssh certificate has an unknown type '3'"),
|
||||
}
|
||||
},
|
||||
"ok/host": func() test {
|
||||
cert := &ssh.Certificate{CertType: ssh.HostCert}
|
||||
return test{
|
||||
modifier: sshDefaultExtensionModifier{},
|
||||
cert: cert,
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
assert.Len(t, 0, cert.Extensions)
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/user/extensions-exists": func() test {
|
||||
cert := &ssh.Certificate{CertType: ssh.UserCert, Permissions: ssh.Permissions{Extensions: map[string]string{
|
||||
"foo": "bar",
|
||||
}}}
|
||||
return test{
|
||||
modifier: sshDefaultExtensionModifier{},
|
||||
cert: cert,
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
val, ok := cert.Extensions["foo"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "bar")
|
||||
|
||||
val, ok = cert.Extensions["permit-X11-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-agent-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-port-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-pty"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-user-rc"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/user/no-extensions": func() test {
|
||||
return test{
|
||||
modifier: sshDefaultExtensionModifier{},
|
||||
cert: &ssh.Certificate{CertType: ssh.UserCert},
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
_, ok := cert.Extensions["foo"]
|
||||
assert.False(t, ok)
|
||||
|
||||
val, ok := cert.Extensions["permit-X11-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-agent-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-port-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-pty"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-user-rc"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if err := tc.modifier.Modify(tc.cert, SignSSHOptions{}); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
tc.valid(tc.cert)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshCertDefaultValidator_Valid(t *testing.T) {
|
||||
pub, _, err := keyutil.GenerateDefaultKeyPair()
|
||||
assert.FatalError(t, err)
|
||||
|
|
|
@ -665,6 +665,9 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
|||
AccessToken: tok,
|
||||
})
|
||||
}
|
||||
case "/metadata/instance/compute/azEnvironment":
|
||||
w.Header().Add("Content-Type", "text/plain")
|
||||
w.Write([]byte("AzurePublicCloud"))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
|
@ -672,6 +675,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
|||
srv.Start()
|
||||
az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration"
|
||||
az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token"
|
||||
az.config.instanceComputeURL = srv.URL + "/metadata/instance/compute/azEnvironment"
|
||||
return az, srv, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -790,8 +790,6 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
|||
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())
|
||||
case sshCertValidBeforeModifier:
|
||||
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix())
|
||||
case sshCertDefaultsModifier:
|
||||
assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert})
|
||||
case *sshLimitDuration:
|
||||
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
|
||||
|
|
|
@ -85,7 +85,7 @@ Requires **--insecure** flag.`,
|
|||
},
|
||||
cli.StringFlag{
|
||||
Name: "pidfile",
|
||||
Usage: "that path to the <file> to write the process ID.",
|
||||
Usage: "the path to the <file> to write the process ID.",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "insecure",
|
||||
|
|
2
go.mod
2
go.mod
|
@ -18,7 +18,7 @@ require (
|
|||
github.com/hashicorp/vault/api/auth/approle v0.4.0
|
||||
github.com/hashicorp/vault/api/auth/kubernetes v0.4.0
|
||||
github.com/micromdm/scep/v2 v2.1.0
|
||||
github.com/newrelic/go-agent/v3 v3.20.4
|
||||
github.com/newrelic/go-agent/v3 v3.21.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/rs/xid v1.4.0
|
||||
github.com/sirupsen/logrus v1.9.0
|
||||
|
|
4
go.sum
4
go.sum
|
@ -742,8 +742,8 @@ github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzE
|
|||
github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w=
|
||||
github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w=
|
||||
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
|
||||
github.com/newrelic/go-agent/v3 v3.20.4 h1:fkxr0oUEYrPeXyfJC0D0BwDs1FYMe4NgUSqnzqPESI0=
|
||||
github.com/newrelic/go-agent/v3 v3.20.4/go.mod h1:rT6ZUxJc5rQbWLyCtjqQCOcfb01lKRFbc1yMQkcboWM=
|
||||
github.com/newrelic/go-agent/v3 v3.21.0 h1:KpkoW6PnSVzEDEO0W/C9LZEZZGwAb+a9g5DN8ifvt4Y=
|
||||
github.com/newrelic/go-agent/v3 v3.21.0/go.mod h1:rT6ZUxJc5rQbWLyCtjqQCOcfb01lKRFbc1yMQkcboWM=
|
||||
github.com/nishanths/predeclared v0.0.0-20200524104333-86fad755b4d3/go.mod h1:nt3d53pc1VYcphSCIaYAJtnPYnr3Zyn8fMq2wvPGPso=
|
||||
github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs=
|
||||
github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA=
|
||||
|
|
Loading…
Reference in a new issue