package pilorama

import (
	"context"
	"errors"
	"fmt"
	"sort"
	"strings"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/shard/mode"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
)

var (
	// errInvalidKeyFormat is returned by TreeListTrees when a stored tree key
	// cannot be split into a container ID part and a tree ID part.
	errInvalidKeyFormat = errors.New("invalid format: key must be cid and treeID")
)

// memoryForest keeps several replicated trees in one in-memory storage.
// Every tree is addressed by the key "<container ID>/<tree ID>".
type memoryForest struct {
	treeMap map[string]*memoryTree // "<cid>/<treeID>" -> per-tree state and operation log
}

// Compile-time assertion that *memoryForest satisfies Forest.
var _ Forest = (*memoryForest)(nil)

// NewMemoryForest returns an in-memory forest that holds no trees yet.
// TODO: this function will eventually be removed and is here for debugging.
func NewMemoryForest() ForestStorage {
	f := new(memoryForest)
	f.treeMap = map[string]*memoryTree{}
	return f
}

// TreeMove implements the Forest interface.
//
// The tree identified by d.CID and treeID is created on first use. op is
// mutated in place: it receives a fresh timestamp, and when op.Child is RootID
// a spare node ID from findSpareID replaces it. The applied move is appended to
// the tree's operation log, and a pointer to a copy of it is returned.
func (f *memoryForest) TreeMove(_ context.Context, d CIDDescriptor, treeID string, op *Move) (*Move, error) {
	if !d.checkValid() {
		return nil, ErrInvalidCIDDescriptor
	}

	key := d.CID.String() + "/" + treeID
	tree, exists := f.treeMap[key]
	if !exists {
		tree = newMemoryTree()
		f.treeMap[key] = tree
	}

	op.Time = tree.timestamp(d.Position, d.Size)
	if op.Child == RootID {
		op.Child = tree.findSpareID()
	}

	applied := tree.do(op)
	tree.operations = append(tree.operations, applied)
	return &applied.Move, nil
}

// TreeAddByPath implements the Forest interface.
//
// The tree is created on first use. getPathPrefix reports an index into path
// and a node: presumably the length of the already-existing prefix and the
// node it ends at. One node is created under that node for every remaining
// path segment, tagged with attr, with each new node becoming the parent of
// the next. A final leaf carrying a copy of m is attached under the last node.
// Every applied move is returned in order, ending with the leaf.
func (f *memoryForest) TreeAddByPath(_ context.Context, d CIDDescriptor, treeID string, attr string, path []string, m []KeyValue) ([]Move, error) {
	switch {
	case !d.checkValid():
		return nil, ErrInvalidCIDDescriptor
	case !isAttributeInternal(attr):
		return nil, ErrNotPathAttribute
	}

	key := d.CID.String() + "/" + treeID
	tree, exists := f.treeMap[key]
	if !exists {
		tree = newMemoryTree()
		f.treeMap[key] = tree
	}

	existing, parent := tree.getPathPrefix(attr, path)
	applied := make([]Move, 0, len(path)-existing+1)

	// Create the missing intermediate nodes, chaining each under the previous one.
	for _, segment := range path[existing:] {
		logged := tree.do(&Move{
			Parent: parent,
			Meta: Meta{
				Time:  tree.timestamp(d.Position, d.Size),
				Items: []KeyValue{{Key: attr, Value: []byte(segment)}},
			},
			Child: tree.findSpareID(),
		})
		tree.operations = append(tree.operations, logged)
		applied = append(applied, logged.Move)
		parent = logged.Child
	}

	// Shallow-copy m so the stored Items slice does not alias the caller's slice.
	leafItems := make([]KeyValue, len(m))
	copy(leafItems, m)
	leaf := tree.do(&Move{
		Parent: parent,
		Meta: Meta{
			Time:  tree.timestamp(d.Position, d.Size),
			Items: leafItems,
		},
		Child: tree.findSpareID(),
	})
	tree.operations = append(tree.operations, leaf)
	return append(applied, leaf.Move), nil
}

// TreeApply implements the Forest interface.
//
// The target tree is created on first use, and op is handed to its Apply
// method. The trailing bool argument is ignored by this implementation.
func (f *memoryForest) TreeApply(_ context.Context, cnr cid.ID, treeID string, op *Move, _ bool) error {
	key := cnr.String() + "/" + treeID
	tree, exists := f.treeMap[key]
	if !exists {
		tree = newMemoryTree()
		f.treeMap[key] = tree
	}
	return tree.Apply(op)
}

// Init does nothing and always succeeds.
func (*memoryForest) Init() error { return nil }

// Open ignores its arguments and always succeeds.
func (*memoryForest) Open(context.Context, mode.Mode) error { return nil }

// SetMode ignores the requested mode and always succeeds.
func (*memoryForest) SetMode(mode.Mode) error { return nil }

