package p2p

import (
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"sync"

	"blockchain-network/pkg/blockchain"
	"blockchain-network/pkg/mempool"
)

// Message types. These are the values of Message.Type that handleMessage
// dispatches on.
const (
	// MsgBlock carries a JSON-encoded blockchain.Block in the payload.
	MsgBlock     = "block"
	// MsgTx carries a JSON-encoded blockchain.Transaction in the payload.
	MsgTx        = "transaction"
	// MsgGetBlocks asks the receiver to send its blocks back to the sender;
	// the payload is not read.
	MsgGetBlocks = "getblocks"
	// MsgPeerList carries a JSON array of peer address strings in the payload.
	MsgPeerList  = "peerlist"
)

// Message is the envelope exchanged between peers. It is encoded to and
// decoded from JSON using the field tags below.
type Message struct {
	// Type selects how handleMessage interprets Payload (see the Msg*
	// constants).
	Type    string          `json:"type"`
	// Payload is the message body, kept as raw JSON so it can be decoded
	// into the concrete type only after Type has been inspected.
	Payload json.RawMessage `json:"payload"`
}

// Peer is one entry in Network.peers: a remote node together with the
// connection used to talk to it.
type Peer struct {
	// Address is the string the peer is registered under in Network.peers.
	Address string
	// Conn is the connection that messages for this peer are written to.
	Conn    net.Conn
}

// Network manages P2P connections: it accepts inbound peers, dials outbound
// ones, and relays blocks and transactions between them.
type Network struct {
	// port is the TCP port Start listens on.
	port       int
	// peers holds the currently registered peers, keyed by address string.
	peers      map[string]*Peer
	// blockchain receives blocks arriving from peers and supplies the blocks
	// sent in response to MsgGetBlocks.
	blockchain *blockchain.Blockchain
	// mempool receives transactions arriving from peers.
	mempool    *mempool.Mempool
	// mu guards peers.
	mu         sync.RWMutex
	// listener is set by Start and consumed by acceptConnections; it is nil
	// until Start succeeds.
	listener   net.Listener
}

// NewNetwork builds a Network bound to the given port, chain and mempool.
// The peer table starts out empty and nothing is listening until Start is
// called.
func NewNetwork(port int, bc *blockchain.Blockchain, mp *mempool.Mempool) *Network {
	n := new(Network)
	n.port = port
	n.peers = map[string]*Peer{}
	n.blockchain = bc
	n.mempool = mp
	return n
}

// Start opens a TCP listener on the configured port and launches the accept
// loop on a background goroutine. It returns the error from net.Listen when
// the port cannot be bound, and nil otherwise.
func (n *Network) Start() error {
	addr := fmt.Sprintf(":%d", n.port)

	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	n.listener = ln

	go n.acceptConnections()

	fmt.Printf("P2P network listening on port %d\n", n.port)
	return nil
}

// acceptConnections accepts inbound connections on n.listener and hands each
// one to handlePeer on its own goroutine.
//
// The loop exits once the listener has been closed. Without that check every
// subsequent Accept call fails immediately with net.ErrClosed and the loop
// would spin forever printing errors. Any other Accept error is logged and
// the loop keeps accepting.
func (n *Network) acceptConnections() {
	for {
		conn, err := n.listener.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				// Listener shut down: nothing more will ever be accepted.
				return
			}
			fmt.Printf("Accept error: %v\n", err)
			continue
		}

		go n.handlePeer(conn)
	}
}

// handlePeer registers conn in the peer table, then decodes JSON messages
// from it and dispatches each to handleMessage until decoding fails (which
// includes the remote side closing the connection). Errors returned by
// handleMessage are logged and do not end the loop.
//
// On return every peer-table entry that refers to conn is removed and the
// connection is closed.
//
// A connection may already be registered when this runs (ConnectToPeer
// registers outbound connections before starting handlePeer), possibly under
// a key that differs from conn.RemoteAddr().String(). In that case the
// existing key is reused instead of adding a second entry for the same
// connection, which would make broadcasts reach the peer twice and leave a
// stale entry behind after disconnect.
func (n *Network) handlePeer(conn net.Conn) {
	defer conn.Close()

	peerAddr := conn.RemoteAddr().String()

	n.mu.Lock()
	registered := false
	for key, p := range n.peers {
		if p.Conn == conn {
			peerAddr = key
			registered = true
			break
		}
	}
	if !registered {
		n.peers[peerAddr] = &Peer{
			Address: peerAddr,
			Conn:    conn,
		}
	}
	n.mu.Unlock()

	fmt.Printf("New peer connected: %s\n", peerAddr)

	defer func() {
		// Remove by connection identity rather than by key so that an entry
		// now owned by a different connection is never deleted, and no entry
		// pointing at this (about to be closed) connection survives.
		n.mu.Lock()
		for key, p := range n.peers {
			if p.Conn == conn {
				delete(n.peers, key)
			}
		}
		n.mu.Unlock()
	}()

	decoder := json.NewDecoder(conn)

	for {
		var msg Message
		if err := decoder.Decode(&msg); err != nil {
			fmt.Printf("Decode error from %s: %v\n", peerAddr, err)
			return
		}

		if err := n.handleMessage(msg, peerAddr); err != nil {
			fmt.Printf("Handle message error: %v\n", err)
		}
	}
}

