diff --git a/api/router_filter.go b/api/router_filter.go index e7e6f47..cdd87b3 100644 --- a/api/router_filter.go +++ b/api/router_filter.go @@ -111,6 +111,15 @@ func (hf *HandlerFilters) DefaultHandler(handler http.HandlerFunc) *HandlerFilte } func (hf *HandlerFilters) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if handler := hf.match(r); handler != nil { + handler.ServeHTTP(w, r) + return + } + + hf.defaultHandler.ServeHTTP(w, r) +} + +func (hf *HandlerFilters) match(r *http.Request) http.Handler { LOOP: for _, filter := range hf.filters { for _, header := range filter.headers { @@ -121,13 +130,12 @@ LOOP: } for _, query := range filter.queries { queryVal := r.URL.Query().Get(query.Key) - if !r.URL.Query().Has(query.Key) || queryVal != "" && query.Value != queryVal { + if !r.URL.Query().Has(query.Key) || query.Value != "" && query.Value != queryVal { continue LOOP } } - filter.h.ServeHTTP(w, r) - return + return filter.h } - hf.defaultHandler.ServeHTTP(w, r) + return nil } diff --git a/api/router_filter_test.go b/api/router_filter_test.go new file mode 100644 index 0000000..2511229 --- /dev/null +++ b/api/router_filter_test.go @@ -0,0 +1,68 @@ +package api + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFilter(t *testing.T) { + key1, val1 := "key1", "val1" + key2, val2 := "key2", "val2" + key3, val3 := "key3", "val3" + + anyVal := "" + + notNilHandler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) + + t.Run("queries", func(t *testing.T) { + f := NewHandlerFilter(). + Add(NewFilter(). + QueriesMatch(key1, val1, key2, anyVal). + Queries(key3). + Handler(notNilHandler)) + + r, err := http.NewRequest(http.MethodGet, "https://localhost:8084", nil) + require.NoError(t, err) + + query := make(url.Values) + query.Set(key1, val1) + query.Set(key2, val2) + query.Set(key3, val3) + r.URL.RawQuery = query.Encode() + + h := f.match(r) + require.NotNil(t, h) + + query.Set(key1, val2) + r.URL.RawQuery = query.Encode() + + h = f.match(r) + require.Nil(t, h) + }) + + t.Run("headers", func(t *testing.T) { + f := NewHandlerFilter(). + Add(NewFilter(). + HeadersMatch(key1, val1, key2, anyVal). + Headers(key3). + Handler(notNilHandler)) + + r, err := http.NewRequest(http.MethodGet, "https://localhost:8084", nil) + require.NoError(t, err) + + r.Header.Set(key1, val1) + r.Header.Set(key2, val2) + r.Header.Set(key3, val3) + + h := f.match(r) + require.NotNil(t, h) + + r.Header.Set(key1, val2) + + h = f.match(r) + require.Nil(t, h) + }) +}