[#174] Fix router filter query matching

Signed-off-by: Denis Kirillov <d.kirillov@yadro.com>
This commit is contained in:
Denis Kirillov 2023-07-18 09:25:50 +03:00
parent 6e3595e35b
commit 73ed3f7782
2 changed files with 80 additions and 4 deletions

View file

@ -111,6 +111,15 @@ func (hf *HandlerFilters) DefaultHandler(handler http.HandlerFunc) *HandlerFilte
} }
func (hf *HandlerFilters) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (hf *HandlerFilters) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if handler := hf.match(r); handler != nil {
handler.ServeHTTP(w, r)
return
}
hf.defaultHandler.ServeHTTP(w, r)
}
func (hf *HandlerFilters) match(r *http.Request) http.Handler {
LOOP: LOOP:
for _, filter := range hf.filters { for _, filter := range hf.filters {
for _, header := range filter.headers { for _, header := range filter.headers {
@ -121,13 +130,12 @@ LOOP:
} }
for _, query := range filter.queries { for _, query := range filter.queries {
queryVal := r.URL.Query().Get(query.Key) queryVal := r.URL.Query().Get(query.Key)
if !r.URL.Query().Has(query.Key) || queryVal != "" && query.Value != queryVal { if !r.URL.Query().Has(query.Key) || query.Value != "" && query.Value != queryVal {
continue LOOP continue LOOP
} }
} }
filter.h.ServeHTTP(w, r) return filter.h
return
} }
hf.defaultHandler.ServeHTTP(w, r) return nil
} }

68
api/router_filter_test.go Normal file
View file

@ -0,0 +1,68 @@
package api
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/require"
)
func TestFilter(t *testing.T) {
key1, val1 := "key1", "val1"
key2, val2 := "key2", "val2"
key3, val3 := "key3", "val3"
anyVal := ""
notNilHandler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
t.Run("queries", func(t *testing.T) {
f := NewHandlerFilter().
Add(NewFilter().
QueriesMatch(key1, val1, key2, anyVal).
Queries(key3).
Handler(notNilHandler))
r, err := http.NewRequest(http.MethodGet, "https://localhost:8084", nil)
require.NoError(t, err)
query := make(url.Values)
query.Set(key1, val1)
query.Set(key2, val2)
query.Set(key3, val3)
r.URL.RawQuery = query.Encode()
h := f.match(r)
require.NotNil(t, h)
query.Set(key1, val2)
r.URL.RawQuery = query.Encode()
h = f.match(r)
require.Nil(t, h)
})
t.Run("headers", func(t *testing.T) {
f := NewHandlerFilter().
Add(NewFilter().
HeadersMatch(key1, val1, key2, anyVal).
Headers(key3).
Handler(notNilHandler))
r, err := http.NewRequest(http.MethodGet, "https://localhost:8084", nil)
require.NoError(t, err)
r.Header.Set(key1, val1)
r.Header.Set(key2, val2)
r.Header.Set(key3, val3)
h := f.match(r)
require.NotNil(t, h)
r.Header.Set(key1, val2)
h = f.match(r)
require.Nil(t, h)
})
}