package storage

import (
	"errors"
	"os"
	"path/filepath"
	"sync"
	"time"
)

// ErrKeyNotFound is returned when a key has no live value: it was never
// written, or its newest version is a tombstone.
var ErrKeyNotFound = errors.New("key not found")

// ErrClosed is returned by operations attempted after the engine was closed.
var ErrClosed = errors.New("storage engine closed")

// LSMConfig holds the tunable parameters of an LSMTree. DefaultLSMConfig
// returns a fully populated instance.
type LSMConfig struct {
	// DataDir is the directory holding the WAL and SSTable files.
	DataDir string

	// MemTableSize is the accounted size in bytes (sum of key and value
	// lengths) at which the active MemTable is rotated and flushed
	// (default 64 MB).
	MemTableSize int64

	// BlockCacheSize is the capacity handed to the block cache
	// (default 1 GB).
	BlockCacheSize int64

	// BloomFilterBits is the number of bloom-filter bits per key used when
	// writing SSTables.
	BloomFilterBits int

	// Compression names the block codec: "snappy" or "zstd".
	Compression string

	// MaxLevels is the number of LSM levels (default 7).
	MaxLevels int

	// CompactionWorkers is the number of compaction workers.
	CompactionWorkers int
}

// DefaultLSMConfig returns a newly allocated configuration with the engine's
// defaults: data under /tmp/lsm-data, a 64 MiB MemTable, a 1 GiB block cache,
// 10 bloom-filter bits per key, snappy compression, 7 levels and 2 compaction
// workers. Each call returns a fresh value the caller may modify.
func DefaultLSMConfig() *LSMConfig {
	const (
		mib = 1 << 20
		gib = 1 << 30
	)

	cfg := new(LSMConfig)
	cfg.DataDir = "/tmp/lsm-data"
	cfg.MemTableSize = 64 * mib
	cfg.BlockCacheSize = gib
	cfg.BloomFilterBits = 10
	cfg.Compression = "snappy"
	cfg.MaxLevels = 7
	cfg.CompactionWorkers = 2
	return cfg
}

// LSMTree is a log-structured merge-tree storage engine. Writes are appended
// to a write-ahead log and applied to an in-memory MemTable. Full MemTables
// are frozen and flushed to SSTables in level 0, and a compactor is triggered
// from there.
type LSMTree struct {
	config *LSMConfig

	// memTable takes new writes. immutables holds frozen MemTables in
	// rotation order (oldest first) until they are flushed to L0.
	memTable   *MemTable
	immutables []*MemTable

	// levels has one entry per configured level; levels[0] is L0.
	levels []*Level

	wal        *WriteAheadLog
	blockCache *BlockCache
	compactor  *Compactor

	// mu guards memTable, immutables and closed. Each Level carries its own
	// lock for its SSTable list.
	mu     sync.RWMutex
	closed bool

	// wg counts background flush goroutines so Close can wait for them.
	wg sync.WaitGroup
}

// Level is one tier of the LSM-tree: its index plus the SSTables currently
// assigned to it.
type Level struct {
	// level is this level's index: 0 for L0, 1 for L1, and so on.
	level int

	// sstables lists the level's tables. New L0 tables are appended at the
	// end by the flush path, so in L0 later entries are newer.
	sstables []*SSTable

	// mu protects sstables: readers take it shared, writers exclusively.
	mu sync.RWMutex
}

// NewLSMTree opens an LSM-tree storage engine rooted at config.DataDir.
//
// It creates the data directory if needed and opens the write-ahead log at
// <DataDir>/wal. It allocates config.MaxLevels empty levels and replays any
// existing WAL entries into the active MemTable. A nil config is replaced by
// DefaultLSMConfig().
//
// Errors:
//   - config.MaxLevels < 1 is rejected, because the flush path always
//     publishes into levels[0].
//   - Errors from creating the directory or opening the WAL are returned
//     unchanged.
//   - If WAL replay fails, the WAL is closed before the replay error is
//     returned, so the handle is not leaked.
func NewLSMTree(config *LSMConfig) (*LSMTree, error) {
	if config == nil {
		config = DefaultLSMConfig()
	}
	// A negative count would make the make() below panic. Zero levels would
	// panic later, inside a background flush goroutine, on levels[0].
	if config.MaxLevels < 1 {
		return nil, errors.New("invalid config: MaxLevels must be at least 1")
	}

	if err := os.MkdirAll(config.DataDir, 0755); err != nil {
		return nil, err
	}

	wal, err := NewWriteAheadLog(filepath.Join(config.DataDir, "wal"))
	if err != nil {
		return nil, err
	}

	levels := make([]*Level, config.MaxLevels)
	for i := range levels {
		levels[i] = &Level{level: i, sstables: []*SSTable{}}
	}

	lsm := &LSMTree{
		config:     config,
		memTable:   NewMemTable(),
		immutables: []*MemTable{},
		levels:     levels,
		wal:        wal,
		blockCache: NewBlockCache(config.BlockCacheSize),
	}

	// Replay un-flushed writes before the tree is handed out.
	if err := lsm.recoverFromWAL(); err != nil {
		// Best effort: the replay error is the one the caller needs to see.
		lsm.wal.Close()
		return nil, err
	}

	// Created only after a successful recovery, so the failure path above
	// has no compactor to tear down.
	lsm.compactor = NewCompactor(lsm)
	return lsm, nil
}

