package object

import (
	"bytes"
	"cmp"
	"context"
	"crypto/ecdsa"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"slices"
	"sync"

	internalclient "git.frostfs.info/TrueCloudLab/frostfs-node/cmd/frostfs-cli/internal/client"
	"git.frostfs.info/TrueCloudLab/frostfs-node/cmd/frostfs-cli/internal/commonflags"
	"git.frostfs.info/TrueCloudLab/frostfs-node/cmd/frostfs-cli/internal/key"
	commonCmd "git.frostfs.info/TrueCloudLab/frostfs-node/cmd/internal/common"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/policy"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/network"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object_manager/placement"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client"
	apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	netmapSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"github.com/spf13/cobra"
	"golang.org/x/sync/errgroup"
)

// Names of the command-line flags that are specific to the "nodes" command.
const (
	// verifyPresenceAllFlag makes the presence check run against every
	// netmap node instead of only the nodes required by the placement.
	verifyPresenceAllFlag       = "verify-presence-all"
	// preferInternalAddressesFlag makes node connections try the node's
	// network endpoints before its external addresses.
	preferInternalAddressesFlag = "prefer-internal-addresses"
)

// Sentinel errors reported by the "nodes" command.
var (
	// errNoAvailableEndpoint is returned by createClient when no client
	// could be created and no more specific error is available.
	errNoAvailableEndpoint    = errors.New("failed to create client: no available endpoint")
	// errMalformedComplexObject is reported when the parts of a complex
	// object are a mix of EC and non-EC objects.
	errMalformedComplexObject = errors.New("object consists of EC and non EC parts")
)

// phyObject describes a single physically stored object (a regular object,
// a part of a complex object or an EC chunk) whose placement is inspected.
type phyObject struct {
	// containerID is the container the object belongs to.
	containerID               cid.ID
	// objectID is the identifier of the physical object itself.
	objectID                  oid.ID
	// storedOnAllContainerNodes is set for lock objects, tombstones and
	// objects that have children; such objects are required on every node
	// of the placement vectors rather than on a limited number of them.
	storedOnAllContainerNodes bool
	// ecHeader is non-nil only when the object is an EC chunk.
	ecHeader                  *ecHeader
}

// ecHeader holds the EC-specific attributes of a chunk.
type ecHeader struct {
	// index is the index of the chunk within its parent object.
	index  uint32
	// parent is the identifier of the object the chunk belongs to.
	parent oid.ID
}

// objectPlacement groups the nodes related to a single physical object.
type objectPlacement struct {
	// requiredNodes are the nodes that should store the object according
	// to the container placement policy and the network map.
	requiredNodes  []netmapSDK.NodeInfo
	// confirmedNodes are the nodes on which the object was actually found.
	confirmedNodes []netmapSDK.NodeInfo
}

// objectNodesResult accumulates everything the command prints at the end.
type objectNodesResult struct {
	// errors collects non-fatal errors (failed connections, failed presence
	// checks) that are reported alongside the placement.
	errors     []error
	// placements maps a physical object ID to its required/confirmed nodes.
	placements map[oid.ID]objectPlacement
}

// ObjNodesDataObject is the JSON representation of a single physical object
// and its placement. Node lists contain hex-encoded node public keys. The EC
// fields are pointers so that they are omitted for non-EC objects.
type ObjNodesDataObject struct {
	ObjectID         string   `json:"object_id"`
	RequiredNodes    []string `json:"required_nodes,omitempty"`
	ConfirmedNodes   []string `json:"confirmed_nodes,omitempty"`
	ECParentObjectID *string  `json:"ec_parent_object_id,omitempty"`
	ECIndex          *uint32  `json:"ec_index,omitempty"`
}

// objNodesResultJSON is the top-level JSON document printed by the command:
// the requested object ID, its physical data objects and collected errors.
type objNodesResultJSON struct {
	ObjectID    string               `json:"object_id"`
	DataObjects []ObjNodesDataObject `json:"data_objects,omitempty"`
	Errors      []string             `json:"errors,omitempty"`
}

