package middleware

import (
	"context"
	"io"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs"
	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/metrics"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	"go.uber.org/zap"
)

// readCounter wraps a request body and tallies, with sync/atomic, how many
// bytes have been read through it.
type readCounter struct {
	io.ReadCloser
	countBytes uint64
}

// writeCounter wraps a response writer and tallies, with sync/atomic, how many
// bytes have been written through it.
type writeCounter struct {
	http.ResponseWriter
	countBytes uint64
}

// responseWrapper remembers the status code given to the first WriteHeader
// call (guarded by the embedded sync.Once) and the moment serving started.
type responseWrapper struct {
	sync.Once
	http.ResponseWriter

	statusCode int
	startTime  time.Time
}

// MetricsSettings supplies configuration consulted by the Metrics middleware.
type MetricsSettings interface {
	// ResolveNamespaceAlias returns the namespace name to report in user API
	// statistics for the given request namespace.
	ResolveNamespaceAlias(namespace string) string
}

// ContainerIDResolveFunc is a func to resolve container id by name.
type ContainerIDResolveFunc func(ctx context.Context, bucket string) (cid.ID, error)

// cidResolveFunc returns the container ID string to report for a request in
// the stats handler.
type cidResolveFunc func(ctx context.Context, reqInfo *ReqInfo) (cnrID string)

// systemPath is the URL path suffix of requests that the stats handler leaves
// out of the total S3 request and error counters.
const systemPath = "/system"

// Metrics returns a middleware that records basic statistics about every
// request into appMetrics.
//
// resolveBucket maps the request's bucket name to a container ID for the
// per-user statistics; settings maps the request namespace to the name
// reported there. Resolution failures are logged to log at debug level.
func Metrics(log *zap.Logger, resolveBucket ContainerIDResolveFunc, appMetrics *metrics.AppMetrics, settings MetricsSettings) Func {
	// The resolver is stateless apart from its captured arguments, so build it once.
	cidResolver := resolveCID(log, resolveBucket)

	return func(next http.Handler) http.Handler {
		return stats(next.ServeHTTP, cidResolver, appMetrics, settings)
	}
}

// stats wraps f so that every request it serves is accounted in appMetrics:
// the in-flight request gauge, per-user API statistics, the total request and
// error counters, the request duration histogram and the input/output byte
// totals.
//
// resolveCID supplies the container ID reported in per-user statistics and
// settings maps the request namespace to the name reported there.
func stats(f http.HandlerFunc, resolveCID cidResolveFunc, appMetrics *metrics.AppMetrics, settings MetricsSettings) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		reqInfo := GetReqInfo(r.Context())

		appMetrics.Statistic().CurrentS3RequestsInc(reqInfo.API)
		defer appMetrics.Statistic().CurrentS3RequestsDec(reqInfo.API)

		// Count bytes flowing in through the request body and out through the response.
		in := &readCounter{ReadCloser: r.Body}
		out := &writeCounter{ResponseWriter: w}

		r.Body = in

		statsWriter := &responseWrapper{
			ResponseWriter: out,
			startTime:      time.Now(),
		}

		f(statsWriter, r)

		// Elapsed time since the handler was started, in fractional seconds.
		durationSecs := time.Since(statsWriter.startTime).Seconds()

		// The counters are incremented with sync/atomic, so load them the same
		// way (in case the body or writer is still used by another goroutine),
		// and take one snapshot so every metric below reports the same values.
		inBytes := atomic.LoadUint64(&in.countBytes)
		outBytes := atomic.LoadUint64(&out.countBytes)

		cnrID := resolveCID(r.Context(), reqInfo)
		appMetrics.UsersAPIStats().Update(reqInfo.User, reqInfo.BucketName, cnrID, settings.ResolveNamespaceAlias(reqInfo.Namespace),
			requestTypeFromAPI(reqInfo.API), inBytes, outBytes)

		// statusCode stays 0 when the handler never called WriteHeader (e.g. it
		// only called Write, which implies 200); such requests are not counted
		// as errors.
		code := statsWriter.statusCode
		// A successful request has a 2xx response code.
		successReq := code >= http.StatusOK && code < http.StatusMultipleChoices
		// Requests whose path ends with systemPath are excluded from the totals.
		if !strings.HasSuffix(r.URL.Path, systemPath) {
			appMetrics.Statistic().TotalS3RequestsInc(reqInfo.API)
			if !successReq && code != 0 {
				appMetrics.Statistic().TotalS3ErrorsInc(reqInfo.API)
			}
		}

		// Increment the prometheus http request response histogram with appropriate label.
		appMetrics.Statistic().RequestDurationsUpdate(reqInfo.API, durationSecs)

		appMetrics.Statistic().TotalInputBytesAdd(inBytes)
		appMetrics.Statistic().TotalOutputBytesAdd(outBytes)
	}
}

