Add context to the Authorize method.

Fix tests.
This commit is contained in:
Mariano Cano 2019-07-29 12:34:27 -07:00
parent 2127d09ef3
commit e1cd5ee8c3
5 changed files with 62 additions and 50 deletions

View file

@ -1,6 +1,7 @@
package authority package authority
import ( import (
"context"
"crypto/x509" "crypto/x509"
"net/http" "net/http"
"strings" "strings"
@ -72,31 +73,37 @@ func (a *Authority) authorizeToken(ott string) (provisioner.Interface, error) {
return p, nil return p, nil
} }
// Authorize is a passthrough to AuthorizeSign. // Authorize grabs the method from the context and authorizes a signature
// NOTE: Authorize will be deprecated in a future release. Please use the // request by validating the one-time-token.
// context specific Authorize[Sign|Revoke|etc.] going forwards. func (a *Authority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) { var errContext = apiCtx{"ott": ott}
return a.AuthorizeSign(ott) switch m := provisioner.MethodFromContext(ctx); m {
case provisioner.SignMethod, provisioner.SignSSHMethod:
p, err := a.authorizeToken(ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
}
// Call the provisioner AuthorizeSign method to apply provisioner specific
// auth claims and get the signing options.
opts, err := p.AuthorizeSign(context.Background(), ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
}
return opts, nil
case provisioner.RevokeMethod:
return nil, &apiError{errors.New("authorize: revoke method is not supported"), http.StatusInternalServerError, errContext}
default:
return nil, &apiError{errors.Errorf("authorize: method %d is not supported", m), http.StatusInternalServerError, errContext}
}
} }
// AuthorizeSign authorizes a signature request by validating and authenticating // AuthorizeSign authorizes a signature request by validating and authenticating
// a OTT that must be sent w/ the request. // a OTT that must be sent w/ the request.
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
var errContext = apiCtx{"ott": ott} ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
return a.Authorize(ctx, ott)
p, err := a.authorizeToken(ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
}
// Call the provisioner AuthorizeSign method to apply provisioner specific
// auth claims and get the signing options.
opts, err := p.AuthorizeSign(ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
}
return opts, nil
} }
// authorizeRevoke authorizes a revocation request by validating and authenticating // authorizeRevoke authorizes a revocation request by validating and authenticating

View file

