Add context to the Authorize method.

Fix tests.
This commit is contained in:
Mariano Cano 2019-07-29 12:34:27 -07:00
parent 2127d09ef3
commit e1cd5ee8c3
5 changed files with 62 additions and 50 deletions

View file

@ -1,6 +1,7 @@
package authority
import (
"context"
"crypto/x509"
"net/http"
"strings"
@ -72,18 +73,12 @@ func (a *Authority) authorizeToken(ott string) (provisioner.Interface, error) {
return p, nil
}
// Authorize is a passthrough to AuthorizeSign.
// NOTE: Authorize will be deprecated in a future release. Please use the
// context specific Authorize[Sign|Revoke|etc.] going forwards.
func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
return a.AuthorizeSign(ott)
}
// AuthorizeSign authorizes a signature request by validating and authenticating
// a OTT that must be sent w/ the request.
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
// Authorize grabs the method from the context and authorizes a signature
// request by validating the one-time-token.
func (a *Authority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
var errContext = apiCtx{"ott": ott}
switch m := provisioner.MethodFromContext(ctx); m {
case provisioner.SignMethod, provisioner.SignSSHMethod:
p, err := a.authorizeToken(ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
@ -91,12 +86,24 @@ func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error)
// Call the provisioner AuthorizeSign method to apply provisioner specific
// auth claims and get the signing options.
opts, err := p.AuthorizeSign(ott)
opts, err := p.AuthorizeSign(context.Background(), ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
}
return opts, nil
case provisioner.RevokeMethod:
return nil, &apiError{errors.New("authorize: revoke method is not supported"), http.StatusInternalServerError, errContext}
default:
return nil, &apiError{errors.Errorf("authorize: method %d is not supported", m), http.StatusInternalServerError, errContext}
}
}
// AuthorizeSign authorizes a signature request by validating and authenticating
// a OTT that must be sent w/ the request.
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
return a.Authorize(ctx, ott)
}
// authorizeRevoke authorizes a revocation request by validating and authenticating

View file

@ -1,11 +1,14 @@
package authority
import (
"context"
"crypto/x509"
"net/http"
"testing"
"time"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/cli/crypto/pemutil"
@ -73,7 +76,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: a,
ott: "foo",
err: &apiError{errors.New("authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}},
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
}
},
"fail/prehistoric-token": func(t *testing.T) *authorizeTest {
@ -92,7 +95,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: a,
ott: raw,
err: &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
"fail/provisioner-not-found": func(t *testing.T) *authorizeTest {
@ -114,7 +117,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: a,
ott: raw,
err: &apiError{errors.New("authorizeToken: provisioner not found or invalid audience (https://test.ca.smallstep.com/revoke)"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
"ok/simpledb": func(t *testing.T) *authorizeTest {
@ -151,7 +154,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: _a,
ott: raw,
err: &apiError{errors.New("authorizeToken: token already used"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
"ok/mockNoSQLDB": func(t *testing.T) *authorizeTest {
@ -199,7 +202,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: _a,
ott: raw,
err: &apiError{errors.New("authorizeToken: failed when checking if token already used: force"),
http.StatusInternalServerError, context{"ott": raw}},
http.StatusInternalServerError, apiCtx{"ott": raw}},
}
},
"fail/mockNoSQLDB/token-already-used": func(t *testing.T) *authorizeTest {
@ -224,7 +227,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: _a,
ott: raw,
err: &apiError{errors.New("authorizeToken: token already used"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
}
@ -391,7 +394,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) {
auth: a,
ott: "foo",
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}},
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
}
},
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
@ -409,7 +412,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) {
auth: a,
ott: raw,
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
"ok": func(t *testing.T) *authorizeTest {
@ -484,7 +487,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a,
ott: "foo",
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}},
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
}
},
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
@ -502,7 +505,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a,
ott: raw,
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
"ok": func(t *testing.T) *authorizeTest {
@ -526,8 +529,8 @@ func TestAuthority_Authorize(t *testing.T) {
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
got, err := tc.auth.Authorize(tc.ott)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
got, err := tc.auth.Authorize(ctx, tc.ott)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Nil(t, got)
@ -577,7 +580,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a,
crt: fooCrt,
err: &apiError{errors.New("renew: force"),
http.StatusInternalServerError, context{"serialNumber": "102012593071130646873265215610956555026"}},
http.StatusInternalServerError, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}},
}
},
"fail/revoked": func(t *testing.T) *authorizeTest {
@ -591,7 +594,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a,
crt: fooCrt,
err: &apiError{errors.New("renew: certificate has been revoked"),
http.StatusUnauthorized, context{"serialNumber": "102012593071130646873265215610956555026"}},
http.StatusUnauthorized, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}},
}
},
"fail/load-provisioner": func(t *testing.T) *authorizeTest {
@ -605,7 +608,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a,
crt: otherCrt,
err: &apiError{errors.New("renew: provisioner not found"),
http.StatusUnauthorized, context{"serialNumber": "41633491264736369593451462439668497527"}},
http.StatusUnauthorized, apiCtx{"serialNumber": "41633491264736369593451462439668497527"}},
}
},
"fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest {
@ -620,7 +623,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a,
crt: renewDisabledCrt,
err: &apiError{errors.New("renew: renew is disabled for provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
http.StatusUnauthorized, context{"serialNumber": "119772236532068856521070735128919532568"}},
http.StatusUnauthorized, apiCtx{"serialNumber": "119772236532068856521070735128919532568"}},
}
},
"ok": func(t *testing.T) *authorizeTest {

View file

@ -35,7 +35,7 @@ func TestGetEncryptedKey(t *testing.T) {
a: a,
kid: "foo",
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"),
http.StatusNotFound, context{}},
http.StatusNotFound, apiCtx{}},
}
},
}