// handleMessage dispatches one decoded message received from the peer
// registered as from.
//
//   - MsgBlock: the payload is decoded as a blockchain.Block and passed to
//     blockchain.AddBlock; if accepted it is relayed to every peer except from.
//   - MsgTx: the payload is decoded as a blockchain.Transaction and passed to
//     mempool.AddTransaction; if accepted it is relayed to every peer except
//     from.
//   - MsgGetBlocks: the local blocks are sent back to from.
//   - MsgPeerList: the payload is decoded as a list of address strings and
//     each one is dialed on its own goroutine.
//
// It returns an error when the payload cannot be decoded, when the chain or
// mempool rejects the item, or when the message type is not recognised.
// Underlying errors are wrapped with %w so callers can inspect them with
// errors.Is / errors.As.
func (n *Network) handleMessage(msg Message, from string) error {
	switch msg.Type {
	case MsgBlock:
		var block blockchain.Block
		if err := json.Unmarshal(msg.Payload, &block); err != nil {
			return fmt.Errorf("decoding block from %s: %w", from, err)
		}

		if err := n.blockchain.AddBlock(&block); err != nil {
			return fmt.Errorf("invalid block: %w", err)
		}

		fmt.Printf("Received new block at height %d from %s\n",
			block.Header.Height, from)

		// Relay to everyone but the sender so the block is not echoed back.
		n.BroadcastBlock(&block, from)

	case MsgTx:
		var tx blockchain.Transaction
		if err := json.Unmarshal(msg.Payload, &tx); err != nil {
			return fmt.Errorf("decoding transaction from %s: %w", from, err)
		}

		if err := n.mempool.AddTransaction(&tx); err != nil {
			return fmt.Errorf("invalid transaction: %w", err)
		}

		fmt.Printf("Received new transaction from %s\n", from)

		// Relay to everyone but the sender.
		n.BroadcastTransaction(&tx, from)

	case MsgGetBlocks:
		n.sendBlocks(from)

	case MsgPeerList:
		var peerAddrs []string
		if err := json.Unmarshal(msg.Payload, &peerAddrs); err != nil {
			return fmt.Errorf("decoding peer list from %s: %w", from, err)
		}

		// Dial in the background so a slow or unreachable address does not
		// stall this peer's read loop; failures are logged rather than lost.
		for _, addr := range peerAddrs {
			go func(addr string) {
				if err := n.ConnectToPeer(addr); err != nil {
					fmt.Printf("Connect to advertised peer %s failed: %v\n", addr, err)
				}
			}(addr)
		}

	default:
		return fmt.Errorf("unknown message type %q from %s", msg.Type, from)
	}

	return nil
}

// BroadcastBlock sends block, wrapped in a MsgBlock message, to every
// registered peer except the one registered as excludePeer.
//
// The message is serialized once and the same bytes (JSON followed by a
// newline, the framing json.Encoder produces) are written to each peer. The
// peer table is only locked while the recipient list is copied; the network
// writes happen after the lock is released so that one slow peer cannot block
// registration or removal of other peers.
//
// The method has no error return: a serialization failure aborts the
// broadcast and is logged, and a write failure to one peer is logged and does
// not prevent delivery to the others.
func (n *Network) BroadcastBlock(block *blockchain.Block, excludePeer string) {
	payload, err := json.Marshal(block)
	if err != nil {
		fmt.Printf("Broadcast block: marshal failed: %v\n", err)
		return
	}

	frame, err := json.Marshal(Message{Type: MsgBlock, Payload: payload})
	if err != nil {
		fmt.Printf("Broadcast block: marshal failed: %v\n", err)
		return
	}
	frame = append(frame, '\n')

	n.mu.RLock()
	targets := make([]*Peer, 0, len(n.peers))
	for addr, peer := range n.peers {
		if addr != excludePeer {
			targets = append(targets, peer)
		}
	}
	n.mu.RUnlock()

	for _, peer := range targets {
		if _, err := peer.Conn.Write(frame); err != nil {
			fmt.Printf("Broadcast block to %s failed: %v\n", peer.Address, err)
		}
	}
}

