package getsvc

import (
	"context"
	"crypto/sha256"
	"errors"
	"hash"

	objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object"
	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/refs"
	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/session"
	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/status"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/client"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/network"
	objectSvc "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object"
	getsvc "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object/get"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object/util"
	apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	versionSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/version"
	"git.frostfs.info/TrueCloudLab/tzhash/tz"
)

// toPrm converts a V2 Get request into parameters for the object Get service.
//
// Object parts are written to stream through a streamObjectWriter. An error is
// returned when the request has no address, the address cannot be decoded, or
// the common parameters cannot be extracted. For requests that are not
// local-only, a request forwarder is installed. It carries the default key
// from key storage and shares the same stream writer.
func (s *Service) toPrm(req *objectV2.GetRequest, stream objectSvc.GetObjectStream) (*getsvc.Prm, error) {
	body := req.GetBody()

	rawAddr := body.GetAddress()
	if rawAddr == nil {
		return nil, errMissingObjAddress
	}

	var addr oid.Address
	if err := addr.ReadFromV2(*rawAddr); err != nil {
		return nil, errInvalidObjAddress(err)
	}

	commonPrm, err := util.CommonPrmFromV2(req)
	if err != nil {
		return nil, err
	}

	writer := &streamObjectWriter{stream}

	prm := new(getsvc.Prm)
	prm.SetCommonParameters(commonPrm)
	prm.WithAddress(addr)
	prm.WithRawFlag(body.GetRaw())
	prm.SetObjectWriter(writer)

	// Local-only requests get no forwarder, so no key is fetched for them.
	if commonPrm.LocalOnly() {
		return prm, nil
	}

	key, err := s.keyStorage.GetKey(nil)
	if err != nil {
		return nil, err
	}

	// GlobalProgress starts at its zero value.
	fwd := &getRequestForwarder{
		Key:     key,
		Request: req,
		Stream:  writer,
	}
	prm.SetRequestForwarder(groupAddressRequestForwarder(fwd.forwardRequestToNode))

	return prm, nil
}

// toRangePrm converts a V2 GetRange request into parameters for the object Get
// service.
//
// Payload chunks are written to stream through a streamObjectRangeWriter. An
// error is returned when the request has no address, the address cannot be
// decoded, the common parameters cannot be extracted, or the assembled
// parameters fail Validate. For requests that are not local-only, a request
// forwarder is installed. It carries the default key from key storage and
// shares the same chunk writer.
func (s *Service) toRangePrm(req *objectV2.GetRangeRequest, stream objectSvc.GetObjectRangeStream) (*getsvc.RangePrm, error) {
	body := req.GetBody()

	rawAddr := body.GetAddress()
	if rawAddr == nil {
		return nil, errMissingObjAddress
	}

	var addr oid.Address
	if err := addr.ReadFromV2(*rawAddr); err != nil {
		return nil, errInvalidObjAddress(err)
	}

	commonPrm, err := util.CommonPrmFromV2(req)
	if err != nil {
		return nil, err
	}

	writer := &streamObjectRangeWriter{stream}

	prm := new(getsvc.RangePrm)
	prm.SetCommonParameters(commonPrm)
	prm.WithAddress(addr)
	prm.WithRawFlag(body.GetRaw())
	prm.SetChunkWriter(writer)
	prm.SetRange(object.NewRangeFromV2(body.GetRange()))

	// Validation runs before any key is fetched, so malformed requests are
	// rejected early.
	if err = prm.Validate(); err != nil {
		return nil, errRequestParamsValidation(err)
	}

	if commonPrm.LocalOnly() {
		return prm, nil
	}

	key, err := s.keyStorage.GetKey(nil)
	if err != nil {
		return nil, err
	}

	// GlobalProgress starts at its zero value.
	fwd := &getRangeRequestForwarder{
		Key:     key,
		Request: req,
		Stream:  writer,
	}
	prm.SetRequestForwarder(groupAddressRequestForwarder(fwd.forwardRequestToNode))

	return prm, nil
}