// objectNodesCmd is the "nodes" command: it reports on which nodes an object
// should be stored and on which nodes it was actually found. The work is done
// by objectNodes; flags are registered in initObjectNodesCmd.
var objectNodesCmd = &cobra.Command{
	Use:   "nodes",
	Short: "List of nodes where the object is stored",
	Long: `List of nodes where the object should be stored and where it is actually stored.
	Lock objects must exist on all nodes of the container.
	For complex and EC objects, a node is considered to store an object if the node stores at least one part of the complex object or one chunk of the EC object.
	By default, the actual storage of the object is checked only on the nodes that should store the object. To check all nodes, use the flag --verify-presence-all.`,
	Run: objectNodes,
}

// initObjectNodesCmd registers the common flags and the command-specific
// flags (container ID, object ID, presence verification mode, JSON output and
// address preference) on objectNodesCmd.
func initObjectNodesCmd() {
	commonflags.Init(objectNodesCmd)

	flags := objectNodesCmd.Flags()

	// The required marks must be set on objectNodesCmd itself: marking them
	// on another command leaves --cid/--oid optional for "nodes".
	flags.String(commonflags.CIDFlag, "", commonflags.CIDFlagUsage)
	_ = objectNodesCmd.MarkFlagRequired(commonflags.CIDFlag)

	flags.String(commonflags.OIDFlag, "", commonflags.OIDFlagUsage)
	_ = objectNodesCmd.MarkFlagRequired(commonflags.OIDFlag)

	flags.Bool(verifyPresenceAllFlag, false, "Verify the actual presence of the object on all netmap nodes.")
	flags.Bool(commonflags.JSON, false, "Print information about the object placement as json.")
	flags.Bool(preferInternalAddressesFlag, false, "Use internal addresses first to get object info.")
}

// objectNodes implements the "nodes" command. It resolves the requested
// object into its physical objects, computes where they are required to be
// stored, checks where they are actually stored and prints the outcome.
func objectNodes(cmd *cobra.Command, _ []string) {
	var (
		containerID cid.ID
		objectID    oid.ID
	)
	readObjectAddress(cmd, &containerID, &objectID)

	privKey := key.GetOrGenerate(cmd)
	rpcClient := internalclient.GetSDKClientByFlag(cmd, privKey, commonflags.RPC)

	// Physical objects: the object itself, parts of a complex object or EC chunks.
	phyObjects := getPhyObjects(cmd, containerID, objectID, rpcClient, privKey)

	// Inputs for the placement computation.
	cnrPolicy, netMap := getPlacementPolicyAndNetmap(cmd, containerID, rpcClient)

	// Expected placement first, then the actual one is added to the same result.
	placementResult := getRequiredPlacement(cmd, phyObjects, cnrPolicy, netMap)
	getActualPlacement(cmd, netMap, privKey, phyObjects, placementResult)

	printPlacement(cmd, objectID, phyObjects, placementResult)
}

