package context

import (
	"context"
	"errors"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/distribution/distribution/v3/internal/uuid"
	"github.com/gorilla/mux"
	log "github.com/sirupsen/logrus"
)

// Common errors used with this package.
//
// ErrNoRequestContext is returned by GetRequest when the context does not
// hold a non-nil *http.Request under the "http.request" key.
//
// ErrNoResponseWriterContext is returned by GetResponseWriter when the
// context does not hold an http.ResponseWriter under the "http.response" key.
var (
	ErrNoRequestContext        = errors.New("no http request in context")
	ErrNoResponseWriterContext = errors.New("no http response in context")
)

// parseIP parses ipStr as a textual IP address using net.ParseIP. When the
// string is not a valid IP address a warning is logged and nil is returned.
func parseIP(ipStr string) net.IP {
	if parsed := net.ParseIP(ipStr); parsed != nil {
		return parsed
	}

	log.Warnf("invalid remote IP address: %q", ipStr)
	return nil
}

// RemoteAddr extracts the remote address of the request, taking into
// account proxy headers.
//
// The lookup order is:
//
//  1. the first comma-separated element of X-Forwarded-For, if it parses as
//     an IP address;
//  2. the X-Real-Ip header, if it parses as an IP address;
//  3. r.RemoteAddr, returned unmodified.
//
// A header value that does not parse as an IP address is skipped (parseIP
// logs a warning for it) and the next source is tried.
//
// NOTE: both headers are read from the request as-is; nothing in this
// function checks that they were set by a trusted proxy.
func RemoteAddr(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// Only the first (left-most) element is considered. List elements
		// may be padded with optional whitespace, which in HTTP covers both
		// spaces and horizontal tabs, so trim all of it rather than only
		// spaces.
		first, _, _ := strings.Cut(forwarded, ",")
		first = strings.TrimSpace(first)
		if parseIP(first) != nil {
			return first
		}
	}

	// X-Real-Ip is less supported, but worth checking in the
	// absence of X-Forwarded-For
	if realIP := r.Header.Get("X-Real-Ip"); realIP != "" {
		if parseIP(realIP) != nil {
			return realIP
		}
	}

	return r.RemoteAddr
}

// RemoteIP returns the remote IP of the request, taking into account proxy
// headers. It starts from the value produced by RemoteAddr; when that value
// has the "IP:port" form the port is dropped, otherwise the value is
// returned unchanged.
func RemoteIP(r *http.Request) string {
	addr := RemoteAddr(r)

	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		// Not in "IP:port" form; hand back what RemoteAddr gave us.
		return addr
	}

	return host
}

// WithRequest returns a context that wraps ctx and carries the request r.
// The returned context is stamped with the current time, available at
// "http.request.startedat", and a freshly generated unique id, available at
// "http.request.id". The request itself is available at "http.request", and
// other common attributes are available under the prefix "http.request.".
//
// WithRequest panics if ctx already resolves a non-nil value for
// "http.request".
func WithRequest(ctx context.Context, r *http.Request) context.Context {
	if existing := ctx.Value("http.request"); existing != nil {
		// NOTE(stevvooe): This needs to be considered a programming error. It
		// is unlikely that we'd want to have more than one request in
		// context.
		panic("only one request per context")
	}

	reqCtx := &httpRequestContext{Context: ctx, r: r}
	reqCtx.startedAt = time.Now()
	reqCtx.id = uuid.Generate().String()

	return reqCtx
}

// GetRequest returns the http request stored in ctx under "http.request".
// ErrNoRequestContext is returned when that key does not resolve to a
// non-nil *http.Request.
func GetRequest(ctx context.Context) (*http.Request, error) {
	req, ok := ctx.Value("http.request").(*http.Request)
	if !ok || req == nil {
		return nil, ErrNoRequestContext
	}

	return req, nil
}

// GetRequestID looks up the request id stored under "http.request.id" and
// returns it as a string. No error is returned: the lookup is delegated to
// GetStringValue, so a context without a request id yields whatever that
// helper returns for a missing key.
func GetRequestID(ctx context.Context) string {
	const requestIDKey = "http.request.id"

	return GetStringValue(ctx, requestIDKey)
}

