diff --git a/authority/tls.go b/authority/tls.go index da15ed51..554e9478 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -30,35 +30,34 @@ type SignOptions struct { } var ( - stepOIDRoot = asn1.ObjectIdentifier([]int{1, 3, 6, 1, 4, 1, 37476, 9000, 64}) - stepOIDProvisioner = asn1.ObjectIdentifier(append([]int(nil), append(stepOIDRoot, 1)...)) - stepOIDProvisionerName = asn1.ObjectIdentifier(append([]int(nil), append(stepOIDProvisioner, 1)...)) - stepOIDProvisionerKeyID = asn1.ObjectIdentifier(append([]int(nil), append(stepOIDProvisioner, 2)...)) + stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} + stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) ) +type stepProvisionerASN1 struct { + Type int + Name []byte + CredentialID []byte +} + +const provisionerTypeJWK = 1 + func withProvisionerOID(name, kid string) x509util.WithOption { return func(p x509util.Profile) error { crt := p.Subject() - irw := asn1.RawValue{Tag: asn1.TagGeneralString, Class: asn1.ClassPrivate, Bytes: []byte(name)} - krw := asn1.RawValue{Tag: asn1.TagGeneralString, Class: asn1.ClassPrivate, Bytes: []byte(kid)} - - irwb, err := asn1.Marshal(irw) - if err != nil { - return err - } - krwb, err := asn1.Marshal(krw) + b, err := asn1.Marshal(stepProvisionerASN1{ + Type: provisionerTypeJWK, + Name: []byte(name), + CredentialID: []byte(kid), + }) if err != nil { return err } crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{ - Id: stepOIDProvisionerName, + Id: stepOIDProvisioner, Critical: false, - Value: irwb, - }, pkix.Extension{ - Id: stepOIDProvisionerKeyID, - Critical: false, - Value: krwb, + Value: b, }) return nil diff --git a/authority/tls_test.go b/authority/tls_test.go index ef4f8998..d5876716 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -148,9 +148,6 @@ func TestSign(t *testing.T) { } for name, genTestCase := range tests { - if name != "ok" { - continue - } t.Run(name, func(t *testing.T) { tc := genTestCase(t) @@ -199,22 +196,18 @@ func TestSign(t *testing.T) { found := 0 for _, ext := range leaf.Extensions { id := ext.Id.String() - if id != stepOIDProvisionerName.String() && id != stepOIDProvisionerKeyID.String() { + if id != stepOIDProvisioner.String() { continue } found++ - rw := asn1.RawValue{} - _, err := asn1.Unmarshal(ext.Value, &rw) + val := stepProvisionerASN1{} + _, err := asn1.Unmarshal(ext.Value, &val) assert.FatalError(t, err) - assert.Equals(t, rw.Tag, asn1.TagGeneralString) - assert.Equals(t, rw.Class, asn1.ClassPrivate) - if id == stepOIDProvisionerName.String() { - assert.Equals(t, string(rw.Bytes), p.Issuer) - } else { - assert.Equals(t, string(rw.Bytes), p.Key.KeyID) - } + assert.Equals(t, val.Type, provisionerTypeJWK) + assert.Equals(t, val.Name, []byte(p.Issuer)) + assert.Equals(t, val.CredentialID, []byte(p.Key.KeyID)) } - assert.Equals(t, found, 2) + assert.Equals(t, found, 1) realIntermediate, err := x509.ParseCertificate(a.intermediateIdentity.Crt.Raw) assert.FatalError(t, err)