From 97fa1183bf5de817a87dea61b629fac95ee22cb4 Mon Sep 17 00:00:00 2001 From: Mark Sagi-Kazar Date: Tue, 27 Sep 2022 15:31:01 +0200 Subject: [PATCH] feat: add WeakStringList type to support lists in aud claim Signed-off-by: Mark Sagi-Kazar --- registry/auth/token/types.go | 55 ++++++++++++++++++++ registry/auth/token/types_test.go | 85 +++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+) create mode 100644 registry/auth/token/types.go create mode 100644 registry/auth/token/types_test.go diff --git a/registry/auth/token/types.go b/registry/auth/token/types.go new file mode 100644 index 000000000..4559d6daf --- /dev/null +++ b/registry/auth/token/types.go @@ -0,0 +1,55 @@ +package token + +import ( + "encoding/json" + "reflect" +) + +// WeakStringList is a slice of strings that can be deserialized from either a single string value or a list of strings. +type WeakStringList []string + +func (s *WeakStringList) UnmarshalJSON(data []byte) (err error) { + var value interface{} + + if err = json.Unmarshal(data, &value); err != nil { + return err + } + + switch v := value.(type) { + case string: + *s = []string{v} + + case []string: + *s = v + + case []interface{}: + var ss []string + + for _, vv := range v { + vs, ok := vv.(string) + if !ok { + return &json.UnsupportedTypeError{ + Type: reflect.TypeOf(vv), + } + } + + ss = append(ss, vs) + } + + *s = ss + + case nil: + return nil + + default: + return &json.UnsupportedTypeError{ + Type: reflect.TypeOf(v), + } + } + + return +} + +func (s WeakStringList) MarshalJSON() (b []byte, err error) { + return json.Marshal([]string(s)) +} diff --git a/registry/auth/token/types_test.go b/registry/auth/token/types_test.go new file mode 100644 index 000000000..283d86a80 --- /dev/null +++ b/registry/auth/token/types_test.go @@ -0,0 +1,85 @@ +package token + +import ( + "encoding/json" + "testing" +) + +func TestWeakStringList_Unmarshal(t *testing.T) { + t.Run("OK", func(t *testing.T) { + testCases := []struct { + value string + expected WeakStringList + }{ + { + value: `"audience"`, + expected: WeakStringList{"audience"}, + }, + { + value: `["audience1", "audience2"]`, + expected: WeakStringList{"audience1", "audience2"}, + }, + { + value: `null`, + expected: nil, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run("", func(t *testing.T) { + var actual WeakStringList + + err := json.Unmarshal([]byte(testCase.value), &actual) + if err != nil { + t.Fatal(err) + } + + assertStringListEqual(t, testCase.expected, actual) + }) + } + }) + + t.Run("Error", func(t *testing.T) { + var actual WeakStringList + + err := json.Unmarshal([]byte("1234"), &actual) + if err == nil { + t.Fatal("expected unmarshal to fail") + } + }) +} + +func TestWeakStringList_Marshal(t *testing.T) { + value := WeakStringList{"audience"} + + expected := `["audience"]` + + actual, err := json.Marshal(value) + if err != nil { + t.Fatal(err) + } + + if expected != string(actual) { + t.Errorf("expected marshaled list to be %v, got %v", expected, actual) + } +} + +func assertStringListEqual(t *testing.T, expected []string, actual []string) { + t.Helper() + + if len(expected) != len(actual) { + t.Errorf("length mismatch: expected %d long slice, got %d", len(expected), len(actual)) + + return + } + + for i, v := range expected { + if v != actual[i] { + t.Errorf("expected %d. item to be %q, got %q", i, v, actual[i]) + } + + return + } +}