// src/rcpp_interface.cpp
// Exports for 4-mode pixel morphing + LAP solvers

// [[Rcpp::plugins(cpp17)]]
#include <Rcpp.h>
#include <algorithm>
#include <cctype>
#include <limits>
#include <string>
#include "../core/lap_internal.h"
#include "../core/lap_utils_rcpp.h"
#include "../gabow_tarjan/utils_gabow_tarjan.h"

using namespace Rcpp;

// =======================
// Forward decls for greedy matching (implemented in solvers/greedy_matching.cpp)
extern Rcpp::List greedy_matching_sorted_impl(Rcpp::NumericMatrix cost_matrix, bool maximize);
extern Rcpp::List greedy_matching_row_best_impl(Rcpp::NumericMatrix cost_matrix, bool maximize);
extern Rcpp::List greedy_matching_pq_impl(Rcpp::NumericMatrix cost_matrix, bool maximize);
extern Rcpp::List greedy_matching_impl(Rcpp::NumericMatrix cost_matrix, bool maximize, std::string strategy);

// =======================
// Forward decls for LAP solvers (no [[Rcpp::export]] here)
Rcpp::List solve_cycle_cancel_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_gabow_tarjan_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_lapmod_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_bottleneck_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_csa_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_sinkhorn_impl(Rcpp::NumericMatrix cost, double lambda, double tol,
                               int max_iter, Rcpp::Nullable<Rcpp::NumericVector> r_weights,
                               Rcpp::Nullable<Rcpp::NumericVector> c_weights);
Rcpp::IntegerVector sinkhorn_round_impl(Rcpp::NumericMatrix P);
Rcpp::List solve_ramshaw_tarjan_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_push_relabel_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_jv_duals_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_network_simplex_rcpp(const Rcpp::NumericMatrix& cost_matrix);
// =======================
Rcpp::List prepare_cost_matrix_impl(NumericMatrix cost, bool maximize);
Rcpp::List solve_bruteforce_impl(NumericMatrix cost, bool maximize);
Rcpp::List solve_jv_impl(NumericMatrix cost, bool maximize);
Rcpp::List solve_murty_impl(Rcpp::NumericMatrix cost, int k, bool maximize, std::string single_method);
Rcpp::List solve_auction_impl(Rcpp::NumericMatrix cost, bool maximize, double eps_in);
Rcpp::List solve_auction_scaled_impl(Rcpp::NumericMatrix cost, bool maximize, std::string schedule);
Rcpp::List solve_auction_scaled_impl(Rcpp::NumericMatrix cost, bool maximize,
                                     double initial_epsilon_factor,
                                     double alpha,
                                     double final_epsilon);
Rcpp::List solve_auction_gauss_seidel_impl(Rcpp::NumericMatrix cost, bool maximize, double eps_in);
Rcpp::List solve_ssp_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_hungarian_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_csflow_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_kbest_lawler_impl(Rcpp::NumericMatrix cost, int k, std::string method_base, bool maximize);
Rcpp::List solve_hk01_impl(Rcpp::NumericMatrix cost, bool maximize);
Rcpp::List solve_line_metric_impl(const Rcpp::NumericVector& x,
                                  const Rcpp::NumericVector& y,
                                  const std::string& cost,
                                  bool maximize);
Rcpp::List solve_ssap_bucket_impl(Rcpp::NumericMatrix cost, bool maximize);

// =======================
// Pixel morphing core (implemented in morph_pixel_level.cpp)
// =======================
extern Rcpp::List analyze_color_overlap(const Rcpp::NumericVector& pixelsA,
                                        const Rcpp::NumericVector& pixelsB,
                                        int H, int W,
                                        int quantize_bits);

// REMOVED: extract_patches - No longer needed with square tiling implementation

// REMOVED: compute_color_match_assignment - R does the assignment!
// REMOVED: compute_color_walk_assignment - R does the assignment!

extern Rcpp::NumericMatrix compute_pixel_cost(const Rcpp::NumericVector& pixelsA,
                                              const Rcpp::NumericVector& pixelsB,
                                              int H, int W,
                                              double alpha, double beta);