// Close does nothing and always succeeds.
func (*memoryForest) Close() error { return nil }

// SetParentID ignores its argument.
func (*memoryForest) SetParentID(string) {}

// TreeGetByPath implements the Forest interface.
//
// attr must be an internal path attribute, otherwise ErrNotPathAttribute is
// returned. ErrTreeNotFound is returned when the tree does not exist.
func (f *memoryForest) TreeGetByPath(_ context.Context, cid cid.ID, treeID string, attr string, path []string, latest bool) ([]Node, error) {
	if !isAttributeInternal(attr) {
		return nil, ErrNotPathAttribute
	}

	if tree, ok := f.treeMap[cid.String()+"/"+treeID]; ok {
		return tree.getByPath(attr, path, latest), nil
	}
	return nil, ErrTreeNotFound
}

// TreeGetMeta implements the Forest interface.
//
// It returns the metadata and parent of nodeID as recorded in the tree's
// infoMap. An unknown node yields whatever the map's zero value provides.
// ErrTreeNotFound is returned when the tree does not exist.
func (f *memoryForest) TreeGetMeta(_ context.Context, cid cid.ID, treeID string, nodeID Node) (Meta, Node, error) {
	tree, ok := f.treeMap[cid.String()+"/"+treeID]
	if !ok {
		return Meta{}, 0, ErrTreeNotFound
	}

	info := tree.infoMap[nodeID]
	return info.Meta, info.Parent, nil
}

// TreeGetChildren implements the Forest interface.
//
// It lists the direct children of nodeID, each with the metadata and parent
// stored in the tree's infoMap. ErrTreeNotFound is returned when the tree
// does not exist.
func (f *memoryForest) TreeGetChildren(_ context.Context, cid cid.ID, treeID string, nodeID Node) ([]NodeInfo, error) {
	tree, ok := f.treeMap[cid.String()+"/"+treeID]
	if !ok {
		return nil, ErrTreeNotFound
	}

	childIDs := tree.tree.getChildren(nodeID)
	result := make([]NodeInfo, 0, len(childIDs))
	for _, id := range childIDs {
		info := tree.infoMap[id]
		result = append(result, NodeInfo{ID: id, Meta: info.Meta, ParentID: info.Parent})
	}
	return result, nil
}

// TreeGetOpLog implements the pilorama.Forest interface.
//
// It returns the first logged operation whose timestamp is at least height.
// It returns an empty Move with a nil error when no such operation exists, and
// ErrTreeNotFound when the tree does not exist.
func (f *memoryForest) TreeGetOpLog(_ context.Context, cid cid.ID, treeID string, height uint64) (Move, error) {
	tree, ok := f.treeMap[cid.String()+"/"+treeID]
	if !ok {
		return Move{}, ErrTreeNotFound
	}

	ops := tree.operations
	// sort.Search requires the log to be ordered by Time.
	idx := sort.Search(len(ops), func(i int) bool { return ops[i].Time >= height })
	if idx < len(ops) {
		return ops[idx].Move, nil
	}
	return Move{}, nil
}

// TreeDrop implements the pilorama.Forest interface.
//
// With an empty treeID every tree of the container is removed; this never
// fails. Otherwise only the named tree is removed, and ErrTreeNotFound is
// returned if it does not exist.
func (f *memoryForest) TreeDrop(_ context.Context, cid cid.ID, treeID string) error {
	cidStr := cid.String()
	if treeID == "" {
		// Match on "<cid>/" rather than just "<cid>", so that trees of another
		// container whose encoded ID merely starts with cidStr are left alone.
		prefix := cidStr + "/"
		for k := range f.treeMap {
			// Deleting entries during a range over the same map is permitted by the Go spec.
			if strings.HasPrefix(k, prefix) {
				delete(f.treeMap, k)
			}
		}
		return nil
	}

	fullID := cidStr + "/" + treeID
	if _, ok := f.treeMap[fullID]; !ok {
		return ErrTreeNotFound
	}
	delete(f.treeMap, fullID)
	return nil
}

// TreeList implements the pilorama.Forest interface.
//
// It returns the IDs of all trees stored for the given container. The order
// follows map iteration, so it is unspecified. A nil slice is returned when
// the container has no trees.
func (f *memoryForest) TreeList(_ context.Context, cid cid.ID) ([]string, error) {
	var res []string
	cidStr := cid.EncodeToString()

	for k := range f.treeMap {
		// Keys are "<cid>/<treeID>". Cut only at the first "/" so that a tree
		// ID which itself contains "/" is returned intact rather than truncated.
		cidPart, treeID, found := strings.Cut(k, "/")
		if !found || cidPart != cidStr {
			continue
		}
		res = append(res, treeID)
	}

	return res, nil
}

