#' Calculate Feature Importance for Anomalies
#'
#' For every record flagged as anomalous, identifies the single feature that
#' deviates most from the rest of the data and turns it into a "reason code"
#' explaining why the record was flagged.
#'
#' Numeric features are scored by a robust z-score,
#' \code{abs(value - median) / mad}, where the median and MAD are computed over
#' all non-missing values of the column in \code{flagged_data}. Columns that
#' are not numeric, or whose MAD is not a positive number, are skipped.
#' Categorical features are scored by rarity,
#' \code{1 - (rows sharing the record's category / non-missing rows)}.
#' The two scores are compared directly and the highest one wins. Because the
#' rarity score is always below 1, a categorical feature only wins when no
#' numeric feature of that record has a larger z-score. Ties go to the earlier
#' column (numeric columns first, in metadata order).
#'
#' @importFrom stats median mad var
#' @param flagged_data A data frame with anomaly scores and an
#'   \code{is_anomaly} column, typically the output of
#'   \code{flag_top_anomalies()}. Only rows where \code{is_anomaly} is
#'   \code{TRUE} receive a reason; all other rows get \code{NA}.
#' @param metadata Metadata from \code{prep_for_anomaly()}. Its
#'   \code{numeric_cols} and \code{categorical_cols} elements name the columns
#'   to consider; names not present in \code{flagged_data} are ignored.
#' @param top_k Integer. Kept for backward compatibility: only the single
#'   highest-scoring feature is reported, so values other than 1 currently
#'   have no effect on the output.
#' @param max_cols Integer giving the maximum number of columns to consider.
#'   When more candidate columns exist, up to \code{ceiling(0.7 * max_cols)}
#'   numeric columns with the highest variance are kept. The remaining slots
#'   go to the categorical columns with the most distinct values. If
#'   \code{NULL}, all columns are used. Default is 10 for performance.
#'
#' @return The input data frame with additional columns:
#'   \item{reason_feature}{Name of the feature contributing most to the anomaly}
#'   \item{reason_value}{The value of that feature for this record, as character}
#'   \item{reason_code}{A brief description combining feature name and value,
#'     e.g. "High cost (25000.00)" or "Rare region: North"}
#'   \item{reason_deviation}{The robust z-score (numeric features) or the
#'     rarity score \code{1 - relative frequency} (categorical features)}
#'
#' @export
#'
#' @examples
#' \donttest{
#' data <- data.frame(
#'   patient_id = 1:50,
#'   age = rnorm(50, 50, 15),
#'   cost = rnorm(50, 10000, 5000)
#' )
#' scored_data <- score_anomaly(data, id_cols = "patient_id")
#' flagged_data <- flag_top_anomalies(scored_data)
#' metadata <- attr(scored_data, "metadata")
#' flagged_data <- calculate_feature_importance(flagged_data, metadata)
#' }
calculate_feature_importance <- function(flagged_data, metadata, top_k = 1,
                                         max_cols = 10) {
  # Validate inputs
  if (!is.data.frame(flagged_data)) {
    stop("flagged_data must be a data frame")
  }
  if (is.null(metadata)) {
    stop("metadata must be provided")
  }

  # [[ ]] instead of $ so a missing entry is never partially matched
  numeric_cols <- metadata[["numeric_cols"]]
  categorical_cols <- metadata[["categorical_cols"]]
  numeric_cols <- numeric_cols[numeric_cols %in% names(flagged_data)]
  categorical_cols <- categorical_cols[
    categorical_cols %in% names(flagged_data)
  ]

  selected <- select_importance_columns(
    flagged_data, numeric_cols, categorical_cols, max_cols
  )

  # rep() rather than a scalar so zero-row data frames do not error
  n_rows <- nrow(flagged_data)
  flagged_data$reason_feature <- rep(NA_character_, n_rows)
  flagged_data$reason_value <- rep(NA_character_, n_rows)
  flagged_data$reason_code <- rep(NA_character_, n_rows)
  flagged_data$reason_deviation <- rep(NA_real_, n_rows)

  # %in% TRUE treats NA flags as "not anomalous", like `== TRUE` + which()
  anomaly_idx <- which(flagged_data[["is_anomaly"]] %in% TRUE)
  if (length(anomaly_idx) == 0) {
    return(flagged_data)
  }

  reasons <- explain_anomaly_rows(
    flagged_data, anomaly_idx, selected$numeric, selected$categorical
  )
  flagged_data$reason_feature[anomaly_idx] <- reasons$feature
  flagged_data$reason_value[anomaly_idx] <- reasons$value
  flagged_data$reason_code[anomaly_idx] <- reasons$code
  flagged_data$reason_deviation[anomaly_idx] <- reasons$deviation

  flagged_data
}

