diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index ebd8e5a4..913c8a2b 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" "github.com/smallstep/assert" @@ -221,39 +222,37 @@ func TestOIDC_authorizeToken(t *testing.T) { args args code int wantIssuer string - wantErr bool + expErr error }{ - {"ok1", p1, args{t1}, http.StatusOK, issuer, false}, - {"ok tenantid", p2, args{t2}, http.StatusOK, tenantIssuer, false}, - {"ok admin", p3, args{t3}, http.StatusOK, issuer, false}, - {"ok domain", p3, args{t4}, http.StatusOK, issuer, false}, - {"ok no email", p3, args{t5}, http.StatusOK, issuer, false}, - {"fail-domain", p3, args{failDomain}, http.StatusUnauthorized, "", true}, - {"fail-key", p1, args{failKey}, http.StatusUnauthorized, "", true}, - {"fail-token", p1, args{failTok}, http.StatusUnauthorized, "", true}, - {"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, "", true}, - {"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, "", true}, - {"fail-audience", p1, args{failAud}, http.StatusUnauthorized, "", true}, - {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, "", true}, - {"fail-expired", p1, args{failExp}, http.StatusUnauthorized, "", true}, - {"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, "", true}, + {"ok1", p1, args{t1}, http.StatusOK, issuer, nil}, + {"ok tenantid", p2, args{t2}, http.StatusOK, tenantIssuer, nil}, + {"ok admin", p3, args{t3}, http.StatusOK, issuer, nil}, + {"ok domain", p3, args{t4}, http.StatusOK, issuer, nil}, + {"ok no email", p3, args{t5}, http.StatusOK, issuer, nil}, + {"fail-domain", p3, args{failDomain}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: email "name@example.com" is not allowed`)}, + {"fail-key", p1, args{failKey}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; cannot validate oidc token`)}, + {"fail-token", p1, args{failTok}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; error parsing oidc token: invalid character '~' looking for beginning of value`)}, + {"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; error parsing oidc token claims: invalid character '~' looking for beginning of value`)}, + {"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: square/go-jose/jwt: validation failed, invalid issuer claim (iss)`)}, + {"fail-audience", p1, args{failAud}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: square/go-jose/jwt: validation failed, invalid audience claim (aud)`)}, + {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; cannot validate oidc token`)}, + {"fail-expired", p1, args{failExp}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: square/go-jose/jwt: validation failed, token is expired (exp)`)}, + {"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: square/go-jose/jwt: validation failed, token not valid yet (nbf)`)}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.prov.authorizeToken(tt.args.token) - if (err != nil) != tt.wantErr { - fmt.Println(tt) - t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) - return - } - if err != nil { + if tt.expErr != nil { + require.Error(t, err) + require.EqualError(t, err, tt.expErr.Error()) + var sc render.StatusCodedError - assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") - assert.Equals(t, sc.StatusCode(), tt.code) - assert.Nil(t, got) + require.ErrorAs(t, err, &sc, "error does not implement StatusCodedError interface") + require.Equal(t, tt.code, sc.StatusCode()) + require.Nil(t, got) } else { - assert.NotNil(t, got) - assert.Equals(t, got.Issuer, tt.wantIssuer) + require.NotNil(t, got) + require.Equal(t, tt.wantIssuer, got.Issuer) } }) }