package websocket

import (
	"context"
	"encoding/json"
	"log"
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/redis/go-redis/v9"
)

// upgrader converts incoming HTTP requests into WebSocket connections.
//
// NOTE(review): CheckOrigin accepts every Origin header, which leaves the
// endpoint open to cross-site WebSocket hijacking. Restrict it to the
// dashboard's own origin(s) before deploying to production.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(*http.Request) bool { return true },
}

// WebSocketServer fans real-time updates out to connected WebSocket clients.
// Updates arrive either from Redis pub/sub or from Publish, and are delivered
// by the event loop in Run.
type WebSocketServer struct {
	// Connected clients, used as a set (values are only ever true).
	clients map[*Client]bool

	// Event-loop inputs consumed by Run.
	broadcast  chan interface{} // messages to deliver to every client
	register   chan *Client     // newly connected clients
	unregister chan *Client     // clients whose connection has ended

	redis *redis.Client // source of pub/sub updates
	mu    sync.RWMutex  // guards clients
}

// Client is a single WebSocket connection managed by a WebSocketServer.
type Client struct {
	ws *websocket.Conn // underlying connection

	// Outbound queue drained by writePump; the server closes it to stop the pump.
	send chan interface{}

	// Topics requested by the peer via {"action":"subscribe"} messages.
	// NOTE(review): broadcasts in this file are not filtered by topic.
	topics []string
}

// NewWebSocketServer returns a server that receives updates through redisClient.
// The register channel is unbuffered, so Run must be running before
// HandleWebSocket can complete a registration.
func NewWebSocketServer(redisClient *redis.Client) *WebSocketServer {
	const broadcastBuffer = 1000

	srv := new(WebSocketServer)
	srv.clients = make(map[*Client]bool)
	srv.broadcast = make(chan interface{}, broadcastBuffer)
	srv.register = make(chan *Client)
	srv.unregister = make(chan *Client)
	srv.redis = redisClient
	return srv
}

// Run is the server's event loop. It starts the Redis subscriber, then
// serialises client registration, unregistration and broadcast fan-out until
// ctx is cancelled.
//
// NOTE(review): once Run returns, nothing receives on register/unregister, so
// HandleWebSocket and readPump will block on those unbuffered sends.
func (wss *WebSocketServer) Run(ctx context.Context) {
	// Feed Redis pub/sub messages into wss.broadcast.
	go wss.subscribeRedis(ctx)

	for {
		select {
		case <-ctx.Done():
			return

		case client := <-wss.register:
			wss.mu.Lock()
			wss.clients[client] = true
			total := len(wss.clients) // read under the lock to avoid a racy len
			wss.mu.Unlock()
			log.Printf("Client connected. Total: %d", total)

		case client := <-wss.unregister:
			wss.mu.Lock()
			// The client may already have been dropped by the broadcast
			// branch; only close send once.
			if _, ok := wss.clients[client]; ok {
				delete(wss.clients, client)
				close(client.send)
			}
			total := len(wss.clients)
			wss.mu.Unlock()
			log.Printf("Client disconnected. Total: %d", total)

		case message := <-wss.broadcast:
			// A full write lock is required: slow clients are deleted from the
			// map inside this loop, and mutating it under RLock is a data race.
			wss.mu.Lock()
			for client := range wss.clients {
				select {
				case client.send <- message:
				default:
					// The client's queue is full; drop it rather than stall
					// delivery to everyone else.
					close(client.send)
					delete(wss.clients, client)
				}
			}
			wss.mu.Unlock()
		}
	}
}