// getPhyObjects resolves the requested object into the list of physical
// objects that carry its data.
//
// A raw HEAD request is sent first. On success the object itself is the only
// physical object. A SplitInfoError is resolved into the parts of a complex
// object, an ECInfoError into EC chunks. Any other error terminates the
// command via commonCmd.ExitOnErr.
func getPhyObjects(cmd *cobra.Command, cnrID cid.ID, objID oid.ID, cli *client.Client, pk *ecdsa.PrivateKey) []phyObject {
	var target oid.Address
	target.SetContainer(cnrID)
	target.SetObject(objID)

	var prmHead internalclient.HeadObjectPrm
	prmHead.SetClient(cli)
	prmHead.SetAddress(target)
	prmHead.SetRawFlag(true)

	Prepare(cmd, &prmHead)
	readSession(cmd, &prmHead, pk, cnrID, objID)

	res, err := internalclient.HeadObject(cmd.Context(), prmHead)
	if err != nil {
		var (
			splitErr *objectSDK.SplitInfoError
			ecErr    *objectSDK.ECInfoError
		)
		switch {
		case errors.As(err, &splitErr):
			return getComplexObjectParts(cmd, cnrID, objID, cli, prmHead, splitErr)
		case errors.As(err, &ecErr):
			return getECObjectChunks(cmd, cnrID, objID, ecErr)
		}
		commonCmd.ExitOnErr(cmd, "failed to get object info: %w", err)
		return nil
	}

	hdr := res.Header()

	// Locks, tombstones and objects with children are expected on every
	// node of the placement vectors.
	var onAllNodes bool
	switch hdr.Type() {
	case objectSDK.TypeLock, objectSDK.TypeTombstone:
		onAllNodes = true
	default:
		onAllNodes = len(hdr.Children()) > 0
	}

	single := phyObject{
		containerID:               cnrID,
		objectID:                  objID,
		storedOnAllContainerNodes: onAllNodes,
	}
	if ec := hdr.ECHeader(); ec != nil {
		single.ecHeader = &ecHeader{
			index:  ec.Index(),
			parent: ec.Parent(),
		}
	}
	return []phyObject{single}
}

// getComplexObjectParts returns the physical objects of a complex (split)
// object: its members are collected first and then flattened, which expands
// EC-stored members into their chunks.
func getComplexObjectParts(cmd *cobra.Command, cnrID cid.ID, objID oid.ID, cli *client.Client, prmHead internalclient.HeadObjectPrm, errSplitInfo *objectSDK.SplitInfoError) []phyObject {
	return flattenComplexMembersIfECContainer(
		cmd,
		cnrID,
		getCompexObjectMembers(cmd, cnrID, objID, cli, prmHead, errSplitInfo),
		prmHead,
	)
}

// getCompexObjectMembers collects the member IDs of a complex object using
// the split info from errSplitInfo. Three strategies are tried in order and
// the first successful one wins: the linking object, the split ID, and
// finally restoring the chain in reverse.
func getCompexObjectMembers(cmd *cobra.Command, cnrID cid.ID, objID oid.ID, cli *client.Client, prmHead internalclient.HeadObjectPrm, errSplitInfo *objectSDK.SplitInfoError) []oid.ID {
	info := errSplitInfo.SplitInfo()

	var (
		memberIDs []oid.ID
		found     bool
	)

	memberIDs, found = tryGetSplitMembersByLinkingObject(cmd, info, prmHead, cnrID)
	if !found {
		memberIDs, found = tryGetSplitMembersBySplitID(cmd, info, cli, cnrID)
	}
	if !found {
		// Last resort; its result is returned unconditionally.
		memberIDs = tryRestoreChainInReverse(cmd, info, prmHead, cli, cnrID, objID)
	}
	return memberIDs
}