// Put stores key=value. The write is appended to the WAL first and then
// applied to the active MemTable.
//
// A nil value is the engine's tombstone (Delete calls Put(key, nil)). It is
// logged as OpDelete so WAL replay restores it through the delete path. It
// does not rely on the log preserving the difference between a nil and an
// empty value.
//
// Once the MemTable's accounted size (sum of key and value lengths) reaches
// config.MemTableSize, the table is rotated and flushed in the background.
//
// Returns ErrClosed after Close, or the WAL append error. On a WAL error the
// MemTable is left untouched.
func (lsm *LSMTree) Put(key, value []byte) error {
	lsm.mu.Lock()
	defer lsm.mu.Unlock()

	if lsm.closed {
		return ErrClosed
	}

	// Log before applying so the write is in the log before it becomes visible.
	var err error
	if value == nil {
		err = lsm.wal.Append(key, value, OpDelete)
	} else {
		err = lsm.wal.Append(key, value, OpPut)
	}
	if err != nil {
		return err
	}

	lsm.memTable.Put(key, value)
	// Tombstones are charged only for their key bytes.
	lsm.memTable.size += int64(len(key) + len(value))

	if lsm.memTable.size >= lsm.config.MemTableSize {
		lsm.rotateMemTable()
	}
	return nil
}

// Get returns the current value stored for key.
//
// Sources are consulted from newest to oldest so the most recent write wins:
//  1. the active MemTable;
//  2. the immutable MemTables, newest first;
//  3. each level from L0 down, scanning that level's SSTables newest first.
//
// The first source that has the key decides the result. A nil value there is
// a tombstone.
//
// Returns ErrClosed after Close. Returns ErrKeyNotFound when the key is absent
// or deleted.
func (lsm *LSMTree) Get(key []byte) ([]byte, error) {
	lsm.mu.RLock()
	defer lsm.mu.RUnlock()

	if lsm.closed {
		return nil, ErrClosed
	}

	var value []byte
	var found bool

	// 1. Active MemTable: the most recent writes.
	value, found = lsm.memTable.Get(key)

	// 2. Immutable MemTables. rotateMemTable appends to the end, so the last
	//    element is the newest and must shadow older ones.
	for i := len(lsm.immutables) - 1; !found && i >= 0; i-- {
		value, found = lsm.immutables[i].Get(key)
	}

	// 3. SSTables, L0 first.
	for i := 0; !found && i < len(lsm.levels); i++ {
		value, found = lsm.searchLevel(lsm.levels[i], key)
	}

	if !found || value == nil { // absent, or the newest version is a tombstone
		return nil, ErrKeyNotFound
	}
	return value, nil
}

// searchLevel looks key up in one level's SSTables, newest first. It returns
// the stored value (nil for a tombstone) and whether any SSTable held the key.
func (lsm *LSMTree) searchLevel(level *Level, key []byte) ([]byte, bool) {
	level.mu.RLock()
	defer level.mu.RUnlock()

	// Tables are appended as they are written, so walking backwards visits
	// newer files first. L0 files may overlap, and the newest must win.
	for i := len(level.sstables) - 1; i >= 0; i-- {
		sst := level.sstables[i]

		// The bloom filter lets us skip tables that cannot contain the key.
		if !sst.bloomFilter.MayContain(key) {
			continue
		}
		if value, cached := lsm.blockCache.Get(sst.fileID, key); cached {
			return value, true
		}
		if value, ok := sst.Get(key); ok {
			lsm.blockCache.Put(sst.fileID, key, value)
			return value, true
		}
	}
	return nil, false
}

