package testutil import ( "bytes" "fmt" "io" "net/http" "net/url" "sort" "strings" ) // RequestResponseMap is an ordered mapping from Requests to Responses type RequestResponseMap []RequestResponseMapping // RequestResponseMapping defines a Response to be sent in response to a given // Request type RequestResponseMapping struct { Request Request Response Response } // Request is a simplified http.Request object type Request struct { // Method is the http method of the request, for example GET Method string // Route is the http route of this request Route string // QueryParams are the query parameters of this request QueryParams map[string][]string // Body is the byte contents of the http request Body []byte // Headers are the header for this request Headers http.Header } func (r Request) String() string { queryString := "" if len(r.QueryParams) > 0 { keys := make([]string, 0, len(r.QueryParams)) queryParts := make([]string, 0, len(r.QueryParams)) for k := range r.QueryParams { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { for _, val := range r.QueryParams[k] { queryParts = append(queryParts, fmt.Sprintf("%s=%s", k, url.QueryEscape(val))) } } queryString = "?" + strings.Join(queryParts, "&") } var headers []string if len(r.Headers) > 0 { var headerKeys []string for k := range r.Headers { headerKeys = append(headerKeys, k) } sort.Strings(headerKeys) for _, k := range headerKeys { for _, val := range r.Headers[k] { headers = append(headers, fmt.Sprintf("%s:%s", k, val)) } } } return fmt.Sprintf("%s %s%s\n%s\n%s", r.Method, r.Route, queryString, headers, r.Body) } // Response is a simplified http.Response object type Response struct { // Statuscode is the http status code of the Response StatusCode int // Headers are the http headers of this Response Headers http.Header // Body is the response body Body []byte } // testHandler is an http.Handler with a defined mapping from Request to an // ordered list of Response objects type testHandler struct { responseMap map[string][]Response } // NewHandler returns a new test handler that responds to defined requests // with specified responses // Each time a Request is received, the next Response is returned in the // mapping, until no Responses are defined, at which point a 404 is sent back func NewHandler(requestResponseMap RequestResponseMap) http.Handler { responseMap := make(map[string][]Response) for _, mapping := range requestResponseMap { responses, ok := responseMap[mapping.Request.String()] if ok { responseMap[mapping.Request.String()] = append(responses, mapping.Response) } else { responseMap[mapping.Request.String()] = []Response{mapping.Response} } } return &testHandler{responseMap: responseMap} } func (app *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() requestBody, _ := io.ReadAll(r.Body) request := Request{ Method: r.Method, Route: r.URL.Path, QueryParams: r.URL.Query(), Body: requestBody, Headers: make(map[string][]string), } // Add headers of interest here for k, v := range r.Header { if k == "If-None-Match" { request.Headers[k] = v } } responses, ok := app.responseMap[request.String()] if !ok || len(responses) == 0 { http.NotFound(w, r) return } response := responses[0] app.responseMap[request.String()] = responses[1:] responseHeader := w.Header() for k, v := range response.Headers { responseHeader[k] = v } w.WriteHeader(response.StatusCode) if _, err := io.Copy(w, bytes.NewReader(response.Body)); err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } }