package api import ( "fmt" "net/http" ) type HandlerFilters struct { filters []Filter defaultHandler http.Handler } type Filter struct { queries []Pair headers []Pair allowedQueries map[string]struct{} noQueries bool h http.Handler } type Pair struct { Key string Value string } func NewHandlerFilter() *HandlerFilters { return &HandlerFilters{} } func NewFilter() *Filter { return &Filter{} } func (hf *HandlerFilters) Add(filter *Filter) *HandlerFilters { hf.filters = append(hf.filters, *filter) return hf } // HeadersMatch adds a matcher for header values. // It accepts a sequence of key/value pairs. Values may define variables. // Panics if number of parameters is not even. // Supports only exact matching. // If the value is an empty string, it will match any value if the key is set. func (f *Filter) HeadersMatch(pairs ...string) *Filter { length := len(pairs) if length%2 != 0 { panic(fmt.Errorf("filter headers: number of parameters must be multiple of 2, got %v", pairs)) } for i := 0; i < length; i += 2 { f.headers = append(f.headers, Pair{ Key: pairs[i], Value: pairs[i+1], }) } return f } // Headers is similar to HeadersMatch but accept only header keys, set value to empty string internally. func (f *Filter) Headers(headers ...string) *Filter { for _, header := range headers { f.headers = append(f.headers, Pair{ Key: header, Value: "", }) } return f } func (f *Filter) Handler(handler http.HandlerFunc) *Filter { f.h = handler return f } // QueriesMatch adds a matcher for URL query values. // It accepts a sequence of key/value pairs. Values may define variables. // Panics if number of parameters is not even. // Supports only exact matching. // If the value is an empty string, it will match any value if the key is set. func (f *Filter) QueriesMatch(pairs ...string) *Filter { length := len(pairs) if length%2 != 0 { panic(fmt.Errorf("filter headers: number of parameters must be multiple of 2, got %v", pairs)) } for i := 0; i < length; i += 2 { f.queries = append(f.queries, Pair{ Key: pairs[i], Value: pairs[i+1], }) } return f } // Queries is similar to QueriesMatch but accept only query keys, set value to empty string internally. func (f *Filter) Queries(queries ...string) *Filter { for _, query := range queries { f.queries = append(f.queries, Pair{ Key: query, Value: "", }) } return f } // NoQueries sets flag indicating that request shouldn't have query parameters. func (f *Filter) NoQueries() *Filter { f.noQueries = true return f } // AllowedQueries adds query parameter keys that may be present in request. func (f *Filter) AllowedQueries(queries ...string) *Filter { f.allowedQueries = make(map[string]struct{}, len(queries)) for _, query := range queries { f.allowedQueries[query] = struct{}{} } return f } func (hf *HandlerFilters) DefaultHandler(handler http.HandlerFunc) *HandlerFilters { hf.defaultHandler = handler return hf } func (hf *HandlerFilters) ServeHTTP(w http.ResponseWriter, r *http.Request) { if handler := hf.match(r); handler != nil { handler.ServeHTTP(w, r) return } hf.defaultHandler.ServeHTTP(w, r) } func (hf *HandlerFilters) match(r *http.Request) http.Handler { queries := r.URL.Query() LOOP: for _, filter := range hf.filters { if filter.noQueries && len(queries) > 0 { continue } if len(filter.allowedQueries) > 0 { for key := range queries { if _, ok := filter.allowedQueries[key]; !ok { continue LOOP } } } for _, header := range filter.headers { hdrVals := r.Header.Values(header.Key) if len(hdrVals) == 0 || header.Value != "" && header.Value != hdrVals[0] { continue LOOP } } for _, query := range filter.queries { queryVal := queries.Get(query.Key) if !queries.Has(query.Key) || query.Value != "" && query.Value != queryVal { continue LOOP } } return filter.h } return nil }