package transform

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/yourusername/etl-pipeline/pkg/models"
)

// Filter is a single "field operator value" condition that can be evaluated
// against records with Apply.
type Filter struct {
	expression string // raw condition text, e.g. "age > 18"
}

// NewFilter returns a Filter for expression. The expression is stored as-is;
// NewFilter performs no validation, so a malformed expression is only
// reported when Apply is called.
func NewFilter(expression string) *Filter {
	f := new(Filter)
	f.expression = expression
	return f
}

// Apply reports whether record satisfies the filter expression.
//
// The expression has the form "field operator value", for example
// "age > 18", "status == 'active'" or "price <= 100.0". Supported operators
// are ==, = (alias for ==), !=, >, >=, < and <=.
//
// It returns an error when the expression is malformed, when the field is
// not present in the record, or when the comparison itself fails.
func (f *Filter) Apply(record *models.Record) (bool, error) {
	fieldName, operator, valueStr, err := f.parse()
	if err != nil {
		return false, err
	}

	fieldValue, ok := record.Get(fieldName)
	if !ok {
		return false, fmt.Errorf("field %s not found", fieldName)
	}

	return f.compare(fieldValue, operator, valueStr)
}

// parse splits f.expression into its field name, operator and value.
//
// Field and operator are the first two whitespace-separated tokens. The value
// is the remainder of the expression with surrounding whitespace trimmed;
// whitespace inside it is kept verbatim (so 'in  progress' keeps both
// spaces). One pair of matching single or double quotes around the value is
// removed.
func (f *Filter) parse() (field, operator, value string, err error) {
	tokens := strings.Fields(f.expression)
	if len(tokens) < 3 {
		return "", "", "", fmt.Errorf("invalid filter expression: %s", f.expression)
	}

	// Skip past the field and operator tokens in the raw text instead of
	// re-joining tokens[2:], which would collapse runs of whitespace inside
	// the value. Each token is preceded only by whitespace in rest, so Index
	// always finds it at its real position.
	rest := f.expression
	for _, tok := range tokens[:2] {
		rest = rest[strings.Index(rest, tok)+len(tok):]
	}
	value = strings.TrimSpace(rest)

	// Strip a single pair of matching quotes only; strings.Trim would also
	// remove unbalanced quote characters that are part of the value.
	if n := len(value); n >= 2 && (value[0] == '\'' || value[0] == '"') && value[n-1] == value[0] {
		value = value[1 : n-1]
	}

	return tokens[0], tokens[1], value, nil
}

// compare evaluates "fieldValue operator valueStr".
//
// "==" and "=" (an alias) test equality and "!=" its negation; these never
// fail. The ordering operators ">", ">=", "<" and "<=" delegate to their
// helper and propagate its error. Any other operator yields an
// "unknown operator" error.
func (f *Filter) compare(fieldValue interface{}, operator, valueStr string) (bool, error) {
	switch operator {
	case "==", "=", "!=":
		same := f.equals(fieldValue, valueStr)
		if operator == "!=" {
			return !same, nil
		}
		return same, nil
	}

	var order func(interface{}, string) (bool, error)
	switch operator {
	case ">":
		order = f.greaterThan
	case ">=":
		order = f.greaterThanOrEqual
	case "<":
		order = f.lessThan
	case "<=":
		order = f.lessThanOrEqual
	default:
		return false, fmt.Errorf("unknown operator: %s", operator)
	}
	return order(fieldValue, valueStr)
}

// equals reports whether fieldValue equals the literal valueStr.
//
// Strings are compared verbatim. For int, int64, float64 and bool fields,
// valueStr is parsed as that type first; if it does not parse, the values are
// treated as unequal. Any other type is compared through its default
// formatting ("%v").
func (f *Filter) equals(fieldValue interface{}, valueStr string) bool {
	switch v := fieldValue.(type) {
	case string:
		return v == valueStr
	case int:
		// Check err: a failed parse returns 0 (or a clamped value), which
		// would otherwise make "count == abc" match a count of 0.
		val, err := strconv.Atoi(valueStr)
		return err == nil && v == val
	case int64:
		val, err := strconv.ParseInt(valueStr, 10, 64)
		return err == nil && v == val
	case float64:
		val, err := strconv.ParseFloat(valueStr, 64)
		return err == nil && v == val
	case bool:
		// Without the err check, "flag == maybe" would match false.
		val, err := strconv.ParseBool(valueStr)
		return err == nil && v == val
	default:
		return fmt.Sprintf("%v", v) == valueStr
	}
}

// greaterThan reports whether fieldValue > valueStr.
func (f *Filter) greaterThan(fieldValue interface{}, valueStr string) (bool, error) {
	return f.orderedCompare(fieldValue, valueStr, ">")
}

// greaterThanOrEqual reports whether fieldValue >= valueStr.
func (f *Filter) greaterThanOrEqual(fieldValue interface{}, valueStr string) (bool, error) {
	return f.orderedCompare(fieldValue, valueStr, ">=")
}

// lessThan reports whether fieldValue < valueStr.
func (f *Filter) lessThan(fieldValue interface{}, valueStr string) (bool, error) {
	return f.orderedCompare(fieldValue, valueStr, "<")
}

// lessThanOrEqual reports whether fieldValue <= valueStr.
func (f *Filter) lessThanOrEqual(fieldValue interface{}, valueStr string) (bool, error) {
	return f.orderedCompare(fieldValue, valueStr, "<=")
}

// orderedCompare evaluates "fieldValue op valueStr" for an ordering operator
// op (">", ">=", "<" or "<=").
//
// valueStr is parsed to match the dynamic type of fieldValue: strconv.Atoi
// for int, strconv.ParseInt (base 10, 64-bit) for int64 and
// strconv.ParseFloat (64-bit) for float64. A parse failure is returned
// wrapped with %w, so callers can still reach the *strconv.NumError via
// errors.As. Any other field type cannot be ordered and yields an error.
// As with Go's own operators, every comparison involving NaN is false.
func (f *Filter) orderedCompare(fieldValue interface{}, valueStr, op string) (bool, error) {
	wrapParseErr := func(err error) error {
		return fmt.Errorf("comparing %T with %s: %w", fieldValue, op, err)
	}

	var less, equal, greater bool
	switch v := fieldValue.(type) {
	case int:
		val, err := strconv.Atoi(valueStr)
		if err != nil {
			return false, wrapParseErr(err)
		}
		less, equal, greater = v < val, v == val, v > val
	case int64:
		val, err := strconv.ParseInt(valueStr, 10, 64)
		if err != nil {
			return false, wrapParseErr(err)
		}
		less, equal, greater = v < val, v == val, v > val
	case float64:
		val, err := strconv.ParseFloat(valueStr, 64)
		if err != nil {
			return false, wrapParseErr(err)
		}
		// With NaN on either side all three flags are false, so every
		// operator below correctly reports false.
		less, equal, greater = v < val, v == val, v > val
	default:
		return false, fmt.Errorf("cannot compare %T with %s", v, op)
	}

	switch op {
	case ">":
		return greater, nil
	case ">=":
		return greater || equal, nil
	case "<":
		return less, nil
	case "<=":
		return less || equal, nil
	default:
		return false, fmt.Errorf("unknown operator: %s", op)
	}
}