// requestTypeByOperation classifies S3 operation names into the request type
// reported in per-user API statistics. Operations missing from this table are
// reported as metrics.UNKNOWNRequest by requestTypeFromAPI.
var requestTypeByOperation = map[string]metrics.RequestType{
	// HEAD-like requests.
	OptionsBucketOperation: metrics.HEADRequest,
	OptionsObjectOperation: metrics.HEADRequest,
	HeadObjectOperation:    metrics.HEADRequest,
	HeadBucketOperation:    metrics.HEADRequest,

	// Requests that create or modify objects, buckets and their settings.
	CreateMultipartUploadOperation:     metrics.PUTRequest,
	UploadPartCopyOperation:            metrics.PUTRequest,
	UploadPartOperation:                metrics.PUTRequest,
	CompleteMultipartUploadOperation:   metrics.PUTRequest,
	PutObjectACLOperation:              metrics.PUTRequest,
	PutObjectTaggingOperation:          metrics.PUTRequest,
	CopyObjectOperation:                metrics.PUTRequest,
	PutObjectRetentionOperation:        metrics.PUTRequest,
	PutObjectLegalHoldOperation:        metrics.PUTRequest,
	PutObjectOperation:                 metrics.PUTRequest,
	PutBucketCorsOperation:             metrics.PUTRequest,
	PutBucketACLOperation:              metrics.PUTRequest,
	PutBucketLifecycleOperation:        metrics.PUTRequest,
	PutBucketEncryptionOperation:       metrics.PUTRequest,
	PutBucketPolicyOperation:           metrics.PUTRequest,
	PutBucketObjectLockConfigOperation: metrics.PUTRequest,
	PutBucketTaggingOperation:          metrics.PUTRequest,
	PutBucketVersioningOperation:       metrics.PUTRequest,
	PutBucketNotificationOperation:     metrics.PUTRequest,
	CreateBucketOperation:              metrics.PUTRequest,
	PostObjectOperation:                metrics.PUTRequest,

	// Listing requests.
	ListPartsOperation:            metrics.LISTRequest,
	ListMultipartUploadsOperation: metrics.LISTRequest,
	ListObjectsV2MOperation:       metrics.LISTRequest,
	ListObjectsV2Operation:        metrics.LISTRequest,
	ListObjectsV1Operation:        metrics.LISTRequest,
	ListBucketsOperation:          metrics.LISTRequest,

	// Requests that read objects, buckets and their settings.
	GetObjectACLOperation:              metrics.GETRequest,
	GetObjectTaggingOperation:          metrics.GETRequest,
	SelectObjectContentOperation:       metrics.GETRequest,
	GetObjectRetentionOperation:        metrics.GETRequest,
	GetObjectLegalHoldOperation:        metrics.GETRequest,
	GetObjectAttributesOperation:       metrics.GETRequest,
	GetObjectOperation:                 metrics.GETRequest,
	GetBucketLocationOperation:         metrics.GETRequest,
	GetBucketPolicyOperation:           metrics.GETRequest,
	GetBucketLifecycleOperation:        metrics.GETRequest,
	GetBucketEncryptionOperation:       metrics.GETRequest,
	GetBucketCorsOperation:             metrics.GETRequest,
	GetBucketACLOperation:              metrics.GETRequest,
	GetBucketWebsiteOperation:          metrics.GETRequest,
	GetBucketAccelerateOperation:       metrics.GETRequest,
	GetBucketRequestPaymentOperation:   metrics.GETRequest,
	GetBucketLoggingOperation:          metrics.GETRequest,
	GetBucketReplicationOperation:      metrics.GETRequest,
	GetBucketTaggingOperation:          metrics.GETRequest,
	GetBucketObjectLockConfigOperation: metrics.GETRequest,
	GetBucketVersioningOperation:       metrics.GETRequest,
	GetBucketNotificationOperation:     metrics.GETRequest,
	ListenBucketNotificationOperation:  metrics.GETRequest,

	// Requests that remove objects, buckets and their settings.
	AbortMultipartUploadOperation:   metrics.DELETERequest,
	DeleteObjectTaggingOperation:    metrics.DELETERequest,
	DeleteObjectOperation:           metrics.DELETERequest,
	DeleteBucketCorsOperation:       metrics.DELETERequest,
	DeleteBucketWebsiteOperation:    metrics.DELETERequest,
	DeleteBucketTaggingOperation:    metrics.DELETERequest,
	DeleteMultipleObjectsOperation:  metrics.DELETERequest,
	DeleteBucketPolicyOperation:     metrics.DELETERequest,
	DeleteBucketLifecycleOperation:  metrics.DELETERequest,
	DeleteBucketEncryptionOperation: metrics.DELETERequest,
	DeleteBucketOperation:           metrics.DELETERequest,
}