// TreeHeight implements the pilorama.Forest interface.
//
// It returns the timestamp of the most recently logged operation of the tree.
// It returns 0 when the tree exists but its log is empty, and ErrTreeNotFound
// when the tree does not exist.
func (f *memoryForest) TreeHeight(_ context.Context, cid cid.ID, treeID string) (uint64, error) {
	fullID := cid.EncodeToString() + "/" + treeID
	tree, ok := f.treeMap[fullID]
	if !ok {
		return 0, ErrTreeNotFound
	}
	// A tree can exist with an empty log: TreeApplyStream creates it before
	// reading from a possibly already-closed channel, and TreeApply creates it
	// before Apply may fail. Indexing the last element would then panic.
	if len(tree.operations) == 0 {
		return 0, nil
	}
	return tree.operations[len(tree.operations)-1].Time, nil
}

// TreeExists implements the pilorama.Forest interface.
// It reports whether a tree is stored under the given container and tree ID;
// the error is always nil.
func (f *memoryForest) TreeExists(_ context.Context, cid cid.ID, treeID string) (bool, error) {
	_, found := f.treeMap[cid.EncodeToString()+"/"+treeID]
	return found, nil
}

// TreeUpdateLastSyncHeight implements the pilorama.Forest interface.
// It stores height as the tree's last synchronization height, or returns
// ErrTreeNotFound when the tree does not exist.
func (f *memoryForest) TreeUpdateLastSyncHeight(_ context.Context, cid cid.ID, treeID string, height uint64) error {
	if t, ok := f.treeMap[cid.EncodeToString()+"/"+treeID]; ok {
		t.syncHeight = height
		return nil
	}
	return ErrTreeNotFound
}

// TreeLastSyncHeight implements the pilorama.Forest interface.
// It returns the height last stored by TreeUpdateLastSyncHeight, or
// ErrTreeNotFound when the tree does not exist.
func (f *memoryForest) TreeLastSyncHeight(_ context.Context, cid cid.ID, treeID string) (uint64, error) {
	if t, ok := f.treeMap[cid.EncodeToString()+"/"+treeID]; ok {
		return t.syncHeight, nil
	}
	return 0, ErrTreeNotFound
}

// TreeListTrees implements Forest.
//
// It pages through all stored trees in lexicographic order of their
// "<cid>/<treeID>" keys. Each page holds at most prm.BatchSize items, or
// treeListTreesBatchSizeDefault when BatchSize is not positive. When a page is
// full, NextPageToken is set to the key of its last item, and the next call
// resumes strictly after that key. An error is returned if a stored key has
// no "/" separator or its container part fails to decode.
func (f *memoryForest) TreeListTrees(_ context.Context, prm TreeListTreesPrm) (*TreeListTreesResult, error) {
	batchSize := prm.BatchSize
	if batchSize <= 0 {
		batchSize = treeListTreesBatchSizeDefault
	}

	keys := make([]string, 0, len(f.treeMap))
	for k := range f.treeMap {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	start := 0
	if len(prm.NextPageToken) > 0 {
		last := string(prm.NextPageToken)
		// Find the first key >= last. Skip it if it equals the token, so the
		// previous page's final item is not returned twice.
		start = sort.SearchStrings(keys, last)
		if start < len(keys) && keys[start] == last {
			start++
		}
	}

	var result TreeListTreesResult
	for _, key := range keys[start:] {
		// Cut only at the first "/": tree IDs are arbitrary strings and may
		// contain "/" themselves. Splitting on every "/" would make such a key
		// fail the whole listing.
		cidStr, treeID, found := strings.Cut(key, "/")
		if !found {
			return nil, errInvalidKeyFormat
		}
		var contID cid.ID
		if err := contID.DecodeString(cidStr); err != nil {
			return nil, fmt.Errorf("invalid format: %w", err)
		}

		result.Items = append(result.Items, ContainerIDTreeID{
			CID:    contID,
			TreeID: treeID,
		})

		if len(result.Items) == batchSize {
			result.NextPageToken = []byte(key)
			break
		}
	}
	return &result, nil
}

// TreeApplyStream implements ForestStorage.
//
// The target tree is created on first use. Moves received from source are
// applied one by one until source is closed, which yields nil. Processing
// stops early with ctx.Err() when ctx is cancelled, or with the first error
// returned by Apply.
func (f *memoryForest) TreeApplyStream(ctx context.Context, cnr cid.ID, treeID string, source <-chan *Move) error {
	key := cnr.String() + "/" + treeID
	tree, exists := f.treeMap[key]
	if !exists {
		tree = newMemoryTree()
		f.treeMap[key] = tree
	}

	for {
		var (
			m    *Move
			open bool
		)
		select {
		case <-ctx.Done():
			return ctx.Err()
		case m, open = <-source:
		}

		if !open {
			return nil
		}
		if err := tree.Apply(m); err != nil {
			return err
		}
	}
}