package metrics

import (
	"sync"
	"sync/atomic"

	"github.com/prometheus/client_golang/prometheus"
)

type (
	// httpAPIStats holds per-API counters keyed by API name. The map is
	// allocated lazily by Inc, and all access goes through the embedded
	// RWMutex.
	httpAPIStats struct {
		apiStats map[string]int
		sync.RWMutex
	}

	// httpStats holds statistics about HTTP requests made by all clients,
	// along with the Prometheus descriptors used to export them.
	httpStats struct {
		// totalInputBytes and totalOutputBytes are accessed with 64-bit
		// sync/atomic functions and therefore must be the first fields of
		// the struct: on 386, ARM and 32-bit MIPS only the first word of an
		// allocated struct is guaranteed to be 64-bit aligned. Placed after
		// the three httpAPIStats fields they can end up only 4-byte aligned
		// on those platforms, and the atomic calls then panic.
		totalInputBytes  uint64
		totalOutputBytes uint64

		currentS3Requests httpAPIStats
		totalS3Requests   httpAPIStats
		totalS3Errors     httpAPIStats

		currentS3RequestsDesc *prometheus.Desc
		totalS3RequestsDesc   *prometheus.Desc
		totalS3ErrorsDesc     *prometheus.Desc
		txBytesTotalDesc      *prometheus.Desc
		rxBytesTotalDesc      *prometheus.Desc
	}

	// APIStatMetrics gathers per-API request statistics (in-flight and total
	// requests, errors, transferred bytes and request durations). Its
	// Describe and Collect methods let it be registered as a
	// prometheus.Collector. All of its methods are no-ops on a nil
	// *APIStatMetrics.
	APIStatMetrics struct {
		stats                *httpStats
		httpRequestsDuration *prometheus.HistogramVec
	}
)

// statisticSubsystem is the appMetricsDesc key under which the API statistics
// metric descriptions listed below are looked up.
const statisticSubsystem = "statistic"

// Metric names inside statisticSubsystem.
const (
	requestsSecondsMetric = "requests_seconds" // request-duration histogram
	requestsCurrentMetric = "requests_current" // in-flight requests per API
	requestsTotalMetric   = "requests_total"   // requests recorded per API
	errorsTotalMetric     = "errors_total"     // errors recorded per API
	txBytesTotalMetric    = "tx_bytes_total"   // transmitted bytes
	rxBytesTotalMetric    = "rx_bytes_total"   // received bytes
)

// newAPIStatMetrics assembles an APIStatMetrics from a fresh httpStats and a
// request-duration histogram built from the statistic subsystem's
// requests_seconds description.
func newAPIStatMetrics() *APIStatMetrics {
	stats := newHTTPStats()

	// Histogram bucket upper bounds, in seconds.
	buckets := []float64{.05, .1, .25, .5, 1, 2.5, 5, 10}
	durations := mustNewHistogramVec(appMetricsDesc[statisticSubsystem][requestsSecondsMetric], buckets)

	return &APIStatMetrics{
		stats:                stats,
		httpRequestsDuration: durations,
	}
}

// CurrentS3RequestsInc records that a request to api has started.
// It does nothing when called on a nil *APIStatMetrics.
func (a *APIStatMetrics) CurrentS3RequestsInc(api string) {
	if a != nil {
		a.stats.currentS3Requests.Inc(api)
	}
}

// CurrentS3RequestsDec records that a request to api has finished.
// It does nothing when called on a nil *APIStatMetrics.
func (a *APIStatMetrics) CurrentS3RequestsDec(api string) {
	if a != nil {
		inFlight := &a.stats.currentS3Requests
		inFlight.Dec(api)
	}
}

// TotalS3RequestsInc adds one to the total number of requests recorded for
// api. It does nothing when called on a nil *APIStatMetrics.
func (a *APIStatMetrics) TotalS3RequestsInc(api string) {
	if a != nil {
		a.stats.totalS3Requests.Inc(api)
	}
}

// TotalS3ErrorsInc adds one to the error count recorded for api.
// It does nothing when called on a nil *APIStatMetrics.
func (a *APIStatMetrics) TotalS3ErrorsInc(api string) {
	if a != nil {
		errs := &a.stats.totalS3Errors
		errs.Inc(api)
	}
}

// TotalInputBytesAdd atomically adds val to the running total of input
// bytes. It does nothing when called on a nil *APIStatMetrics.
func (a *APIStatMetrics) TotalInputBytesAdd(val uint64) {
	if a != nil {
		atomic.AddUint64(&a.stats.totalInputBytes, val)
	}
}

// TotalOutputBytesAdd atomically adds val to the running total of output
// bytes. It does nothing when called on a nil *APIStatMetrics.
func (a *APIStatMetrics) TotalOutputBytesAdd(val uint64) {
	if a != nil {
		counter := &a.stats.totalOutputBytes
		atomic.AddUint64(counter, val)
	}
}

// RequestDurationsUpdate records durationSecs in the request-duration
// histogram under the "api" label. It does nothing when called on a nil
// *APIStatMetrics.
func (a *APIStatMetrics) RequestDurationsUpdate(api string, durationSecs float64) {
	if a == nil {
		return
	}
	observer := a.httpRequestsDuration.With(prometheus.Labels{"api": api})
	observer.Observe(durationSecs)
}