// toHashRangePrm converts a V2 GetRangeHash request into parameters for the
// object Get service's range-hashing operation.
//
// If the request carries a session token, the signer key is looked up by the
// token's ID and issuer. When key storage reports SessionTokenNotFound,
// ForgetTokens is called on the common parameters and the default key
// (GetKey(nil)) is used instead. Any other key lookup failure is returned
// wrapped by errFetchingSessionKey.
//
// The requested ranges and salt are copied into the parameters. The checksum
// type selects the hash generator: SHA256 uses crypto/sha256 and TillichZemor
// uses tz. Any other type produces errUnknownChechsumType.
func (s *Service) toHashRangePrm(req *objectV2.GetRangeHashRequest) (*getsvc.RangeHashPrm, error) {
	body := req.GetBody()

	rawAddr := body.GetAddress()
	if rawAddr == nil {
		return nil, errMissingObjAddress
	}

	var addr oid.Address
	if err := addr.ReadFromV2(*rawAddr); err != nil {
		return nil, errInvalidObjAddress(err)
	}

	commonPrm, err := util.CommonPrmFromV2(req)
	if err != nil {
		return nil, err
	}

	prm := new(getsvc.RangeHashPrm)
	prm.SetCommonParameters(commonPrm)
	prm.WithAddress(addr)

	if tok := commonPrm.SessionToken(); tok != nil {
		sessionInfo := &util.SessionInfo{
			ID:    tok.ID(),
			Owner: tok.Issuer(),
		}

		signer, keyErr := s.keyStorage.GetKey(sessionInfo)
		if keyErr != nil && errors.As(keyErr, new(apistatus.SessionTokenNotFound)) {
			// No key is stored for this session. Forget the token(s) and fall
			// back to the default key.
			commonPrm.ForgetTokens()
			signer, keyErr = s.keyStorage.GetKey(nil)
		}

		if keyErr != nil {
			return nil, errFetchingSessionKey(keyErr)
		}

		prm.WithCachedSignerKey(signer)
	}

	v2Ranges := body.GetRanges()
	ranges := make([]object.Range, 0, len(v2Ranges))
	for i := range v2Ranges {
		ranges = append(ranges, *object.NewRangeFromV2(&v2Ranges[i]))
	}

	prm.SetRangeList(ranges)
	prm.SetSalt(body.GetSalt())

	var newHash func() hash.Hash

	switch typ := body.GetType(); typ {
	case refs.SHA256:
		newHash = func() hash.Hash { return sha256.New() }
	case refs.TillichZemor:
		newHash = func() hash.Hash { return tz.New() }
	default:
		return nil, errUnknownChechsumType(typ)
	}

	prm.SetHashGenerator(newHash)

	return prm, nil
}

// headResponseWriter stores a received object header into a Head response body.
type headResponseWriter struct {
	// mainOnly selects the short header form instead of the full header with
	// signature.
	mainOnly bool

	// body is the response body that receives the header part.
	body *objectV2.HeadResponseBody
}

// WriteHeader converts hdr into a short or full header part, depending on
// mainOnly, and sets it on the response body. It always returns nil.
func (w *headResponseWriter) WriteHeader(_ context.Context, hdr *object.Object) error {
	convert := toFullObjectHeader
	if w.mainOnly {
		convert = toShortObjectHeader
	}

	w.body.SetHeaderPart(convert(hdr))

	return nil
}

// toHeadPrm converts a V2 Head request into parameters for the object Get
// service's Head operation.
//
// The resulting header is written into resp's body through a
// headResponseWriter. The short form is used when the request asks for the
// main header only. An error is returned when the request has no address, the
// address cannot be decoded, or the common parameters cannot be extracted.
// For requests that are not local-only, a request forwarder is installed. It
// carries the original request, resp, the object address and the default key
// from key storage.
func (s *Service) toHeadPrm(req *objectV2.HeadRequest, resp *objectV2.HeadResponse) (*getsvc.HeadPrm, error) {
	body := req.GetBody()

	rawAddr := body.GetAddress()
	if rawAddr == nil {
		return nil, errMissingObjAddress
	}

	var objAddr oid.Address
	if err := objAddr.ReadFromV2(*rawAddr); err != nil {
		return nil, errInvalidObjAddress(err)
	}

	commonPrm, err := util.CommonPrmFromV2(req)
	if err != nil {
		return nil, err
	}

	headerWriter := &headResponseWriter{
		mainOnly: body.GetMainOnly(),
		body:     resp.GetBody(),
	}

	prm := new(getsvc.HeadPrm)
	prm.SetCommonParameters(commonPrm)
	prm.WithAddress(objAddr)
	prm.WithRawFlag(body.GetRaw())
	prm.SetHeaderWriter(headerWriter)

	if !commonPrm.LocalOnly() {
		key, keyErr := s.keyStorage.GetKey(nil)
		if keyErr != nil {
			return nil, keyErr
		}

		fwd := &headRequestForwarder{
			Request:    req,
			Response:   resp,
			ObjectAddr: objAddr,
			Key:        key,
		}
		prm.SetRequestForwarder(groupAddressRequestForwarder(fwd.forwardRequestToNode))
	}

	return prm, nil
}

// splitInfoResponse builds a Get response whose body carries info, in its V2
// form, as the object part.
func splitInfoResponse(info *object.SplitInfo) *objectV2.GetResponse {
	var body objectV2.GetResponseBody
	body.SetObjectPart(info.ToV2())

	out := new(objectV2.GetResponse)
	out.SetBody(&body)

	return out
}