// subscribeRedis forwards every JSON message published on a Redis channel
// matching "dashboard:*" to wss.broadcast. Payloads that are not valid JSON are
// logged and skipped. It returns when ctx is cancelled or the subscription's
// message channel is closed.
func (wss *WebSocketServer) subscribeRedis(ctx context.Context) {
	// "dashboard:*" is a glob pattern, so it must be registered with PSUBSCRIBE;
	// plain SUBSCRIBE would only match a channel literally named "dashboard:*".
	pubsub := wss.redis.PSubscribe(ctx, "dashboard:*")
	defer pubsub.Close()

	// Fetch the message channel once rather than on every loop iteration.
	messages := pubsub.Channel()

	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-messages:
			if !ok {
				// Channel closed: msg is nil and dereferencing it would panic.
				return
			}

			var data interface{}
			if err := json.Unmarshal([]byte(msg.Payload), &data); err != nil {
				log.Printf("Error unmarshaling message: %v", err)
				continue
			}

			// Hand off to Run, but stop waiting if shutdown begins while the
			// broadcast buffer is full; otherwise this goroutine could block
			// forever after Run has returned.
			select {
			case wss.broadcast <- data:
			case <-ctx.Done():
				return
			}
		}
	}
}

// HandleWebSocket is an HTTP handler that upgrades the request to a WebSocket
// connection, registers the new client with the server's event loop, and then
// starts the client's write and read pumps in their own goroutines.
func (wss *WebSocketServer) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, upgradeErr := upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		log.Printf("WebSocket upgrade error: %v", upgradeErr)
		return
	}

	c := &Client{ws: conn, send: make(chan interface{}, 256), topics: []string{}}

	// Blocks until Run accepts the registration (register is unbuffered).
	wss.register <- c

	go c.writePump()
	go c.readPump(wss)
}

// Publish queues message for delivery to every connected client without
// blocking the caller. If the broadcast queue is full, the message is
// discarded and a log line is written instead.
func (wss *WebSocketServer) Publish(message interface{}) {
	select {
	case wss.broadcast <- message:
		return
	default:
	}
	log.Printf("Broadcast channel full, dropping message")
}

// writePump drains c.send, JSON-encodes each message and writes it to the
// peer as a text frame. Messages that fail to marshal are logged and skipped.
// It returns when c.send is closed (the server dropped this client) or when a
// write fails, and always closes the connection on the way out.
func (c *Client) writePump() {
	// Bound every write so a peer that stops reading cannot pin this
	// goroutine inside WriteMessage indefinitely.
	const writeWait = 10 * time.Second

	defer c.ws.Close()

	for message := range c.send {
		data, err := json.Marshal(message)
		if err != nil {
			log.Printf("JSON marshal error: %v", err)
			continue
		}

		_ = c.ws.SetWriteDeadline(time.Now().Add(writeWait)) // a failure here surfaces from WriteMessage
		if err := c.ws.WriteMessage(websocket.TextMessage, data); err != nil {
			log.Printf("WebSocket write error: %v", err)
			return
		}
	}

	// c.send was closed by the server: tell the peer we are going away.
	// Best effort only; the deferred Close tears the connection down anyway.
	_ = c.ws.SetWriteDeadline(time.Now().Add(writeWait))
	_ = c.ws.WriteMessage(websocket.CloseMessage, []byte{})
}

// readPump reads requests from the peer until ReadMessage fails, then
// unregisters the client and closes the connection. The only request acted on
// is {"action":"subscribe","topics":[...]}, which appends the topics to
// c.topics. Frames that are not valid JSON, or carry another action, are
// ignored.
func (c *Client) readPump(wss *WebSocketServer) {
	// Frames from the peer are untrusted; cap their size so a client cannot
	// make ReadMessage buffer an arbitrarily large message. Subscription
	// requests are small, so 64 KiB is generous.
	const maxMessageSize = 64 << 10

	defer func() {
		wss.unregister <- c
		c.ws.Close()
	}()

	c.ws.SetReadLimit(maxMessageSize)

	for {
		_, message, err := c.ws.ReadMessage()
		if err != nil {
			break
		}

		var req struct {
			Action string   `json:"action"`
			Topics []string `json:"topics"`
		}

		if err := json.Unmarshal(message, &req); err != nil {
			continue
		}

		if req.Action == "subscribe" {
			c.topics = append(c.topics, req.Topics...)
			// %q quotes client-supplied strings so they cannot forge log lines.
			log.Printf("Client subscribed to topics: %q", req.Topics)
		}
	}
}
