package erasurecode

import (
	"bytes"
	"crypto/ecdsa"
	"fmt"

	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"github.com/klauspost/reedsolomon"
)

// Reconstruct rebuilds the complete object (header and payload) from its EC parts.
// Every non-nil object in parts must carry an EC header whose `total` field equals len(parts),
// and its position in the slice must match the index field of that header.
// At least one element of parts must be non-nil.
// The parts slice isn't changed and can be used concurrently for reading.
func (c *Constructor) Reconstruct(parts []*objectSDK.Object) (*objectSDK.Object, error) {
	// The header goes first: it tells how many payload bytes must be restored.
	obj, err := c.ReconstructHeader(parts)
	if err != nil {
		return nil, err
	}

	c.fillPayload(parts)

	size := int(obj.PayloadSize())
	data, err := reconstructExact(c.enc, size, c.payloadShards)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrMalformedSlice, err)
	}

	obj.SetPayload(data)
	return obj, nil
}

// ReconstructHeader rebuilds only the object header from the EC parts.
// Every non-nil object in parts must carry an EC header whose `total` field equals len(parts),
// and its position in the slice must match the index field of that header.
// At least one element of parts must be non-nil.
// The parts slice isn't changed and can be used concurrently for reading.
func (c *Constructor) ReconstructHeader(parts []*objectSDK.Object) (*objectSDK.Object, error) {
	// Drop whatever state is left from a previous call before loading new shards.
	c.clear()

	err := c.fillHeader(parts)
	if err != nil {
		return nil, err
	}

	hdr, err := c.reconstructHeader()
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrMalformedSlice, err)
	}
	return hdr, nil
}

// ReconstructParts restores selected EC parts in place, without assembling the full object.
// Every non-nil object in parts must carry an EC header whose `total` field equals len(parts),
// and its position in the slice must match the index field of that header.
// At least one element of parts must be non-nil.
// Elements of parts whose corresponding element in required is true must be nil;
// they are overwritten with the restored parts.
// Because partial reconstruction only makes sense for full objects, all parts must have non-empty payload.
// If key is not nil, all reconstructed parts are signed with this key.
func (c *Constructor) ReconstructParts(parts []*objectSDK.Object, required []bool, key *ecdsa.PrivateKey) error {
	if np, nr := len(parts), len(required); np != nr {
		return fmt.Errorf("len(parts) != len(required): %d != %d", np, nr)
	}

	c.clear()

	err := c.fillHeader(parts)
	if err != nil {
		return err
	}
	c.fillPayload(parts)

	// Restore the requested shards: payload first, header second.
	// The header pass is skipped if the payload pass has already failed.
	err = c.enc.ReconstructSome(c.payloadShards, required)
	if err == nil {
		err = c.enc.ReconstructSome(c.headerShards, required)
	}
	if err != nil {
		return fmt.Errorf("%w: %w", ErrMalformedSlice, err)
	}

	// The first available part serves as the source of fields shared by all parts.
	srcIdx := 0
	for i, p := range parts {
		if p != nil {
			srcIdx = i
			break
		}
	}

	ec := parts[srcIdx].GetECHeader()

	var parentInfo objectSDK.ECParentInfo
	parentInfo.ID = ec.Parent()
	parentInfo.SplitID = ec.ParentSplitID()
	parentInfo.SplitParentID = ec.ParentSplitParentID()
	parentInfo.Attributes = ec.ParentAttributes()

	total := ec.Total()

	for i, need := range required {
		// Only requested slots that are still empty get a new object.
		if !need || parts[i] != nil {
			continue
		}

		payload := c.payloadShards[i]

		restored := objectSDK.New()
		copyRequiredFields(restored, parts[srcIdx])
		restored.SetPayload(payload)
		restored.SetPayloadSize(uint64(len(payload)))
		restored.SetECHeader(objectSDK.NewECHeader(parentInfo, uint32(i), total, c.headerShards[i], c.headerLength))

		if err := setIDWithSignature(restored, key); err != nil {
			return err
		}
		parts[i] = restored
	}
	return nil
}

// reconstructHeader joins the header shards into c.headerLength bytes and
// unmarshals them into a new object.
// If unmarshaling fails, the object pointer is still returned together with the error.
func (c *Constructor) reconstructHeader() (*objectSDK.Object, error) {
	raw, err := reconstructExact(c.enc, int(c.headerLength), c.headerShards)
	if err != nil {
		return nil, err
	}

	hdr := new(objectSDK.Object)
	err = hdr.Unmarshal(raw)
	return hdr, err
}

// reconstructExact restores missing data shards via enc.ReconstructData and
// returns the first size bytes of the joined shards.
//
// size normally originates from an unvalidated object header, so it is checked
// before being used as an allocation capacity. A negative size (for example a
// uint64 value that wrapped during conversion to int at the call site) or a
// size exceeding the total length of all shards yields reedsolomon.ErrShortData.
// Errors from enc.ReconstructData and enc.Join are returned unchanged.
func reconstructExact(enc reedsolomon.Encoder, size int, shards [][]byte) ([]byte, error) {
	if err := enc.ReconstructData(shards); err != nil {
		return nil, err
	}

	// A negative size can only come from an overflowed conversion of a huge
	// unsigned value: the shards can never hold that much data, and passing it
	// to make() below would panic.
	if size < 0 {
		return nil, reedsolomon.ErrShortData
	}

	// Technically, this error will be returned from enc.Join().
	// However, allocating based on unvalidated user data is an easy attack vector.
	// Preallocating seems to have enough benefits to justify a slight increase in code complexity.
	maxSize := 0
	for i := range shards {
		maxSize += len(shards[i])
	}
	if size > maxSize {
		return nil, reedsolomon.ErrShortData
	}

	buf := bytes.NewBuffer(make([]byte, 0, size))
	if err := enc.Join(buf, shards, size); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}