Merge pull request #3742 from sagikazarmark/fix-aud-claim-list
Accept list of strings in audience claim in token auth
This commit is contained in:
commit
29b5e79f82
6 changed files with 162 additions and 11 deletions
|
@ -180,7 +180,7 @@ func (issuer *TokenIssuer) CreateJWT(subject string, audience string, grantedAcc
|
||||||
claimSet := token.ClaimSet{
|
claimSet := token.ClaimSet{
|
||||||
Issuer: issuer.Issuer,
|
Issuer: issuer.Issuer,
|
||||||
Subject: subject,
|
Subject: subject,
|
||||||
Audience: audience,
|
Audience: []string{audience},
|
||||||
Expiration: now.Add(exp).Unix(),
|
Expiration: now.Add(exp).Unix(),
|
||||||
NotBefore: now.Unix(),
|
NotBefore: now.Unix(),
|
||||||
IssuedAt: now.Unix(),
|
IssuedAt: now.Unix(),
|
||||||
|
|
|
@ -42,13 +42,13 @@ type ResourceActions struct {
|
||||||
// ClaimSet describes the main section of a JSON Web Token.
|
// ClaimSet describes the main section of a JSON Web Token.
|
||||||
type ClaimSet struct {
|
type ClaimSet struct {
|
||||||
// Public claims
|
// Public claims
|
||||||
Issuer string `json:"iss"`
|
Issuer string `json:"iss"`
|
||||||
Subject string `json:"sub"`
|
Subject string `json:"sub"`
|
||||||
Audience string `json:"aud"`
|
Audience AudienceList `json:"aud"`
|
||||||
Expiration int64 `json:"exp"`
|
Expiration int64 `json:"exp"`
|
||||||
NotBefore int64 `json:"nbf"`
|
NotBefore int64 `json:"nbf"`
|
||||||
IssuedAt int64 `json:"iat"`
|
IssuedAt int64 `json:"iat"`
|
||||||
JWTID string `json:"jti"`
|
JWTID string `json:"jti"`
|
||||||
|
|
||||||
// Private claims
|
// Private claims
|
||||||
Access []*ResourceActions `json:"access"`
|
Access []*ResourceActions `json:"access"`
|
||||||
|
@ -143,8 +143,8 @@ func (t *Token) Verify(verifyOpts VerifyOptions) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that the Audience claim is allowed.
|
// Verify that the Audience claim is allowed.
|
||||||
if !contains(verifyOpts.AcceptedAudiences, t.Claims.Audience) {
|
if !containsAny(verifyOpts.AcceptedAudiences, t.Claims.Audience) {
|
||||||
log.Infof("token intended for another audience: %q", t.Claims.Audience)
|
log.Infof("token intended for another audience: %v", t.Claims.Audience)
|
||||||
return ErrInvalidToken
|
return ErrInvalidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -116,7 +116,7 @@ func makeTestToken(issuer, audience string, access []*ResourceActions, rootKey l
|
||||||
claimSet := &ClaimSet{
|
claimSet := &ClaimSet{
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
Subject: "foo",
|
Subject: "foo",
|
||||||
Audience: audience,
|
Audience: []string{audience},
|
||||||
Expiration: exp.Unix(),
|
Expiration: exp.Unix(),
|
||||||
NotBefore: now.Unix(),
|
NotBefore: now.Unix(),
|
||||||
IssuedAt: now.Unix(),
|
IssuedAt: now.Unix(),
|
||||||
|
|
55
registry/auth/token/types.go
Normal file
55
registry/auth/token/types.go
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
package token
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AudienceList is a slice of strings that can be deserialized from either a single string value or a list of strings.
|
||||||
|
type AudienceList []string
|
||||||
|
|
||||||
|
func (s *AudienceList) UnmarshalJSON(data []byte) (err error) {
|
||||||
|
var value interface{}
|
||||||
|
|
||||||
|
if err = json.Unmarshal(data, &value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
*s = []string{v}
|
||||||
|
|
||||||
|
case []string:
|
||||||
|
*s = v
|
||||||
|
|
||||||
|
case []interface{}:
|
||||||
|
var ss []string
|
||||||
|
|
||||||
|
for _, vv := range v {
|
||||||
|
vs, ok := vv.(string)
|
||||||
|
if !ok {
|
||||||
|
return &json.UnsupportedTypeError{
|
||||||
|
Type: reflect.TypeOf(vv),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ss = append(ss, vs)
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = ss
|
||||||
|
|
||||||
|
case nil:
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return &json.UnsupportedTypeError{
|
||||||
|
Type: reflect.TypeOf(v),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AudienceList) MarshalJSON() (b []byte, err error) {
|
||||||
|
return json.Marshal([]string(s))
|
||||||
|
}
|
85
registry/auth/token/types_test.go
Normal file
85
registry/auth/token/types_test.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
package token
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAudienceList_Unmarshal(t *testing.T) {
|
||||||
|
t.Run("OK", func(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
value string
|
||||||
|
expected AudienceList
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
value: `"audience"`,
|
||||||
|
expected: AudienceList{"audience"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
value: `["audience1", "audience2"]`,
|
||||||
|
expected: AudienceList{"audience1", "audience2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
value: `null`,
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
var actual AudienceList
|
||||||
|
|
||||||
|
err := json.Unmarshal([]byte(testCase.value), &actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertStringListEqual(t, testCase.expected, actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Error", func(t *testing.T) {
|
||||||
|
var actual AudienceList
|
||||||
|
|
||||||
|
err := json.Unmarshal([]byte("1234"), &actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected unmarshal to fail")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAudienceList_Marshal(t *testing.T) {
|
||||||
|
value := AudienceList{"audience"}
|
||||||
|
|
||||||
|
expected := `["audience"]`
|
||||||
|
|
||||||
|
actual, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected != string(actual) {
|
||||||
|
t.Errorf("expected marshaled list to be %v, got %v", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertStringListEqual(t *testing.T, expected []string, actual []string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if len(expected) != len(actual) {
|
||||||
|
t.Errorf("length mismatch: expected %d long slice, got %d", len(expected), len(actual))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, v := range expected {
|
||||||
|
if v != actual[i] {
|
||||||
|
t.Errorf("expected %d. item to be %q, got %q", i, v, actual[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
|
@ -56,3 +56,14 @@ func contains(ss []string, q string) bool {
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// containsAny returns true if any of q is found in ss.
|
||||||
|
func containsAny(ss []string, q []string) bool {
|
||||||
|
for _, s := range ss {
|
||||||
|
if contains(q, s) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue