diff --git a/authority/authorize_test.go b/authority/authorize_test.go index e4863764..1d7e69ad 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -449,7 +449,7 @@ func TestAuthority_authorizeSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Len(t, 8, got) + assert.Len(t, 6, got) } } }) diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 9e8fc7ad..3c4e3f86 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -10,6 +10,7 @@ import ( "encoding/hex" "encoding/pem" "fmt" + "net" "net/http" "net/url" "strings" @@ -529,7 +530,7 @@ func TestAWS_AuthorizeSign(t *testing.T) { assert.FatalError(t, err) type args struct { - token string + token, cn string } tests := []struct { name string @@ -539,24 +540,24 @@ func TestAWS_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1}, 5, http.StatusOK, false}, - {"ok", p2, args{t2}, 7, http.StatusOK, false}, - {"ok", p2, args{t2Hostname}, 7, http.StatusOK, false}, - {"ok", p2, args{t2PrivateIP}, 7, http.StatusOK, false}, - {"ok", p1, args{t4}, 5, http.StatusOK, false}, - {"fail account", p3, args{t3}, 0, http.StatusUnauthorized, true}, - {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, - {"fail subject", p1, args{failSubject}, 0, http.StatusUnauthorized, true}, - {"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true}, - {"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true}, - {"fail account", p1, args{failAccount}, 0, http.StatusUnauthorized, true}, - {"fail instanceID", p1, args{failInstanceID}, 0, http.StatusUnauthorized, true}, - {"fail privateIP", p1, args{failPrivateIP}, 0, http.StatusUnauthorized, true}, - {"fail region", p1, args{failRegion}, 0, http.StatusUnauthorized, true}, - {"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true}, - {"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true}, - {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, - {"fail instance age", p2, args{failInstanceAge}, 0, http.StatusUnauthorized, true}, + {"ok", p1, args{t1, "foo.local"}, 5, http.StatusOK, false}, + {"ok", p2, args{t2, "instance-id"}, 9, http.StatusOK, false}, + {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 9, http.StatusOK, false}, + {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 9, http.StatusOK, false}, + {"ok", p1, args{t4, "instance-id"}, 5, http.StatusOK, false}, + {"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true}, + {"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true}, + {"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true}, + {"fail issuer", p1, args{token: failIssuer}, 0, http.StatusUnauthorized, true}, + {"fail audience", p1, args{token: failAudience}, 0, http.StatusUnauthorized, true}, + {"fail account", p1, args{token: failAccount}, 0, http.StatusUnauthorized, true}, + {"fail instanceID", p1, args{token: failInstanceID}, 0, http.StatusUnauthorized, true}, + {"fail privateIP", p1, args{token: failPrivateIP}, 0, http.StatusUnauthorized, true}, + {"fail region", p1, args{token: failRegion}, 0, http.StatusUnauthorized, true}, + {"fail exp", p1, args{token: failExp}, 0, http.StatusUnauthorized, true}, + {"fail nbf", p1, args{token: failNbf}, 0, http.StatusUnauthorized, true}, + {"fail key", p1, args{token: failKey}, 0, http.StatusUnauthorized, true}, + {"fail instance age", p2, args{token: failInstanceAge}, 0, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -571,6 +572,33 @@ func TestAWS_AuthorizeSign(t *testing.T) { assert.Equals(t, sc.StatusCode(), tt.code) } else { assert.Len(t, tt.wantLen, got) + for _, o := range got { + switch v := o.(type) { + case *provisionerExtensionOption: + assert.Equals(t, v.Type, int(TypeAWS)) + assert.Equals(t, v.Name, tt.aws.GetName()) + assert.Equals(t, v.CredentialID, tt.aws.Accounts[0]) + assert.Len(t, 2, v.KeyValuePairs) + case profileDefaultDuration: + assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration()) + case commonNameValidator: + assert.Equals(t, string(v), tt.args.cn) + case defaultPublicKeyValidator: + case *validityValidator: + assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration()) + case ipAddressesValidator: + assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")}) + case emailAddressesValidator: + assert.Equals(t, v, nil) + case urisValidator: + assert.Equals(t, v, nil) + case dnsNamesValidator: + assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) + default: + assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + } + } } }) } diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 6a226e09..500ab2bb 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -432,7 +432,7 @@ func TestAzure_AuthorizeSign(t *testing.T) { wantErr bool }{ {"ok", p1, args{t1}, 4, http.StatusOK, false}, - {"ok", p2, args{t2}, 6, http.StatusOK, false}, + {"ok", p2, args{t2}, 9, http.StatusOK, false}, {"ok", p1, args{t11}, 4, http.StatusOK, false}, {"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true}, {"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true}, @@ -456,6 +456,33 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.Equals(t, sc.StatusCode(), tt.code) } else { assert.Len(t, tt.wantLen, got) + for _, o := range got { + switch v := o.(type) { + case *provisionerExtensionOption: + assert.Equals(t, v.Type, int(TypeAzure)) + assert.Equals(t, v.Name, tt.azure.GetName()) + assert.Equals(t, v.CredentialID, tt.azure.TenantID) + assert.Len(t, 0, v.KeyValuePairs) + case profileDefaultDuration: + assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration()) + case commonNameValidator: + assert.Equals(t, string(v), "virtualMachine") + case defaultPublicKeyValidator: + case *validityValidator: + assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration()) + case ipAddressesValidator: + assert.Equals(t, v, nil) + case emailAddressesValidator: + assert.Equals(t, v, nil) + case urisValidator: + assert.Equals(t, v, nil) + case dnsNamesValidator: + assert.Equals(t, []string(v), []string{"virtualMachine"}) + default: + assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + } + } } }) } diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 34a09003..a9c781d2 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -516,7 +516,7 @@ func TestGCP_AuthorizeSign(t *testing.T) { wantErr bool }{ {"ok", p1, args{t1}, 4, http.StatusOK, false}, - {"ok", p2, args{t2}, 6, http.StatusOK, false}, + {"ok", p2, args{t2}, 9, http.StatusOK, false}, {"ok", p3, args{t3}, 4, http.StatusOK, false}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, @@ -545,6 +545,33 @@ func TestGCP_AuthorizeSign(t *testing.T) { assert.Equals(t, sc.StatusCode(), tt.code) } else { assert.Len(t, tt.wantLen, got) + for _, o := range got { + switch v := o.(type) { + case *provisionerExtensionOption: + assert.Equals(t, v.Type, int(TypeGCP)) + assert.Equals(t, v.Name, tt.gcp.GetName()) + assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0]) + assert.Len(t, 4, v.KeyValuePairs) + case profileDefaultDuration: + assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration()) + case commonNameSliceValidator: + assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) + case defaultPublicKeyValidator: + case *validityValidator: + assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration()) + case ipAddressesValidator: + assert.Equals(t, v, nil) + case emailAddressesValidator: + assert.Equals(t, v, nil) + case urisValidator: + assert.Equals(t, v, nil) + case dnsNamesValidator: + assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) + default: + assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + } + } } }) } diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 24630ad3..cc513dc6 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -157,8 +157,8 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), // validators commonNameValidator(claims.Subject), - defaultSANsValidator(claims.SANs), defaultPublicKeyValidator{}, + defaultSANsValidator(claims.SANs), newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), }, nil } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index ed97d8f1..5644f49f 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -6,7 +6,6 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" - "net" "net/http" "strings" "testing" @@ -253,18 +252,36 @@ func TestJWK_AuthorizeSign(t *testing.T) { token string } tests := []struct { - name string - prov *JWK - args args - code int - err error - dns []string - emails []string - ips []net.IP + name string + prov *JWK + args args + code int + err error + sans []string }{ - {name: "fail-signature", prov: p1, args: args{failSig}, code: http.StatusUnauthorized, err: errors.New("jwk.AuthorizeSign: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")}, - {"ok-sans", p1, args{t1}, http.StatusOK, nil, []string{"foo"}, []string{"max@smallstep.com"}, []net.IP{net.ParseIP("127.0.0.1")}}, - {"ok-no-sans", p1, args{t2}, http.StatusOK, nil, []string{"subject"}, []string{}, []net.IP{}}, + { + name: "fail-signature", + prov: p1, + args: args{failSig}, + code: http.StatusUnauthorized, + err: errors.New("jwk.AuthorizeSign: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive"), + }, + { + name: "ok-sans", + prov: p1, + args: args{t1}, + code: http.StatusOK, + err: nil, + sans: []string{"127.0.0.1", "max@smallstep.com", "foo"}, + }, + { + name: "ok-no-sans", + prov: p1, + args: args{t2}, + code: http.StatusOK, + err: nil, + sans: []string{"subject"}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -278,7 +295,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { } } else { if assert.NotNil(t, got) { - assert.Len(t, 8, got) + assert.Len(t, 6, got) for _, o := range got { switch v := o.(type) { case *provisionerExtensionOption: @@ -291,15 +308,11 @@ func TestJWK_AuthorizeSign(t *testing.T) { case commonNameValidator: assert.Equals(t, string(v), "subject") case defaultPublicKeyValidator: - case dnsNamesValidator: - assert.Equals(t, []string(v), tt.dns) - case emailAddressesValidator: - assert.Equals(t, []string(v), tt.emails) - case ipAddressesValidator: - assert.Equals(t, []net.IP(v), tt.ips) case *validityValidator: assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) + case defaultSANsValidator: + assert.Equals(t, []string(v), tt.sans) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index dac0bfb8..00b5c7f4 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -426,7 +426,15 @@ func (o *provisionerExtensionOption) Option(Options) x509util.WithOption { if err != nil { return err } - crt.ExtraExtensions = append(crt.ExtraExtensions, ext) + // NOTE: HACK. + // Prepend the provisioner extension. In the auth.Sign code we will + // force the resulting certificate to only have one extension, the + // first stepOIDProvisioner that is found in the ExtraExtensions. + // A client could pass a csr containing a malicious stepOIDProvisioner + // ExtraExtension. If we were to append (rather than prepend) the correct + // stepOIDProvisioner extension, then the resulting certificate would + // contain the malicious extension, rather than the one applied by step-ca. + crt.ExtraExtensions = append([]pkix.Extension{ext}, crt.ExtraExtensions...) return nil } } diff --git a/authority/provisioner/sign_options_test.go b/authority/provisioner/sign_options_test.go index 62e285e9..3e02853a 100644 --- a/authority/provisioner/sign_options_test.go +++ b/authority/provisioner/sign_options_test.go @@ -356,6 +356,7 @@ func Test_ExtraExtsEnforcer_Enforce(t *testing.T) { e1 := pkix.Extension{Id: []int{1, 2, 3, 4, 5}, Critical: false, Value: []byte("foo")} e2 := pkix.Extension{Id: []int{2, 2, 2}, Critical: false, Value: []byte("bar")} stepExt := pkix.Extension{Id: stepOIDProvisioner, Critical: false, Value: []byte("baz")} + fakeStepExt := pkix.Extension{Id: stepOIDProvisioner, Critical: false, Value: []byte("zap")} type test struct { cert *x509.Certificate check func(*x509.Certificate) @@ -379,7 +380,7 @@ func Test_ExtraExtsEnforcer_Enforce(t *testing.T) { }, "ok/step-provisioner-ext": func() test { return test{ - cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{e1, stepExt, e2}}, + cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{e1, stepExt, fakeStepExt, e2}}, check: func(cert *x509.Certificate) { assert.Equals(t, len(cert.ExtraExtensions), 1) assert.Equals(t, cert.ExtraExtensions[0], stepExt) @@ -668,6 +669,47 @@ func Test_profileDefaultDuration_Option(t *testing.T) { } } +func Test_newProvisionerExtension_Option(t *testing.T) { + type test struct { + cert *x509.Certificate + valid func(*x509.Certificate) + } + tests := map[string]func() test{ + "ok/one-element": func() test { + return test{ + cert: new(x509.Certificate), + valid: func(cert *x509.Certificate) { + if assert.Len(t, 1, cert.ExtraExtensions) { + ext := cert.ExtraExtensions[0] + assert.Equals(t, ext.Id, stepOIDProvisioner) + } + }, + } + }, + "ok/prepend": func() test { + return test{ + cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: stepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, + valid: func(cert *x509.Certificate) { + if assert.Len(t, 3, cert.ExtraExtensions) { + ext := cert.ExtraExtensions[0] + assert.Equals(t, ext.Id, stepOIDProvisioner) + assert.False(t, ext.Critical) + } + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tt := run() + prof := &x509util.Leaf{} + prof.SetSubject(tt.cert) + assert.FatalError(t, newProvisionerExtensionOption(TypeJWK, "foo", "bar", "baz", "zap").Option(Options{})(prof)) + tt.valid(prof.Subject()) + }) + } +} + func Test_profileLimitDuration_Option(t *testing.T) { n, fn := mockNow() defer fn() diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 3ebaeb6b..24f3e71e 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -2,7 +2,6 @@ package provisioner import ( "context" - "net" "net/http" "testing" "time" @@ -407,13 +406,11 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.FatalError(t, err) type test struct { - p *X5C - token string - code int - err error - dns []string - emails []string - ips []net.IP + p *X5C + token string + code int + err error + sans []string } tests := map[string]func(*testing.T) test{ "fail/invalid-token": func(t *testing.T) test { @@ -434,11 +431,9 @@ func TestX5C_AuthorizeSign(t *testing.T) { withX5CHdr(certs)) assert.FatalError(t, err) return test{ - p: p, - token: tok, - dns: []string{"foo"}, - emails: []string{}, - ips: []net.IP{}, + p: p, + token: tok, + sans: []string{"foo"}, } }, "ok/multi-sans": func(t *testing.T) test { @@ -449,11 +444,9 @@ func TestX5C_AuthorizeSign(t *testing.T) { withX5CHdr(certs)) assert.FatalError(t, err) return test{ - p: p, - token: tok, - dns: []string{"foo"}, - emails: []string{"max@smallstep.com"}, - ips: []net.IP{net.ParseIP("127.0.0.1")}, + p: p, + token: tok, + sans: []string{"127.0.0.1", "foo", "max@smallstep.com"}, } }, } @@ -470,7 +463,7 @@ func TestX5C_AuthorizeSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - tot := 0 + assert.Equals(t, len(opts), 6) for _, o := range opts { switch v := o.(type) { case *provisionerExtensionOption: @@ -487,21 +480,15 @@ func TestX5C_AuthorizeSign(t *testing.T) { case commonNameValidator: assert.Equals(t, string(v), "foo") case defaultPublicKeyValidator: - case dnsNamesValidator: - assert.Equals(t, []string(v), tc.dns) - case emailAddressesValidator: - assert.Equals(t, []string(v), tc.emails) - case ipAddressesValidator: - assert.Equals(t, []net.IP(v), tc.ips) + case defaultSANsValidator: + assert.Equals(t, []string(v), tc.sans) case *validityValidator: assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } - tot++ } - assert.Equals(t, tot, 8) } } } diff --git a/authority/tls_test.go b/authority/tls_test.go index 183c3083..807f970d 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -89,6 +89,17 @@ func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateReques return csr } +func setExtraExtsCSR(exts []pkix.Extension) func(*x509.CertificateRequest) { + return func(csr *x509.CertificateRequest) { + csr.ExtraExtensions = exts + } +} + +type basicConstraints struct { + IsCA bool `asn1:"optional"` + MaxPathLen int `asn1:"optional,default:-1"` +} + func TestAuthority_Sign(t *testing.T) { pub, priv, err := keys.GenerateDefaultKeyPair() assert.FatalError(t, err) @@ -271,7 +282,16 @@ ZYtQ9Ot36qc= } }, "ok with enforced modifier": func(t *testing.T) *signTest { - csr := getCSR(t, priv) + bcExt := pkix.Extension{} + bcExt.Id = asn1.ObjectIdentifier{2, 5, 29, 19} + bcExt.Critical = false + bcExt.Value, err = asn1.Marshal(basicConstraints{IsCA: true, MaxPathLen: 4}) + assert.FatalError(t, err) + + csr := getCSR(t, priv, setExtraExtsCSR([]pkix.Extension{ + bcExt, + {Id: stepOIDProvisioner, Value: []byte("foo")}, + {Id: []int{1, 1, 1}, Value: []byte("bar")}})) now := time.Now().UTC() enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{ NotBefore: now, @@ -347,19 +367,26 @@ ZYtQ9Ot36qc= // Verify Provisioner OID found := 0 for _, ext := range leaf.Extensions { - id := ext.Id.String() - if id != stepOIDProvisioner.String() { - continue + switch { + case ext.Id.Equal(stepOIDProvisioner): + found++ + val := stepProvisionerASN1{} + _, err := asn1.Unmarshal(ext.Value, &val) + assert.FatalError(t, err) + assert.Equals(t, val.Type, provisionerTypeJWK) + assert.Equals(t, val.Name, []byte(p.Name)) + assert.Equals(t, val.CredentialID, []byte(p.Key.KeyID)) + // Basic Constraints + case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 19})): + val := basicConstraints{} + _, err := asn1.Unmarshal(ext.Value, &val) + assert.FatalError(t, err) + assert.False(t, val.IsCA, false) + assert.Equals(t, val.MaxPathLen, 0) } - found++ - val := stepProvisionerASN1{} - _, err := asn1.Unmarshal(ext.Value, &val) - assert.FatalError(t, err) - assert.Equals(t, val.Type, provisionerTypeJWK) - assert.Equals(t, val.Name, []byte(p.Name)) - assert.Equals(t, val.CredentialID, []byte(p.Key.KeyID)) } assert.Equals(t, found, 1) + assert.Len(t, 6, leaf.Extensions) realIntermediate, err := x509.ParseCertificate(a.x509Issuer.Raw) assert.FatalError(t, err)