extern Rcpp::NumericVector downscale_image(const Rcpp::NumericVector& pixels,
                                           int H, int W, int H_new, int W_new);

extern Rcpp::IntegerVector upscale_assignment(const Rcpp::IntegerVector& assignment,
                                              int H_orig, int W_orig,
                                              int H_scaled, int W_scaled);

extern Rcpp::List morph_pixel_level_impl(const Rcpp::NumericVector& pixelsA,
                                         const Rcpp::NumericVector& pixelsB,
                                         const Rcpp::IntegerVector& assignment,
                                         int H, int W,
                                         int n_frames);

extern Rcpp::List color_palette_info(const Rcpp::NumericVector& pixelsA,
                                     const Rcpp::NumericVector& pixelsB,
                                     int H, int W,
                                     int quantize_bits);

extern Rcpp::NumericMatrix spatial_cost_matrix(const Rcpp::IntegerVector& idxA,
                                               const Rcpp::IntegerVector& idxB,
                                               int H, int W);

// =======================
// LAP Solver Exports
// =======================

// R entry point: hands `cost` and `maximize` unchanged to
// prepare_cost_matrix_impl and returns the list it builds.
// [[Rcpp::export]]
Rcpp::List lap_prepare_cost_matrix(NumericMatrix cost, bool maximize) {
  Rcpp::List prepared = prepare_cost_matrix_impl(cost, maximize);
  return prepared;
}

// R entry point for the brute-force solver; all work is delegated to
// solve_bruteforce_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_bruteforce(NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_bruteforce_impl(cost, maximize);
  return solution;
}

// R entry point: forwards `cost` and `maximize` to solve_jv_impl and returns
// its result list unchanged.
// [[Rcpp::export]]
Rcpp::List lap_solve_jv(NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_jv_impl(cost, maximize);
  return solution;
}

// R entry point for k-best enumeration via solve_murty_impl.
// `single_method` names the per-subproblem solver and defaults to "jv".
// [[Rcpp::export]]
Rcpp::List lap_kbest_murty(Rcpp::NumericMatrix cost, int k, bool maximize,
                           std::string single_method = "jv") {
  Rcpp::List ranked = solve_murty_impl(cost, k, maximize, single_method);
  return ranked;
}

// R entry point for the auction solver.
// When `eps` is NULL, a quiet NaN is passed to solve_auction_impl.
// NOTE(review): presumably the impl treats NaN as "choose epsilon
// automatically" — confirm in the impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_auction(Rcpp::NumericMatrix cost, bool maximize,
                             Rcpp::Nullable<double> eps = R_NilValue) {
  const double eps_in = eps.isNull()
                            ? std::numeric_limits<double>::quiet_NaN()
                            : Rcpp::as<double>(eps.get());
  return solve_auction_impl(cost, maximize, eps_in);
}

// R entry point for the epsilon-scaled auction solver with a named schedule.
//
// `schedule` is lower-cased first, so matching is case-insensitive. It must
// then be one of "alpha7" (the default), "pow2" or "halves"; any other value
// raises an error via LAP_ERROR. The normalised name is forwarded to
// solve_auction_scaled_impl.
// std::tolower comes from <cctype>, which the file now includes explicitly
// instead of relying on a transitive include.
// [[Rcpp::export]]
Rcpp::List lap_solve_auction_scaled(Rcpp::NumericMatrix cost, bool maximize,
                                    std::string schedule = "alpha7") {
  // Go through unsigned char: passing a negative char to std::tolower is UB.
  std::transform(schedule.begin(), schedule.end(), schedule.begin(),
                 [](unsigned char c){ return static_cast<char>(std::tolower(c)); });
  if (schedule != "alpha7" && schedule != "pow2" && schedule != "halves") {
    LAP_ERROR("Invalid schedule: '%s'. Use: 'alpha7', 'pow2', 'halves'.", schedule.c_str());
  }
  return solve_auction_scaled_impl(cost, maximize, schedule);
}

