package marshal

import (
	"encoding/binary"
	"fmt"
)

// Version is the serialization format version. Increase it whenever the
// encoding changes in a way that breaks compatibility.
const Version byte = 0

// Encoded sizes, in bytes, of the fixed-width scalar types.
const (
	ByteSize  int = 1
	UInt8Size int = ByteSize
	BoolSize  int = ByteSize
)

// nilSlice is the length prefix written for a nil slice. As a varint it
// occupies a single byte, which is what nilSliceSize records.
const (
	nilSlice     int64 = -1
	nilSliceSize int   = 1
)

// Wire representations of the two bool values.
const (
	byteTrue  uint8 = 1
	byteFalse uint8 = 0
)

// maxSliceLen caps both the element count of an encoded slice and the byte
// length of an encoded string (0x1000000).
// Value taken from https://github.com/neo-project/neo/blob/38218bbee5bbe8b33cd8f9453465a19381c9a547/src/Neo/IO/Helper.cs#L77
const maxSliceLen = 1 << 24

// MarshallerError describes a failure while encoding or decoding data.
type MarshallerError struct {
	// errMsg is the human-readable description of the failure.
	errMsg string
	// offset is the byte position in the buffer the error refers to.
	// A negative value means the error has no meaningful position and
	// Error() omits it.
	offset int
}

// Error implements the error interface. A nil *MarshallerError yields "".
func (e *MarshallerError) Error() string {
	switch {
	case e == nil:
		return ""
	case e.offset >= 0:
		return fmt.Sprintf("%s (offset: %d)", e.errMsg, e.offset)
	default:
		return e.errMsg
	}
}

// errBufTooSmall builds the error returned when buf has too few bytes left
// at offset to encode (marshal == true) or decode (marshal == false) a value
// of the type named t.
func errBufTooSmall(t string, marshal bool, offset int) error {
	var action string
	if marshal {
		action = "marshal"
	} else {
		action = "unmarshal"
	}
	msg := fmt.Sprintf("not enough bytes left to %s value of type '%s'", action, t)
	return &MarshallerError{errMsg: msg, offset: offset}
}

// VerifyMarshal checks that encoding filled buf exactly: lastOffset (the
// offset returned by the final Marshal call) must equal len(buf). The error
// it returns on a mismatch carries no offset.
func VerifyMarshal(buf []byte, lastOffset int) error {
	if len(buf) == lastOffset {
		return nil
	}
	return &MarshallerError{
		errMsg: "actual data size differs from expected",
		offset: -1,
	}
}

// VerifyUnmarshal checks that decoding consumed all of buf: lastOffset (the
// offset returned by the final Unmarshal call) must equal len(buf).
//
// On a mismatch the returned error carries lastOffset as its offset, which is
// the position where the unconsumed bytes begin.
func VerifyUnmarshal(buf []byte, lastOffset int) error {
	if len(buf) != lastOffset {
		return &MarshallerError{
			errMsg: "unmarshalled bytes left",
			// Leaving offset unset would make Error() always report a
			// misleading "(offset: 0)"; point at the trailing data instead.
			offset: lastOffset,
		}
	}
	return nil
}

// SliceSize returns the number of bytes SliceMarshal needs to encode slice.
// That is the varint length prefix plus sizeOf for every element. A nil
// slice is encoded as the lone nilSlice prefix.
func SliceSize[T any](slice []T, sizeOf func(T) int) int {
	if slice == nil {
		return nilSliceSize
	}
	total := Int64Size(int64(len(slice)))
	for i := range slice {
		total += sizeOf(slice[i])
	}
	return total
}

// SliceMarshal writes slice into buf starting at offset. It returns the
// offset just past the written data.
//
// Layout: a varint length prefix, then each element encoded with marshalT.
// A nil slice is written as the single prefix nilSlice (-1), so that
// SliceUnmarshal can restore it as nil. An empty non-nil slice gets the
// prefix 0.
//
// Slices longer than maxSliceLen are rejected. If an element fails to
// encode, buf may already contain a partial encoding.
func SliceMarshal[T any](buf []byte, offset int, slice []T, marshalT func([]byte, int, T) (int, error)) (int, error) {
	if slice == nil {
		return Int64Marshal(buf, offset, nilSlice)
	}
	if len(slice) > maxSliceLen {
		return 0, &MarshallerError{
			errMsg: fmt.Sprintf("slice size is too big: '%d'", len(slice)),
			offset: offset,
		}
	}
	offset, err := Int64Marshal(buf, offset, int64(len(slice)))
	if err != nil {
		return 0, err
	}
	for _, v := range slice {
		offset, err = marshalT(buf, offset, v)
		if err != nil {
			return 0, err
		}
	}
	return offset, nil
}

