package vm

import (
	"encoding/binary"
	"errors"
	"fmt"
)

// Opcode is a single-byte VM instruction code.
type Opcode byte

// Opcode values understood by the VM. Every opcode occupies one byte of
// code; PUSH is the only one with an inline operand (the 8 bytes after it,
// read as a big-endian uint64). A byte with no case in executeOpcode makes
// execution fail with an "unknown opcode" error.
const (
	// Stack operations. PUSH pushes its immediate, POP discards the top
	// value and DUP pushes a copy of the top value.
	PUSH Opcode = 0x01
	POP  Opcode = 0x02
	DUP  Opcode = 0x03
	SWAP Opcode = 0x04

	// Arithmetic on uint64 words (wrapping on overflow). Each pops b (the
	// top value) and a (the value beneath it) and pushes "a op b"; DIV and
	// MOD fail when b is zero.
	ADD Opcode = 0x10
	SUB Opcode = 0x11
	MUL Opcode = 0x12
	DIV Opcode = 0x13
	MOD Opcode = 0x14

	// Comparison: pop b (top) and a, then push 1 if "a op b" holds, else 0.
	EQ Opcode = 0x20
	LT Opcode = 0x21
	GT Opcode = 0x22

	// Logic. AND and OR are bitwise on a and b (b on top); NOT is logical,
	// replacing the top value with 1 if it is 0 and with 0 otherwise.
	AND Opcode = 0x30
	OR  Opcode = 0x31
	NOT Opcode = 0x32

	// Control flow. JUMP pops a byte offset into code and continues there;
	// JUMPI pops a condition (top) and an offset and jumps only when the
	// condition is non-zero. Offsets outside the code are an error. RETURN
	// ends execution.
	JUMP   Opcode = 0x40
	JUMPI  Opcode = 0x41
	RETURN Opcode = 0x42

	// Storage. SLOAD replaces the top value (a key) with the value stored
	// under it, or 0 if unset; SSTORE pops a value (top) and a key and
	// stores the value under the key.
	SLOAD  Opcode = 0x50
	SSTORE Opcode = 0x51

	// Blockchain / execution-context opcodes.
	// NOTE(review): check executeOpcode for which of these are implemented.
	CALLER    Opcode = 0x60
	CALLVALUE Opcode = 0x61
	TIMESTAMP Opcode = 0x62
	BLOCKHASH Opcode = 0x63

	// System. STOP ends execution, like RETURN.
	STOP Opcode = 0xff
)

// VM is a stack-based interpreter for the bytecode described by the Opcode
// constants. Each opcode is one byte; PUSH is followed by an 8-byte
// big-endian immediate.
//
// Fields:
//   - stack: operand stack of uint64 words; the last element is the top.
//   - memory: byte buffer. NOTE(review): initialised by NewVM but not read
//     or written by any opcode handled in this file.
//   - storage: key/value store accessed by SLOAD, SSTORE and GetStorage.
//   - code: the bytecode being executed; pc indexes into it.
//   - gas: gas consumed so far; gasLimit: the budget Execute enforces.
//   - caller, callValue, blockHeight, timestamp: execution-context values.
//     NOTE(review): NewVM leaves them zero and no setter is visible here.
type VM struct {
	stack       []uint64
	memory      []byte
	storage     map[uint64]uint64
	code        []byte
	pc          int // index of the next byte of code to read
	gas         uint64
	gasLimit    uint64
	caller      string
	callValue   uint64
	blockHeight uint64
	timestamp   int64
}

// NewVM returns a VM ready to run code with the given gas budget.
//
// The code slice is stored as-is (not copied). The stack is pre-allocated
// with room for 1024 values, memory and storage start empty (but non-nil),
// and the program counter and consumed gas start at zero. Caller, call
// value, block height and timestamp are left at their zero values.
func NewVM(code []byte, gasLimit uint64) *VM {
	vm := new(VM)
	vm.code = code
	vm.gasLimit = gasLimit
	vm.stack = make([]uint64, 0, 1024)
	vm.memory = []byte{}
	vm.storage = map[uint64]uint64{}
	return vm
}

