package storage

import (
	"bytes"
	"fmt"
	"sort"
	"sync"
	"time"
)

// Compactor handles LSM-tree compaction.
//
// Once started with Run, a background loop periodically compares each
// level's SSTable count with a per-level threshold and compacts levels that
// exceed it into the next level down. Stop ends the loop and waits for it.
type Compactor struct {
	// lsm is the tree whose levels are inspected and compacted.
	lsm      *LSMTree
	// stopChan is closed by Stop to tell the background loop to return.
	stopChan chan struct{}
	// wg counts background goroutines that Stop waits for.
	wg       sync.WaitGroup
}

// NewCompactor returns a Compactor bound to lsm. It does not start any
// goroutines; call Run to begin background compaction and Stop to end it.
func NewCompactor(lsm *LSMTree) *Compactor {
	comp := new(Compactor)
	comp.lsm = lsm
	comp.stopChan = make(chan struct{})
	return comp
}

// Run starts the periodic compaction loop on its own goroutine and returns
// immediately. The goroutine is counted in c.wg before it is launched, so a
// later Stop is guaranteed to wait for it.
func (c *Compactor) Run() {
	c.wg.Add(1)
	go func() {
		c.runLoop()
	}()
}

// runLoop is the body of the goroutine started by Run. Every 10 seconds it
// calls checkAndCompact. It returns, marking itself done in c.wg, once
// stopChan is closed. A check already in progress runs to completion first,
// because the stop signal is only observed between ticks.
func (c *Compactor) runLoop() {
	defer c.wg.Done()

	tick := time.NewTicker(10 * time.Second)
	defer tick.Stop()

	for {
		select {
		case <-c.stopChan:
			return
		case <-tick.C:
			c.checkAndCompact()
		}
	}
}

// TriggerCompaction asynchronously compacts the given level into level+1.
//
// The spawned goroutine is registered with c.wg, so Stop waits for in-flight
// compactions instead of returning while one may still be writing files.
// Levels that have no level below them (and negative levels) are ignored,
// because compacting them would index past the end of c.lsm.levels.
//
// NOTE(review): calling TriggerCompaction concurrently with Stop races
// WaitGroup.Add against WaitGroup.Wait. Callers should not trigger
// compactions once Stop has been called.
func (c *Compactor) TriggerCompaction(level int) {
	if level < 0 || level+1 >= len(c.lsm.levels) {
		return
	}

	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		c.compactLevel(level)
	}()
}

// checkAndCompact compares every level's SSTable count with its
// getLevelThreshold value and triggers an asynchronous compaction for each
// level holding more tables than allowed. The last level is never checked
// because there is no level below it to compact into.
func (c *Compactor) checkAndCompact() {
	last := len(c.lsm.levels) - 1
	for lvl := 0; lvl < last; lvl++ {
		limit := c.getLevelThreshold(lvl)

		c.lsm.levels[lvl].mu.RLock()
		count := len(c.lsm.levels[lvl].sstables)
		c.lsm.levels[lvl].mu.RUnlock()

		if count <= limit {
			continue
		}
		c.TriggerCompaction(lvl)
	}
}

// getLevelThreshold returns how many SSTables a level may hold before
// checkAndCompact schedules a compaction for it. Level 0 allows 4. Level n >= 1
// allows 10^n (10, 100, 1000, ...). Negative levels get 10.
func (c *Compactor) getLevelThreshold(level int) int {
	if level != 0 {
		limit := 10
		for n := level - 1; n > 0; n-- {
			limit *= 10
		}
		return limit
	}
	return 4
}