// Delete removes key by writing a tombstone. It is exactly Put(key, nil), so it
// shares Put's ErrClosed check, WAL-first ordering and MemTable rotation. Reads
// treat the nil value as "deleted".
func (lsm *LSMTree) Delete(key []byte) error {
	var tombstone []byte // the nil slice is the engine's deletion marker
	return lsm.Put(key, tombstone)
}

// Scan returns the live key/value pairs whose keys fall in [startKey, endKey).
//
// Sources are visited from newest to oldest, and the first version seen for a
// key wins:
//  1. the active MemTable;
//  2. the immutable MemTables, newest first;
//  3. each level's SSTables, newest first.
//
// A tombstone also claims its key, so a deletion hides older values in lower
// sources, but tombstones themselves are never returned. Results are gathered
// in visit order and are not sorted by key.
//
// Returns ErrClosed after Close. An empty range yields a non-nil, empty slice.
func (lsm *LSMTree) Scan(startKey, endKey []byte) ([]KVPair, error) {
	lsm.mu.RLock()
	defer lsm.mu.RUnlock()

	if lsm.closed {
		return nil, ErrClosed
	}

	results := []KVPair{}
	seen := make(map[string]bool)

	// collect keeps the first (newest) version of each key. The key is marked
	// seen even when that version is a tombstone, so older values stay hidden.
	collect := func(key, value []byte) {
		k := string(key)
		if seen[k] {
			return
		}
		seen[k] = true
		if value != nil {
			results = append(results, KVPair{Key: key, Value: value})
		}
	}

	lsm.memTable.Scan(startKey, endKey, collect)

	// rotateMemTable appends to the end, so iterate backwards for newest first.
	for i := len(lsm.immutables) - 1; i >= 0; i-- {
		lsm.immutables[i].Scan(startKey, endKey, collect)
	}

	for _, level := range lsm.levels {
		level.mu.RLock()
		// Newer SSTables sit at the end of each level's list.
		for i := len(level.sstables) - 1; i >= 0; i-- {
			level.sstables[i].Scan(startKey, endKey, collect)
		}
		level.mu.RUnlock()
	}

	return results, nil
}

// Apply is the entry point for committed Raft commands. The intended encoding
// is [keyLen:4][key][value], with the command decoded and applied as a write.
//
// NOTE(review): this is still a placeholder. No command is decoded or applied,
// whether short (< 8 bytes) or longer, and nil is always returned. Callers
// therefore cannot tell that replicated writes are being dropped.
func (lsm *LSMTree) Apply(command []byte) error {
	return nil
}

// rotateMemTable freezes the active MemTable and starts a background flush.
//
// The caller must hold lsm.mu for writing; Put calls it under the lock. The
// function:
//   - appends the active table to lsm.immutables, where Get and Scan still
//     see it;
//   - installs a fresh MemTable;
//   - resets the WAL;
//   - starts a flush goroutine tracked by lsm.wg so Close can wait for it.
//
// NOTE(review): the WAL is reset before the frozen table reaches disk, so a
// crash before the flush finishes loses those writes. One WAL segment per
// MemTable, deleted after its flush, would close that gap.
func (lsm *LSMTree) rotateMemTable() {
	frozen := lsm.memTable
	lsm.memTable = NewMemTable()
	lsm.immutables = append(lsm.immutables, frozen)

	lsm.wal.Reset()

	lsm.wg.Add(1)
	go lsm.flushImmutable()
}