// flattenComplexMembersIfECContainer turns the members of a complex object
// into physical objects. Every member is requested with a raw HEAD in its own
// goroutine:
//   - an ECInfoError means the member is EC-stored and is replaced by its chunks;
//   - a member with children (linking object) carries no data and is skipped;
//   - any other regular member is added as is.
//
// Members of a non-regular type, request errors other than ECInfoError and a
// mix of EC and non-EC members terminate the command.
func flattenComplexMembersIfECContainer(cmd *cobra.Command, cnrID cid.ID, members []oid.ID, prmHead internalclient.HeadObjectPrm) []phyObject {
	flattened := make([]phyObject, 0, len(members))
	if len(members) == 0 {
		return flattened
	}

	// mu guards flattened and both flags, which the goroutines below update.
	var (
		mu              sync.Mutex
		sawEC, sawPlain bool
	)

	prmHead.SetRawFlag(true) // to get an error instead of whole object

	eg, egCtx := errgroup.WithContext(cmd.Context())
	for _, partID := range members {
		eg.Go(func() error {
			var partAddr oid.Address
			partAddr.SetContainer(cnrID)
			partAddr.SetObject(partID)

			partPrm := prmHead
			partPrm.SetAddress(partAddr)

			res, err := internalclient.HeadObject(egCtx, partPrm)
			if err != nil {
				var ecErr *objectSDK.ECInfoError
				if !errors.As(err, &ecErr) {
					return err
				}

				mu.Lock()
				defer mu.Unlock()
				flattened = append(flattened, getECObjectChunks(cmd, cnrID, partID, ecErr)...)
				sawEC = true
				return nil
			}

			hdr := res.Header()
			if partType := hdr.Type(); partType != objectSDK.TypeRegular {
				commonCmd.ExitOnErr(cmd, "failed to flatten parts of complex object: %w", fmt.Errorf("object '%s' with type '%s' is not supported as part of complex object", partAddr, partType))
			}

			if len(hdr.Children()) > 0 {
				// linking object is not data object, so skip it
				return nil
			}

			mu.Lock()
			defer mu.Unlock()
			flattened = append(flattened, phyObject{
				containerID: cnrID,
				objectID:    partID,
			})
			sawPlain = true
			return nil
		})
	}

	commonCmd.ExitOnErr(cmd, "failed to flatten parts of complex object: %w", eg.Wait())
	if sawEC && sawPlain {
		commonCmd.ExitOnErr(cmd, "failed to flatten parts of complex object: %w", errMalformedComplexObject)
	}
	return flattened
}

// getECObjectChunks converts the chunk list carried by errECInfo into
// physical objects. Every returned object has an ecHeader holding the chunk
// index and objID as the parent. A chunk ID that cannot be decoded is
// reported through commonCmd.ExitOnErr.
func getECObjectChunks(cmd *cobra.Command, cnrID cid.ID, objID oid.ID, errECInfo *objectSDK.ECInfoError) []phyObject {
	ecInfo := errECInfo.ECInfo()
	result := make([]phyObject, 0, len(ecInfo.Chunks))
	for _, ch := range ecInfo.Chunks {
		var chID oid.ID
		err := chID.ReadFromV2(ch.ID)
		if err != nil {
			// Same "<context>: <cause>" shape as the other messages in this file.
			commonCmd.ExitOnErr(cmd, "failed to read EC chunk ID: %w", err)
			return nil
		}
		result = append(result, phyObject{
			containerID: cnrID,
			objectID:    chID,
			ecHeader: &ecHeader{
				index:  ch.Index,
				parent: objID,
			},
		})
	}
	return result
}

// getPlacementPolicyAndNetmap fetches the container placement policy and the
// network map concurrently. If either request fails the command is terminated
// via commonCmd.ExitOnErr.
func getPlacementPolicyAndNetmap(cmd *cobra.Command, cnrID cid.ID, cli *client.Client) (placementPolicy netmapSDK.PlacementPolicy, netmap *netmapSDK.NetMap) {
	eg, egCtx := errgroup.WithContext(cmd.Context())

	// Each goroutine writes to its own named result, so no locking is needed.
	eg.Go(func() error {
		var err error
		placementPolicy, err = getPlacementPolicy(egCtx, cnrID, cli)
		return err
	})
	eg.Go(func() error {
		var err error
		netmap, err = getNetMap(egCtx, cli)
		return err
	})

	commonCmd.ExitOnErr(cmd, "rpc error: %w", eg.Wait())
	return placementPolicy, netmap
}

// getPlacementPolicy requests the container cnrID through cli and returns its
// placement policy. On a request error a zero policy and the error are returned.
func getPlacementPolicy(ctx context.Context, cnrID cid.ID, cli *client.Client) (netmapSDK.PlacementPolicy, error) {
	var prm internalclient.GetContainerPrm
	prm.Client = cli
	prm.ClientParams.ContainerID = &cnrID

	res, err := internalclient.GetContainer(ctx, prm)
	if err != nil {
		var empty netmapSDK.PlacementPolicy
		return empty, err
	}

	cnr := res.Container()
	return cnr.PlacementPolicy(), nil
}

