package transaction

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// TxnStatus is the lifecycle state of a Transaction. The zero value is
// TxnActive.
type TxnStatus int

// Transaction states. Coordinator.Commit moves a transaction from TxnActive
// through TxnPreparing and TxnCommitting to TxnCommitted, or to TxnAborted
// when the commit fails.
const (
	TxnActive TxnStatus = iota
	TxnPreparing
	TxnPrepared
	TxnCommitting
	TxnCommitted
	TxnAborting
	TxnAborted
)

// String returns the constant's name, e.g. "TxnCommitted", or
// "TxnStatus(n)" for a value outside the defined set. It makes statuses
// readable when printed with %v or %s; %d still prints the number.
func (s TxnStatus) String() string {
	switch s {
	case TxnActive:
		return "TxnActive"
	case TxnPreparing:
		return "TxnPreparing"
	case TxnPrepared:
		return "TxnPrepared"
	case TxnCommitting:
		return "TxnCommitting"
	case TxnCommitted:
		return "TxnCommitted"
	case TxnAborting:
		return "TxnAborting"
	case TxnAborted:
		return "TxnAborted"
	default:
		return fmt.Sprintf("TxnStatus(%d)", int(s))
	}
}

// Transaction is a distributed transaction: writes buffered per shard by
// Coordinator.Write and applied to storage by Coordinator.Commit.
//
// The Coordinator holds mu while it modifies Status, Participants and
// Writes. Because mu is unexported, code that may run concurrently with a
// Commit should read the status through CurrentStatus rather than the
// Status field.
type Transaction struct {
	ID           string          // assigned by Coordinator.Begin
	Status       TxnStatus       // lifecycle state; see CurrentStatus
	Participants []int           // shard IDs written to, in first-write order
	Writes       map[int][]Write // buffered writes keyed by shard ID
	mu           sync.RWMutex    // guards Status, Participants and Writes
}

// CurrentStatus returns t's status, read under t's lock so that it is safe
// to call while another goroutine is committing t.
func (t *Transaction) CurrentStatus() TxnStatus {
	t.mu.RLock()
	defer t.mu.RUnlock()
	return t.Status
}

// Write is one buffered key/value assignment. A Transaction holds Writes per
// shard until Coordinator.Commit passes each one to Storage.Put.
type Write struct {
	Key, Value []byte
}

// Coordinator manages distributed transactions. Begin creates and registers
// transactions, Write buffers their writes and Commit applies them to
// storage. Run starts a background cleanup goroutine that Stop shuts down.
type Coordinator struct {
	transactions sync.Map       // transaction ID (string) -> *Transaction; filled by Begin
	raft         RaftNode       // NOTE(review): set by NewCoordinator, not called by the methods shown
	storage      Storage        // destination of committed writes (Commit calls Put)
	stopChan     chan struct{}  // closed by Stop to end the cleanup goroutine
	wg           sync.WaitGroup // tracks the cleanup goroutine started by Run
}

// RaftNode is the consensus dependency handed to NewCoordinator.
// NOTE(review): the Coordinator stores it but none of the methods shown
// here call it yet.
type RaftNode interface {
	// IsLeader presumably reports whether this node currently leads the
	// Raft group.
	IsLeader() bool

	// Replicate presumably proposes command to the replicated log and
	// returns an error if it could not be replicated.
	Replicate(command []byte) error
}

// Storage is the key/value backend behind a Coordinator. Coordinator.Commit
// applies each buffered write by calling Put.
type Storage interface {
	// Get presumably returns the value stored under key.
	Get(key []byte) ([]byte, error)

	// Put stores value under key; Commit calls it once per buffered write.
	Put(key []byte, value []byte) error

	// Delete presumably removes key from the store.
	Delete(key []byte) error
}

// NewCoordinator returns a Coordinator wired to the given raft node and
// storage backend, with its stop channel ready. The background goroutine is
// not started until Run is called.
func NewCoordinator(raft RaftNode, storage Storage) *Coordinator {
	c := new(Coordinator)
	c.raft = raft
	c.storage = storage
	c.stopChan = make(chan struct{})
	return c
}

// Run launches the coordinator's background cleanup loop in its own
// goroutine and returns immediately. Stop ends the loop and waits for it.
func (c *Coordinator) Run() {
	// Register the goroutine with the WaitGroup before launching it so that a
	// later Stop cannot finish waiting before the goroutine is counted.
	c.wg.Add(1)
	go func() {
		c.cleanupOldTransactions()
	}()
}