// SliceUnmarshal decodes a slice written by SliceMarshal, starting at offset.
// It returns the slice and the offset just past it.
//
// A nilSlice prefix yields a nil slice. A prefix above maxSliceLen, or any
// other negative prefix, is rejected with a MarshallerError. Errors from
// unmarshalT are returned unchanged.
func SliceUnmarshal[T any](buf []byte, offset int, unmarshalT func(buf []byte, offset int) (T, int, error)) ([]T, int, error) {
	size, offset, err := Int64Unmarshal(buf, offset)
	if err != nil {
		return nil, 0, err
	}
	if size == nilSlice {
		return nil, offset, nil
	}
	if size > maxSliceLen {
		return nil, 0, &MarshallerError{
			errMsg: fmt.Sprintf("slice size is too big: '%d'", size),
			offset: offset,
		}
	}
	if size < 0 {
		return nil, 0, &MarshallerError{
			errMsg: fmt.Sprintf("invalid slice size: '%d'", size),
			offset: offset,
		}
	}
	n := int(size) // safe: 0 <= size <= maxSliceLen

	// The length prefix comes from the input and has not been checked
	// against the data that follows it. Sizing the allocation from it alone
	// would let a few crafted bytes force an allocation of maxSliceLen
	// elements. Encoded elements are expected to take at least one byte
	// each, so the remaining input length bounds the initial capacity.
	// append still grows the slice if that expectation does not hold for T.
	capacity := n
	if remaining := len(buf) - offset; remaining < capacity {
		capacity = remaining
	}
	result := make([]T, 0, capacity)
	for i := 0; i < n; i++ {
		var elem T
		elem, offset, err = unmarshalT(buf, offset)
		if err != nil {
			return nil, 0, err
		}
		result = append(result, elem)
	}
	return result, offset, nil
}

// Int64Size returns how many bytes binary.PutVarint uses to encode v.
//
// It mirrors the standard library's encoder:
// https://cs.opensource.google/go/go/+/master:src/encoding/binary/varint.go;l=92;drc=dac9b9ddbd5160c5f4552410f5f8281bd5eed38c
// and
// https://cs.opensource.google/go/go/+/master:src/encoding/binary/varint.go;l=41;drc=dac9b9ddbd5160c5f4552410f5f8281bd5eed38c
func Int64Size(v int64) int {
	// Zig-zag mapping: small magnitudes of either sign become small unsigned
	// values. v>>63 is all ones for negative v and zero otherwise.
	zz := uint64(v<<1) ^ uint64(v>>63)
	n := 1
	// Each varint byte carries 7 payload bits.
	for ; zz >= 0x80; zz >>= 7 {
		n++
	}
	return n
}

// Int64Marshal writes v into buf at offset as a signed varint (the
// binary.PutVarint format). It returns the offset just past it.
// If fewer than Int64Size(v) bytes remain, nothing is written.
func Int64Marshal(buf []byte, offset int, v int64) (int, error) {
	need := Int64Size(v)
	if avail := len(buf) - offset; avail < need {
		return 0, errBufTooSmall("int64", true, offset)
	}
	written := binary.PutVarint(buf[offset:], v)
	return offset + written, nil
}

// Int64Unmarshal decodes a signed varint (the binary.Varint format) from buf
// at offset. It returns the value and the offset just past it.
//
// It returns a "not enough bytes" MarshallerError when offset lies past the
// end of buf or the buffer ends mid-varint. It returns an overflow
// MarshallerError when the encoding does not fit in 64 bits.
func Int64Unmarshal(buf []byte, offset int) (int64, int, error) {
	// buf[offset:] below panics for offset > len(buf). Report it like
	// UInt8Unmarshal does for an offset past the end: as a short buffer.
	if offset > len(buf) {
		return 0, 0, errBufTooSmall("int64", false, offset)
	}
	v, read := binary.Varint(buf[offset:])
	if read == 0 {
		return 0, 0, errBufTooSmall("int64", false, offset)
	}
	if read < 0 {
		return 0, 0, &MarshallerError{
			errMsg: "int64 unmarshal overflow",
			offset: offset,
		}
	}
	return v, offset + read, nil
}