// getNetMap requests a network map snapshot through cli and returns a pointer
// to a copy of it, or the request error.
func getNetMap(ctx context.Context, cli *client.Client) (*netmapSDK.NetMap, error) {
	var prm internalclient.NetMapSnapshotPrm
	prm.SetClient(cli)

	res, err := internalclient.NetMapSnapshot(ctx, prm)
	if err != nil {
		return nil, err
	}

	snapshot := new(netmapSDK.NetMap)
	*snapshot = res.NetMap()
	return snapshot, nil
}

// getRequiredPlacement computes the nodes every physical object is required
// to be stored on, choosing the EC or the replica algorithm according to the
// placement policy.
func getRequiredPlacement(cmd *cobra.Command, objects []phyObject, placementPolicy netmapSDK.PlacementPolicy, netmap *netmapSDK.NetMap) *objectNodesResult {
	compute := getReplicaRequiredPlacement
	if policy.IsECPlacement(placementPolicy) {
		compute = getECRequiredPlacement
	}
	return compute(cmd, objects, placementPolicy, netmap)
}

// getReplicaRequiredPlacement computes the required nodes for a replica
// policy. For every placement vector of an object it takes the first N nodes,
// where N is the number of objects of the corresponding replica descriptor.
// All nodes of the vector are taken for objects flagged as stored on all
// container nodes. A placement build error terminates the command.
func getReplicaRequiredPlacement(cmd *cobra.Command, objects []phyObject, placementPolicy netmapSDK.PlacementPolicy, netmap *netmapSDK.NetMap) *objectNodesResult {
	placements := make(map[oid.ID]objectPlacement)
	builder := placement.NewNetworkMapBuilder(netmap)

	for i := range objects {
		obj := objects[i]

		vectors, err := builder.BuildPlacement(obj.containerID, &obj.objectID, placementPolicy)
		commonCmd.ExitOnErr(cmd, "failed to get required placement for object: %w", err)

		for repIdx, vector := range vectors {
			limit := placementPolicy.ReplicaDescriptor(repIdx).NumberOfObjects()

			required := vector
			if !obj.storedOnAllContainerNodes && uint64(len(vector)) > uint64(limit) {
				required = vector[:limit]
			}
			if len(required) == 0 {
				// Nothing to add: do not create a map entry for the object.
				continue
			}

			op := placements[obj.objectID]
			op.requiredNodes = append(op.requiredNodes, required...)
			placements[obj.objectID] = op
		}
	}

	return &objectNodesResult{placements: placements}
}

// getECRequiredPlacement computes the required nodes for an EC policy by
// processing every physical object with getECRequiredPlacementInternal and
// collecting the outcome in a fresh result.
func getECRequiredPlacement(cmd *cobra.Command, objects []phyObject, placementPolicy netmapSDK.PlacementPolicy, netmap *netmapSDK.NetMap) *objectNodesResult {
	collected := &objectNodesResult{
		placements: map[oid.ID]objectPlacement{},
	}
	for i := range objects {
		getECRequiredPlacementInternal(cmd, objects[i], placementPolicy, netmap, collected)
	}
	return collected
}

