package api

import (
	"context"
	"io"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/metrics"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer"
)

// RequestTypeFromAPI maps an S3 API operation name to the metrics.RequestType
// used to group per-user statistics.
//
// Matching is exact and case-sensitive. Names that are not listed in any case
// yield metrics.UNKNOWNRequest.
func RequestTypeFromAPI(api string) metrics.RequestType {
	switch api {
	case "Options", "HeadObject", "HeadBucket":
		return metrics.HEADRequest
	case "CreateMultipartUpload", "UploadPartCopy", "UploadPart", "CompleteMultipartUpload",
		"PutObjectACL", "PutObjectTagging", "CopyObject", "PutObjectRetention", "PutObjectLegalHold",
		"PutObject", "PutBucketCors", "PutBucketACL", "PutBucketLifecycle", "PutBucketEncryption",
		"PutBucketPolicy", "PutBucketObjectLockConfig", "PutBucketTagging", "PutBucketVersioning",
		"PutBucketNotification", "CreateBucket", "PostObject":
		return metrics.PUTRequest
	case "ListObjectParts", "ListMultipartUploads", "ListObjectsV2M", "ListObjectsV2", "ListBucketVersions",
		"ListObjectsV1", "ListBuckets":
		return metrics.LISTRequest
	// "GetObjectLegalHold" follows the CamelCase spelling used by every other
	// name in this switch (including "PutObjectLegalHold"). The lowercase
	// "getobjectlegalhold" is kept so callers that still pass it keep working.
	case "GetObjectACL", "GetObjectTagging", "SelectObjectContent", "GetObjectRetention",
		"GetObjectLegalHold", "getobjectlegalhold",
		"GetObjectAttributes", "GetObject", "GetBucketLocation", "GetBucketPolicy",
		"GetBucketLifecycle", "GetBucketEncryption", "GetBucketCors", "GetBucketACL",
		"GetBucketWebsite", "GetBucketAccelerate", "GetBucketRequestPayment", "GetBucketLogging",
		"GetBucketReplication", "GetBucketTagging", "GetBucketObjectLockConfig",
		"GetBucketVersioning", "GetBucketNotification", "ListenBucketNotification":
		return metrics.GETRequest
	case "AbortMultipartUpload", "DeleteObjectTagging", "DeleteObject", "DeleteBucketCors",
		"DeleteBucketWebsite", "DeleteBucketTagging", "DeleteMultipleObjects", "DeleteBucketPolicy",
		"DeleteBucketLifecycle", "DeleteBucketEncryption", "DeleteBucket":
		return metrics.DELETERequest
	default:
		return metrics.UNKNOWNRequest
	}
}

// UsersStat receives one usage record per call to Update. The parameter names
// suggest in and out are byte counts for the two traffic directions; confirm
// against the implementation.
type UsersStat interface {
	Update(user, bucket, cnrID string, reqType int, in, out uint64)
}

// readCounter wraps an io.ReadCloser. Its Read method adds the number of bytes
// returned by every read to countBytes.
type readCounter struct {
	io.ReadCloser
	countBytes uint64
}

// writeCounter wraps an http.ResponseWriter. Its Write method adds the number
// of bytes accepted by every write to countBytes.
type writeCounter struct {
	http.ResponseWriter
	countBytes uint64
}

// responseWrapper wraps an http.ResponseWriter and remembers the status code
// passed to WriteHeader together with the time the request started.
type responseWrapper struct {
	// Once guards WriteHeader so that only the first call is recorded and
	// forwarded to the wrapped writer.
	sync.Once
	http.ResponseWriter

	// statusCode stays 0 until WriteHeader is called.
	statusCode int
	startTime  time.Time
}

// systemPath is the URL path suffix that Stats checks for: requests whose path
// ends with it are left out of the total request and error counters.
const systemPath = "/system"

//var apiStatMetrics = metrics.newApiStatMetrics()

// CIDResolveFunc resolves the CID (returned as cnrID) for a request.
// Stats calls it after the wrapped handler has returned and passes the result
// to the metrics update.
type CIDResolveFunc func(ctx context.Context, reqInfo *ReqInfo) (cnrID string)