// R entry point for the scaled auction solver with explicit numeric knobs.
// A NULL `final_epsilon` is encoded as -1.0 before the call.
// NOTE(review): the impl presumably reads a negative value as "not
// supplied" — confirm.
// [[Rcpp::export]]
Rcpp::List lap_solve_auction_scaled_params(Rcpp::NumericMatrix cost, bool maximize,
                                           double initial_epsilon_factor = 1.0,
                                           double alpha = 7.0,
                                           Rcpp::Nullable<double> final_epsilon = R_NilValue) {
  const double fe = final_epsilon.isNull()
                        ? -1.0
                        : Rcpp::as<double>(final_epsilon.get());
  Rcpp::List solution =
      solve_auction_scaled_impl(cost, maximize, initial_epsilon_factor, alpha, fe);
  return solution;
}

// R entry point for the Gauss-Seidel auction variant.
// When `eps` is NULL, a quiet NaN is forwarded to
// solve_auction_gauss_seidel_impl in its place.
// [[Rcpp::export]]
Rcpp::List lap_solve_auction_gs(Rcpp::NumericMatrix cost, bool maximize,
                                Rcpp::Nullable<double> eps = R_NilValue) {
  const double eps_in = eps.isNull()
                            ? std::numeric_limits<double>::quiet_NaN()
                            : Rcpp::as<double>(eps.get());
  return solve_auction_gauss_seidel_impl(cost, maximize, eps_in);
}

// R entry point: delegates to solve_ssp_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_ssp(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_ssp_impl(cost, maximize);
  return solution;
}

// R entry point: delegates to solve_hungarian_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_hungarian(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_hungarian_impl(cost, maximize);
  return solution;
}

// R entry point: delegates to solve_csflow_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_csflow(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_csflow_impl(cost, maximize);
  return solution;
}

// R entry point for k-best enumeration via solve_kbest_lawler_impl.
// Here `method_base` (default "jv") comes before `maximize` (default false),
// which is the reverse of lap_kbest_murty's argument order.
// [[Rcpp::export]]
Rcpp::List lap_kbest_lawler(Rcpp::NumericMatrix cost, int k,
                            std::string method_base = "jv", bool maximize = false) {
  Rcpp::List ranked = solve_kbest_lawler_impl(cost, k, method_base, maximize);
  return ranked;
}

// R entry point: delegates to solve_hk01_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_hk01(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_hk01_impl(cost, maximize);
  return solution;
}

// R entry point for the 1-D (line) metric solver.
// `x` and `y` are point vectors. `cost` is a string naming the metric
// (default "L1"). Everything is forwarded to solve_line_metric_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_line_metric_cpp(Rcpp::NumericVector x,
                                     Rcpp::NumericVector y,
                                     std::string cost = "L1",
                                     bool maximize = false) {
  Rcpp::List solution = solve_line_metric_impl(x, y, cost, maximize);
  return solution;
}

// R entry point: delegates to solve_ssap_bucket_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_ssap_bucket(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_ssap_bucket_impl(cost, maximize);
  return solution;
}

// R entry point: delegates to solve_gabow_tarjan_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_gabow_tarjan(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_gabow_tarjan_impl(cost, maximize);
  return solution;
}

// R entry point: delegates to solve_lapmod_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_lapmod(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_lapmod_impl(cost, maximize);
  return solution;
}

// R entry point: delegates to solve_bottleneck_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_bottleneck(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_bottleneck_impl(cost, maximize);
  return solution;
}

// R entry point: delegates to solve_csa_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_csa(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_csa_impl(cost, maximize);
  return solution;
}

// R entry point for the Sinkhorn solver.
// Defaults: lambda = 10, tol = 1e-9, max_iter = 1000, and NULL row/column
// weights. All arguments are passed through to solve_sinkhorn_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_sinkhorn(Rcpp::NumericMatrix cost, double lambda = 10.0,
                              double tol = 1e-9, int max_iter = 1000,
                              Rcpp::Nullable<Rcpp::NumericVector> r_weights = R_NilValue,
                              Rcpp::Nullable<Rcpp::NumericVector> c_weights = R_NilValue) {
  Rcpp::List plan =
      solve_sinkhorn_impl(cost, lambda, tol, max_iter, r_weights, c_weights);
  return plan;
}

