package network

import (
	"errors"
	"fmt"
	"net"
	"sync"

	"github.com/CityOfZion/neo-go/pkg/io"
	"github.com/CityOfZion/neo-go/pkg/network/payload"
)

type handShakeStage uint8

//go:generate stringer -type=handShakeStage
const (
	nothingDone     handShakeStage = 0
	versionSent     handShakeStage = 1
	versionReceived handShakeStage = 2
	verAckSent      handShakeStage = 3
	verAckReceived  handShakeStage = 4
)

var (
	errStateMismatch = errors.New("tried to send protocol message before handshake completed")
)

// TCPPeer represents a connected remote node in the
// network over TCP.
type TCPPeer struct {
	// underlying TCP connection.
	conn net.Conn
	addr net.TCPAddr

	// The version of the peer.
	version *payload.Version

	handShake handShakeStage

	done chan error

	wg sync.WaitGroup
}

// NewTCPPeer returns a TCPPeer structure based on the given connection.
func NewTCPPeer(conn net.Conn) *TCPPeer {
	raddr := conn.RemoteAddr()
	// can't fail because raddr is a real connection
	tcpaddr, _ := net.ResolveTCPAddr(raddr.Network(), raddr.String())
	return &TCPPeer{
		conn: conn,
		done: make(chan error, 1),
		addr: *tcpaddr,
	}
}

// WriteMsg implements the Peer interface. This will write/encode the message
// to the underlying connection, this only works for messages other than Version
// or VerAck.
func (p *TCPPeer) WriteMsg(msg *Message) error {
	if !p.Handshaked() {
		return errStateMismatch
	}
	return p.writeMsg(msg)
}

func (p *TCPPeer) writeMsg(msg *Message) error {
	select {
	case err := <-p.done:
		return err
	default:
		w := io.NewBinWriterFromIO(p.conn)
		return msg.Encode(w)
	}
}

// Handshaked returns status of the handshake, whether it's completed or not.
func (p *TCPPeer) Handshaked() bool {
	return p.handShake == verAckReceived
}

// SendVersion checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersion(msg *Message) error {
	if p.handShake != nothingDone {
		return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
	}
	err := p.writeMsg(msg)
	if err == nil {
		p.handShake = versionSent
	}
	return err
}

// HandleVersion checks for the handshake state and version message contents.
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
	if p.handShake != versionSent {
		return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
	}
	p.version = version
	p.handShake = versionReceived
	return nil
}

// SendVersionAck checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersionAck(msg *Message) error {
	if p.handShake != versionReceived {
		return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
	}
	err := p.writeMsg(msg)
	if err == nil {
		p.handShake = verAckSent
	}
	return err
}

// HandleVersionAck checks handshake sequence correctness when VerAck message
// is received.
func (p *TCPPeer) HandleVersionAck() error {
	if p.handShake != verAckSent {
		return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
	}
	p.handShake = verAckReceived
	return nil
}

// NetAddr implements the Peer interface.
func (p *TCPPeer) NetAddr() *net.TCPAddr {
	return &p.addr
}

// Done implements the Peer interface and notifies
// all other resources operating on it that this peer
// is no longer running.
func (p *TCPPeer) Done() chan error {
	return p.done
}

// Disconnect will fill the peer's done channel with the given error.
func (p *TCPPeer) Disconnect(err error) {
	p.conn.Close()
	select {
	case p.done <- err:
		// one message to the queue
	default:
		// the other side may already be gone, it's OK
	}
}

// Version implements the Peer interface.
func (p *TCPPeer) Version() *payload.Version {
	return p.version
}