// Describe sends the descriptors of the HTTP statistics, followed by those of
// the request-duration histogram, to ch. It does nothing when called on a nil
// *APIStatMetrics.
func (a *APIStatMetrics) Describe(ch chan<- *prometheus.Desc) {
	if a != nil {
		a.stats.Describe(ch)
		a.httpRequestsDuration.Describe(ch)
	}
}

// Collect sends the current HTTP statistics samples, followed by the
// request-duration histogram samples, to ch. It does nothing when called on a
// nil *APIStatMetrics.
func (a *APIStatMetrics) Collect(ch chan<- prometheus.Metric) {
	if a != nil {
		a.stats.Collect(ch)
		a.httpRequestsDuration.Collect(ch)
	}
}

// newHTTPStats returns an httpStats with zeroed counters and one descriptor
// per exported metric, each built from the statistic subsystem's entry in
// appMetricsDesc.
func newHTTPStats() *httpStats {
	statDesc := func(metric string) *prometheus.Desc {
		return newDesc(appMetricsDesc[statisticSubsystem][metric])
	}

	return &httpStats{
		currentS3RequestsDesc: statDesc(requestsCurrentMetric),
		totalS3RequestsDesc:   statDesc(requestsTotalMetric),
		totalS3ErrorsDesc:     statDesc(errorsTotalMetric),
		txBytesTotalDesc:      statDesc(txBytesTotalMetric),
		rxBytesTotalDesc:      statDesc(rxBytesTotalMetric),
	}
}

// Describe sends every descriptor owned by s to desc, in a fixed order:
// current requests, total requests, total errors, tx bytes, rx bytes.
func (s *httpStats) Describe(desc chan<- *prometheus.Desc) {
	all := []*prometheus.Desc{
		s.currentS3RequestsDesc,
		s.totalS3RequestsDesc,
		s.totalS3ErrorsDesc,
		s.txBytesTotalDesc,
		s.rxBytesTotalDesc,
	}
	for _, d := range all {
		desc <- d
	}
}

// Collect sends one constant sample per API label for the in-flight request,
// total request and error counters, followed by the global transmitted and
// received byte totals.
func (s *httpStats) Collect(ch chan<- prometheus.Metric) {
	// The in-flight count goes down as well as up (see httpAPIStats.Dec), so
	// it must be exported as a gauge: Prometheus counters are expected to be
	// monotonic, and rate()/increase() treat any drop as a counter reset.
	for api, value := range s.currentS3Requests.Load() {
		ch <- prometheus.MustNewConstMetric(s.currentS3RequestsDesc, prometheus.GaugeValue, float64(value), api)
	}

	for api, value := range s.totalS3Requests.Load() {
		ch <- prometheus.MustNewConstMetric(s.totalS3RequestsDesc, prometheus.CounterValue, float64(value), api)
	}

	for api, value := range s.totalS3Errors.Load() {
		ch <- prometheus.MustNewConstMetric(s.totalS3ErrorsDesc, prometheus.CounterValue, float64(value), api)
	}

	// Network sent/received bytes: tx (transmitted) is the gateway's output,
	// rx (received) is its input. These were previously swapped.
	ch <- prometheus.MustNewConstMetric(s.txBytesTotalDesc, prometheus.CounterValue, float64(s.getOutputBytes()))
	ch <- prometheus.MustNewConstMetric(s.rxBytesTotalDesc, prometheus.CounterValue, float64(s.getInputBytes()))
}

// Inc adds one to the counter kept for api, creating the backing map on the
// first call. Calling Inc on a nil *httpAPIStats does nothing.
func (s *httpAPIStats) Inc(api string) {
	if s == nil {
		return
	}

	s.Lock()
	defer s.Unlock()

	counts := s.apiStats
	if counts == nil {
		counts = make(map[string]int)
		s.apiStats = counts
	}
	counts[api]++
}

// Dec decrements the counter kept for api. Counters never go below zero:
// decrementing an API that is unknown or already at zero is a no-op and does
// not create an entry. Calling Dec on a nil *httpAPIStats does nothing.
func (s *httpAPIStats) Dec(api string) {
	if s == nil {
		return
	}
	s.Lock()
	defer s.Unlock()
	// A missing key (or a nil map) reads as 0, so this single check covers
	// both the "unknown API" and the "already zero" cases.
	if s.apiStats[api] > 0 {
		s.apiStats[api]--
	}
}

// Load returns a snapshot copy of the recorded per-API counters. The returned
// map is never nil and may be modified freely by the caller without affecting
// s. A nil receiver yields an empty map, matching the nil tolerance of Inc and
// Dec.
func (s *httpAPIStats) Load() map[string]int {
	if s == nil {
		return map[string]int{}
	}

	// Load only reads, so a shared lock is enough and lets concurrent
	// readers proceed in parallel; Inc and Dec still take the write lock.
	s.RLock()
	defer s.RUnlock()

	snapshot := make(map[string]int, len(s.apiStats))
	for api, count := range s.apiStats {
		snapshot[api] = count
	}
	return snapshot
}

// getInputBytes atomically reads the input-byte total accumulated by
// TotalInputBytesAdd.
func (s *httpStats) getInputBytes() uint64 {
	total := atomic.LoadUint64(&s.totalInputBytes)
	return total
}

// getOutputBytes atomically reads the output-byte total accumulated by
// TotalOutputBytesAdd.
func (s *httpStats) getOutputBytes() uint64 {
	total := atomic.LoadUint64(&s.totalOutputBytes)
	return total
}