// flushImmutable writes the oldest immutable MemTable to a new L0 SSTable.
//
// It runs as a goroutine started by rotateMemTable, which has already called
// lsm.wg.Add(1). lsm.mu is held only to pick a table and to publish the
// result, never during file I/O. Once L0 holds more than 4 SSTables, a level-0
// compaction is requested.
//
// Several flush goroutines may be in flight and may pick the same oldest table.
// A table is therefore published (appended to L0 and removed from
// lsm.immutables) only if it is still the oldest immutable. A goroutine that
// loses this race discards its copy and moves on to the next oldest table. As
// a result, no frozen table is dropped, and L0 stays in oldest-to-newest order.
//
// If writing the SSTable fails, the table stays in lsm.immutables, where reads
// still see it. The next flush picks it up again as the oldest entry.
func (lsm *LSMTree) flushImmutable() {
	defer lsm.wg.Done()

	for {
		lsm.mu.Lock()
		if len(lsm.immutables) == 0 {
			lsm.mu.Unlock()
			return
		}
		imm := lsm.immutables[0]
		lsm.mu.Unlock()

		sst, err := lsm.writeSSTFile(imm, 0)
		if err != nil {
			// NOTE(review): the error is dropped; consider logging it.
			return
		}

		lsm.mu.Lock()
		if len(lsm.immutables) == 0 || lsm.immutables[0] != imm {
			// Another goroutine already published this table; try the next one.
			// NOTE(review): the duplicate SSTable file just written stays on disk.
			lsm.mu.Unlock()
			continue
		}

		l0 := lsm.levels[0]
		l0.mu.Lock()
		l0.sstables = append(l0.sstables, sst)
		l0Count := len(l0.sstables) // read under the lock to avoid a data race
		l0.mu.Unlock()

		// Clear the slot so the flushed MemTable is not kept alive by the
		// slice's backing array.
		lsm.immutables[0] = nil
		lsm.immutables = lsm.immutables[1:]
		lsm.mu.Unlock()

		if l0Count > 4 {
			lsm.compactor.TriggerCompaction(0)
		}
		return
	}
}

// writeSSTFile persists the entries of memTable into a new SSTable file for
// the given level and returns the finalized table.
//
// The file ID is the current wall-clock time in nanoseconds. The file path is
// <DataDir>/sstable/<formatSSTableFilename(id, level)>, and its parent
// directory is created on demand. Entries are fed to the writer by scanning
// memTable with nil bounds (presumably the full range — confirm against
// MemTable.Scan).
//
// NOTE(review): IDs derived from time.Now().UnixNano() can collide when two
// files are created within the clock's resolution.
func (lsm *LSMTree) writeSSTFile(memTable *MemTable, level int) (*SSTable, error) {
	id := uint64(time.Now().UnixNano())
	name := formatSSTableFilename(id, level)
	path := filepath.Join(lsm.config.DataDir, "sstable", name)

	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return nil, err
	}

	writer, err := NewSSTableWriter(path, id, level, lsm.config.BloomFilterBits)
	if err != nil {
		return nil, err
	}

	memTable.Scan(nil, nil, func(k, v []byte) {
		writer.Add(k, v)
	})

	return writer.Finalize()
}

// RunCompaction hands control to the compactor's Run method. Whether Run
// blocks or only schedules background work is up to Compactor — TODO confirm.
func (lsm *LSMTree) RunCompaction() {
	c := lsm.compactor
	c.Run()
}

// recoverFromWAL replays every WAL entry, in log order, into the active
// MemTable.
//
// OpPut entries are re-applied with MemTable.Put and OpDelete entries with
// MemTable.Delete; entries with any other op are skipped. Each replayed entry
// is charged to the MemTable's size exactly as Put charges it:
// len(key)+len(value). A recovered table therefore counts toward
// config.MemTableSize instead of starting from zero.
//
// Returns the error from reading the WAL, if any.
func (lsm *LSMTree) recoverFromWAL() error {
	entries, err := lsm.wal.ReadAll()
	if err != nil {
		return err
	}

	for _, entry := range entries {
		switch entry.Op {
		case OpPut:
			lsm.memTable.Put(entry.Key, entry.Value)
		case OpDelete:
			lsm.memTable.Delete(entry.Key)
		default:
			// Unknown op: nothing was applied, so nothing is charged.
			continue
		}
		lsm.memTable.size += int64(len(entry.Key) + len(entry.Value))
	}

	return nil
}

// Close shuts the engine down and returns nil. Only the first call does any
// work; later calls return immediately.
//
// The tree is marked closed under lsm.mu, so later Put, Get and Scan calls
// fail with ErrClosed. Close then waits for the background flushes tracked by
// lsm.wg, closes the WAL, and finally stops the compactor.
//
// NOTE(review): if closing the WAL can report an error, that error is
// currently discarded.
func (lsm *LSMTree) Close() error {
	lsm.mu.Lock()
	wasClosed := lsm.closed
	lsm.closed = true
	lsm.mu.Unlock()

	if wasClosed {
		return nil
	}

	// Let in-flight flushes publish their SSTables before tearing down.
	lsm.wg.Wait()
	lsm.wal.Close()
	lsm.compactor.Stop()
	return nil
}

// KVPair is a single key/value entry, as returned by Scan.
type KVPair struct {
	// Key is the entry's key.
	Key []byte

	// Value is the entry's value. Scan skips tombstones, so it is non-nil in
	// Scan results.
	Value []byte
}