@ -1,11 +1,14 @@
package authority package authority
import ( import (
"context"
"crypto/x509" "crypto/x509"
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
@ -73,7 +76,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: a, auth: a,
ott: "foo", ott: "foo",
err: &apiError{errors.New("authorizeToken: error parsing token"), err: &apiError{errors.New("authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}}, http.StatusUnauthorized, apiCtx{"ott": "foo"}},
} }
}, },
"fail/prehistoric-token": func(t *testing.T) *authorizeTest { "fail/prehistoric-token": func(t *testing.T) *authorizeTest {
@ -92,7 +95,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: a, auth: a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"), err: &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"),
http.StatusUnauthorized, context{"ott": raw}}, http.StatusUnauthorized, apiCtx{"ott": raw}},
} }
}, },
"fail/provisioner-not-found": func(t *testing.T) *authorizeTest { "fail/provisioner-not-found": func(t *testing.T) *authorizeTest {
@ -114,7 +117,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: a, auth: a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorizeToken: provisioner not found or invalid audience (https://test.ca.smallstep.com/revoke)"), err: &apiError{errors.New("authorizeToken: provisioner not found or invalid audience (https://test.ca.smallstep.com/revoke)"),
http.StatusUnauthorized, context{"ott": raw}}, http.StatusUnauthorized, apiCtx{"ott": raw}},
} }
}, },
"ok/simpledb": func(t *testing.T) *authorizeTest { "ok/simpledb": func(t *testing.T) *authorizeTest {
@ -151,7 +154,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: _a, auth: _a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorizeToken: token already used"), err: &apiError{errors.New("authorizeToken: token already used"),
http.StatusUnauthorized, context{"ott": raw}}, http.StatusUnauthorized, apiCtx{"ott": raw}},
} }
}, },
"ok/mockNoSQLDB": func(t *testing.T) *authorizeTest { "ok/mockNoSQLDB": func(t *testing.T) *authorizeTest {
@ -199,7 +202,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: _a, auth: _a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorizeToken: failed when checking if token already used: force"), err: &apiError{errors.New("authorizeToken: failed when checking if token already used: force"),
http.StatusInternalServerError, context{"ott": raw}}, http.StatusInternalServerError, apiCtx{"ott": raw}},
} }
}, },
"fail/mockNoSQLDB/token-already-used": func(t *testing.T) *authorizeTest { "fail/mockNoSQLDB/token-already-used": func(t *testing.T) *authorizeTest {
@ -224,7 +227,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: _a, auth: _a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorizeToken: token already used"), err: &apiError{errors.New("authorizeToken: token already used"),
http.StatusUnauthorized, context{"ott": raw}}, http.StatusUnauthorized, apiCtx{"ott": raw}},
} }
}, },
} }
@ -391,7 +394,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) {
auth: a, auth: a,
ott: "foo", ott: "foo",
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"), err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}}, http.StatusUnauthorized, apiCtx{"ott": "foo"}},
} }
}, },
"fail/invalid-subject": func(t *testing.T) *authorizeTest { "fail/invalid-subject": func(t *testing.T) *authorizeTest {
@ -409,7 +412,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) {
auth: a, auth: a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"), err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
http.StatusUnauthorized, context{"ott": raw}}, http.StatusUnauthorized, apiCtx{"ott": raw}},
} }
}, },
"ok": func(t *testing.T) *authorizeTest { "ok": func(t *testing.T) *authorizeTest {
@ -484,7 +487,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
ott: "foo", ott: "foo",
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"), err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}}, http.StatusUnauthorized, apiCtx{"ott": "foo"}},
} }
}, },
"fail/invalid-subject": func(t *testing.T) *authorizeTest { "fail/invalid-subject": func(t *testing.T) *authorizeTest {
@ -502,7 +505,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"), err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
http.StatusUnauthorized, context{"ott": raw}}, http.StatusUnauthorized, apiCtx{"ott": raw}},
} }
}, },
"ok": func(t *testing.T) *authorizeTest { "ok": func(t *testing.T) *authorizeTest {
@ -526,8 +529,8 @@ func TestAuthority_Authorize(t *testing.T) {
for name, genTestCase := range tests { for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := genTestCase(t) tc := genTestCase(t)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
got, err := tc.auth.Authorize(tc.ott) got, err := tc.auth.Authorize(ctx, tc.ott)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.Nil(t, got) assert.Nil(t, got)
@ -577,7 +580,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a, auth: a,
crt: fooCrt, crt: fooCrt,
err: &apiError{errors.New("renew: force"), err: &apiError{errors.New("renew: force"),
http.StatusInternalServerError, context{"serialNumber": "102012593071130646873265215610956555026"}}, http.StatusInternalServerError, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}},
} }
}, },
"fail/revoked": func(t *testing.T) *authorizeTest { "fail/revoked": func(t *testing.T) *authorizeTest {
@ -591,7 +594,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a, auth: a,
crt: fooCrt, crt: fooCrt,
err: &apiError{errors.New("renew: certificate has been revoked"), err: &apiError{errors.New("renew: certificate has been revoked"),
http.StatusUnauthorized, context{"serialNumber": "102012593071130646873265215610956555026"}}, http.StatusUnauthorized, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}},
} }
}, },
"fail/load-provisioner": func(t *testing.T) *authorizeTest { "fail/load-provisioner": func(t *testing.T) *authorizeTest {
@ -605,7 +608,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a, auth: a,
crt: otherCrt, crt: otherCrt,
err: &apiError{errors.New("renew: provisioner not found"), err: &apiError{errors.New("renew: provisioner not found"),
http.StatusUnauthorized, context{"serialNumber": "41633491264736369593451462439668497527"}}, http.StatusUnauthorized, apiCtx{"serialNumber": "41633491264736369593451462439668497527"}},
} }
}, },
"fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest { "fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest {
@ -620,7 +623,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a, auth: a,
crt: renewDisabledCrt, crt: renewDisabledCrt,
err: &apiError{errors.New("renew: renew is disabled for provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), err: &apiError{errors.New("renew: renew is disabled for provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
http.StatusUnauthorized, context{"serialNumber": "119772236532068856521070735128919532568"}}, http.StatusUnauthorized, apiCtx{"serialNumber": "119772236532068856521070735128919532568"}},
} }
}, },
"ok": func(t *testing.T) *authorizeTest { "ok": func(t *testing.T) *authorizeTest {

View file

@ -35,7 +35,7 @@ func TestGetEncryptedKey(t *testing.T) {
a: a, a: a,
kid: "foo", kid: "foo",
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"), err: &apiError{errors.Errorf("encrypted key with kid foo was not found"),
http.StatusNotFound, context{}}, http.StatusNotFound, apiCtx{}},
} }
}, },
} }