// WithResponseWriter wraps w and ctx in a single instrumented value and
// returns it twice: once as the new context and once as the response writer
// to use from then on. Writing through the returned writer makes response
// statistics available through the returned context.
func WithResponseWriter(ctx context.Context, w http.ResponseWriter) (context.Context, http.ResponseWriter) {
	wrapped := &instrumentedResponseWriter{
		Context:        ctx,
		ResponseWriter: w,
	}

	return wrapped, wrapped
}

// GetResponseWriter returns the http.ResponseWriter stored in ctx under
// "http.response". ErrNoResponseWriterContext is returned when that key does
// not resolve to a non-nil http.ResponseWriter. The returned instance
// provides instrumentation in the context.
func GetResponseWriter(ctx context.Context) (http.ResponseWriter, error) {
	if rw, ok := ctx.Value("http.response").(http.ResponseWriter); ok && rw != nil {
		return rw, nil
	}

	return nil, ErrNoResponseWriterContext
}

// getVarsFromRequest lets us change the request vars implementation for
// testing and maybe future changes. It defaults to gorilla/mux's Vars and is
// the function WithVars calls to obtain the variables for a request.
var getVarsFromRequest = mux.Vars

// WithVars extracts gorilla/mux vars and makes them available on the returned
// context. Variables are available at keys with the prefix "vars.". For
// example, if looking for the variable "name", it can be accessed as
// "vars.name". Implementations that are accessing values need not know that
// the underlying context is implemented with gorilla/mux vars.
func WithVars(ctx context.Context, r *http.Request) context.Context {
	return &muxVarsContext{
		Context: ctx,
		vars:    getVarsFromRequest(r),
	}
}

// GetRequestLogger returns a logger that contains fields from the request in
// the current context. If the request is not available in the context, no
// fields will display. Request loggers can safely be pushed onto the context.
//
// The logger is obtained from GetLogger with the request keys listed below;
// each of them is one of the "http.request." keys that a context created by
// WithRequest can resolve.
func GetRequestLogger(ctx context.Context) Logger {
	return GetLogger(ctx,
		"http.request.id",
		"http.request.method",
		"http.request.host",
		"http.request.uri",
		"http.request.referer",
		"http.request.useragent",
		"http.request.remoteaddr",
		"http.request.contenttype")
}

// GetResponseLogger builds a logger from the response stats as they are at
// the time of the call: bytes written, status and content type. When the
// time elapsed since "http.request.startedat" is positive, it is added as
// the "http.response.duration" field.
//
// Because the values are read at call time, pushing a logger returned from
// this function on the context will lead to missing or invalid data. Only
// call this at the end of a request, after the response has been written.
func GetResponseLogger(ctx context.Context) Logger {
	logger := getLogrusLogger(ctx,
		"http.response.written",
		"http.response.status",
		"http.response.contenttype")

	if elapsed := Since(ctx, "http.request.startedat"); elapsed > 0 {
		return logger.WithField("http.response.duration", elapsed.String())
	}

	return logger
}

// httpRequestContext makes information about a request available to context.
// It embeds the parent context; its Value method answers the "http.request"
// family of keys and defers every other lookup to the parent.
type httpRequestContext struct {
	context.Context

	startedAt time.Time     // set to time.Now() by WithRequest; exposed as "http.request.startedat"
	id        string        // generated by WithRequest via uuid.Generate; exposed as "http.request.id"
	r         *http.Request // the wrapped request; exposed as "http.request"
}

// Value returns a keyed element of the request for use in the context. To
// get the request itself, query "http.request". For other components, access
// them as "http.request.<component>"; for example, "http.request.uri" yields
// r.RequestURI.
//
// The referer and content type are only returned when non-empty. When they
// are empty, and for any key that is not a string or is not recognized here,
// the lookup is passed on to the parent context.
func (ctx *httpRequestContext) Value(key any) any {
	name, isString := key.(string)
	if !isString {
		return ctx.Context.Value(key)
	}

	switch name {
	case "http.request":
		return ctx.r
	case "http.request.id":
		return ctx.id
	case "http.request.startedat":
		return ctx.startedAt
	case "http.request.method":
		return ctx.r.Method
	case "http.request.host":
		return ctx.r.Host
	case "http.request.uri":
		return ctx.r.RequestURI
	case "http.request.remoteaddr":
		return RemoteAddr(ctx.r)
	case "http.request.useragent":
		return ctx.r.UserAgent()
	case "http.request.referer":
		if ref := ctx.r.Referer(); ref != "" {
			return ref
		}
	case "http.request.contenttype":
		if contentType := ctx.r.Header.Get("Content-Type"); contentType != "" {
			return contentType
		}
	}

	// Unknown key, or a known key with an empty value: ask the parent.
	return ctx.Context.Value(key)
}