// BroadcastTransaction sends tx, wrapped in a MsgTx message, to every
// registered peer except the one registered as excludePeer.
//
// The message is serialized once and the same bytes (JSON followed by a
// newline, the framing json.Encoder produces) are written to each peer. The
// peer table is only locked while the recipient list is copied; the network
// writes happen after the lock is released so that one slow peer cannot block
// registration or removal of other peers.
//
// The method has no error return: a serialization failure aborts the
// broadcast and is logged, and a write failure to one peer is logged and does
// not prevent delivery to the others.
func (n *Network) BroadcastTransaction(tx *blockchain.Transaction, excludePeer string) {
	payload, err := json.Marshal(tx)
	if err != nil {
		fmt.Printf("Broadcast transaction: marshal failed: %v\n", err)
		return
	}

	frame, err := json.Marshal(Message{Type: MsgTx, Payload: payload})
	if err != nil {
		fmt.Printf("Broadcast transaction: marshal failed: %v\n", err)
		return
	}
	frame = append(frame, '\n')

	n.mu.RLock()
	targets := make([]*Peer, 0, len(n.peers))
	for addr, peer := range n.peers {
		if addr != excludePeer {
			targets = append(targets, peer)
		}
	}
	n.mu.RUnlock()

	for _, peer := range targets {
		if _, err := peer.Conn.Write(frame); err != nil {
			fmt.Printf("Broadcast transaction to %s failed: %v\n", peer.Address, err)
		}
	}
}

// ConnectToPeer dials address over TCP, registers the connection as a peer,
// starts handlePeer for it on a new goroutine and sends a MsgGetBlocks
// request so the remote side replies with its blocks.
//
// It returns nil without doing anything further when a peer is already
// registered for this address. The peer is registered under the connection's
// resolved remote address (conn.RemoteAddr().String()), which is the same key
// handlePeer uses, so a host-name or otherwise non-canonical address cannot
// produce a second entry for the same connection. The existence check is
// repeated under the write lock after dialing, so two concurrent calls for
// the same peer cannot both register; the loser closes its connection.
//
// It returns the dial error if the connection cannot be established, or a
// wrapped error if the initial block request cannot be written.
func (n *Network) ConnectToPeer(address string) error {
	// Cheap pre-check to avoid dialing a peer that is obviously known.
	n.mu.RLock()
	_, known := n.peers[address]
	n.mu.RUnlock()
	if known {
		return nil
	}

	conn, err := net.Dial("tcp", address)
	if err != nil {
		return err
	}

	key := conn.RemoteAddr().String()

	n.mu.Lock()
	_, dupAddress := n.peers[address]
	_, dupKey := n.peers[key]
	if dupAddress || dupKey {
		n.mu.Unlock()
		// Someone else connected to this peer while we were dialing.
		conn.Close()
		return nil
	}
	n.peers[key] = &Peer{
		Address: key,
		Conn:    conn,
	}
	n.mu.Unlock()

	go n.handlePeer(conn)

	// Request blocks for sync.
	if err := json.NewEncoder(conn).Encode(Message{Type: MsgGetBlocks}); err != nil {
		// Closing makes handlePeer's read fail, which removes the peer entry.
		conn.Close()
		return fmt.Errorf("requesting blocks from %s: %w", address, err)
	}

	return nil
}

// sendBlocks writes every block returned by blockchain.GetBlocks to the peer
// registered as peerAddr, one MsgBlock message per block, in the order
// GetBlocks returns them. It does nothing if no such peer is registered.
//
// Sending stops at the first serialization or write error, which is logged:
// once a write has failed the connection is not expected to recover, and
// skipping a single block would hand the peer a sequence with a hole in it.
func (n *Network) sendBlocks(peerAddr string) {
	n.mu.RLock()
	peer, exists := n.peers[peerAddr]
	n.mu.RUnlock()

	if !exists {
		return
	}

	// One encoder for the whole stream instead of one per block.
	encoder := json.NewEncoder(peer.Conn)

	for _, block := range n.blockchain.GetBlocks() {
		payload, err := json.Marshal(block)
		if err != nil {
			fmt.Printf("Send blocks to %s: marshal failed: %v\n", peerAddr, err)
			return
		}

		msg := Message{
			Type:    MsgBlock,
			Payload: payload,
		}
		if err := encoder.Encode(msg); err != nil {
			fmt.Printf("Send blocks to %s failed: %v\n", peerAddr, err)
			return
		}
	}
}

// GetPeerCount reports how many entries the peer table currently holds. The
// table is read under the read lock.
func (n *Network) GetPeerCount() int {
	n.mu.RLock()
	count := len(n.peers)
	n.mu.RUnlock()

	return count
}