// R entry point: converts matrix `P` into an integer vector via
// sinkhorn_round_impl.
// [[Rcpp::export]]
Rcpp::IntegerVector sinkhorn_round(Rcpp::NumericMatrix P) {
  Rcpp::IntegerVector rounded = sinkhorn_round_impl(P);
  return rounded;
}

// R entry point: delegates to solve_ramshaw_tarjan_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_ramshaw_tarjan(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_ramshaw_tarjan_impl(cost, maximize);
  return solution;
}

// R entry point: delegates to solve_push_relabel_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_push_relabel(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_push_relabel_impl(cost, maximize);
  return solution;
}

// R entry point: delegates to solve_jv_duals_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_jv_duals(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_jv_duals_impl(cost, maximize);
  return solution;
}

// R entry point for the network-simplex solver.
// There is no `maximize` flag: solve_network_simplex_rcpp takes only the matrix.
// [[Rcpp::export]]
Rcpp::List lap_solve_network_simplex(Rcpp::NumericMatrix cost) {
  Rcpp::List solution = solve_network_simplex_rcpp(cost);
  return solution;
}

// =======================
// Greedy matching exports (implemented in solvers/greedy_matching.cpp)
// =======================

// R entry point: delegates to greedy_matching_sorted_impl
// (`maximize` defaults to false).
// [[Rcpp::export]]
Rcpp::List greedy_matching_sorted(Rcpp::NumericMatrix cost_matrix, bool maximize = false) {
  Rcpp::List matching = greedy_matching_sorted_impl(cost_matrix, maximize);
  return matching;
}

// R entry point: delegates to greedy_matching_row_best_impl
// (`maximize` defaults to false).
// [[Rcpp::export]]
Rcpp::List greedy_matching_row_best(Rcpp::NumericMatrix cost_matrix, bool maximize = false) {
  Rcpp::List matching = greedy_matching_row_best_impl(cost_matrix, maximize);
  return matching;
}

// R entry point: delegates to greedy_matching_pq_impl
// (`maximize` defaults to false).
// [[Rcpp::export]]
Rcpp::List greedy_matching_pq(Rcpp::NumericMatrix cost_matrix, bool maximize = false) {
  Rcpp::List matching = greedy_matching_pq_impl(cost_matrix, maximize);
  return matching;
}

// R entry point for greedy matching with a named strategy
// (default "row_best"). The strategy string is not validated here; it is
// passed straight to greedy_matching_impl.
// [[Rcpp::export]]
Rcpp::List greedy_matching(Rcpp::NumericMatrix cost_matrix, bool maximize = false,
                          std::string strategy = "row_best") {
  Rcpp::List matching = greedy_matching_impl(cost_matrix, maximize, strategy);
  return matching;
}

// =======================
// Pixel morphing exports
// =======================

// R entry point: validates the two image buffers, then forwards to
// analyze_color_overlap.
//
// pixelsA / pixelsB must each hold exactly H*W*3 values (3 values per pixel).
// quantize_bits defaults to 5 and is passed through unchanged.
// The dimensions are sign-checked, and the expected length is computed in
// 64-bit arithmetic. Previously H*W*3 was evaluated in `int`, which overflows
// (undefined behaviour) for large images and let negative H/W pairs slip
// past the size check.
// [[Rcpp::export]]
Rcpp::List analyze_color_overlap_cpp(Rcpp::NumericVector pixelsA,
                                     Rcpp::NumericVector pixelsB,
                                     int H, int W,
                                     int quantize_bits = 5) {
  if (H < 0 || W < 0)
    LAP_ERROR("H and W must be non-negative.");
  const long long N = static_cast<long long>(H) * W;
  const long long expected = N * 3;
  if (pixelsA.size() != expected || pixelsB.size() != expected)
    LAP_ERROR("pixelsA and pixelsB must be H*W*3.");
  return analyze_color_overlap(pixelsA, pixelsB, H, W, quantize_bits);
}