// compactLevel merges every SSTable currently in the given level into new
// SSTable(s), appends them to level+1, and then closes the input tables.
//
// The inputs are detached from the level under its lock before merging. A
// concurrent compactLevel for the same level therefore finds nothing to do,
// rather than merging the same tables twice. If the merge fails, the inputs are
// put back at the front of the level, ahead of any tables added meanwhile.
// This keeps their data in the tree, and the next threshold check retries.
//
// NOTE(review): between detaching the inputs and installing the merged output,
// their entries are in neither level; readers that consult only the level
// slices may briefly miss them.
func (c *Compactor) compactLevel(level int) {
	// The deepest level has no target; without this guard level+1 would index
	// past the end of c.lsm.levels.
	if level < 0 || level+1 >= len(c.lsm.levels) {
		return
	}

	c.lsm.levels[level].mu.Lock()
	if len(c.lsm.levels[level].sstables) == 0 {
		c.lsm.levels[level].mu.Unlock()
		return
	}

	// Take ownership of all tables currently in the level.
	inputTables := c.lsm.levels[level].sstables
	c.lsm.levels[level].sstables = []*SSTable{}
	c.lsm.levels[level].mu.Unlock()

	merged, err := c.tryMergeSSTables(inputTables)
	if err != nil {
		// Restore the inputs (oldest first) so a failed write loses nothing.
		c.lsm.levels[level].mu.Lock()
		current := c.lsm.levels[level].sstables
		restored := make([]*SSTable, 0, len(inputTables)+len(current))
		restored = append(restored, inputTables...)
		restored = append(restored, current...)
		c.lsm.levels[level].sstables = restored
		c.lsm.levels[level].mu.Unlock()
		return
	}

	// Appending keeps level+1 ordered oldest to newest: the merged data is
	// newer than anything already in the level below.
	c.lsm.levels[level+1].mu.Lock()
	c.lsm.levels[level+1].sstables = append(c.lsm.levels[level+1].sstables, merged...)
	c.lsm.levels[level+1].mu.Unlock()

	// The inputs' contents now live in the merged table(s); closing is
	// best-effort cleanup.
	for _, sst := range inputTables {
		sst.file.Close()
		// os.Remove(sst.path) // Uncomment to delete files
	}
}

// mergeSSTables merges the given tables into a single new SSTable and returns
// it in a one-element slice, or an empty slice when the inputs hold no entries.
// It returns nil if writing the merged table fails. Use tryMergeSSTables when
// the caller needs to tell that failure apart from an empty result.
func (c *Compactor) mergeSSTables(sstables []*SSTable) []*SSTable {
	merged, err := c.tryMergeSSTables(sstables)
	if err != nil {
		return nil
	}
	return merged
}

// tryMergeSSTables collects every entry of every input table, sorts the
// entries by key, keeps one entry per key, and writes the result to a new
// SSTable one level below the inputs.
//
// When a key occurs more than once, the occurrence from the table latest in
// sstables wins. This matches this file's convention of appending newer tables
// at the end of a level. NOTE(review): confirm L0 flushes also append.
//
// It returns an empty (non-nil) slice when there are no entries, and an error
// if the output table cannot be created or finalized. A partially written
// output file may be left on disk in that case.
func (c *Compactor) tryMergeSSTables(sstables []*SSTable) ([]*SSTable, error) {
	type entry struct {
		key   []byte
		value []byte
	}
	var entries []entry

	for _, sst := range sstables {
		// NOTE(review): key/value are retained after the callback returns;
		// this is only safe if Scan hands out buffers it does not reuse.
		sst.Scan(nil, nil, func(key, value []byte) {
			entries = append(entries, entry{key: key, value: value})
		})
	}

	// A stable sort keeps equal keys in scan order (oldest table first), which
	// the keep-last deduplication below relies on. An unstable sort would
	// leave the surviving version of a duplicated key arbitrary.
	sort.SliceStable(entries, func(i, j int) bool {
		return bytes.Compare(entries[i].key, entries[j].key) < 0
	})

	// Keep the last entry of each run of equal keys: the newest version.
	deduplicated := make([]entry, 0, len(entries))
	for i, e := range entries {
		if i+1 < len(entries) && bytes.Equal(e.key, entries[i+1].key) {
			continue
		}
		deduplicated = append(deduplicated, e)
	}

	if len(deduplicated) == 0 {
		return []*SSTable{}, nil
	}

	id := uint64(time.Now().UnixNano())
	targetLevel := sstables[0].level + 1
	// Each output gets its own file. A shared fixed name would let a later
	// compaction overwrite the file backing an earlier merged table, and
	// concurrent compactions of different levels would collide.
	path := fmt.Sprintf("%s/sstable/merged-L%d-%d.sst", c.lsm.config.DataDir, targetLevel, id)

	writer, err := NewSSTableWriter(path, id, targetLevel, c.lsm.config.BloomFilterBits)
	if err != nil {
		return nil, fmt.Errorf("creating merged sstable %q: %w", path, err)
	}

	for _, e := range deduplicated {
		writer.Add(e.key, e.value)
	}

	merged, err := writer.Finalize()
	if err != nil {
		return nil, fmt.Errorf("finalizing merged sstable %q: %w", path, err)
	}
	return []*SSTable{merged}, nil
}

// Stop shuts the compactor down. Closing the stop channel makes the Run loop
// return, and Stop then blocks until every goroutine counted in c.wg has
// finished. Call it at most once: a second call closes an already closed
// channel and panics.
func (c *Compactor) Stop() {
	done := c.stopChan
	close(done)

	pending := &c.wg
	pending.Wait()
}
