package handlers

import (
	"encoding/json"
	"log"
	"net/http"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/yourusername/chatserver/internal/hub"
	"github.com/yourusername/chatserver/internal/models"
	"github.com/yourusername/chatserver/internal/repository"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

// upgrader turns incoming HTTP requests into WebSocket connections, using
// 1 KiB read and write buffers.
//
// NOTE(review): CheckOrigin accepts every Origin header, so any web page can
// open a connection to this endpoint from a visitor's browser. Restrict it to
// the expected origins before deploying to production.
var upgrader = websocket.Upgrader{
	CheckOrigin:     func(*http.Request) bool { return true }, // Configure properly in production
	ReadBufferSize:  1 << 10,
	WriteBufferSize: 1 << 10,
}

// WebSocketHandler accepts chat WebSocket connections. Each connection is
// registered with hub, and messages read from it are stored through msgRepo
// and then broadcast via hub.
type WebSocketHandler struct {
	// hub registers connected clients and broadcasts messages to their rooms.
	hub     *hub.Hub
	// msgRepo persists messages received from clients.
	msgRepo repository.MessageRepository
}

// NewWebSocketHandler returns a WebSocketHandler that registers connections
// with h and stores incoming chat messages through msgRepo.
func NewWebSocketHandler(h *hub.Hub, msgRepo repository.MessageRepository) *WebSocketHandler {
	handler := new(WebSocketHandler)
	handler.hub = h
	handler.msgRepo = msgRepo
	return handler
}

// ServeWS upgrades an HTTP request to a WebSocket connection and attaches it
// to the hub as a new client.
//
// The user and room are read from the "user_id" and "room_id" query
// parameters. Both must be present and must be hex-encoded ObjectIDs, because
// readPump converts them with primitive.ObjectIDFromHex. Invalid requests are
// rejected with 400 Bad Request *before* the upgrade, so the client gets a
// meaningful HTTP error instead of a socket that opens and is immediately
// dropped.
//
// NOTE(review): user_id comes straight from the query string and is not
// authenticated here, so any client can claim any user. Confirm that upstream
// middleware authenticates the request, or derive the user from the request
// context instead.
//
// On success the client is registered with the hub, and two goroutines are
// started:
//   - writePump drains client.Send to the socket.
//   - readPump reads incoming messages until the connection ends.
func (h *WebSocketHandler) ServeWS(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	userID := query.Get("user_id")
	roomID := query.Get("room_id")

	if userID == "" || roomID == "" {
		http.Error(w, "user_id and room_id are required", http.StatusBadRequest)
		return
	}
	// Reject malformed IDs up front. Otherwise readPump would store every
	// message from this connection with a zero ObjectID.
	if _, err := primitive.ObjectIDFromHex(userID); err != nil {
		http.Error(w, "user_id must be a valid object ID", http.StatusBadRequest)
		return
	}
	if _, err := primitive.ObjectIDFromHex(roomID); err != nil {
		http.Error(w, "room_id must be a valid object ID", http.StatusBadRequest)
		return
	}

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already written an HTTP error response to w.
		log.Printf("WebSocket upgrade error: %v", err)
		return
	}

	client := &hub.Client{
		ID:     uuid.New().String(),
		UserID: userID,
		RoomID: roomID,
		Conn:   conn,
		// Outbound queue drained by writePump. What happens when it is full
		// is decided by the hub (not visible here).
		Send: make(chan []byte, 256),
	}

	h.hub.Register(client)

	// writePump performs all data writes to conn. readPump owns reads and
	// unregisters the client when the connection ends.
	go h.writePump(client)
	go h.readPump(client)
}