// getECRequiredPlacementInternal adds the required nodes of a single object
// to result for an EC policy.
//
// The placement is built for the EC parent when the object is a chunk,
// otherwise for the object itself. For every placement vector:
//   - objects flagged as stored on all container nodes require every node;
//   - EC chunks require the node at position (chunk index mod vector length);
//   - other objects require nothing.
//
// Empty vectors are skipped. A placement build error terminates the command.
func getECRequiredPlacementInternal(cmd *cobra.Command, object phyObject, placementPolicy netmapSDK.PlacementPolicy, netmap *netmapSDK.NetMap, result *objectNodesResult) {
	placementObjectID := object.objectID
	if object.ecHeader != nil {
		placementObjectID = object.ecHeader.parent
	}

	vectors, err := placement.NewNetworkMapBuilder(netmap).BuildPlacement(object.containerID, &placementObjectID, placementPolicy)
	commonCmd.ExitOnErr(cmd, "failed to get required placement: %w", err)

	for _, vector := range vectors {
		if len(vector) == 0 {
			// No nodes to require; this also keeps the modulo below from
			// dividing by zero.
			continue
		}

		required := vector
		switch {
		case object.storedOnAllContainerNodes:
			// keep the whole vector
		case object.ecHeader != nil:
			// Unsigned arithmetic: converting a large uint32 index to int
			// first could yield a negative value where int is 32 bits wide.
			nodeIdx := int(uint64(object.ecHeader.index) % uint64(len(vector)))
			required = vector[nodeIdx : nodeIdx+1]
		default:
			continue
		}

		op := result.placements[object.objectID]
		op.requiredNodes = append(op.requiredNodes, required...)
		result.placements[object.objectID] = op
	}
}

// getActualPlacement checks on which candidate nodes the objects are really
// stored and records the confirmed nodes in result.placements.
//
// One goroutine per candidate node creates a client; for every object a
// further goroutine of the same group performs the presence check. Connection
// and check failures are not fatal: they are appended to result.errors.
func getActualPlacement(cmd *cobra.Command, netmap *netmapSDK.NetMap, pk *ecdsa.PrivateKey, objects []phyObject, result *objectNodesResult) {
	// mu guards result.errors and result.placements.
	var mu sync.Mutex
	recordErr := func(err error) {
		mu.Lock()
		defer mu.Unlock()
		result.errors = append(result.errors, err)
	}

	eg, egCtx := errgroup.WithContext(cmd.Context())
	for _, node := range getNodesToCheckObjectExistance(cmd, netmap, result) {
		eg.Go(func() error {
			nodeCli, err := createClient(egCtx, cmd, node, pk)
			if err != nil {
				recordErr(fmt.Errorf("failed to connect to node %s: %w", hex.EncodeToString(node.PublicKey()), err))
				return nil
			}

			for _, obj := range objects {
				eg.Go(func() error {
					stored, err := isObjectStoredOnNode(egCtx, cmd, obj.containerID, obj.objectID, nodeCli, pk)
					if err != nil {
						recordErr(fmt.Errorf("failed to check object %s existence on node %s: %w", obj.objectID.EncodeToString(), hex.EncodeToString(node.PublicKey()), err))
						return nil
					}
					if !stored {
						return nil
					}

					mu.Lock()
					defer mu.Unlock()
					op := result.placements[obj.objectID]
					op.confirmedNodes = append(op.confirmedNodes, node)
					result.placements[obj.objectID] = op
					return nil
				})
			}
			return nil
		})
	}

	commonCmd.ExitOnErr(cmd, "failed to get actual placement: %w", eg.Wait())
}

// getNodesToCheckObjectExistance returns the nodes on which object presence
// is verified: every netmap node when --verify-presence-all is set, otherwise
// the required nodes of all placements, deduplicated by node hash.
func getNodesToCheckObjectExistance(cmd *cobra.Command, netmap *netmapSDK.NetMap, result *objectNodesResult) []netmapSDK.NodeInfo {
	if all, _ := cmd.Flags().GetBool(verifyPresenceAllFlag); all {
		return netmap.Nodes()
	}

	seen := map[uint64]struct{}{}
	var unique []netmapSDK.NodeInfo
	for _, p := range result.placements {
		for _, node := range p.requiredNodes {
			h := node.Hash()
			if _, dup := seen[h]; dup {
				continue
			}
			seen[h] = struct{}{}
			unique = append(unique, node)
		}
	}
	return unique
}

