package main

import (
	"bufio"
	"context"
	"flag"
	"fmt"
	"os"
	"strings"
	"time"

	"github.com/distributed-db/pkg/client"
)

// main runs the interactive client and exits with the status it reports.
// The work lives in runREPL so that its deferred cleanup (rolling back an
// open transaction, closing the connection) runs before os.Exit is called;
// os.Exit itself does not run deferred functions.
func main() {
	os.Exit(runREPL())
}

// runREPL connects to the database named by the -addr flag and runs a
// read-eval-print loop over standard input until EXIT/QUIT or end of input.
//
// It returns the process exit status:
//   - 0 on a normal exit (EXIT/QUIT or end of input);
//   - 1 if connecting fails or reading standard input fails.
//
// A transaction still open when the loop ends is rolled back explicitly.
func runREPL() int {
	addr := flag.String("addr", "localhost:9001", "Database server address")
	flag.Parse()

	fmt.Printf("Connecting to database at %s...\n", *addr)

	db, err := client.Connect(*addr)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to connect: %v\n", err)
		return 1
	}
	defer db.Close()

	fmt.Println("Connected! Type 'help' for commands.")

	// bufio.Scanner rejects lines longer than its buffer (64 KiB by default),
	// which would end the session on a large SET value. Allow lines up to
	// maxLineBytes; the buffer only grows as large as the longest line read.
	const maxLineBytes = 16 << 20
	scanner := bufio.NewScanner(os.Stdin)
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineBytes)

	var currentTxn *client.Transaction

	// Registered after db.Close, so it runs first. A transaction left open
	// when the session ends (EXIT/QUIT, end of input, read error) is rolled
	// back explicitly instead of being dropped.
	defer func() {
		if currentTxn != nil {
			fmt.Println("Rolling back open transaction")
			currentTxn.Rollback()
		}
	}()

	// No per-command deadline is applied; every command shares this context.
	ctx := context.Background()

	for {
		if currentTxn != nil {
			fmt.Print("tx> ")
		} else {
			fmt.Print("> ")
		}

		if !scanner.Scan() {
			break
		}

		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		// TrimSpace and Fields use the same definition of whitespace, so a
		// non-empty trimmed line always yields at least one field.
		parts := strings.Fields(line)
		cmd := strings.ToUpper(parts[0])

		switch cmd {
		case "HELP":
			printHelp()

		case "SET":
			if len(parts) < 3 {
				fmt.Println("Usage: SET <key> <value>")
				continue
			}
			key := parts[1]
			// Take the value verbatim from the input line rather than
			// re-joining parts, so runs of internal whitespace survive.
			value := restAfterFields(line, 2)

			if currentTxn != nil {
				currentTxn.Put([]byte(key), []byte(value))
				fmt.Println("OK (buffered)")
			} else {
				start := time.Now()
				if err := db.Put(ctx, []byte(key), []byte(value)); err != nil {
					fmt.Printf("Error: %v\n", err)
				} else {
					fmt.Printf("OK (%v)\n", time.Since(start))
				}
			}

		case "GET":
			if len(parts) < 2 {
				fmt.Println("Usage: GET <key>")
				continue
			}
			key := parts[1]

			// NOTE(review): reads go to db even inside a transaction, so
			// writes buffered in currentTxn are not visible here — confirm
			// this is intended.
			start := time.Now()
			value, err := db.Get(ctx, []byte(key))
			if err != nil {
				fmt.Printf("Error: %v\n", err)
			} else {
				fmt.Printf("%q (%v)\n", string(value), time.Since(start))
			}

		case "DELETE":
			if len(parts) < 2 {
				fmt.Println("Usage: DELETE <key>")
				continue
			}
			key := parts[1]

			if currentTxn != nil {
				currentTxn.Delete([]byte(key))
				fmt.Println("OK (buffered)")
			} else {
				start := time.Now()
				if err := db.Delete(ctx, []byte(key)); err != nil {
					fmt.Printf("Error: %v\n", err)
				} else {
					fmt.Printf("OK (%v)\n", time.Since(start))
				}
			}

		case "SCAN":
			if len(parts) < 3 {
				fmt.Println("Usage: SCAN <start_key> <end_key>")
				continue
			}
			startKey := parts[1]
			endKey := parts[2]

			// NOTE(review): like GET, this bypasses currentTxn.
			start := time.Now()
			results, err := db.Scan(ctx, []byte(startKey), []byte(endKey))
			if err != nil {
				fmt.Printf("Error: %v\n", err)
			} else {
				for _, kv := range results {
					fmt.Printf("%s = %q\n", string(kv.Key), string(kv.Value))
				}
				fmt.Printf("(%d keys in %v)\n", len(results), time.Since(start))
			}

		case "BEGIN":
			if currentTxn != nil {
				fmt.Println("Error: Transaction already in progress")
				continue
			}
			currentTxn = db.Begin()
			fmt.Printf("Transaction ID: %s\n", currentTxn.ID())

		case "COMMIT":
			if currentTxn == nil {
				fmt.Println("Error: No transaction in progress")
				continue
			}
			start := time.Now()
			if err := currentTxn.Commit(ctx); err != nil {
				fmt.Printf("Error: %v\n", err)
			} else {
				fmt.Printf("OK (committed in %v)\n", time.Since(start))
			}
			// The transaction is finished whether or not the commit succeeded.
			currentTxn = nil

		case "ROLLBACK":
			if currentTxn == nil {
				fmt.Println("Error: No transaction in progress")
				continue
			}
			currentTxn.Rollback()
			fmt.Println("OK (rolled back)")
			currentTxn = nil

		case "EXIT", "QUIT":
			fmt.Println("Goodbye!")
			return 0

		default:
			fmt.Printf("Unknown command: %s (type 'help' for commands)\n", cmd)
		}
	}

	// Scan returned false. First finish the dangling prompt line, then
	// distinguish a clean end of input from a read error (for example a
	// line longer than maxLineBytes).
	fmt.Println()
	if err := scanner.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "Error reading input: %v\n", err)
		return 1
	}
	return 0
}

// restAfterFields returns the part of s that follows its first n
// whitespace-separated fields (as split by strings.Fields), starting at the
// next field. Whitespace inside the remainder is preserved exactly. It
// returns "" when s has n or fewer fields. n must be non-negative.
func restAfterFields(s string, n int) string {
	fields := strings.Fields(s)
	if len(fields) <= n {
		return ""
	}
	// Each field starts at the first non-space byte of the remaining text,
	// and fields contain no whitespace. So strings.Index finds the field
	// itself, never an earlier match.
	for _, f := range fields[:n] {
		s = s[strings.Index(s, f)+len(f):]
	}
	return s[strings.Index(s, fields[n]):]
}

// printHelp writes the REPL command reference to standard output. The text
// begins with an empty line and is followed by one more, so it stands apart
// from the surrounding prompts.
func printHelp() {
	const helpText = `
Available commands:
  SET <key> <value>       - Set a key-value pair
  GET <key>               - Get value for key
  DELETE <key>            - Delete a key
  SCAN <start> <end>      - Scan range of keys
  BEGIN                   - Begin transaction
  COMMIT                  - Commit transaction
  ROLLBACK                - Rollback transaction
  HELP                    - Show this help
  EXIT                    - Exit client
`
	// Both operands are strings, so Print inserts no separator. The result
	// is byte-for-byte what fmt.Println(helpText) would write.
	fmt.Print(helpText, "\n")
}
