Merge pull request #3742 from sagikazarmark/fix-aud-claim-list

Accept list of strings in audience claim in token auth
This commit is contained in:
Milos Gajdos 2023-04-26 18:39:26 +01:00 committed by GitHub
commit 29b5e79f82
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 162 additions and 11 deletions

View file

@ -180,7 +180,7 @@ func (issuer *TokenIssuer) CreateJWT(subject string, audience string, grantedAcc
claimSet := token.ClaimSet{
Issuer: issuer.Issuer,
Subject: subject,
Audience: audience,
Audience: []string{audience},
Expiration: now.Add(exp).Unix(),
NotBefore: now.Unix(),
IssuedAt: now.Unix(),

View file

@ -44,7 +44,7 @@ type ClaimSet struct {
// Public claims
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience string `json:"aud"`
Audience AudienceList `json:"aud"`
Expiration int64 `json:"exp"`
NotBefore int64 `json:"nbf"`
IssuedAt int64 `json:"iat"`
@ -143,8 +143,8 @@ func (t *Token) Verify(verifyOpts VerifyOptions) error {
}
// Verify that the Audience claim is allowed.
if !contains(verifyOpts.AcceptedAudiences, t.Claims.Audience) {
log.Infof("token intended for another audience: %q", t.Claims.Audience)
if !containsAny(verifyOpts.AcceptedAudiences, t.Claims.Audience) {
log.Infof("token intended for another audience: %v", t.Claims.Audience)
return ErrInvalidToken
}

View file

@ -116,7 +116,7 @@ func makeTestToken(issuer, audience string, access []*ResourceActions, rootKey l
claimSet := &ClaimSet{
Issuer: issuer,
Subject: "foo",
Audience: audience,
Audience: []string{audience},
Expiration: exp.Unix(),
NotBefore: now.Unix(),
IssuedAt: now.Unix(),

View file

@ -0,0 +1,55 @@
package token
import (
"encoding/json"
"reflect"
)
// AudienceList is a slice of strings that can be deserialized from either a single string value or a list of strings.
type AudienceList []string
func (s *AudienceList) UnmarshalJSON(data []byte) (err error) {
var value interface{}
if err = json.Unmarshal(data, &value); err != nil {
return err
}
switch v := value.(type) {
case string:
*s = []string{v}
case []string:
*s = v
case []interface{}:
var ss []string
for _, vv := range v {
vs, ok := vv.(string)
if !ok {
return &json.UnsupportedTypeError{
Type: reflect.TypeOf(vv),
}
}
ss = append(ss, vs)
}
*s = ss
case nil:
return nil
default:
return &json.UnsupportedTypeError{
Type: reflect.TypeOf(v),
}
}
return
}
func (s AudienceList) MarshalJSON() (b []byte, err error) {
return json.Marshal([]string(s))
}

View file

@ -0,0 +1,85 @@
package token
import (
"encoding/json"
"testing"
)
func TestAudienceList_Unmarshal(t *testing.T) {
t.Run("OK", func(t *testing.T) {
testCases := []struct {
value string
expected AudienceList
}{
{
value: `"audience"`,
expected: AudienceList{"audience"},
},
{
value: `["audience1", "audience2"]`,
expected: AudienceList{"audience1", "audience2"},
},
{
value: `null`,
expected: nil,
},
}
for _, testCase := range testCases {
testCase := testCase
t.Run("", func(t *testing.T) {
var actual AudienceList
err := json.Unmarshal([]byte(testCase.value), &actual)
if err != nil {
t.Fatal(err)
}
assertStringListEqual(t, testCase.expected, actual)
})
}
})
t.Run("Error", func(t *testing.T) {
var actual AudienceList
err := json.Unmarshal([]byte("1234"), &actual)
if err == nil {
t.Fatal("expected unmarshal to fail")
}
})
}
func TestAudienceList_Marshal(t *testing.T) {
value := AudienceList{"audience"}
expected := `["audience"]`
actual, err := json.Marshal(value)
if err != nil {
t.Fatal(err)
}
if expected != string(actual) {
t.Errorf("expected marshaled list to be %v, got %v", expected, actual)
}
}
func assertStringListEqual(t *testing.T, expected []string, actual []string) {
t.Helper()
if len(expected) != len(actual) {
t.Errorf("length mismatch: expected %d long slice, got %d", len(expected), len(actual))
return
}
for i, v := range expected {
if v != actual[i] {
t.Errorf("expected %d. item to be %q, got %q", i, v, actual[i])
}
return
}
}

View file

@ -56,3 +56,14 @@ func contains(ss []string, q string) bool {
return false
}
// containsAny returns true if any of q is found in ss.
func containsAny(ss []string, q []string) bool {
for _, s := range ss {
if contains(q, s) {
return true
}
}
return false
}