// createClient creates an SDK client for the candidate node.
//
// The node's network endpoints and external addresses are tried one by one;
// the --prefer-internal-addresses flag decides which group goes first. The
// first address that yields a client wins. If the last attempted address
// failed, that error is returned; if no client was obtained without an error
// (for example, the node has no addresses), errNoAvailableEndpoint is returned.
func createClient(ctx context.Context, cmd *cobra.Command, candidate netmapSDK.NodeInfo, pk *ecdsa.PrivateKey) (*client.Client, error) {
	var endpoints []string
	// The callback returns false for every endpoint, as the original code
	// does — presumably this tells the iterator to go on; verify against the SDK.
	candidate.IterateNetworkEndpoints(func(endpoint string) bool {
		endpoints = append(endpoints, endpoint)
		return false
	})
	external := candidate.ExternalAddresses()

	first, second := external, endpoints
	if preferInternal, _ := cmd.Flags().GetBool(preferInternalAddressesFlag); preferInternal {
		first, second = endpoints, external
	}
	var addresses []string
	addresses = append(addresses, first...)
	addresses = append(addresses, second...)

	var (
		cli     *client.Client
		lastErr error
	)
	for _, rawAddr := range addresses {
		var addr network.Address
		if lastErr = addr.FromString(rawAddr); lastErr != nil {
			continue
		}
		if cli, lastErr = internalclient.GetSDKClient(ctx, cmd, pk, addr); lastErr == nil {
			break
		}
	}

	switch {
	case lastErr != nil:
		return nil, lastErr
	case cli == nil:
		return nil, errNoAvailableEndpoint
	default:
		return cli, nil
	}
}

// isObjectStoredOnNode reports whether the node behind cli stores the object.
// It sends a HEAD request with TTL 1. A successful response means the object
// is stored; ObjectNotFound and ObjectAlreadyRemoved statuses mean it is not.
// Any other error is returned to the caller.
func isObjectStoredOnNode(ctx context.Context, cmd *cobra.Command, cnrID cid.ID, objID oid.ID, cli *client.Client, pk *ecdsa.PrivateKey) (bool, error) {
	var target oid.Address
	target.SetContainer(cnrID)
	target.SetObject(objID)

	var prmHead internalclient.HeadObjectPrm
	prmHead.SetClient(cli)
	prmHead.SetAddress(target)

	Prepare(cmd, &prmHead)
	prmHead.SetTTL(1)
	readSession(cmd, &prmHead, pk, cnrID, objID)

	res, err := internalclient.HeadObject(ctx, prmHead)

	var (
		notFound *apistatus.ObjectNotFound
		removed  *apistatus.ObjectAlreadyRemoved
	)
	switch {
	case err == nil:
		return res != nil, nil
	case errors.As(err, &notFound), errors.As(err, &removed):
		return false, nil
	default:
		return false, err
	}
}

// printPlacement normalizes the collected result and prints it as JSON when
// the JSON flag is set, or as plain text otherwise.
func printPlacement(cmd *cobra.Command, objID oid.ID, objects []phyObject, result *objectNodesResult) {
	normilizeObjectNodesResult(objects, result)

	asJSON, _ := cmd.Flags().GetBool(commonflags.JSON)
	if asJSON {
		printObjectNodesAsJSON(cmd, objID, objects, result)
		return
	}
	printObjectNodesAsText(cmd, objID, objects, result)
}

// normilizeObjectNodesResult makes the output order deterministic.
//
// objects are sorted in place: non-EC objects come first, ordered by object
// ID; EC chunks follow, ordered by parent ID and then by chunk index. The
// required and confirmed node lists of every object are sorted by public key.
// As a side effect every object gets an entry in result.placements, even if
// it had none before.
func normilizeObjectNodesResult(objects []phyObject, result *objectNodesResult) {
	byPublicKey := func(a, b netmapSDK.NodeInfo) int {
		return bytes.Compare(a.PublicKey(), b.PublicKey())
	}

	slices.SortFunc(objects, func(a, b phyObject) int {
		switch {
		case a.ecHeader == nil && b.ecHeader == nil:
			return bytes.Compare(a.objectID[:], b.objectID[:])
		case a.ecHeader == nil:
			return -1
		case b.ecHeader == nil:
			return 1
		}
		if byParent := bytes.Compare(a.ecHeader.parent[:], b.ecHeader.parent[:]); byParent != 0 {
			return byParent
		}
		return cmp.Compare(a.ecHeader.index, b.ecHeader.index)
	})

	for i := range objects {
		id := objects[i].objectID
		op := result.placements[id]
		slices.SortFunc(op.confirmedNodes, byPublicKey)
		slices.SortFunc(op.requiredNodes, byPublicKey)
		result.placements[id] = op
	}
}

