#' Select Best Model via Bayesian Model Averaging
#'
#' Fits multiple bivariate hurdle models across a grid of lag orders and
#' horseshoe hyperparameters, then performs model selection using LOO-CV
#' and stacking weights.
#'
#' Every combination of `k_grid` and the rows of `hs_grid` is fitted with
#' `fit_one()`, using `seed + fit_id` as the seed of each fit and a dedicated
#' sub-directory of `output_base_dir` for its output. The pointwise
#' log-likelihood is read from the variable `log_lik_joint` (falling back to
#' `log_lik`), non-finite values are replaced by `-1e10`, and
#' [loo::loo()] is applied to the resulting matrix.
#'
#' If [loo::loo()] fails for a fit, a minimal stand-in object of class
#' `"loo"` is built from the log mean likelihood of each column. That
#' stand-in is an in-sample quantity (not a leave-one-out estimate) and has
#' no Pareto-k diagnostics; when any such object is present, or when
#' stacking itself fails, uniform weights are used instead of stacking
#' weights.
#'
#' @param DT A data.table with the data.
#' @param spec Character; model specification ("A", "B", "C", "D").
#' @param controls Character vector of control variable names.
#' @param k_grid Integer vector of lag orders to evaluate. Must not be empty.
#' @param hs_grid Data.frame with columns hs_tau0, hs_slab_scale, hs_slab_df
#'   defining the horseshoe hyperparameter grid. If NULL or a data.frame
#'   with zero rows, a default grid crossing `hs_tau0 = c(0.1, 0.5, 1)` with
#'   `hs_slab_scale = c(1, 5)` at `hs_slab_df = 4` is used.
#' @param model A compiled CmdStan model. If NULL, loads the default.
#' @param output_base_dir Base directory for output files. If NULL, uses tempdir().
#' @param iter_warmup Integer; warmup iterations.
#' @param iter_sampling Integer; sampling iterations.
#' @param chains Integer; number of chains.
#' @param seed Integer; random seed.
#' @param use_parallel Logical; if TRUE and furrr is available, fits models in parallel.
#' @param verbose Logical; print progress messages.
#'
#' @return A list with components:
#'   \item{fits}{List of fitted model objects (one per grid row, as returned
#'     by `fit_one()`).}
#'   \item{loos}{List of LOO objects; an element is NULL when no
#'     log-likelihood draws could be extracted for that fit.}
#'   \item{weights}{Numeric vector of stacking weights (uniform weights when
#'     stacking is not possible), in grid order.}
#'   \item{table}{Data.frame with results sorted by ELPD (best first),
#'     including convergence and Pareto-k diagnostics.}
#'
#' @export
#'
#' @examples
#' \donttest{
#' library(data.table)
#'
#' # 1. Create a COMPLETE dummy dataset
#' # select_by_bma -> fit_one -> build_design requires ALL these columns:
#' DT <- data.table(
#'   year = 2000:2020,
#'   I = rpois(21, lambda = 4),
#'   C = rpois(21, lambda = 3),
#'   zI = rnorm(21),
#'   zC = rnorm(21),
#'   t_norm = seq(-1, 1, length.out = 21),
#'   t_poly2 = seq(-1, 1, length.out = 21)^2,
#'   Regime = factor(sample(c("A", "B"), 21, replace = TRUE)),
#'   trans_PS = sample(0:1, 21, replace = TRUE),
#'   trans_SF = sample(0:1, 21, replace = TRUE),
#'   trans_FC = sample(0:1, 21, replace = TRUE),
#'   log_exposure50 = rep(0, 21)
#' )
#'
#' # 2. Run the function
#' # IMPORTANT: use_parallel = FALSE to avoid complexity/errors in CRAN checks
#' # We reduce the grid size (k_grid=0) for speed in this example
#' try({
#'   result <- select_by_bma(
#'     DT,
#'     spec = "C",
#'     k_grid = 0,
#'     hs_grid = data.frame(hs_tau0=0.5, hs_slab_scale=1, hs_slab_df=4),
#'     use_parallel = FALSE,
#'     iter_warmup = 100, iter_sampling = 100, chains = 1 # Minimal MCMC for speed
#'   )
#'
#'   if (!is.null(result$table)) {
#'     print(result$table)
#'   }
#' })
#' }
select_by_bma <- function(
    DT,
    spec = "C",
    controls = character(0),
    k_grid = 0:3,
    hs_grid = data.frame(
      hs_tau0 = c(0.1, 0.5, 1.0),
      hs_slab_scale = c(1, 5, 1, 5, 1, 5),
      hs_slab_df = 4,
      stringsAsFactors = FALSE
    ),
    model = NULL,
    output_base_dir = NULL,
    iter_warmup = 900,
    iter_sampling = 1200,
    chains = 4,
    seed = 123,
    use_parallel = TRUE,
    verbose = TRUE
) {

  # ---- Input validation (cheap checks first, before any heavy work) ----
  if (length(k_grid) < 1) {
    stop("'k_grid' must contain at least one lag order.", call. = FALSE)
  }

  # NULL or an empty data.frame selects the default 3 x 2 hyperparameter grid.
  if (is.null(hs_grid) || (is.data.frame(hs_grid) && nrow(hs_grid) < 1)) {
    hs_grid <- expand.grid(
      hs_tau0 = c(0.1, 0.5, 1.0),
      hs_slab_scale = c(1, 5),
      hs_slab_df = 4,
      stringsAsFactors = FALSE
    )
  }

  if (!is.data.frame(hs_grid)) {
    stop("'hs_grid' must be a data.frame (or NULL for the default grid).",
         call. = FALSE)
  }

  hs_required <- c("hs_tau0", "hs_slab_scale", "hs_slab_df")
  hs_missing <- setdiff(hs_required, names(hs_grid))
  if (length(hs_missing) > 0) {
    # A missing column would otherwise be passed on silently as numeric(0).
    stop(sprintf("'hs_grid' is missing required column(s): %s",
                 paste(hs_missing, collapse = ", ")),
         call. = FALSE)
  }

  if (!requireNamespace("loo", quietly = TRUE)) {
    stop("Package 'loo' is required.")
  }

  if (is.null(model)) {
    model <- get_hurdle_model()
  }

  if (is.null(output_base_dir)) {
    output_base_dir <- tempdir()
  }

  # ---- Build the full grid: every lag order x every hyperparameter row ----
  param_grid <- expand.grid(
    k = k_grid,
    hs_idx = seq_len(nrow(hs_grid)),
    stringsAsFactors = FALSE
  )
  param_grid$hs_tau0 <- hs_grid$hs_tau0[param_grid$hs_idx]
  param_grid$hs_slab_scale <- hs_grid$hs_slab_scale[param_grid$hs_idx]
  param_grid$hs_slab_df <- hs_grid$hs_slab_df[param_grid$hs_idx]
  param_grid$fit_id <- seq_len(nrow(param_grid))

  total_fits <- nrow(param_grid)

  if (verbose) {
    message(sprintf("Fitting %d models (k_grid: %s, hs_grid: %d rows)",
                    total_fits, paste(k_grid, collapse = ","), nrow(hs_grid)))
  }

  # Fit grid row `i`, compute its LOO object and collect diagnostics.
  # Returns a list with the fit, the loo object (or NULL), the grid row and
  # scalar diagnostics (NA when they cannot be computed).
  fit_single <- function(i) {
    # Apply `fun` to the non-missing values of `x`; NA if there are none.
    # Avoids the -Inf / Inf / NaN that max()/min()/mean() give on empty input.
    na_safe <- function(x, fun) {
      x <- x[!is.na(x)]
      if (length(x) == 0) NA_real_ else as.numeric(fun(x))
    }

    row <- param_grid[i, , drop = FALSE]

    out_dir <- file.path(output_base_dir,
                         sprintf("fit_%s_k%d_hs%d_%d",
                                 spec, as.integer(row$k),
                                 as.integer(row$hs_idx),
                                 as.integer(row$fit_id)))

    f <- fit_one(
      DT = DT,
      k = as.integer(row$k),
      spec = spec,
      controls = controls,
      model = model,
      output_dir = out_dir,
      iter_warmup = iter_warmup,
      iter_sampling = iter_sampling,
      chains = chains,
      seed = seed + row$fit_id,  # distinct seed per grid row
      hs_tau0 = as.numeric(row$hs_tau0),
      hs_slab_scale = as.numeric(row$hs_slab_scale),
      hs_slab_df = as.numeric(row$hs_slab_df),
      verbose = FALSE
    )

    # Prefer the joint log-likelihood; fall back to `log_lik`; NULL if neither
    # variable can be extracted.
    log_lik <- tryCatch(
      as.matrix(f$fit$draws("log_lik_joint", format = "draws_matrix")),
      error = function(e) {
        tryCatch(
          as.matrix(f$fit$draws("log_lik", format = "draws_matrix")),
          error = function(e2) NULL
        )
      }
    )

    if (is.null(log_lik) || !is.matrix(log_lik) || ncol(log_lik) < 1) {
      return(list(
        fit = f,
        loo = NULL,
        params = row,
        pareto_k = NA_real_,
        k_max = NA_real_,
        k_frac_bad = NA_real_,
        rhat_max = NA_real_,
        ess_bulk_min = NA_real_,
        ess_tail_min = NA_real_,
        n_divergences = NA_integer_,
        treedepth_max = NA_real_
      ))
    }

    # Non-finite entries (NA, NaN, +/-Inf) would break PSIS; treat them as
    # effectively-zero likelihood.
    log_lik[!is.finite(log_lik)] <- -1e10

    l <- tryCatch(
      loo::loo(log_lik, cores = 1),
      error = function(e) {
        # Stand-in when PSIS-LOO fails: per-column log mean likelihood,
        # computed with the max-shift trick for numerical stability. This is
        # an in-sample quantity, not a leave-one-out estimate.
        n_obs <- ncol(log_lik)
        col_max <- apply(log_lik, 2, max)
        shifted <- sweep(log_lik, 2, col_max, "-")
        logmeanexp <- col_max + log(colMeans(exp(shifted)))
        elpd_est <- sum(logmeanexp)
        elpd_se <- stats::sd(logmeanexp) * sqrt(n_obs)

        structure(
          list(
            # byrow = TRUE so that the first ROW is (Estimate, SE) of
            # elpd_loo and the second row (p_loo) is entirely NA.
            estimates = matrix(
              c(elpd_est, elpd_se, NA_real_, NA_real_),
              nrow = 2, ncol = 2, byrow = TRUE,
              dimnames = list(c("elpd_loo", "p_loo"), c("Estimate", "SE"))
            ),
            pointwise = data.frame(elpd_loo = logmeanexp),
            # All-NA marks the object as having no PSIS diagnostics.
            diagnostics = list(pareto_k = rep(NA_real_, length(logmeanexp)))
          ),
          class = "loo"
        )
      }
    )

    pareto_k <- tryCatch(
      as.numeric(l$diagnostics$pareto_k),
      error = function(e) rep(NA_real_, ncol(log_lik))
    )
    k_bad <- is.finite(pareto_k) & (pareto_k > 0.7)

    if (any(k_bad, na.rm = TRUE)) {
      # NOTE(review): loo::loo() on a plain matrix appears to ignore
      # `moment_match` (it is absorbed by `...`), in which case this re-run
      # reproduces the first result. Real moment matching would need the
      # fitted model object -- confirm and replace if intended.
      l <- tryCatch(
        loo::loo(log_lik, moment_match = TRUE, cores = 1),
        error = function(e) l
      )
      pareto_k <- tryCatch(
        as.numeric(l$diagnostics$pareto_k),
        error = function(e) pareto_k
      )
    }

    # Convergence summaries; NA when the summary is unavailable or empty.
    summ <- tryCatch(f$fit$summary(), error = function(e) NULL)
    rhat_max <- if (is.null(summ)) NA_real_ else na_safe(summ$rhat, max)
    ess_bulk_min <- if (is.null(summ)) NA_real_ else na_safe(summ$ess_bulk, min)
    ess_tail_min <- if (is.null(summ)) NA_real_ else na_safe(summ$ess_tail, min)

    sd_diag <- tryCatch(f$fit$sampler_diagnostics(), error = function(e) NULL)

    n_divergences <- NA_integer_
    treedepth_max <- NA_real_

    if (!is.null(sd_diag)) {
      if (is.array(sd_diag)) {
        n_divergences <- sum(sd_diag[, , "divergent__"], na.rm = TRUE)
        treedepth_max <- max(sd_diag[, , "treedepth__"], na.rm = TRUE)
      } else if (is.list(sd_diag)) {
        # numeric(1) templates: sum()/max() may return integer or double
        # depending on the column type, so coerce explicitly.
        n_divergences <- sum(vapply(sd_diag, function(df) {
          as.numeric(sum(df[["divergent__"]], na.rm = TRUE))
        }, numeric(1)))
        treedepth_max <- max(vapply(sd_diag, function(df) {
          as.numeric(max(df[["treedepth__"]], na.rm = TRUE))
        }, numeric(1)))
      }
    }

    # NA (rather than -Inf / NaN) when no Pareto-k value is available.
    k_max <- na_safe(pareto_k, max)
    k_frac_bad <- na_safe(pareto_k, function(v) mean(v > 0.7))

    list(
      fit = f,
      loo = l,
      params = row,
      pareto_k = pareto_k,
      k_max = k_max,
      k_frac_bad = k_frac_bad,
      rhat_max = rhat_max,
      ess_bulk_min = ess_bulk_min,
      ess_tail_min = ess_tail_min,
      n_divergences = n_divergences,
      treedepth_max = treedepth_max
    )
  }

  # ---- Run all fits, in parallel only if furrr and future are installed ----
  can_parallel <- use_parallel &&
    requireNamespace("furrr", quietly = TRUE) &&
    requireNamespace("future", quietly = TRUE)

  if (can_parallel) {
    if (verbose) message("Using parallel execution with furrr")

    results <- furrr::future_map(
      seq_len(total_fits),
      fit_single,
      .options = furrr::furrr_options(
        seed = TRUE,
        packages = c("cmdstanr", "posterior", "loo", "dplyr", "data.table")
      ),
      .progress = verbose
    )
  } else {
    if (verbose) message("Using sequential execution")

    results <- lapply(seq_len(total_fits), function(i) {
      if (verbose) message(sprintf("  Fitting model %d/%d", i, total_fits))
      fit_single(i)
    })
  }

  fits <- lapply(results, `[[`, "fit")
  loos <- lapply(results, `[[`, "loo")

  # ---- Model weights ----
  n_fits <- length(loos)
  loo_ok <- !vapply(loos, is.null, logical(1))
  valid_loos <- loos[loo_ok]

  if (length(valid_loos) == 0) {
    warning("No valid LOO objects obtained. Using uniform weights.")
    w <- rep(1 / n_fits, n_fits)
  } else {
    # TRUE only for objects carrying real (non-missing) Pareto-k values; the
    # stand-in built when loo::loo() fails has an all-NA pareto_k and must
    # not be stacked against genuine PSIS-LOO objects.
    has_pareto <- function(L) {
      pk <- if (is.null(L) || is.null(L$diagnostics)) {
        NULL
      } else {
        L$diagnostics$pareto_k
      }
      !is.null(pk) && length(pk) > 0 && !all(is.na(pk))
    }
    all_have_pareto <- all(vapply(valid_loos, has_pareto, logical(1)))

    if (!all_have_pareto) {
      if (verbose) message("Some LOO objects lack PSIS diagnostics. Using uniform weights.")
      w <- rep(1 / n_fits, n_fits)
    } else {
      w <- tryCatch(
        loo::loo_model_weights(valid_loos, method = "stacking"),
        error = function(e) {
          if (verbose) message("Stacking failed: using uniform weights.")
          rep(1 / length(valid_loos), length(valid_loos))
        }
      )

      # Re-expand to grid length: fits without a LOO object get weight 0.
      if (length(w) < n_fits) {
        w_full <- rep(0, n_fits)
        w_full[which(loo_ok)] <- w
        w <- w_full
      }
    }
  }

  # ---- Results table ----
  # Returns c(Estimate, SE) of elpd_loo, or two NAs when unavailable.
  extract_elpd <- function(L) {
    if (is.null(L)) return(c(NA_real_, NA_real_))
    est <- tryCatch(L$estimates["elpd_loo", ], error = function(e) NULL)
    if (is.null(est)) return(c(NA_real_, NA_real_))
    as.numeric(c(est["Estimate"], est["SE"]))
  }

  # vapply() guarantees a 2 x n_fits numeric matrix regardless of n_fits.
  elpd_vals <- t(vapply(loos, extract_elpd, numeric(2)))

  diag_tab <- data.frame(
    fit_id = param_grid$fit_id,
    k_max = vapply(results, function(r) as.numeric(r$k_max), numeric(1)),
    k_bad_frac = vapply(results, function(r) as.numeric(r$k_frac_bad), numeric(1)),
    rhat_max = vapply(results, function(r) as.numeric(r$rhat_max), numeric(1)),
    ess_bulk_min = vapply(results, function(r) as.numeric(r$ess_bulk_min), numeric(1)),
    ess_tail_min = vapply(results, function(r) as.numeric(r$ess_tail_min), numeric(1)),
    n_divergences = vapply(results, function(r) as.numeric(r$n_divergences), numeric(1)),
    treedepth_max = vapply(results, function(r) as.numeric(r$treedepth_max), numeric(1)),
    stringsAsFactors = FALSE
  )

  res_tab <- data.frame(
    fit_id = param_grid$fit_id,
    k = param_grid$k,
    hs_tau0 = param_grid$hs_tau0,
    hs_slab_scale = param_grid$hs_slab_scale,
    hs_slab_df = param_grid$hs_slab_df,
    elpd = elpd_vals[, 1],
    elpd_se = elpd_vals[, 2],
    weight = as.numeric(w),
    stringsAsFactors = FALSE
  )

  res_tab <- merge(res_tab, diag_tab, by = "fit_id", all.x = TRUE)
  # Highest ELPD first; rows with NA ELPD sort last.
  res_tab <- res_tab[order(-res_tab$elpd), ]
  rownames(res_tab) <- NULL

  if (verbose && nrow(res_tab) > 0) {
    best <- res_tab[1, ]
    message(sprintf(
      "Best: fit_id=%d, k=%d, hs_tau0=%.2f, elpd=%.2f (SE=%.2f)",
      as.integer(best$fit_id), as.integer(best$k),
      best$hs_tau0, best$elpd, best$elpd_se
    ))
  }

  list(
    fits = fits,
    loos = loos,
    weights = w,
    table = res_tab
  )
}