// Execute runs the bytecode from the current program counter until it moves
// past the end of the code (RETURN and STOP jump there), an opcode fails, or
// the gas budget is exhausted.
//
// Each opcode's cost is charged before it runs. If paying it would take the
// consumed gas above gasLimit, Execute returns an "out of gas" error without
// executing that opcode. Any error from an opcode is returned unchanged.
//
// On success the top stack value (which stays on the stack) is returned as
// its decimal string representation, e.g. "42"; if the stack is empty the
// result is nil.
func (vm *VM) Execute() ([]byte, error) {
	for vm.pc < len(vm.code) {
		opcode := Opcode(vm.code[vm.pc])

		// Charge up front: checking the budget before execution but charging
		// after it let an expensive opcode (e.g. SSTORE) run with too little
		// gas and overshoot gasLimit. Comparing cost with the remaining
		// budget, instead of computing vm.gas+cost, cannot overflow.
		cost := vm.getGasCost(opcode)
		if vm.gas > vm.gasLimit || cost > vm.gasLimit-vm.gas {
			return nil, errors.New("out of gas")
		}
		vm.gas += cost
		vm.pc++

		if err := vm.executeOpcode(opcode); err != nil {
			return nil, err
		}
	}

	if len(vm.stack) == 0 {
		return nil, nil
	}
	result := vm.stack[len(vm.stack)-1]
	return []byte(fmt.Sprintf("%d", result)), nil
}

// executeOpcode applies a single opcode to the VM state.
//
// vm.pc must already point at the byte following op: PUSH reads its 8-byte
// big-endian immediate from there and advances past it, JUMP/JUMPI may
// overwrite vm.pc, and RETURN/STOP move it to len(vm.code) so the Execute
// loop terminates.
//
// For two-operand opcodes the top of the stack is the right-hand operand b
// and the value beneath it is the left-hand operand a; "a op b" replaces
// both. Comparison and NOT results are 1 (true) or 0 (false). Arithmetic
// wraps modulo 2^64.
func (vm *VM) executeOpcode(op Opcode) error {
	switch op {
	case PUSH:
		if vm.pc+8 > len(vm.code) {
			return errors.New("invalid PUSH: not enough bytes")
		}
		vm.push(binary.BigEndian.Uint64(vm.code[vm.pc : vm.pc+8]))
		vm.pc += 8

	case POP:
		if err := vm.requireStack(1); err != nil {
			return err
		}
		vm.pop()

	case DUP:
		if err := vm.requireStack(1); err != nil {
			return err
		}
		vm.push(vm.stack[len(vm.stack)-1])

	case SWAP:
		// Exchange the two top values.
		if err := vm.requireStack(2); err != nil {
			return err
		}
		n := len(vm.stack)
		vm.stack[n-1], vm.stack[n-2] = vm.stack[n-2], vm.stack[n-1]

	case ADD:
		return vm.binaryOp(func(a, b uint64) (uint64, error) { return a + b, nil })

	case SUB:
		return vm.binaryOp(func(a, b uint64) (uint64, error) { return a - b, nil })

	case MUL:
		return vm.binaryOp(func(a, b uint64) (uint64, error) { return a * b, nil })

	case DIV:
		return vm.binaryOp(func(a, b uint64) (uint64, error) {
			if b == 0 {
				return 0, errors.New("division by zero")
			}
			return a / b, nil
		})

	case MOD:
		return vm.binaryOp(func(a, b uint64) (uint64, error) {
			if b == 0 {
				return 0, errors.New("modulo by zero")
			}
			return a % b, nil
		})

	case EQ:
		return vm.binaryOp(func(a, b uint64) (uint64, error) { return boolToWord(a == b), nil })

	case LT:
		return vm.binaryOp(func(a, b uint64) (uint64, error) { return boolToWord(a < b), nil })

	case GT:
		return vm.binaryOp(func(a, b uint64) (uint64, error) { return boolToWord(a > b), nil })

	case AND:
		return vm.binaryOp(func(a, b uint64) (uint64, error) { return a & b, nil })

	case OR:
		return vm.binaryOp(func(a, b uint64) (uint64, error) { return a | b, nil })

	case NOT:
		// Logical (not bitwise) negation.
		if err := vm.requireStack(1); err != nil {
			return err
		}
		top := len(vm.stack) - 1
		vm.stack[top] = boolToWord(vm.stack[top] == 0)

	case SLOAD:
		// Replace the key on top with its stored value (0 when unset).
		if err := vm.requireStack(1); err != nil {
			return err
		}
		top := len(vm.stack) - 1
		vm.stack[top] = vm.storage[vm.stack[top]]

	case SSTORE:
		if err := vm.requireStack(2); err != nil {
			return err
		}
		value := vm.pop()
		key := vm.pop()
		vm.storage[key] = value

	case JUMP:
		if err := vm.requireStack(1); err != nil {
			return err
		}
		return vm.jumpTo(vm.pop())

	case JUMPI:
		if err := vm.requireStack(2); err != nil {
			return err
		}
		cond := vm.pop()
		dest := vm.pop()
		if cond != 0 {
			return vm.jumpTo(dest)
		}

	case RETURN, STOP:
		vm.pc = len(vm.code)

	case CALLVALUE:
		vm.push(vm.callValue)

	case TIMESTAMP:
		// NOTE(review): a negative vm.timestamp wraps to a large uint64.
		vm.push(uint64(vm.timestamp))

	// CALLER holds a string and BLOCKHASH has no backing value in VM, so
	// neither has a uint64 encoding yet; both still reach the default case.
	default:
		return fmt.Errorf("unknown opcode: 0x%x", op)
	}

	return nil
}