// Stats wraps f in a handler that updates appMetrics for every request served.
//
// The returned handler calls CurrentS3RequestsInc before f runs and
// CurrentS3RequestsDec when it returns. It replaces the request body and the
// response writer with counting wrappers, and after f returns reports the
// user, bucket name, resolved CID, request type and byte counts through
// appMetrics.Update. Total request and error counters are skipped for paths
// ending with systemPath, and the request duration is recorded for GET
// requests only.
func Stats(f http.HandlerFunc, resolveCID CIDResolveFunc, appMetrics *metrics.AppMetrics) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		reqInfo := GetReqInfo(r.Context())

		appMetrics.Statistic().CurrentS3RequestsInc(reqInfo.API)
		defer appMetrics.Statistic().CurrentS3RequestsDec(reqInfo.API)

		// Count the request body bytes as the handler reads them.
		bodyCounter := &readCounter{ReadCloser: r.Body}
		r.Body = bodyCounter

		// Count the response bytes and capture the status code and start time.
		respCounter := &writeCounter{ResponseWriter: w}
		wrapped := &responseWrapper{
			ResponseWriter: respCounter,
			startTime:      time.Now(),
		}

		f(wrapped, r)

		// Seconds are precise enough here; the value is meant to be read by humans.
		elapsed := time.Since(wrapped.startTime).Seconds()

		issuer := resolveUser(r.Context())
		containerID := resolveCID(r.Context(), reqInfo)
		appMetrics.Update(issuer, reqInfo.BucketName, containerID, RequestTypeFromAPI(reqInfo.API), bodyCounter.countBytes, respCounter.countBytes)

		status := wrapped.statusCode
		if !strings.HasSuffix(r.URL.Path, systemPath) {
			appMetrics.Statistic().TotalS3RequestsInc(reqInfo.API)

			// Only an explicitly written non-2xx status counts as an error;
			// status 0 means WriteHeader was never called on the wrapper.
			is2xx := status >= http.StatusOK && status < http.StatusMultipleChoices
			if status != 0 && !is2xx {
				appMetrics.Statistic().TotalS3ErrorsInc(reqInfo.API)
			}
		}

		// Durations are tracked for GET requests only.
		if r.Method == http.MethodGet {
			appMetrics.Statistic().RequestDurationsUpdate(reqInfo.API, elapsed)
		}

		appMetrics.Statistic().TotalInputBytesAdd(bodyCounter.countBytes)
		appMetrics.Statistic().TotalOutputBytesAdd(respCounter.countBytes)
	}
}

// resolveUser returns the issuer of the bearer token stored in the access box
// found under the BoxData context key. It returns "anon" when the context
// holds no box, or the box has no gate or no bearer token.
func resolveUser(ctx context.Context) string {
	box, ok := ctx.Value(BoxData).(*accessbox.Box)
	if !ok || box == nil || box.Gate == nil || box.Gate.BearerToken == nil {
		return "anon"
	}
	return bearer.ResolveIssuer(*box.Gate.BearerToken).String()
}

// WriteHeader records the status code and forwards it to the wrapped writer.
// Only the first call has any effect; later calls are ignored by the embedded
// sync.Once.
func (w *responseWrapper) WriteHeader(code int) {
	record := func() {
		w.statusCode = code
		w.ResponseWriter.WriteHeader(code)
	}
	w.Once.Do(record)
}

// Flush forwards to the wrapped writer's Flush when it implements
// http.Flusher, and does nothing otherwise.
func (w *responseWrapper) Flush() {
	flusher, canFlush := w.ResponseWriter.(http.Flusher)
	if !canFlush {
		return
	}
	flusher.Flush()
}

// Flush forwards to the wrapped writer's Flush when it implements
// http.Flusher, and does nothing otherwise.
func (w *writeCounter) Flush() {
	flusher, canFlush := w.ResponseWriter.(http.Flusher)
	if !canFlush {
		return
	}
	flusher.Flush()
}

// Write passes p to the wrapped writer and adds the number of bytes it
// reports as written to countBytes, whether or not an error is returned.
func (w *writeCounter) Write(p []byte) (written int, err error) {
	written, err = w.ResponseWriter.Write(p)
	atomic.AddUint64(&w.countBytes, uint64(written))
	return written, err
}

// Read reads from the wrapped body into p and adds the number of bytes it
// reports as read to countBytes, whether or not an error is returned.
func (r *readCounter) Read(p []byte) (read int, err error) {
	read, err = r.ReadCloser.Read(p)
	atomic.AddUint64(&r.countBytes, uint64(read))
	return read, err
}