diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 2eefd331..c3868e5f 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -5,7 +5,6 @@ import ( "crypto/ed25519" "crypto/rsa" "crypto/x509" - "crypto/x509/pkix" "encoding/json" "net" "net/http" @@ -425,18 +424,6 @@ func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) e return v.policyEngine.IsX509CertificateAllowed(cert) } -// var ( -// stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} -// stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) -// ) - -// type stepProvisionerASN1 struct { -// Type int -// Name []byte -// CredentialID []byte -// KeyValuePairs []string `asn1:"optional,omitempty"` -// } - type forceCNOption struct { ForceCN bool } @@ -481,13 +468,14 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption if err != nil { return errs.NewError(http.StatusInternalServerError, err, "error creating certificate") } - // Prepend the provisioner extension. In the auth.Sign code we will - // force the resulting certificate to only have one extension, the - // first stepOIDProvisioner that is found in the ExtraExtensions. - // A client could pass a csr containing a malicious stepOIDProvisioner - // ExtraExtension. If we were to append (rather than prepend) the correct - // stepOIDProvisioner extension, then the resulting certificate would - // contain the malicious extension, rather than the one applied by step-ca. - cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...) + // Replace or append the provisioner extension to avoid the inclusions of + // malicious stepOIDProvisioner using templates. + for i, e := range cert.ExtraExtensions { + if e.Id.Equal(StepOIDProvisioner) { + cert.ExtraExtensions[i] = ext + return nil + } + } + cert.ExtraExtensions = append(cert.ExtraExtensions, ext) return nil } diff --git a/authority/provisioner/sign_options_test.go b/authority/provisioner/sign_options_test.go index fc4d675a..198462c7 100644 --- a/authority/provisioner/sign_options_test.go +++ b/authority/provisioner/sign_options_test.go @@ -3,6 +3,7 @@ package provisioner import ( "crypto/x509" "crypto/x509/pkix" + "encoding/asn1" "fmt" "net" "net/url" @@ -625,6 +626,16 @@ func Test_profileDefaultDuration_Option(t *testing.T) { } func Test_newProvisionerExtension_Option(t *testing.T) { + expectedValue, err := asn1.Marshal(extensionASN1{ + Type: int(TypeJWK), + Name: []byte("name"), + CredentialID: []byte("credentialId"), + KeyValuePairs: []string{"key", "value"}, + }) + if err != nil { + t.Fatal(err) + } + type test struct { cert *x509.Certificate valid func(*x509.Certificate) @@ -636,18 +647,22 @@ func Test_newProvisionerExtension_Option(t *testing.T) { valid: func(cert *x509.Certificate) { if assert.Len(t, 1, cert.ExtraExtensions) { ext := cert.ExtraExtensions[0] - assert.Equals(t, ext.Id, StepOIDProvisioner) + assert.Equals(t, StepOIDProvisioner, ext.Id) + assert.Equals(t, expectedValue, ext.Value) + assert.False(t, ext.Critical) + } }, } }, - "ok/prepend": func() test { + "ok/replace": func() test { return test{ cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: StepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, valid: func(cert *x509.Certificate) { - if assert.Len(t, 3, cert.ExtraExtensions) { + if assert.Len(t, 2, cert.ExtraExtensions) { ext := cert.ExtraExtensions[0] - assert.Equals(t, ext.Id, StepOIDProvisioner) + assert.Equals(t, StepOIDProvisioner, ext.Id) + assert.Equals(t, expectedValue, ext.Value) assert.False(t, ext.Critical) } }, @@ -657,7 +672,7 @@ func Test_newProvisionerExtension_Option(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tt := run() - assert.FatalError(t, newProvisionerExtensionOption(TypeJWK, "foo", "bar", "baz", "zap").Modify(tt.cert, SignOptions{})) + assert.FatalError(t, newProvisionerExtensionOption(TypeJWK, "name", "credentialId", "key", "value").Modify(tt.cert, SignOptions{})) tt.valid(tt.cert) }) }