From f1ef3fb351ce0659c7540f5863397d9d3012e153 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 7 Oct 2021 15:48:11 -0700 Subject: [PATCH] Add GetBool(s string) bool to URI type. --- kms/uri/uri.go | 10 ++++++++++ kms/uri/uri_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 44271e74..3a1b8981 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -95,6 +95,16 @@ func (u *URI) Get(key string) string { return v } +// GetBool returns true if a given key has the value "true". It will returns +// false otherwise. +func (u *URI) GetBool(key string) bool { + v := u.Values.Get(key) + if v == "" { + v = u.URL.Query().Get(key) + } + return strings.EqualFold(v, "true") +} + // GetEncoded returns the first value in the uri with the given key, it will // return empty nil if that field is not present or is empty. If the return // value is hex encoded it will decode it and return it. diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index c2e0a9fe..01fbad0f 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -212,6 +212,40 @@ func TestURI_Get(t *testing.T) { } } +func TestURI_GetBool(t *testing.T) { + mustParse := func(s string) *URI { + u, err := Parse(s) + if err != nil { + t.Fatal(err) + } + return u + } + type args struct { + key string + } + tests := []struct { + name string + uri *URI + args args + want bool + }{ + {"true", mustParse("azurekms:name=foo;vault=bar;hsm=true"), args{"hsm"}, true}, + {"TRUE", mustParse("azurekms:name=foo;vault=bar;hsm=TRUE"), args{"hsm"}, true}, + {"tRUe query", mustParse("azurekms:name=foo;vault=bar?hsm=tRUe"), args{"hsm"}, true}, + {"false", mustParse("azurekms:name=foo;vault=bar;hsm=false"), args{"hsm"}, false}, + {"false query", mustParse("azurekms:name=foo;vault=bar?hsm=false"), args{"hsm"}, false}, + {"empty", mustParse("azurekms:name=foo;vault=bar;hsm=?bar=true"), args{"hsm"}, false}, + {"missing", mustParse("azurekms:name=foo;vault=bar"), args{"hsm"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.uri.GetBool(tt.args.key); got != tt.want { + t.Errorf("URI.GetBool() = %v, want %v", got, tt.want) + } + }) + } +} + func TestURI_GetEncoded(t *testing.T) { mustParse := func(s string) *URI { u, err := Parse(s)