// StringSize returns the number of bytes StringMarshal needs for s: a varint
// byte-length prefix followed by the raw bytes.
func StringSize(s string) int {
	n := len(s)
	return n + Int64Size(int64(n))
}

// StringMarshal writes s into buf at offset as a varint byte-length prefix
// followed by the raw bytes. It returns the offset just past them.
//
// Strings longer than maxSliceLen bytes are rejected. The whole encoding's
// fit is checked before anything is written, so a too-small buf is left
// untouched.
func StringMarshal(buf []byte, offset int, s string) (int, error) {
	n := len(s)
	if n > maxSliceLen {
		return 0, &MarshallerError{
			errMsg: fmt.Sprintf("string is too long: '%d'", n),
			offset: offset,
		}
	}
	if len(buf)-offset < StringSize(s) {
		return 0, errBufTooSmall("string", true, offset)
	}

	next, err := Int64Marshal(buf, offset, int64(n))
	if err != nil {
		return 0, err
	}
	// copy of an empty string writes nothing and returns 0.
	return next + copy(buf[next:], s), nil
}

// StringUnmarshal decodes a string written by StringMarshal, starting at
// offset. It returns the string (a copy of the bytes) and the offset just
// past it.
//
// A length prefix above maxSliceLen or below zero is rejected, and so is a
// buffer too short for the declared length. Those errors carry the offset
// just after the prefix.
func StringUnmarshal(buf []byte, offset int) (string, int, error) {
	n, start, err := Int64Unmarshal(buf, offset)
	if err != nil {
		return "", 0, err
	}
	switch {
	case n == 0:
		return "", start, nil
	case n > maxSliceLen:
		return "", 0, &MarshallerError{
			errMsg: fmt.Sprintf("string is too long: '%d'", n),
			offset: start,
		}
	case n < 0:
		return "", 0, &MarshallerError{
			errMsg: fmt.Sprintf("invalid string size: '%d'", n),
			offset: start,
		}
	}
	length := int(n)
	if len(buf)-start < length {
		return "", 0, errBufTooSmall("string", false, start)
	}
	end := start + length
	return string(buf[start:end]), end, nil
}

// UInt8Marshal stores value in buf at offset. It returns offset+1, or an
// error if no byte is left there.
func UInt8Marshal(buf []byte, offset int, value uint8) (int, error) {
	if avail := len(buf) - offset; avail < 1 {
		return 0, errBufTooSmall("uint8", true, offset)
	}
	buf[offset] = value
	next := offset + 1
	return next, nil
}

// UInt8Unmarshal reads the byte at offset. It returns the byte and offset+1,
// or an error if no byte is left there.
func UInt8Unmarshal(buf []byte, offset int) (uint8, int, error) {
	if avail := len(buf) - offset; avail < 1 {
		return 0, 0, errBufTooSmall("uint8", false, offset)
	}
	value := buf[offset]
	return value, offset + 1, nil
}

// ByteMarshal stores value in buf at offset. byte is an alias of uint8, so
// this delegates to UInt8Marshal, and its errors name the type 'uint8'.
func ByteMarshal(buf []byte, offset int, value byte) (int, error) {
	next, err := UInt8Marshal(buf, offset, value)
	return next, err
}

// ByteUnmarshal reads the byte at offset. It delegates to UInt8Unmarshal,
// since byte is an alias of uint8.
func ByteUnmarshal(buf []byte, offset int) (byte, int, error) {
	value, next, err := UInt8Unmarshal(buf, offset)
	return value, next, err
}

func BoolMarshal(buf []byte, offset int, value bool) (int, error) {
	if value {
		return UInt8Marshal(buf, offset, byteTrue)
	}
	return UInt8Marshal(buf, offset, byteFalse)
}

// BoolUnmarshal reads a bool written by BoolMarshal at offset. Any byte
// other than byteTrue or byteFalse is rejected with an error pointing at
// that byte.
func BoolUnmarshal(buf []byte, offset int) (bool, int, error) {
	raw, next, err := UInt8Unmarshal(buf, offset)
	if err != nil {
		return false, 0, err
	}
	switch raw {
	case byteTrue:
		return true, next, nil
	case byteFalse:
		return false, next, nil
	default:
		return false, 0, &MarshallerError{
			errMsg: fmt.Sprintf("invalid marshalled value for bool: %d", raw),
			// Step back over the byte just consumed to report its position.
			offset: next - BoolSize,
		}
	}
}