package tree

import (
	"context"
	"crypto/sha256"
	"errors"
	"fmt"
	"io"
	"math"
	"math/rand"
	"sync"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/logs"
	containerCore "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/container"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/pilorama"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/morph/client/netmap"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/network"
	metrics "git.frostfs.info/TrueCloudLab/frostfs-observability/metrics/grpc"
	tracing "git.frostfs.info/TrueCloudLab/frostfs-observability/tracing"
	tracing_grpc "git.frostfs.info/TrueCloudLab/frostfs-observability/tracing/grpc"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	netmapSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
	"github.com/panjf2000/ants/v2"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// ErrNotInContainer is returned when operation could not be performed
// because the node is not included in the container.
var ErrNotInContainer = errors.New("node is not in container")

const defaultSyncWorkerCount = 20

// synchronizeAllTrees synchronizes all the trees of the container. It fetches
// tree IDs from the other container nodes. Returns ErrNotInContainer if the node
// is not included in the container.
//
// The tree list is requested from the other container nodes (in random order)
// via s.forEachNode; a node whose TreeList call fails is skipped. Each listed
// tree is then synchronized starting from its stored last sync height, and the
// stored height is advanced when synchronizeTree reports a larger one.
// Per-tree failures are logged together with their cause and do not abort the
// remaining trees.
func (s *Service) synchronizeAllTrees(ctx context.Context, cid cid.ID) error {
	nodes, pos, err := s.getContainerNodes(cid)
	if err != nil {
		return fmt.Errorf("can't get container nodes: %w", err)
	}

	if pos < 0 {
		return ErrNotInContainer
	}

	nodes = randomizeNodeOrder(nodes, pos)
	if len(nodes) == 0 {
		return nil
	}

	rawCID := make([]byte, sha256.Size)
	cid.Encode(rawCID)

	req := &TreeListRequest{
		Body: &TreeListRequest_Body{
			ContainerId: rawCID,
		},
	}

	err = SignMessage(req, s.key)
	if err != nil {
		return fmt.Errorf("could not sign request: %w", err)
	}

	var treesToSync []string
	var outErr error

	err = s.forEachNode(ctx, nodes, func(c TreeServiceClient) bool {
		var resp *TreeListResponse
		resp, outErr = c.TreeList(ctx, req)
		if outErr != nil {
			// Returning false lets forEachNode move on to the next node.
			return false
		}

		treesToSync = resp.GetBody().GetIds()

		return true
	})
	if err != nil {
		outErr = err
	}

	if outErr != nil {
		return fmt.Errorf("could not fetch tree ID list: %w", outErr)
	}

	for _, tid := range treesToSync {
		// A tree that does not exist locally yet is synchronized from the
		// height returned alongside ErrTreeNotFound.
		h, err := s.forest.TreeLastSyncHeight(ctx, cid, tid)
		if err != nil && !errors.Is(err, pilorama.ErrTreeNotFound) {
			s.log.Warn(logs.TreeCouldNotGetLastSynchronizedHeightForATree,
				zap.Stringer("cid", cid),
				zap.String("tree", tid),
				zap.Error(err))
			continue
		}
		newHeight := s.synchronizeTree(ctx, cid, h, tid, nodes)
		if h < newHeight {
			if err := s.forest.TreeUpdateLastSyncHeight(ctx, cid, tid, newHeight); err != nil {
				s.log.Warn(logs.TreeCouldNotUpdateLastSynchronizedHeightForATree,
					zap.Stringer("cid", cid),
					zap.String("tree", tid),
					zap.Error(err))
			}
		}
	}

	return nil
}

// SynchronizeTree synchronizes the tree treeID of container cid with the other
// container nodes.
//
// Unlike synchronizeAllTrees, it always requests the operation log from height
// 0 rather than from the last stored synchronization height, and it does not
// update the stored height. Returns ErrNotInContainer if the local node is not
// included in the container, and nil if it is the only container node.
// Failures while streaming or applying operations are not reported to the
// caller.
func (s *Service) SynchronizeTree(ctx context.Context, cid cid.ID, treeID string) error {
	nodes, pos, err := s.getContainerNodes(cid)
	if err != nil {
		return fmt.Errorf("can't get container nodes: %w", err)
	}
	if pos < 0 {
		return ErrNotInContainer
	}

	// Synchronize with the other nodes only; nothing to do if we are alone.
	if others := randomizeNodeOrder(nodes, pos); len(others) > 0 {
		// The next-sync height is discarded: this full resync does not move
		// the stored height.
		_ = s.synchronizeTree(ctx, cid, 0, treeID, others)
	}
	return nil
}

// mergeOperationStreams merges the per-node operation streams into merged,
// ordered by operation time, and closes merged once every stream has been
// closed by its producer.
//
// At each step the pending operation with the smallest Time is forwarded; ties
// go to the stream with the lowest index. NOTE(review): an operation whose
// Time equals math.MaxUint64 is never selected, so such an operation would end
// the merge early.
//
// The result is the smallest Time of the last operation forwarded from each
// stream that got exhausted, or math.MaxUint64 if no stream delivered anything.
// Merging different node streams interleaves their operations, e.g. with x
// from stream A and o from stream B:
//
//	--o---o--x--x--x--o---x--x------> t
//	                  ^
//
// If all operations have been applied successfully, the next synchronization
// must start from the last operation of stream B, the one marked above.
func mergeOperationStreams(streams []chan *pilorama.Move, merged chan<- *pilorama.Move) uint64 {
	defer close(merged)

	// heads[i] is the next operation of streams[i] not forwarded yet;
	// nil means the stream is exhausted.
	heads := make([]*pilorama.Move, len(streams))
	for i, stream := range streams {
		heads[i] = <-stream
	}

	lastHeight := uint64(math.MaxUint64)
	for {
		best, bestTime := -1, uint64(math.MaxUint64)
		for i, head := range heads {
			if head != nil && head.Time < bestTime {
				best, bestTime = i, head.Time
			}
		}
		if best < 0 {
			return lastHeight
		}

		merged <- heads[best]
		heads[best] = <-streams[best]
		if heads[best] == nil && bestTime < lastHeight {
			lastHeight = bestTime
		}
	}
}

// applyOperationStream applies every operation read from operationStream to
// the tree treeID of container cid, running up to 1024 TreeApply calls
// concurrently. An operation whose Time equals that of the previously read one
// is skipped as a duplicate, which assumes the stream is ordered by time.
//
// It returns the smallest Time among operations that failed to apply, or
// math.MaxUint64 if every applied operation succeeded. Because operations are
// applied concurrently, a later operation may succeed while an earlier one
// fails; the next synchronization must therefore restart from the earliest
// failure.
func (s *Service) applyOperationStream(ctx context.Context, cid cid.ID, treeID string,
	operationStream <-chan *pilorama.Move) uint64 {
	const workersCount = 1024

	var eg errgroup.Group
	eg.SetLimit(workersCount)

	var failedMtx sync.Mutex
	minFailed := uint64(math.MaxUint64)
	var lastSeen *pilorama.Move

	for op := range operationStream {
		op := op

		// Same time as the previous operation: already scheduled.
		if lastSeen != nil && lastSeen.Time == op.Time {
			continue
		}
		lastSeen = op

		eg.Go(func() error {
			err := s.forest.TreeApply(ctx, cid, treeID, op, true)
			if err == nil {
				return nil
			}

			failedMtx.Lock()
			if op.Time < minFailed {
				minFailed = op.Time
			}
			failedMtx.Unlock()
			return err
		})
	}

	// Failures are already folded into minFailed; the first error itself is
	// not needed.
	_ = eg.Wait()
	return minFailed
}

// startStream requests the operation log of tree treeID in container cid from
// treeClient, starting at the given height, and sends every received
// operation to opsCh in arrival order.
//
// It returns (height, nil) when the stream ends cleanly with io.EOF and
// (height, err) when receiving fails with any other error. It returns 0 and an
// error when the request cannot be signed or sent, or when a received entry
// carries no operation or undecodable metadata. Responses come from a remote
// node, so missing fields are rejected instead of being dereferenced.
func (s *Service) startStream(ctx context.Context, cid cid.ID, treeID string,
	height uint64, treeClient TreeServiceClient, opsCh chan<- *pilorama.Move) (uint64, error) {
	rawCID := make([]byte, sha256.Size)
	cid.Encode(rawCID)

	req := &GetOpLogRequest{
		Body: &GetOpLogRequest_Body{
			ContainerId: rawCID,
			TreeId:      treeID,
			Height:      height,
		},
	}
	if err := SignMessage(req, s.key); err != nil {
		return 0, fmt.Errorf("could not sign request: %w", err)
	}

	c, err := treeClient.GetOpLog(ctx, req)
	if err != nil {
		return 0, fmt.Errorf("can't initialize client: %w", err)
	}

	for {
		res, recvErr := c.Recv()
		if errors.Is(recvErr, io.EOF) {
			// Normal end of the stream.
			return height, nil
		}
		if recvErr != nil {
			return height, recvErr
		}

		lm := res.GetBody().GetOperation()
		if lm == nil {
			return 0, errors.New("received GetOpLog response without an operation")
		}
		m := &pilorama.Move{
			Parent: lm.GetParentId(),
			Child:  lm.GetChildId(),
		}
		if err := m.Meta.FromBytes(lm.GetMeta()); err != nil {
			return 0, fmt.Errorf("can't decode operation meta: %w", err)
		}
		opsCh <- m
	}
}

// synchronizeTree synchronizes the tree treeID of container cid with the given
// nodes: it requests their operation logs starting at height from and applies
// the received operations to the local forest.
//
// Each node streams its operations into a separate channel. These channels are
// merged into one stream ordered by operation time, which allows skipping
// operations already received from another node while keeping large batches of
// concurrently applied operations.
//
// The method returns the height the next synchronization should start from:
//   - the time of the earliest operation that failed to apply, when it does not
//     exceed the smallest height at which a node stream ended, so that the
//     failed operation is requested again;
//   - otherwise, one past the smallest height at which a node stream ended;
//   - from itself when no node streamed any operation.
func (s *Service) synchronizeTree(ctx context.Context, cid cid.ID, from uint64,
	treeID string, nodes []netmapSDK.NodeInfo) uint64 {
	s.log.Debug(logs.TreeSynchronizeTree,
		zap.Stringer("cid", cid),
		zap.String("tree", treeID),
		zap.Uint64("from", from))

	errGroup, egCtx := errgroup.WithContext(ctx)
	// NOTE: the merger, the applier and one producer per node must all run at
	// the same time for the merge to make progress, so len(nodes)+2 must not
	// exceed this limit.
	const workersCount = 1024
	errGroup.SetLimit(workersCount)

	nodeOperationStreams := make([]chan *pilorama.Move, len(nodes))
	for i := range nodeOperationStreams {
		nodeOperationStreams[i] = make(chan *pilorama.Move)
	}
	merged := make(chan *pilorama.Move)
	var minStreamedLastHeight uint64
	errGroup.Go(func() error {
		minStreamedLastHeight = mergeOperationStreams(nodeOperationStreams, merged)
		return nil
	})
	var minUnappliedHeight uint64
	errGroup.Go(func() error {
		minUnappliedHeight = s.applyOperationStream(ctx, cid, treeID, merged)
		return nil
	})

	for i, n := range nodes {
		i := i
		n := n
		errGroup.Go(func() error {
			// Closing the stream tells the merger this node is done, whether or
			// not any of its endpoints could be reached.
			defer close(nodeOperationStreams[i])

			n.IterateNetworkEndpoints(func(addr string) bool {
				var a network.Address
				if err := a.FromString(addr); err != nil {
					return false
				}

				cc, err := grpc.DialContext(egCtx, a.URIAddr(),
					grpc.WithChainUnaryInterceptor(
						metrics.NewUnaryClientInterceptor(),
						tracing_grpc.NewUnaryClientInteceptor(),
					),
					grpc.WithChainStreamInterceptor(
						metrics.NewStreamClientInterceptor(),
						tracing_grpc.NewStreamClientInterceptor(),
					),
					grpc.WithTransportCredentials(insecure.NewCredentials()))
				if err != nil {
					// Failed to connect, try the next address.
					return false
				}
				defer cc.Close()

				treeClient := NewTreeServiceClient(cc)
				// Only the first endpoint that could be dialed is used: the
				// stream result does not make us try other endpoints of this
				// node. Stream errors are not reported; other nodes may still
				// deliver the operations.
				_, _ = s.startStream(egCtx, cid, treeID, from, treeClient, nodeOperationStreams[i])
				return true
			})
			return nil
		})
	}
	if err := errGroup.Wait(); err != nil {
		s.log.Warn(logs.TreeFailedToRunTreeSynchronizationOverAllNodes, zap.Error(err))
	}

	switch {
	case minUnappliedHeight != math.MaxUint64 && minUnappliedHeight <= minStreamedLastHeight:
		// Restart from the earliest failed operation so it is retried. "<=" is
		// required: when the failed operation is itself the last one of the
		// earliest-ending stream, "+1" below would skip it for good.
		return minUnappliedHeight
	case minStreamedLastHeight == math.MaxUint64:
		// Nothing was streamed; "+1" would wrap around to 0.
		return from
	default:
		return minStreamedLastHeight + 1
	}
}

// ErrAlreadySyncing is returned when a service synchronization has already
// been started.
var ErrAlreadySyncing = errors.New("service is being synchronized")

// ErrShuttingDown is returned when the service is shitting down and could not
// accept any calls.
var ErrShuttingDown = errors.New("service is shutting down")

// SynchronizeAll asks the tree service to synchronize all the trees according
// to netmap information. Must not be called before Service.Start.
//
// The call never blocks; it only tries to signal the synchronization loop.
// It returns ErrShuttingDown if s.closeCh is already closed, and
// ErrAlreadySyncing if the signal cannot be delivered on s.syncChan right away
// (a synchronization has been started and is blocked by another routine).
func (s *Service) SynchronizeAll() error {
	select {
	case <-s.closeCh:
		return ErrShuttingDown
	default:
	}

	select {
	case s.syncChan <- struct{}{}:
	default:
		return ErrAlreadySyncing
	}
	return nil
}

// syncLoop serves synchronization requests received on s.syncChan, one at a
// time, until s.closeCh is closed or ctx is done.
//
// For each request it lists containers from s.cfg.cnrSource, synchronizes the
// trees of the containers the local node belongs to, drops the trees of
// removed containers and records the duration and outcome in s.metrics. A
// failure to list containers skips the rest of that round. s.initialSyncDone
// is set after every handled request, successful or not.
func (s *Service) syncLoop(ctx context.Context) {
	for {
		select {
		case <-s.closeCh:
			return
		case <-ctx.Done():
			return
		case <-s.syncChan:
		}

		syncCtx, span := tracing.StartSpanFromContext(ctx, "TreeService.sync")
		s.log.Debug(logs.TreeSyncingTrees)

		started := time.Now()

		cnrs, err := s.cfg.cnrSource.List()
		if err != nil {
			s.log.Error(logs.TreeCouldNotFetchContainers, zap.Error(err))
			s.metrics.AddSyncDuration(time.Since(started), false)
		} else {
			newMap, cnrsToSync := s.containersToSync(cnrs)
			s.syncContainers(syncCtx, cnrsToSync)
			s.removeContainers(syncCtx, newMap)

			s.log.Debug(logs.TreeTreesHaveBeenSynchronized)
			s.metrics.AddSyncDuration(time.Since(started), true)
		}
		span.End()

		s.initialSyncDone.Store(true)
	}
}

// syncContainers synchronizes the trees of every container in cnrs as tasks of
// s.syncPool and waits for them to finish. Per-container failures are logged.
// A task that cannot be submitted is logged and skipped. If the pool reports
// ants.ErrPoolClosed, submission stops and the function returns without
// waiting for tasks already submitted.
func (s *Service) syncContainers(ctx context.Context, cnrs []cid.ID) {
	ctx, span := tracing.StartSpanFromContext(ctx, "TreeService.syncContainers")
	defer span.End()

	syncOne := func(cnr cid.ID) {
		s.log.Debug(logs.TreeSyncingContainerTrees, zap.Stringer("cid", cnr))

		if err := s.synchronizeAllTrees(ctx, cnr); err != nil {
			s.log.Error(logs.TreeCouldNotSyncTrees, zap.Stringer("cid", cnr), zap.Error(err))
			return
		}

		s.log.Debug(logs.TreeContainerTreesHaveBeenSynced, zap.Stringer("cid", cnr))
	}

	var wg sync.WaitGroup
	for _, cnr := range cnrs {
		cnr := cnr
		wg.Add(1)
		submitErr := s.syncPool.Submit(func() {
			defer wg.Done()
			syncOne(cnr)
		})
		if submitErr == nil {
			continue
		}

		// The task never ran, so release its WaitGroup slot here.
		wg.Done()
		s.log.Error(logs.TreeCouldNotQueryTreesForSynchronization,
			zap.Stringer("cid", cnr),
			zap.Error(submitErr))
		if errors.Is(submitErr, ants.ErrPoolClosed) {
			return
		}
	}
	wg.Wait()
}

// removeContainers forgets every container tracked in s.cnrMap that is absent
// from newContainers and reported as removed by containerCore.WasRemoved, and
// calls s.DropTree for it with an empty tree ID (presumably meaning "all trees
// of the container" — confirm against DropTree). Containers whose removal
// status cannot be checked are kept and the error is logged. Everything runs
// while holding s.cnrMapMtx.
func (s *Service) removeContainers(ctx context.Context, newContainers map[cid.ID]struct{}) {
	ctx, span := tracing.StartSpanFromContext(ctx, "TreeService.removeContainers")
	defer span.End()

	s.cnrMapMtx.Lock()
	defer s.cnrMapMtx.Unlock()

	var removed []cid.ID
	for cnr := range s.cnrMap {
		if _, stillPresent := newContainers[cnr]; stillPresent {
			continue
		}

		wasRemoved, err := containerCore.WasRemoved(s.cnrSource, cnr)
		switch {
		case err != nil:
			s.log.Error(logs.TreeCouldNotCheckIfContainerExisted,
				zap.Stringer("cid", cnr),
				zap.Error(err))
		case wasRemoved:
			removed = append(removed, cnr)
		}
	}

	for _, cnr := range removed {
		delete(s.cnrMap, cnr)
	}

	for _, cnr := range removed {
		s.log.Debug(logs.TreeRemovingRedundantTrees, zap.Stringer("cid", cnr))

		if err := s.DropTree(ctx, cnr, ""); err != nil {
			s.log.Error(logs.TreeCouldNotRemoveRedundantTree,
				zap.Stringer("cid", cnr),
				zap.Error(err))
		}
	}
}

// containersToSync selects, from cnrs, the containers the local node is a
// member of. It returns them both as a set and as a slice that preserves the
// order of cnrs. Containers whose node list cannot be computed are logged and
// skipped; containers the node is not part of are skipped silently.
func (s *Service) containersToSync(cnrs []cid.ID) (map[cid.ID]struct{}, []cid.ID) {
	// Both results are subsets of cnrs, so len(cnrs) is the right size bound.
	// It also avoids reading s.cnrMap without holding s.cnrMapMtx.
	newMap := make(map[cid.ID]struct{}, len(cnrs))
	cnrsToSync := make([]cid.ID, 0, len(cnrs))

	for _, cnr := range cnrs {
		_, pos, err := s.getContainerNodes(cnr)
		if err != nil {
			s.log.Error(logs.TreeCouldNotCalculateContainerNodes,
				zap.Stringer("cid", cnr),
				zap.Error(err))
			continue
		}

		if pos < 0 {
			// node is not included in the container.
			continue
		}

		newMap[cnr] = struct{}{}
		cnrsToSync = append(cnrsToSync, cnr)
	}
	return newMap, cnrsToSync
}

// randomizeNodeOrder returns a new slice holding every node of cnrNodes except
// the one at index pos (presumably the local node), shuffled with math/rand.
// The input slice is not modified. A single-node input yields nil.
// It is assumed that 0 <= pos < len(cnrNodes).
func randomizeNodeOrder(cnrNodes []netmap.NodeInfo, pos int) []netmap.NodeInfo {
	if len(cnrNodes) == 1 {
		return nil
	}

	others := make([]netmap.NodeInfo, 0, len(cnrNodes)-1)
	others = append(others, cnrNodes[:pos]...)
	others = append(others, cnrNodes[pos+1:]...)

	rand.Shuffle(len(others), func(a, b int) {
		others[a], others[b] = others[b], others[a]
	})
	return others
}