// cleanupOldTransactions runs until c.stopChan is closed. Once a minute it
// sweeps c.transactions so that finished transactions do not accumulate.
// Begin registers every transaction in that map, so without this sweep the
// map only grows.
//
// A transaction is evicted on the first sweep that finds it in a terminal
// state (TxnCommitted or TxnAborted) after the previous sweep had already
// seen it terminal. It therefore stays registered for at least one full
// interval after it finishes. Eviction only drops the map entry; callers
// holding the *Transaction can keep using it.
func (c *Coordinator) cleanupOldTransactions() {
	defer c.wg.Done()

	ticker := time.NewTicker(60 * time.Second)
	defer ticker.Stop()

	// Keys of transactions that were terminal at the previous sweep.
	var finished map[interface{}]struct{}

	for {
		select {
		case <-ticker.C:
			finished = c.sweepFinished(finished)
		case <-c.stopChan:
			return
		}
	}
}

// sweepFinished deletes every terminal transaction whose key is in prev
// (terminal at the previous sweep as well). It returns the keys of the
// terminal transactions seen for the first time in this sweep. prev may
// be nil.
func (c *Coordinator) sweepFinished(prev map[interface{}]struct{}) map[interface{}]struct{} {
	next := make(map[interface{}]struct{})
	c.transactions.Range(func(key, value interface{}) bool {
		txn, ok := value.(*Transaction)
		if !ok {
			return true
		}

		txn.mu.RLock()
		status := txn.Status
		txn.mu.RUnlock()

		if status != TxnCommitted && status != TxnAborted {
			return true
		}
		if _, seen := prev[key]; seen {
			// Deleting during Range is permitted by sync.Map.
			c.transactions.Delete(key)
		} else {
			next[key] = struct{}{}
		}
		return true
	})
	return next
}

// Begin creates a transaction in the TxnActive state, registers it with the
// coordinator under its ID and returns it.
//
// The ID has the form "tx-<n>". n is the wall-clock time in nanoseconds,
// bumped when necessary so it strictly increases within this process. Two
// calls therefore never share an ID, even when the clock is coarse or steps
// backwards. A shared ID would make the second Store silently replace the
// first transaction. IDs are not coordinated across processes.
func (c *Coordinator) Begin() *Transaction {
	txn := &Transaction{
		ID:           fmt.Sprintf("tx-%d", nextTxnStamp()),
		Status:       TxnActive,
		Participants: []int{},
		Writes:       make(map[int][]Write),
	}

	c.transactions.Store(txn.ID, txn)
	return txn
}

// txnStampMu guards lastTxnStamp, the most recent value returned by
// nextTxnStamp.
var (
	txnStampMu   sync.Mutex
	lastTxnStamp int64
)

// nextTxnStamp returns time.Now().UnixNano(), raised when necessary so that
// every call returns a value strictly greater than the previous call's.
func nextTxnStamp() int64 {
	txnStampMu.Lock()
	defer txnStampMu.Unlock()

	stamp := time.Now().UnixNano()
	if stamp <= lastTxnStamp {
		stamp = lastTxnStamp + 1
	}
	lastTxnStamp = stamp
	return stamp
}

// Write buffers a key/value write for shardID in txn and records shardID as
// a participant the first time it is written to. Nothing reaches storage
// until Commit applies the buffered writes.
//
// key and value are copied, so callers may reuse or modify their buffers
// after Write returns without changing what will be committed. A nil slice
// stays nil and an empty slice stays empty.
//
// Write does not check txn.Status; call it only while txn is active, before
// Commit.
func (c *Coordinator) Write(txn *Transaction, shardID int, key, value []byte) {
	w := Write{Key: cloneBytes(key), Value: cloneBytes(value)}

	txn.mu.Lock()
	defer txn.mu.Unlock()

	txn.Writes[shardID] = append(txn.Writes[shardID], w)

	// Participants keeps first-write order; skip shards already listed.
	for _, id := range txn.Participants {
		if id == shardID {
			return
		}
	}
	txn.Participants = append(txn.Participants, shardID)
}

// cloneBytes returns a copy of b that shares no memory with it, keeping a
// nil slice nil and an empty slice empty (non-nil).
func cloneBytes(b []byte) []byte {
	if b == nil {
		return nil
	}
	return append(make([]byte, 0, len(b)), b...)
}