#' Limit candidate columns for feature importance
#'
#' Returns \code{list(numeric = , categorical = )}. When the total number of
#' candidate columns exceeds \code{max_cols}, keeps up to 70% of the slots for
#' the highest-variance numeric columns. The rest go to the categorical
#' columns with the most distinct non-missing values.
#' @noRd
select_importance_columns <- function(data, numeric_cols, categorical_cols,
                                      max_cols) {
  total_cols <- length(numeric_cols) + length(categorical_cols)
  if (is.null(max_cols) || total_cols <= max_cols) {
    return(list(numeric = numeric_cols, categorical = categorical_cols))
  }

  if (length(numeric_cols) > 0) {
    # Non-numeric columns cannot be scored later, so rank them last (NA)
    variances <- vapply(numeric_cols, function(col) {
      x <- data[[col]]
      if (is.numeric(x)) stats::var(x, na.rm = TRUE) else NA_real_
    }, numeric(1))
    # Select top numeric columns by variance (up to 70% of max_cols);
    # seq_len() keeps a zero-slot budget empty instead of yielding c(1, 0)
    n_numeric <- max(min(ceiling(max_cols * 0.7), length(numeric_cols)), 0)
    ranked <- order(variances, decreasing = TRUE, na.last = TRUE)
    numeric_cols <- numeric_cols[ranked][seq_len(n_numeric)]
  }

  n_categorical <- min(max_cols - length(numeric_cols), length(categorical_cols))
  if (n_categorical <= 0) {
    categorical_cols <- character(0)
  } else if (length(categorical_cols) > n_categorical) {
    # Prefer categorical columns with more distinct values (more informative)
    distinct_counts <- vapply(categorical_cols, function(col) {
      x <- data[[col]]
      length(unique(x[!is.na(x)]))
    }, integer(1))
    ranked <- order(distinct_counts, decreasing = TRUE)
    categorical_cols <- categorical_cols[ranked][seq_len(n_categorical)]
  }

  list(numeric = numeric_cols, categorical = categorical_cols)
}

#' Compute the top contributing feature for selected rows
#'
#' Scores every candidate column for the rows in \code{rows} (vectorized per
#' column) and returns a list of four vectors aligned with \code{rows}:
#' \code{feature}, \code{value}, \code{code}, \code{deviation}. Rows with no
#' scorable feature get \code{NA} in all four.
#' @noRd
explain_anomaly_rows <- function(data, rows, numeric_cols, categorical_cols) {
  n <- length(rows)
  out <- list(
    feature = rep(NA_character_, n),
    value = rep(NA_character_, n),
    code = rep(NA_character_, n),
    deviation = rep(NA_real_, n)
  )

  # Robust z-scores; statistics use every non-missing row, not just anomalies
  numeric_scores <- lapply(numeric_cols, function(col) {
    x <- data[[col]]
    if (!is.numeric(x)) {
      return(NULL)
    }
    observed <- x[!is.na(x)]
    if (length(observed) == 0) {
      return(NULL)
    }
    center <- stats::median(observed)
    spread <- stats::mad(observed)
    # isTRUE() also rejects an NA/NaN MAD (e.g. a column with -Inf and Inf)
    if (!isTRUE(spread > 0)) {
      return(NULL)
    }
    list(
      feature = col, type = "numeric", center = center,
      deviation = abs((x[rows] - center) / spread)
    )
  })

  # Rarity = 1 - relative frequency of the record's category
  categorical_scores <- lapply(categorical_cols, function(col) {
    x <- data[[col]]
    observed <- x[!is.na(x)]
    if (length(observed) == 0) {
      return(NULL)
    }
    counts <- table(observed)
    # match() (unlike name indexing) also finds the empty-string category
    freq <- as.numeric(counts)[match(as.character(x[rows]), names(counts))]
    freq[is.na(freq)] <- 0
    deviation <- 1 - freq / length(observed)
    deviation[is.na(x[rows])] <- NA_real_
    list(feature = col, type = "categorical", center = NA_real_,
         deviation = deviation)
  })

  scores <- Filter(Negate(is.null), c(numeric_scores, categorical_scores))
  if (length(scores) == 0) {
    return(out)
  }

  # One row per anomaly, one column per feature; missing scores never win
  deviation_matrix <- do.call(cbind, lapply(scores, function(s) s$deviation))
  deviation_matrix[is.na(deviation_matrix)] <- -Inf
  best <- max.col(deviation_matrix, ties.method = "first")
  best_deviation <- deviation_matrix[cbind(seq_len(n), best)]
  explained <- best_deviation > -Inf

  # Fill results one winning feature at a time so column types are preserved
  for (j in unique(best[explained])) {
    score <- scores[[j]]
    idx <- which(explained & best == j)
    values <- data[[score$feature]][rows[idx]]
    if (identical(score$type, "numeric")) {
      direction <- ifelse(values > score$center, "High", "Low")
      out$code[idx] <- paste0(
        direction, " ", score$feature, " (", format_importance_value(values), ")"
      )
    } else {
      out$code[idx] <- paste0("Rare ", score$feature, ": ", as.character(values))
    }
    out$feature[idx] <- score$feature
    out$value[idx] <- as.character(values)
    out$deviation[idx] <- best_deviation[idx]
  }

  out
}

#' Format numeric values for reason codes
#'
#' Uses 0 decimals for |x| >= 1000, 2 decimals for |x| >= 1, else 4 decimals.
#' @noRd
format_importance_value <- function(x) {
  x <- as.double(x)
  magnitude <- abs(x)
  ifelse(
    magnitude >= 1000, sprintf("%.0f", x),
    ifelse(magnitude >= 1, sprintf("%.2f", x), sprintf("%.4f", x))
  )
}