// readPump reads messages from the client's WebSocket connection, saves them
// through msgRepo, and hands them to the hub for broadcast to the room.
//
// It runs until a read fails. That happens when the peer closes, the read
// deadline expires, a frame exceeds maxMessageSize, or the connection is
// closed by writePump. On exit it unregisters the client from the hub and
// closes the connection.
//
// Every pong from the peer extends the read deadline by pongWait. writePump
// sends the pings that elicit those pongs.
func (h *WebSocketHandler) readPump(client *hub.Client) {
	const (
		// pongWait is how long the peer may stay silent (no data, no pong)
		// before the read fails. writePump's ping interval must be shorter.
		pongWait = 60 * time.Second
		// maxMessageSize caps a single inbound message so one client cannot
		// make the server buffer an arbitrarily large payload.
		// NOTE(review): 64 KiB is a chosen default; tune it to the largest
		// legitimate chat payload.
		maxMessageSize = 64 * 1024
	)

	defer func() {
		h.hub.Unregister(client)
		client.Conn.Close()
	}()

	// The IDs are fixed for the life of the connection, so parse them once.
	// A parse failure used to be ignored, which silently stored the zero
	// ObjectID in every message. Refuse the connection instead.
	roomID, err := primitive.ObjectIDFromHex(client.RoomID)
	if err != nil {
		log.Printf("Invalid room ID %q: %v", client.RoomID, err)
		return
	}
	userID, err := primitive.ObjectIDFromHex(client.UserID)
	if err != nil {
		log.Printf("Invalid user ID %q: %v", client.UserID, err)
		return
	}

	client.Conn.SetReadLimit(maxMessageSize)
	client.Conn.SetReadDeadline(time.Now().Add(pongWait))
	client.Conn.SetPongHandler(func(string) error {
		client.Conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	for {
		_, msgData, err := client.Conn.ReadMessage()
		if err != nil {
			// Close frames carrying one of the listed codes, and non-close
			// errors such as timeouts or read-limit violations, end the loop
			// without logging.
			if websocket.IsUnexpectedCloseError(err,
				websocket.CloseNormalClosure,
				websocket.CloseGoingAway,
				websocket.CloseAbnormalClosure,
			) {
				log.Printf("WebSocket error: %v", err)
			}
			return
		}

		var clientMsg models.ClientMessage
		if err := json.Unmarshal(msgData, &clientMsg); err != nil {
			log.Printf("Error unmarshaling message: %v", err)
			continue
		}

		message := &models.Message{
			RoomID:    roomID,
			UserID:    userID,
			Content:   clientMsg.Content,
			Type:      models.MessageTypeText,
			Timestamp: time.Now(),
		}

		// Persistence is best-effort. A failed save is logged and the message
		// is still broadcast to the room.
		if err := h.msgRepo.Create(message); err != nil {
			log.Printf("Error saving message: %v", err)
		}

		h.hub.Broadcast(message)
	}
}

// writePump forwards each payload queued on client.Send to the peer as a
// text message. It also pings every pingPeriod so the peer's pongs keep
// readPump's read deadline alive. gorilla/websocket allows only one
// concurrent writer, and in this handler all data writes happen here.
//
// It returns in two cases:
//   - client.Send is closed (presumably by the hub on unregister; confirm).
//     A close frame is sent first.
//   - Any write fails.
//
// On return it stops the ticker and closes the connection, which also makes
// any blocked read in readPump fail.
func (h *WebSocketHandler) writePump(client *hub.Client) {
	const (
		// writeWait bounds how long a single write may block on a slow peer.
		writeWait = 10 * time.Second
		// pingPeriod must stay below readPump's 60s pong wait, so that a ping
		// and its pong arrive before the read deadline expires.
		pingPeriod = 54 * time.Second
	)

	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		client.Conn.Close()
	}()

	for {
		select {
		case message, ok := <-client.Send:
			client.Conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// Send was closed: tell the peer we are done. This is best
				// effort; the deferred Close runs either way.
				client.Conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}

			w, err := client.Conn.NextWriter(websocket.TextMessage)
			if err != nil {
				return
			}
			if _, err := w.Write(message); err != nil {
				return
			}
			if err := w.Close(); err != nil {
				return
			}

		case <-ticker.C:
			client.Conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := client.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}