// requestTypeFromAPI maps an S3 operation name to the request type used in
// per-user API statistics, or metrics.UNKNOWNRequest if it is not classified.
func requestTypeFromAPI(api string) metrics.RequestType {
	if reqType, known := requestTypeByOperation[api]; known {
		return reqType
	}
	return metrics.UNKNOWNRequest
}

// resolveCID builds a cidResolveFunc on top of resolveContainerID.
//
// The returned func yields the request bucket's container ID encoded with
// EncodeToString. It yields "" without calling resolveContainerID when the
// request has no bucket name or no API, or is a CreateBucket request. It also
// yields "" when resolution fails, after logging the error at debug level.
func resolveCID(log *zap.Logger, resolveContainerID ContainerIDResolveFunc) cidResolveFunc {
	return func(ctx context.Context, reqInfo *ReqInfo) string {
		// NOTE(review): CreateBucket is skipped, presumably because the
		// container does not exist yet when the request is being served.
		skip := reqInfo.BucketName == "" || reqInfo.API == "" || reqInfo.API == CreateBucketOperation
		if skip {
			return ""
		}

		id, resolveErr := resolveContainerID(ctx, reqInfo.BucketName)
		if resolveErr != nil {
			reqLogOrDefault(ctx, log).Debug(logs.FailedToResolveCID, zap.Error(resolveErr))
			return ""
		}

		return id.EncodeToString()
	}
}

// WriteHeader records code as the response status and forwards it to the
// wrapped writer. Only the first call has any effect; later calls are ignored.
func (w *responseWrapper) WriteHeader(code int) {
	record := func() {
		w.statusCode = code
		w.ResponseWriter.WriteHeader(code)
	}
	w.Once.Do(record)
}

// Flush flushes the wrapped writer if it implements http.Flusher and is a
// no-op otherwise.
func (w *responseWrapper) Flush() {
	flusher, canFlush := w.ResponseWriter.(http.Flusher)
	if !canFlush {
		return
	}
	flusher.Flush()
}

// Flush flushes the wrapped writer if it implements http.Flusher and is a
// no-op otherwise.
func (w *writeCounter) Flush() {
	flusher, canFlush := w.ResponseWriter.(http.Flusher)
	if !canFlush {
		return
	}
	flusher.Flush()
}

// Write forwards p to the wrapped writer and atomically adds the number of
// bytes it reports as written to countBytes, also when it returns an error.
func (w *writeCounter) Write(p []byte) (written int, err error) {
	written, err = w.ResponseWriter.Write(p)
	atomic.AddUint64(&w.countBytes, uint64(written))
	return written, err
}

// Read reads into p from the wrapped body and atomically adds the number of
// bytes obtained to countBytes, including on the call that returns io.EOF or
// another error.
func (r *readCounter) Read(p []byte) (read int, err error) {
	read, err = r.ReadCloser.Read(p)
	atomic.AddUint64(&r.countBytes, uint64(read))
	return read, err
}