package storage

import (
	"bytes"
	"math/rand"
	"sync"
)

// maxLevel caps how many forward pointers a node may carry; the head
// sentinel is allocated with exactly this many.
const (
	maxLevel = 16
)

type (
	// MemTable is an in-memory, key-ordered store backed by a skip list.
	// The methods in this file acquire mu (exclusively for writes, shared
	// for reads) before touching the list.
	MemTable struct {
		// head is a sentinel node; its key and value are never read, only
		// its forward pointers.
		head *skipListNode
		// size starts at zero (NewMemTable) and is not updated by any code
		// in this file. NOTE(review): confirm whether it is meant to track
		// bytes or entry count, and maintain it in Put if so.
		size int64
		mu   sync.RWMutex
	}

	// skipListNode is one entry in the skip list. forward[i] points to the
	// next node at level i, so len(forward) is the node's height.
	skipListNode struct {
		key     []byte
		value   []byte
		forward []*skipListNode
	}
)

// NewMemTable returns an empty MemTable. Its head sentinel is allocated
// with a forward pointer slot for every level up to maxLevel, all nil.
func NewMemTable() *MemTable {
	sentinel := new(skipListNode)
	sentinel.forward = make([]*skipListNode, maxLevel)
	return &MemTable{head: sentinel}
}

// Put inserts key with value, or overwrites the value if key is already
// present. A nil value is stored as nil; Delete relies on this to write a
// tombstone.
//
// Put keeps private copies of key and value, so callers may reuse or modify
// their buffers after Put returns. Without the key copy, mutating the
// caller's buffer would change a linked node's key in place and silently
// break the skip list's sorted order.
func (mt *MemTable) Put(key, value []byte) {
	// clone copies b but keeps nil as nil: a nil value is the tombstone
	// marker, while a non-nil empty slice is a legitimate empty value.
	clone := func(b []byte) []byte {
		if b == nil {
			return nil
		}
		return append(make([]byte, 0, len(b)), b...)
	}

	// Copy the value before locking to keep the critical section short.
	value = clone(value)

	mt.mu.Lock()
	defer mt.mu.Unlock()

	// update[i] is the rightmost node at level i whose key is < key; a new
	// node is spliced in directly after it. A fixed-size array avoids a heap
	// allocation on every Put.
	var update [maxLevel]*skipListNode
	current := mt.head
	for i := maxLevel - 1; i >= 0; i-- {
		for current.forward[i] != nil && bytes.Compare(current.forward[i].key, key) < 0 {
			current = current.forward[i]
		}
		update[i] = current
	}

	// Key already present: overwrite its value in place.
	if next := current.forward[0]; next != nil && bytes.Equal(next.key, key) {
		next.value = value
		return
	}

	level := randomLevel()
	node := &skipListNode{
		key:     clone(key),
		value:   value,
		forward: make([]*skipListNode, level),
	}
	for i := 0; i < level; i++ {
		node.forward[i] = update[i].forward[i]
		update[i].forward[i] = node
	}
}

// Get looks up key and reports its stored value and whether it is present.
//
// A key removed with Delete is still present with a nil value, so a result
// of (nil, true) denotes a tombstone. The returned slice shares memory with
// the stored entry; callers must not modify it.
func (mt *MemTable) Get(key []byte) ([]byte, bool) {
	mt.mu.RLock()
	defer mt.mu.RUnlock()

	// Descend from the top level, advancing while the next key is < key,
	// so node ends as the last entry strictly before key.
	node := mt.head
	for lvl := maxLevel - 1; lvl >= 0; lvl-- {
		for next := node.forward[lvl]; next != nil && bytes.Compare(next.key, key) < 0; next = node.forward[lvl] {
			node = next
		}
	}

	candidate := node.forward[0]
	if candidate == nil || !bytes.Equal(candidate.key, key) {
		return nil, false
	}
	return candidate.value, true
}

// Delete records a tombstone for key by storing a nil value through Put.
// The key is not unlinked: afterwards Get reports (nil, true) for it, and
// Scan still visits it with a nil value.
func (mt *MemTable) Delete(key []byte) {
	var tombstone []byte // nil slice marks the key as deleted
	mt.Put(key, tombstone)
}

// Scan calls fn for every entry with startKey <= key < endKey, in ascending
// key order. A nil startKey means no lower bound; a nil endKey means no upper
// bound. Tombstones written by Delete are visited too, with a nil value.
//
// fn runs while the read lock is held. It must not call Put or Delete on the
// same MemTable (that would deadlock), and it should not call Get or Scan
// either, because sync.RWMutex forbids recursive read locking while a writer
// is waiting. The key and value slices passed to fn share memory with the
// stored entries and must not be modified.
func (mt *MemTable) Scan(startKey, endKey []byte, fn func(key, value []byte)) {
	mt.mu.RLock()
	defer mt.mu.RUnlock()

	// Seek using the skip-list levels rather than walking level 0 from the
	// head: afterwards, current is the last node whose key is < startKey
	// (or the head sentinel), giving O(log n) instead of O(n) to reach the
	// first in-range entry.
	current := mt.head
	if startKey != nil {
		for i := maxLevel - 1; i >= 0; i-- {
			for current.forward[i] != nil && bytes.Compare(current.forward[i].key, startKey) < 0 {
				current = current.forward[i]
			}
		}
	}

	// Level 0 is sorted, so every node from here on is >= startKey; stop at
	// the first key outside the upper bound.
	for node := current.forward[0]; node != nil; node = node.forward[0] {
		if endKey != nil && bytes.Compare(node.key, endKey) >= 0 {
			break
		}
		fn(node.key, node.value)
	}
}

// randomLevel chooses the height of a new node. It starts at 1 and, while
// below maxLevel, keeps adding a level with probability 1/2, so the result
// lies in [1, maxLevel] and its expected value is just under 2. Randomness
// comes from the package-level math/rand source.
func randomLevel() int {
	lvl := 1
	for lvl < maxLevel {
		if rand.Float64() >= 0.5 {
			break
		}
		lvl++
	}
	return lvl
}