View file

@ -19,8 +19,8 @@ func TestRoot(t *testing.T) {
sum string sum string
err *apiError err *apiError
}{ }{
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}}, "not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}}, "invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil}, "success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
} }

View file

@ -1,6 +1,7 @@
package authority package authority
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"crypto/x509" "crypto/x509"
@ -102,7 +103,8 @@ func TestSign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key) token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key)
assert.FatalError(t, err) assert.FatalError(t, err)
extraOpts, err := a.Authorize(token) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
extraOpts, err := a.Authorize(ctx, token)
assert.FatalError(t, err) assert.FatalError(t, err)
type signTest struct { type signTest struct {
@ -123,7 +125,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: invalid certificate request"), err: &apiError{errors.New("sign: invalid certificate request"),
http.StatusBadRequest, http.StatusBadRequest,
context{"csr": csr, "signOptions": signOpts}, apiCtx{"csr": csr, "signOptions": signOpts},
}, },
} }
}, },
@ -137,7 +139,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: invalid extra option type string"), err: &apiError{errors.New("sign: invalid extra option type string"),
http.StatusInternalServerError, http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}, apiCtx{"csr": csr, "signOptions": signOpts},
}, },
} }
}, },
@ -152,7 +154,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"), err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"),
http.StatusInternalServerError, http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}, apiCtx{"csr": csr, "signOptions": signOpts},
}, },
} }
}, },
@ -167,7 +169,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: error creating new leaf certificate"), err: &apiError{errors.New("sign: error creating new leaf certificate"),
http.StatusInternalServerError, http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}, apiCtx{"csr": csr, "signOptions": signOpts},
}, },
} }
}, },
@ -184,7 +186,7 @@ func TestSign(t *testing.T) {
signOpts: _signOpts, signOpts: _signOpts,
err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"), err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"),
http.StatusUnauthorized, http.StatusUnauthorized,
context{"csr": csr, "signOptions": _signOpts}, apiCtx{"csr": csr, "signOptions": _signOpts},
}, },
} }
}, },
@ -199,7 +201,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"), err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
http.StatusUnauthorized, http.StatusUnauthorized,
context{"csr": csr, "signOptions": signOpts}, apiCtx{"csr": csr, "signOptions": signOpts},
}, },
} }
}, },
@ -210,7 +212,7 @@ func TestSign(t *testing.T) {
storeCertificate: func(crt *x509.Certificate) error { storeCertificate: func(crt *x509.Certificate) error {
return &apiError{errors.New("force"), return &apiError{errors.New("force"),
http.StatusInternalServerError, http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}} apiCtx{"csr": csr, "signOptions": signOpts}}
}, },
} }
return &signTest{ return &signTest{
@ -220,7 +222,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: error storing certificate in db: force"), err: &apiError{errors.New("sign: error storing certificate in db: force"),
http.StatusInternalServerError, http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}, apiCtx{"csr": csr, "signOptions": signOpts},
}, },
} }
}, },
@ -373,7 +375,7 @@ func TestRenew(t *testing.T) {
auth: _a, auth: _a,
crt: crt, crt: crt,
err: &apiError{errors.New("error renewing certificate from existing server certificate"), err: &apiError{errors.New("error renewing certificate from existing server certificate"),
http.StatusInternalServerError, context{}}, http.StatusInternalServerError, apiCtx{}},
}, nil }, nil
}, },
"fail-unauthorized": func() (*renewTest, error) { "fail-unauthorized": func() (*renewTest, error) {
@ -568,7 +570,7 @@ func TestRevoke(t *testing.T) {
validAudience := []string{"https://test.ca.smallstep.com/revoke"} validAudience := []string{"https://test.ca.smallstep.com/revoke"}
now := time.Now().UTC() now := time.Now().UTC()
getCtx := func() map[string]interface{} { getCtx := func() map[string]interface{} {
return context{ return apiCtx{
"serialNumber": "sn", "serialNumber": "sn",
"reasonCode": reasonCode, "reasonCode": reasonCode,
"reason": reason, "reason": reason,