// splitInfoRangeResponse builds a GetRange response whose body carries info,
// in its V2 form, as the range part.
func splitInfoRangeResponse(info *object.SplitInfo) *objectV2.GetRangeResponse {
	var body objectV2.GetRangeResponseBody
	body.SetRangePart(info.ToV2())

	out := new(objectV2.GetRangeResponse)
	out.SetBody(&body)

	return out
}

// setSplitInfoHeadResponse sets info, in its V2 form, as the header part of
// resp's existing body.
func setSplitInfoHeadResponse(info *object.SplitInfo, resp *objectV2.HeadResponse) {
	body := resp.GetBody()
	body.SetHeaderPart(info.ToV2())
}

// toHashResponse builds a GetRangeHash response that reports the checksum
// type typ together with the hash list from res.
func toHashResponse(typ refs.ChecksumType, res *getsvc.RangeHashRes) *objectV2.GetRangeHashResponse {
	var body objectV2.GetRangeHashResponseBody
	body.SetType(typ)
	body.SetHashList(res.Hashes())

	out := new(objectV2.GetRangeHashResponse)
	out.SetBody(&body)

	return out
}

// toFullObjectHeader returns a header part holding both the V2 header and the
// V2 signature of hdr.
func toFullObjectHeader(hdr *object.Object) objectV2.GetHeaderPart {
	var (
		v2   = hdr.ToV2()
		part = new(objectV2.HeaderWithSignature)
	)

	part.SetSignature(v2.GetSignature())
	part.SetHeader(v2.GetHeader())

	return part
}

// toShortObjectHeader returns a short header part built from a subset of
// hdr's V2 header fields: version, creation epoch, owner ID, object type,
// payload length, payload hash and homomorphic hash.
func toShortObjectHeader(hdr *object.Object) objectV2.GetHeaderPart {
	full := hdr.ToV2().GetHeader()

	short := new(objectV2.ShortHeader)
	short.SetVersion(full.GetVersion())
	short.SetCreationEpoch(full.GetCreationEpoch())
	short.SetOwnerID(full.GetOwnerID())
	short.SetObjectType(full.GetObjectType())
	short.SetPayloadLength(full.GetPayloadLength())
	short.SetPayloadHash(full.GetPayloadHash())
	short.SetHomomorphicHash(full.GetHomomorphicHash())

	return short
}

// groupAddressRequestForwarder adapts a per-address forwarding function f into
// a getsvc.RequestForwarder that handles a whole node.
//
// The returned forwarder calls f for each address in the node's address group,
// in iteration order, passing the node's public key each time. It stops at
// the first call that succeeds and returns that call's object with a nil
// error. If every call fails, it returns the object from the last call along
// with the error from the first failed call.
//
// NOTE(review): if the address group yields no addresses, both return values
// are nil. Callers may need to handle that case.
func groupAddressRequestForwarder(f func(context.Context, network.Address, client.MultiAddressClient, []byte) (*object.Object, error)) getsvc.RequestForwarder {
	return func(ctx context.Context, info client.NodeInfo, c client.MultiAddressClient) (*object.Object, error) {
		pubKey := info.PublicKey()

		var (
			obj      *object.Object
			firstErr error
		)

		info.AddressGroup().IterateAddresses(func(addr network.Address) bool {
			var err error

			// Assign rather than declare so the outer obj keeps the result.
			obj, err = f(ctx, addr, c, pubKey)
			if err == nil {
				// Success clears errors from earlier addresses and stops iteration.
				firstErr = nil
				return true
			}

			if firstErr == nil {
				firstErr = err
			}

			return false
		})

		return obj, firstErr
	}
}

// writeCurrentVersion sets the version field of metaHdr to the SDK's current
// API version (versionSDK.Current), converted to its V2 form.
func writeCurrentVersion(metaHdr *session.RequestMetaHeader) {
	current := versionSDK.Current()

	var v2 refs.Version
	current.WriteToV2(&v2)

	metaHdr.SetVersion(&v2)
}

// checkStatus returns nil when stV2 carries a success code. Otherwise it
// returns the error that the SDK derives from the status.
func checkStatus(stV2 *status.Status) error {
	if status.IsSuccess(stV2.Code()) {
		return nil
	}

	return apistatus.ErrFromStatus(apistatus.FromStatusV2(stV2))
}

// chunkToSend returns the part of chunk that has not been sent yet.
//
// chunk is treated as covering offsets [local, local+len(chunk)), and every
// offset below global as already sent:
//   - if local == global, the whole chunk is returned unchanged;
//   - if the chunk ends at or before global, nil is returned;
//   - otherwise the suffix starting at offset global is returned.
//
// The caller must ensure local <= global. If local > global, the final slice
// expression uses a negative index and panics.
func chunkToSend(global, local int, chunk []byte) []byte {
	switch {
	case local == global:
		return chunk
	case global >= local+len(chunk):
		// Every byte of this chunk was already sent.
		return nil
	default:
		return chunk[global-local:]
	}
}