// muxVarsContext layers a set of route variables on top of a parent context.
type muxVarsContext struct {
	context.Context
	vars map[string]string
}

// Value resolves route variables. The key "vars" returns the whole variable
// map. Any other string key is looked up in the map after removing a leading
// "vars." prefix, and the variable's value is returned when present. All
// other lookups are passed on to the parent context.
func (ctx *muxVarsContext) Value(key any) any {
	name, isString := key.(string)

	switch {
	case !isString:
		// Only string keys can refer to route variables.
	case name == "vars":
		return ctx.vars
	default:
		// TODO(thaJeztah): this considers "vars.FOO" and "FOO" to be equal.
		// We need to check if that's intentional (could be a bug).
		if val, found := ctx.vars[strings.TrimPrefix(name, "vars.")]; found {
			return val
		}
	}

	return ctx.Context.Value(key)
}

// instrumentedResponseWriter provides response writer information in a
// context. It embeds both an http.ResponseWriter and a context.Context, so a
// single value can serve as either, and it records the response status and
// the number of bytes written as the response is produced.
//
// NOTE(review): an earlier version of this comment said this variant is only
// used when the parent ResponseWriter does not implement CloseNotifier; no
// other variant is visible in this file, so confirm whether that still
// applies.
type instrumentedResponseWriter struct {
	http.ResponseWriter
	context.Context

	mu      sync.Mutex // guards status and written
	status  int        // last recorded status; 0 until a header or body is written
	written int64      // running total of bytes reported by the wrapped writer
}

// Write passes p to the wrapped writer and adds the number of bytes it
// reports to the running total, whether or not it also returned an error.
// If no status has been recorded yet, http.StatusOK is assumed.
func (irw *instrumentedResponseWriter) Write(p []byte) (int, error) {
	n, err := irw.ResponseWriter.Write(p)

	irw.mu.Lock()
	defer irw.mu.Unlock()

	irw.written += int64(n)
	if irw.status == 0 {
		// Guess the likely status if not set.
		irw.status = http.StatusOK
	}

	return n, err
}

// WriteHeader forwards status to the wrapped writer and then records it,
// replacing any previously recorded status.
func (irw *instrumentedResponseWriter) WriteHeader(status int) {
	irw.ResponseWriter.WriteHeader(status)

	irw.mu.Lock()
	defer irw.mu.Unlock()

	irw.status = status
}

// Flush flushes the wrapped writer when it implements http.Flusher and does
// nothing otherwise.
func (irw *instrumentedResponseWriter) Flush() {
	flusher, ok := irw.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}

	flusher.Flush()
}

// Value resolves the "http.response" family of keys:
//
//   - "http.response" returns the instrumented writer itself;
//   - "http.response.written" returns the bytes written so far as an int64;
//   - "http.response.status" returns the recorded status as an int;
//   - "http.response.contenttype" returns the Content-Type response header
//     when it is non-empty.
//
// An empty content type, a non-string key, or any other key is passed on to
// the parent context.
func (irw *instrumentedResponseWriter) Value(key any) any {
	name, isString := key.(string)
	if !isString {
		return irw.Context.Value(key)
	}

	switch name {
	case "http.response":
		return irw
	case "http.response.written":
		irw.mu.Lock()
		total := irw.written
		irw.mu.Unlock()
		return total
	case "http.response.status":
		irw.mu.Lock()
		code := irw.status
		irw.mu.Unlock()
		return code
	case "http.response.contenttype":
		if contentType := irw.Header().Get("Content-Type"); contentType != "" {
			return contentType
		}
	}

	// Unknown key, or no content type set: ask the parent.
	return irw.Context.Value(key)
}