withProvisionerOID and unit test

This commit is contained in:
max furman 2018-10-25 23:49:23 -07:00
parent 283dc42904
commit a4a461466b
4 changed files with 76 additions and 55 deletions

View file

@ -96,7 +96,6 @@ func (a *Authority) Authorize(ott string) ([]interface{}, error) {
&commonNameClaim{claims.Subject}, &commonNameClaim{claims.Subject},
&dnsNamesClaim{claims.Subject}, &dnsNamesClaim{claims.Subject},
&ipAddressesClaim{claims.Subject}, &ipAddressesClaim{claims.Subject},
withIssuerAlternativeNameExtension(p.Issuer + ":" + p.Key.KeyID),
p, p,
} }

View file

@ -124,6 +124,7 @@ func (p *Provisioner) getTLSApps(so SignOptions) ([]x509util.WithOption, []certC
return []x509util.WithOption{ return []x509util.WithOption{
x509util.WithNotBeforeAfterDuration(so.NotBefore, x509util.WithNotBeforeAfterDuration(so.NotBefore,
so.NotAfter, c.DefaultTLSCertDuration()), so.NotAfter, c.DefaultTLSCertDuration()),
withProvisionerOID(p.Issuer, p.Key.KeyID),
}, []certClaim{ }, []certClaim{
&certTemporalClaim{ &certTemporalClaim{
min: c.MinTLSCertDuration(), min: c.MinTLSCertDuration(),

View file

@ -29,26 +29,36 @@ type SignOptions struct {
NotBefore time.Time `json:"notBefore"` NotBefore time.Time `json:"notBefore"`
} }
func withIssuerAlternativeNameExtension(name string) x509util.WithOption { var (
stepOIDRoot = asn1.ObjectIdentifier([]int{1, 3, 6, 1, 4, 1, 37476, 9000, 64})
stepOIDProvisioner = asn1.ObjectIdentifier(append([]int(nil), append(stepOIDRoot, 1)...))
stepOIDProvisionerName = asn1.ObjectIdentifier(append([]int(nil), append(stepOIDProvisioner, 1)...))
stepOIDProvisionerKeyID = asn1.ObjectIdentifier(append([]int(nil), append(stepOIDProvisioner, 2)...))
)
func withProvisionerOID(name, kid string) x509util.WithOption {
return func(p x509util.Profile) error { return func(p x509util.Profile) error {
crt := p.Subject() crt := p.Subject()
iatExt := []asn1.RawValue{ irw := asn1.RawValue{Tag: asn1.TagGeneralString, Class: asn1.ClassPrivate, Bytes: []byte(name)}
asn1.RawValue{ krw := asn1.RawValue{Tag: asn1.TagGeneralString, Class: asn1.ClassPrivate, Bytes: []byte(kid)}
Tag: 2,
Class: 2,
Bytes: []byte(name),
},
}
iatExtBytes, err := asn1.Marshal(iatExt)
if err != nil {
return &apiError{err, http.StatusInternalServerError, nil}
}
irwb, err := asn1.Marshal(irw)
if err != nil {
return err
}
krwb, err := asn1.Marshal(krw)
if err != nil {
return err
}
crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{ crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{
Id: []int{2, 5, 9, 18}, Id: stepOIDProvisionerName,
Critical: false, Critical: false,
Value: iatExtBytes, Value: irwb,
}, pkix.Extension{
Id: stepOIDProvisionerKeyID,
Critical: false,
Value: krwb,
}) })
return nil return nil

View file

@ -5,6 +5,7 @@ import (
"crypto/sha1" "crypto/sha1"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1"
"fmt" "fmt"
"net/http" "net/http"
"testing" "testing"
@ -50,6 +51,8 @@ func TestSign(t *testing.T) {
NotAfter: nb.Add(time.Minute * 5), NotAfter: nb.Add(time.Minute * 5),
} }
p := a.config.AuthorityConfig.Provisioners[1]
type signTest struct { type signTest struct {
auth *Authority auth *Authority
csr *x509.CertificateRequest csr *x509.CertificateRequest
@ -62,13 +65,10 @@ func TestSign(t *testing.T) {
csr := getCSR(t, priv) csr := getCSR(t, priv)
csr.Raw = []byte("foo") csr.Raw = []byte("foo")
return &signTest{ return &signTest{
auth: a, auth: a,
csr: csr, csr: csr,
extraOpts: []interface{}{ extraOpts: []interface{}{p, "42"},
withIssuerAlternativeNameExtension("baz"), signOpts: signOpts,
"42",
},
signOpts: signOpts,
err: &apiError{errors.New("sign: invalid extra option type string"), err: &apiError{errors.New("sign: invalid extra option type string"),
http.StatusInternalServerError, http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}, context{"csr": csr, "signOptions": signOpts},
@ -79,12 +79,10 @@ func TestSign(t *testing.T) {
csr := getCSR(t, priv) csr := getCSR(t, priv)
csr.Raw = []byte("foo") csr.Raw = []byte("foo")
return &signTest{ return &signTest{
auth: a, auth: a,
csr: csr, csr: csr,
extraOpts: []interface{}{ extraOpts: []interface{}{p},
withIssuerAlternativeNameExtension("baz"), signOpts: signOpts,
},
signOpts: signOpts,
err: &apiError{errors.New("sign: error converting x509 csr to stepx509 csr"), err: &apiError{errors.New("sign: error converting x509 csr to stepx509 csr"),
http.StatusInternalServerError, http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}, context{"csr": csr, "signOptions": signOpts},
@ -96,13 +94,10 @@ func TestSign(t *testing.T) {
_a.config.AuthorityConfig.Template = nil _a.config.AuthorityConfig.Template = nil
csr := getCSR(t, priv) csr := getCSR(t, priv)
return &signTest{ return &signTest{
auth: _a, auth: _a,
csr: csr, csr: csr,
extraOpts: []interface{}{ extraOpts: []interface{}{p},
withIssuerAlternativeNameExtension("baz"), signOpts: signOpts,
a.config.AuthorityConfig.Provisioners[1],
},
signOpts: signOpts,
err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"), err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"),
http.StatusInternalServerError, http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}, context{"csr": csr, "signOptions": signOpts},
@ -114,12 +109,10 @@ func TestSign(t *testing.T) {
_a.intermediateIdentity.Key = nil _a.intermediateIdentity.Key = nil
csr := getCSR(t, priv) csr := getCSR(t, priv)
return &signTest{ return &signTest{
auth: _a, auth: _a,
csr: csr, csr: csr,
extraOpts: []interface{}{ extraOpts: []interface{}{p},
withIssuerAlternativeNameExtension("baz"), signOpts: signOpts,
},
signOpts: signOpts,
err: &apiError{errors.New("sign: error creating new leaf certificate"), err: &apiError{errors.New("sign: error creating new leaf certificate"),
http.StatusInternalServerError, http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}, context{"csr": csr, "signOptions": signOpts},
@ -133,13 +126,10 @@ func TestSign(t *testing.T) {
NotAfter: nb.Add(time.Hour * 25), NotAfter: nb.Add(time.Hour * 25),
} }
return &signTest{ return &signTest{
auth: a, auth: a,
csr: csr, csr: csr,
extraOpts: []interface{}{ extraOpts: []interface{}{p},
withIssuerAlternativeNameExtension("baz"), signOpts: _signOpts,
a.config.AuthorityConfig.Provisioners[1],
},
signOpts: _signOpts,
err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"), err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"),
http.StatusUnauthorized, http.StatusUnauthorized,
context{"csr": csr, "signOptions": _signOpts}, context{"csr": csr, "signOptions": _signOpts},
@ -149,18 +139,18 @@ func TestSign(t *testing.T) {
"ok": func(t *testing.T) *signTest { "ok": func(t *testing.T) *signTest {
csr := getCSR(t, priv) csr := getCSR(t, priv)
return &signTest{ return &signTest{
auth: a, auth: a,
csr: csr, csr: csr,
extraOpts: []interface{}{ extraOpts: []interface{}{p},
withIssuerAlternativeNameExtension("baz"), signOpts: signOpts,
a.config.AuthorityConfig.Provisioners[1],
},
signOpts: signOpts,
} }
}, },
} }
for name, genTestCase := range tests { for name, genTestCase := range tests {
if name != "ok" {
continue
}
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := genTestCase(t) tc := genTestCase(t)
@ -205,6 +195,27 @@ func TestSign(t *testing.T) {
assert.Equals(t, leaf.AuthorityKeyId, a.intermediateIdentity.Crt.SubjectKeyId) assert.Equals(t, leaf.AuthorityKeyId, a.intermediateIdentity.Crt.SubjectKeyId)
// Verify Provisioner OID
found := 0
for _, ext := range leaf.Extensions {
id := ext.Id.String()
if id != stepOIDProvisionerName.String() && id != stepOIDProvisionerKeyID.String() {
continue
}
found++
rw := asn1.RawValue{}
_, err := asn1.Unmarshal(ext.Value, &rw)
assert.FatalError(t, err)
assert.Equals(t, rw.Tag, asn1.TagGeneralString)
assert.Equals(t, rw.Class, asn1.ClassPrivate)
if id == stepOIDProvisionerName.String() {
assert.Equals(t, string(rw.Bytes), p.Issuer)
} else {
assert.Equals(t, string(rw.Bytes), p.Key.KeyID)
}
}
assert.Equals(t, found, 2)
realIntermediate, err := x509.ParseCertificate(a.intermediateIdentity.Crt.Raw) realIntermediate, err := x509.ParseCertificate(a.intermediateIdentity.Crt.Raw)
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, intermediate, realIntermediate) assert.Equals(t, intermediate, realIntermediate)