// requireStack reports a stack underflow if fewer than n values are present.
func (vm *VM) requireStack(n int) error {
	if len(vm.stack) < n {
		return errors.New("stack underflow")
	}
	return nil
}

// push places v on top of the stack.
func (vm *VM) push(v uint64) {
	vm.stack = append(vm.stack, v)
}

// pop removes and returns the top stack value. Callers must check
// requireStack first.
func (vm *VM) pop() uint64 {
	top := len(vm.stack) - 1
	v := vm.stack[top]
	vm.stack = vm.stack[:top]
	return v
}

// binaryOp replaces the top two stack values, a (second) and b (top), with
// fn(a, b). If fn returns an error the operands are left on the stack.
func (vm *VM) binaryOp(fn func(a, b uint64) (uint64, error)) error {
	if err := vm.requireStack(2); err != nil {
		return err
	}
	n := len(vm.stack)
	result, err := fn(vm.stack[n-2], vm.stack[n-1])
	if err != nil {
		return err
	}
	vm.stack = append(vm.stack[:n-2], result)
	return nil
}

// jumpTo moves the program counter to dest after checking it lies inside the
// code. The check runs on the uint64 before converting to int, so huge values
// cannot truncate into a valid offset on platforms where int is 32 bits.
func (vm *VM) jumpTo(dest uint64) error {
	if dest >= uint64(len(vm.code)) {
		return errors.New("invalid jump destination")
	}
	vm.pc = int(dest)
	return nil
}

// boolToWord encodes a boolean as the VM's 1/0 stack representation.
func boolToWord(b bool) uint64 {
	if b {
		return 1
	}
	return 0
}

// getGasCost reports how much gas op consumes. Storage access is priced
// highest (SSTORE 100, SLOAD 50), arithmetic (ADD, SUB, MUL, DIV, MOD) costs
// 3, and every other opcode, including PUSH, POP, DUP and unrecognised
// bytes, costs 1.
func (vm *VM) getGasCost(op Opcode) uint64 {
	var cost uint64 = 1
	switch op {
	case SSTORE:
		cost = 100
	case SLOAD:
		cost = 50
	case ADD, SUB, MUL, DIV, MOD:
		cost = 3
	}
	return cost
}

// GetStorage exposes the VM's key/value storage. The returned map is the
// VM's own map, not a copy: writes through it are seen by later SLOADs, and
// SSTOREs are visible in it.
func (vm *VM) GetStorage() map[uint64]uint64 {
	storage := vm.storage
	return storage
}