// printObjectNodesAsText writes a human-readable report to the command's
// output: a header line, one entry per physical object (with EC attributes
// and the hex-encoded public keys of its required and confirmed nodes), and
// finally the collected errors, if any.
func printObjectNodesAsText(cmd *cobra.Command, objID oid.ID, objects []phyObject, result *objectNodesResult) {
	out := cmd.OutOrStdout()

	// printNodes writes a titled list of node public keys; empty lists are omitted.
	printNodes := func(title string, nodes []netmapSDK.NodeInfo) {
		if len(nodes) == 0 {
			return
		}
		fmt.Fprint(out, title)
		for i := range nodes {
			fmt.Fprintf(out, "\t\t- %s\n", hex.EncodeToString(nodes[i].PublicKey()))
		}
	}

	fmt.Fprintf(out, "Object %s stores payload in %d data objects:\n", objID.EncodeToString(), len(objects))

	for _, obj := range objects {
		fmt.Fprintf(out, "- %s\n", obj.objectID)
		if hdr := obj.ecHeader; hdr != nil {
			fmt.Fprintf(out, "\tEC index: %d\n", hdr.index)
			fmt.Fprintf(out, "\tEC parent: %s\n", hdr.parent.EncodeToString())
		}

		if op, found := result.placements[obj.objectID]; found {
			printNodes("\tRequired nodes:\n", op.requiredNodes)
			printNodes("\tConfirmed nodes:\n", op.confirmedNodes)
		}
	}

	if len(result.errors) > 0 {
		fmt.Fprintf(out, "Errors:\n")
		for _, e := range result.errors {
			fmt.Fprintf(out, "\t%s\n", e.Error())
		}
	}
}

// printObjectNodesAsJSON prints the report as a single JSON document via
// cmd.Println. Objects without an entry in result.placements are left out.
// Node lists contain hex-encoded public keys. A marshalling error terminates
// the command.
func printObjectNodesAsJSON(cmd *cobra.Command, objID oid.ID, objects []phyObject, result *objectNodesResult) {
	// hexKeys returns the hex-encoded public keys of nodes (nil for no nodes).
	hexKeys := func(nodes []netmapSDK.NodeInfo) []string {
		var keys []string
		for i := range nodes {
			keys = append(keys, hex.EncodeToString(nodes[i].PublicKey()))
		}
		return keys
	}

	report := objNodesResultJSON{
		ObjectID: objID.EncodeToString(),
	}

	for _, obj := range objects {
		op, found := result.placements[obj.objectID]
		if !found {
			continue
		}

		entry := ObjNodesDataObject{
			ObjectID:       obj.objectID.EncodeToString(),
			RequiredNodes:  hexKeys(op.requiredNodes),
			ConfirmedNodes: hexKeys(op.confirmedNodes),
		}
		if hdr := obj.ecHeader; hdr != nil {
			idx := hdr.index
			parent := hdr.parent.EncodeToString()
			entry.ECIndex = &idx
			entry.ECParentObjectID = &parent
		}
		report.DataObjects = append(report.DataObjects, entry)
	}

	for _, e := range result.errors {
		report.Errors = append(report.Errors, e.Error())
	}

	encoded, err := json.Marshal(&report)
	commonCmd.ExitOnErr(cmd, "failed to marshal json: %w", err)
	cmd.Println(string(encoded))
}