// REMOVED: extract_patches_cpp - No longer needed with square tiling implementation

// REMOVED: compute_color_match_assignment_cpp - R does the assignment!
// REMOVED: compute_color_walk_assignment_cpp - R does the assignment!

// R entry point: validates the two image buffers, then forwards to
// compute_pixel_cost with the weights alpha and beta.
//
// pixelsA / pixelsB must each hold exactly H*W*3 values (3 values per pixel).
// The dimensions are sign-checked, and the expected length is computed in
// 64-bit arithmetic so a large H*W cannot overflow `int`, which would be
// undefined behaviour.
// [[Rcpp::export]]
Rcpp::NumericMatrix compute_pixel_cost_cpp(const Rcpp::NumericVector& pixelsA,
                                           const Rcpp::NumericVector& pixelsB,
                                           int H, int W,
                                           double alpha, double beta) {
  if (H < 0 || W < 0)
    LAP_ERROR("H and W must be non-negative.");
  const long long N = static_cast<long long>(H) * W;
  const long long expected = N * 3;
  if (pixelsA.size() != expected || pixelsB.size() != expected)
    LAP_ERROR("pixelsA and pixelsB must be H*W*3.");
  return compute_pixel_cost(pixelsA, pixelsB, H, W, alpha, beta);
}

// R entry point: validates an H x W image buffer and the target size, then
// forwards to downscale_image.
//
// `pixels` must hold exactly H*W*3 values (3 values per pixel). All four
// dimensions must be non-negative. The expected length is computed in 64-bit
// arithmetic so a large H*W cannot overflow `int`, which would be undefined
// behaviour.
// [[Rcpp::export]]
Rcpp::NumericVector downscale_image_cpp(Rcpp::NumericVector pixels,
                                        int H, int W, int H_new, int W_new) {
  if (H < 0 || W < 0)
    LAP_ERROR("H and W must be non-negative.");
  if (H_new < 0 || W_new < 0)
    LAP_ERROR("H_new and W_new must be non-negative.");
  const long long N = static_cast<long long>(H) * W;
  const long long expected = N * 3;
  if (pixels.size() != expected)
    LAP_ERROR("pixels must be H*W*3.");
  return downscale_image(pixels, H, W, H_new, W_new);
}

// R entry point: validates a scaled-grid assignment and forwards it to
// upscale_assignment.
//
// `assignment` must have exactly H_scaled*W_scaled entries, and all four
// dimensions must be non-negative. The product is computed in 64-bit
// arithmetic so it cannot overflow `int`, which would be undefined behaviour.
// [[Rcpp::export]]
Rcpp::IntegerVector upscale_assignment_cpp(Rcpp::IntegerVector assignment,
                                           int H_orig, int W_orig,
                                           int H_scaled, int W_scaled) {
  if (H_orig < 0 || W_orig < 0 || H_scaled < 0 || W_scaled < 0)
    LAP_ERROR("H_orig, W_orig, H_scaled and W_scaled must be non-negative.");
  const long long N_scaled = static_cast<long long>(H_scaled) * W_scaled;
  if (assignment.size() != N_scaled)
    LAP_ERROR("assignment must have H_scaled*W_scaled elements.");
  return upscale_assignment(assignment, H_orig, W_orig, H_scaled, W_scaled);
}