// Commit commits txn using two-phase commit (2PC) and applies its buffered
// writes to the coordinator's storage.
//
// Only a transaction in TxnActive can be committed. Committing one that is
// aborting or aborted returns ErrTransactionAborted. Any other status, such
// as an already committed transaction, returns an error without touching
// storage, so writes are never applied twice.
//
// The writes are snapshotted under txn's lock when the commit starts. They
// are applied shard by shard in txn.Participants order, and within a shard
// in the order they were buffered.
//
// If ctx is already done when phase 2 would begin, txn is marked TxnAborted
// and ctx.Err() is returned before any write is applied; once phase 2 begins
// it runs to completion. If storage.Put fails, txn is marked TxnAborted and
// the storage error is returned wrapped (inspect it with errors.Is/As).
//
// NOTE: a storage failure part-way through phase 2 leaves the writes applied
// before it in place; there is no rollback.
func (c *Coordinator) Commit(ctx context.Context, txn *Transaction) error {
	batches, err := claimForCommit(txn)
	if err != nil {
		return err
	}

	// Phase 1: PREPARE
	// In a real implementation, send PREPARE to every participant shard and
	// abort unless all of them vote yes. For now every shard counts as prepared.
	allPrepared := true
	if !allPrepared {
		setTxnStatus(txn, TxnAborted)
		return ErrTransactionAborted
	}

	// Last point at which the commit can be abandoned cleanly: nothing has
	// been written to storage yet.
	if err := ctx.Err(); err != nil {
		setTxnStatus(txn, TxnAborted)
		return err
	}

	// Phase 2: COMMIT
	setTxnStatus(txn, TxnCommitting)
	for _, writes := range batches {
		for _, w := range writes {
			if err := c.storage.Put(w.Key, w.Value); err != nil {
				setTxnStatus(txn, TxnAborted)
				return fmt.Errorf("transaction %s: applying write: %w", txn.ID, err)
			}
		}
	}

	setTxnStatus(txn, TxnCommitted)
	return nil
}

// claimForCommit atomically moves txn from TxnActive to TxnPreparing and
// returns a snapshot of its buffered writes, one slice per shard.
func claimForCommit(txn *Transaction) ([][]Write, error) {
	txn.mu.Lock()
	defer txn.mu.Unlock()

	switch txn.Status {
	case TxnActive:
		// Eligible for commit.
	case TxnAborting, TxnAborted:
		return nil, ErrTransactionAborted
	default:
		return nil, fmt.Errorf("transaction %s: cannot commit in status %v", txn.ID, txn.Status)
	}
	txn.Status = TxnPreparing

	// Copying the slice headers gives a stable snapshot. A later append to a
	// shard's slice only writes past the length captured here.
	batches := make([][]Write, 0, len(txn.Writes))
	taken := make(map[int]bool, len(txn.Writes))
	for _, id := range txn.Participants {
		if ws, ok := txn.Writes[id]; ok && !taken[id] {
			batches = append(batches, ws)
			taken[id] = true
		}
	}
	// Shards missing from Participants can only appear if the exported
	// fields were filled in directly; their writes are still applied.
	for id, ws := range txn.Writes {
		if !taken[id] {
			batches = append(batches, ws)
		}
	}
	return batches, nil
}

// setTxnStatus sets txn.Status while holding txn's lock.
func setTxnStatus(txn *Transaction, s TxnStatus) {
	txn.mu.Lock()
	txn.Status = s
	txn.mu.Unlock()
}

// Stop signals the cleanup goroutine started by Run to exit and waits until
// it has returned.
//
// Calling Stop again after an earlier call has returned is safe: the stop
// channel is closed only once, because closing it twice would panic.
// Concurrent calls to Stop are not synchronized with each other and must be
// avoided.
func (c *Coordinator) Stop() {
	select {
	case <-c.stopChan:
		// Already closed by an earlier Stop.
	default:
		close(c.stopChan)
	}
	c.wg.Wait()
}

// TxnError is the error type the Coordinator uses for transaction failures.
type TxnError struct {
	msg string
}

// Error implements the error interface by returning the stored message.
func (e *TxnError) Error() string {
	return e.msg
}

// ErrTransactionAborted is the sentinel Coordinator.Commit returns when the
// prepare phase does not succeed.
var ErrTransactionAborted = &TxnError{msg: "transaction aborted"}