View file

@ -19,8 +19,8 @@ func TestRoot(t *testing.T) {
sum string
err *apiError
}{
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}},
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
}

View file

@ -1,6 +1,7 @@
package authority
import (
"context"
"crypto/rand"
"crypto/sha1"
"crypto/x509"
@ -102,7 +103,8 @@ func TestSign(t *testing.T) {
assert.FatalError(t, err)
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key)
assert.FatalError(t, err)
extraOpts, err := a.Authorize(token)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
extraOpts, err := a.Authorize(ctx, token)
assert.FatalError(t, err)
type signTest struct {
@ -123,7 +125,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts,
err: &apiError{errors.New("sign: invalid certificate request"),
http.StatusBadRequest,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -137,7 +139,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts,
err: &apiError{errors.New("sign: invalid extra option type string"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -152,7 +154,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts,
err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -167,7 +169,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts,
err: &apiError{errors.New("sign: error creating new leaf certificate"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -184,7 +186,7 @@ func TestSign(t *testing.T) {
signOpts: _signOpts,
err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"),
http.StatusUnauthorized,
context{"csr": csr, "signOptions": _signOpts},
apiCtx{"csr": csr, "signOptions": _signOpts},
},
}
},
@ -199,7 +201,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts,
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
http.StatusUnauthorized,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -210,7 +212,7 @@ func TestSign(t *testing.T) {
storeCertificate: func(crt *x509.Certificate) error {
return &apiError{errors.New("force"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}}
apiCtx{"csr": csr, "signOptions": signOpts}}
},
}
return &signTest{
@ -220,7 +222,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts,
err: &apiError{errors.New("sign: error storing certificate in db: force"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -373,7 +375,7 @@ func TestRenew(t *testing.T) {
auth: _a,
crt: crt,
err: &apiError{errors.New("error renewing certificate from existing server certificate"),
http.StatusInternalServerError, context{}},
http.StatusInternalServerError, apiCtx{}},
}, nil
},
"fail-unauthorized": func() (*renewTest, error) {
@ -568,7 +570,7 @@ func TestRevoke(t *testing.T) {
validAudience := []string{"https://test.ca.smallstep.com/revoke"}
now := time.Now().UTC()
getCtx := func() map[string]interface{} {
return context{
return apiCtx{
"serialNumber": "sn",
"reasonCode": reasonCode,
"reason": reason,