// R entry point: validates inputs, sanitises the assignment, and forwards
// to morph_pixel_level_impl.
//
// pixelsA / pixelsB must each hold exactly H*W*3 values (3 values per pixel).
// `assignment` must have H*W entries, treated as 0-based indices in [0, H*W).
// Any entry outside that range (NA_INTEGER included, since it is negative)
// is replaced by its own position i.
//
// Rcpp wraps an integer vector coming from R without copying it. Writing into
// `assignment` directly would therefore silently modify the caller's R object,
// breaking R's copy semantics. The vector is cloned on the first repair only,
// so already-valid input costs no copy.
// Sizes are computed in 64-bit arithmetic to avoid `int` overflow (UB).
// [[Rcpp::export]]
Rcpp::List morph_pixel_level_cpp(Rcpp::NumericVector pixelsA,
                                 Rcpp::NumericVector pixelsB,
                                 Rcpp::IntegerVector assignment,
                                 int H, int W,
                                 int n_frames) {
  if (H < 0 || W < 0)
    LAP_ERROR("H and W must be non-negative.");
  const long long N = static_cast<long long>(H) * W;
  const long long expected = N * 3;
  if (pixelsA.size() != expected || pixelsB.size() != expected)
    LAP_ERROR("pixelsA and pixelsB must be H*W*3.");
  if (assignment.size() != N)
    LAP_ERROR("assignment must have H*W elements.");

  // Shallow handle at first; replaced by a deep copy before the first write.
  Rcpp::IntegerVector safe_assignment = assignment;
  bool cloned = false;
  for (R_xlen_t i = 0; i < N; ++i) {
    const int target = safe_assignment[i];
    if (target < 0 || target >= N) {
      if (!cloned) {
        safe_assignment = Rcpp::clone(assignment);
        cloned = true;
      }
      safe_assignment[i] = static_cast<int>(i);
    }
  }
  return morph_pixel_level_impl(pixelsA, pixelsB, safe_assignment, H, W, n_frames);
}

// R entry point: validates the two image buffers, then forwards to
// color_palette_info.
//
// pixelsA / pixelsB must each hold exactly H*W*3 values (3 values per pixel).
// quantize_bits defaults to 5 and is passed through unchanged.
// The dimensions are sign-checked, and the expected length is computed in
// 64-bit arithmetic so a large H*W cannot overflow `int`, which would be
// undefined behaviour.
// [[Rcpp::export]]
Rcpp::List color_palette_info_cpp(Rcpp::NumericVector pixelsA,
                                  Rcpp::NumericVector pixelsB,
                                  int H, int W,
                                  int quantize_bits = 5) {
  if (H < 0 || W < 0)
    LAP_ERROR("H and W must be non-negative.");
  const long long N = static_cast<long long>(H) * W;
  const long long expected = N * 3;
  if (pixelsA.size() != expected || pixelsB.size() != expected)
    LAP_ERROR("pixelsA and pixelsB must be H*W*3.");
  return color_palette_info(pixelsA, pixelsB, H, W, quantize_bits);
}

// R entry point: forwards the index vectors and grid size to
// spatial_cost_matrix. No validation is done here.
// NOTE(review): whether idxA/idxB are 0- or 1-based is decided by the
// callee — confirm.
// [[Rcpp::export]]
Rcpp::NumericMatrix spatial_cost_matrix_cpp(Rcpp::IntegerVector idxA,
                                            Rcpp::IntegerVector idxB,
                                            int H, int W) {
  Rcpp::NumericMatrix costs = spatial_cost_matrix(idxA, idxB, H, W);
  return costs;
}

// R entry point: delegates to solve_cycle_cancel_impl.
// [[Rcpp::export]]
Rcpp::List lap_solve_cycle_cancel(Rcpp::NumericMatrix cost, bool maximize) {
  Rcpp::List solution = solve_cycle_cancel_impl(cost, maximize);
  return solution;
}


// =======================
// Orlin-Ahuja Algorithm Export (production only)
// =======================

// Forward declaration for production solver
extern Rcpp::List oa_solve_impl(Rcpp::NumericMatrix cost_r, bool maximize, double alpha, int auction_rounds);

// R entry point for oa_solve_impl.
// It always runs in minimisation mode: `maximize` is hard-wired to false.
// `alpha` (default 5.0) and `auction_rounds` (default 10) are passed through.
// [[Rcpp::export]]
Rcpp::List oa_solve(Rcpp::NumericMatrix cost_r, double alpha = 5.0, int auction_rounds = 10) {
  const bool minimise_only = false;
  Rcpp::List solution = oa_solve_impl(cost_r, minimise_only, alpha, auction_rounds);
  return solution;
}

