
#define TOLERANCE 0.0000000001 // added tolerance for cast to int

#include <Rcpp.h>
#include <algorithm> 
#include <vector>

// [[Rcpp::depends(RcppEigen)]]
#include <RcppEigen.h>
#include <Eigen/SparseQR>

using namespace Rcpp;

// Module-level state shared by every helper in this file.
// It is (re)assigned at the start of the exported entry points .lassoImSX,
// .initializeKernel and .postProcessing (and by .getXtX, which leaves Xty
// alone); .getXty touches none of it. Rcpp vectors/matrices are shallow
// handles, so assigning an argument to one of these globals aliases the
// caller's R object: caches filled through them are written into that object.
int n;             // length of y
double h;          // bandwidth
double nh;         // n * h
int L;    // window length as.integer(nh)
double nh2;        // nh * nh
double L2;         // L * L
int Lp1;  // L + 1
double L2p1;       // Lp1 * Lp1

NumericVector y;
NumericVector cusumKernel;   // cusum(kernel(-L:L) / nh); computeCusumKernel fills it with running sums of 1 - d^2 / nh2 (no 1 / nh factor -- NOTE(review): only ratios of entries are used, confirm intended)
NumericVector Xty;            // transpose((I - S) %*% X) %*% y / n
NumericMatrix XtX;            // upper left corner of transpose((I - S) %*% X) %*% ((I - S) %*% X) / n (cache, see getXtX)
LogicalMatrix isComputedXtX;  // true if entry of XtX is computed already
NumericVector XtXgap;         // other entries of the same Gram matrix: XtXgap[d] is used for interior pairs d columns apart
NumericMatrix ImSX;           // the first 2 * L columns of (I - S) %*% X (filled lazily by getImSXj)
LogicalVector isComputedImSX; // true if column in ImSX is computed already

// Fills the global cusumKernel (length 2 * L + 1) with running sums of the
// unnormalised kernel weights 1 - d^2 / nh2 over the offsets d = -L, ..., L:
//   cusumKernel[idx] = sum_{d = -L}^{idx - L} (1 - d^2 / nh2).
// Every consumer in this file divides by another cusumKernel entry, so the
// overall scale of the weights does not matter.
inline void computeCusumKernel() {
  double running = 0.0;
  for (int idx = 0; idx <= 2 * L; ++idx) {
    const int offset = idx - L;  // kernel offset d for this slot
    running = running + 1.0 - offset * offset / nh2;
    cusumKernel[idx] = running;
  }
}

// Computes ret[k] = sum_{i <= k} ImStY[i] for k = 0, ..., n - 2, where
//   ImStY[i] = (sum_{m : |m - i| <= L, 0 <= m < n} (1 - (m - i)^2 / nh2) * DY[m] - R[i]) / n
//   DY[m]    = R[m] / (sum of the kernel weights of row m that fall inside 0..n-1).
// If S is the row-normalised kernel smoother with these weights, then
// ImStY = -((I - S)^T R) / n. computeXty uses this to fill Xty.
//
// R:   length-n input vector.
// ret: output, length >= n - 1. Rcpp vectors are shallow handles, so writing
//      through this by-value parameter fills the caller's vector.
// NOTE(review): the DY loops and the window updates below appear to assume
// n >= 2 * L + 1 (otherwise DY[i + L] runs past the end) -- confirm that
// callers guarantee this.
inline void computeImStR(const NumericVector &R, NumericVector ret) {
  NumericVector DY = NumericVector(n);
  int i = 0;

  // Left boundary rows 0..L: the window is truncated on the left, so the
  // normaliser is the partial sum cusumKernel[i + L].
  for ( ; i <= L; ++i) {
    DY[i] = R[i] / cusumKernel[i + L];
  }

  // Interior rows: full window, full normaliser.
  for ( ; i < n - L; ++i) {
    DY[i] = R[i] / cusumKernel[2 * L];
  }

  // Right boundary rows: mirror image of the left boundary
  // (normaliser index runs 2L - 1, ..., L).
  int j = L;
  for ( ; i < n; ++i) {
    DY[i] = R[i] / cusumKernel[--j + L];
  }

  NumericVector ImStY = NumericVector(n);
  // Running sums over the window around the current centre i:
  //   cumsumConstant  =  sum DY[m]
  //   cumsumLinear    = -sum (m - i) * DY[m]
  //   cumsumQuadratic =  sum (m - i)^2 * DY[m]
  // so the kernel-weighted sum is cumsumConstant - cumsumQuadratic / nh2.
  double cumsumConstant = 0.0;
  double cumsumLinear = 0.0;
  double cumsumQuadratic = 0.0;

  // initialize: window around centre 0 covers rows 0..L
  // (this loop's i shadows the outer i, which is reset below).
  for (int i = 0; i <= L; ++i) {
    cumsumConstant += DY[i];
    cumsumLinear -= i * DY[i];
    cumsumQuadratic += i * i * DY[i];
  }

  ImStY[0] = (cumsumConstant - cumsumQuadratic / nh2 - R[0]) / n;

  // Shifting the centre by one uses (m - i - 1)^2 = (m - i)^2 - 2 (m - i) + 1.
  // Entering rows sit at offset +L (weights L2, L); leaving rows sit at
  // offset -(L + 1) (weights L2p1, Lp1).
  // Rows 1..L: the window only grows on the right.
  i = 1;
  for ( ; i <= L; ++i) {
    cumsumQuadratic = cumsumQuadratic + 2.0 * cumsumLinear + cumsumConstant + L2 * DY[i + L];
    cumsumLinear = cumsumLinear + cumsumConstant - L * DY[i + L];
    cumsumConstant = cumsumConstant + DY[i + L];
    ImStY[i] = (cumsumConstant - cumsumQuadratic / nh2 - R[i]) / n;
  }

  // Interior: one row enters on the right and one leaves on the left.
  for ( ; i < n - L; ++i) {
    cumsumQuadratic = cumsumQuadratic + 2.0 * cumsumLinear + cumsumConstant + L2 * DY[i + L] - L2p1 * DY[i - Lp1];
    cumsumLinear = cumsumLinear + cumsumConstant - L * DY[i + L] - Lp1 * DY[i - Lp1];
    cumsumConstant = cumsumConstant + DY[i + L] - DY[i - Lp1];
    ImStY[i] = (cumsumConstant - cumsumQuadratic / nh2 - R[i]) / n;
  }

  // Right boundary: the window only shrinks on the left.
  for ( ; i < n; ++i) {
    cumsumQuadratic = cumsumQuadratic + 2.0 * cumsumLinear + cumsumConstant - L2p1 * DY[i - Lp1];
    cumsumLinear = cumsumLinear + cumsumConstant - Lp1 * DY[i - Lp1];
    cumsumConstant = cumsumConstant - DY[i - Lp1];
    ImStY[i] = (cumsumConstant - cumsumQuadratic / nh2 - R[i]) / n;
  }

  // Cumulative sums of the first n - 1 entries; the last entry of ImStY is not used.
  std::partial_sum(ImStY.begin(), ImStY.end() - 1, ret.begin());
  return;
}

// Fills the global Xty in place from the global response y by running
// computeImStR(y, Xty). cusumKernel must already be filled
// (computeCusumKernel), because computeImStR divides by its entries.
inline void computeXty() {
  computeImStR(y, Xty);
}

// Fills column j (0-based, j < 2 * L) of the cache ImSX with column j of
// (I - S) %*% X.
//
// The formulas treat X[, j] as a step (0 on rows <= j, 1 on rows > j). They
// treat S as the row-normalised kernel smoother with weights 1 - d^2 / nh2
// for |d| <= L. Row i's normaliser is cusumKernel[2L - max(L - i, 0)], so only
// left-boundary truncation is modelled here. NOTE(review): this relies on
// callers (getXtX's mirroring) never needing right-boundary rows -- confirm.
//
// Rows written: j - L .. min(j + L, n - 1). Row j - L is always exactly 0.
// Other rows are left untouched, so they are expected to be zero from
// allocation.
inline void computeImSXj(int j) {
  // Rows at or above the jump (i <= j): minus the share of row i's kernel
  // mass that lies on rows m > j. cusumKernel[k] with k = L - 1 - (j - i) is
  // exactly that mass.
  const int firstRow = std::max(j - L + 1, 0);
  int k = L - 1;
  for (int i = j; i >= firstRow; --i, --k) {
    ImSX(i, j) = -cusumKernel[k] / cusumKernel[2 * L - std::max(L - i, 0)];
  }

  // Row j - L: its window ends exactly at row j, so no weight reaches the
  // step and the entry is 0. It is written explicitly because the formula
  // above would need cusumKernel[-1], which is out of bounds.
  if (j >= L) {
    ImSX(j - L, j) = 0.0;
  }

  // Rows below the jump (i > j): 1 minus the share of row i's kernel mass
  // that lies on rows m > j.
  k = 0;
  const int lastRow = std::min(j + L, n - 1);
  for (int i = j + 1; i <= lastRow; ++i, ++k) {
    ImSX(i, j) = 1 - cusumKernel[L + k] / cusumKernel[2 * L - std::max(L - i, 0)];
  }
}

// Returns column j of (I - S) %*% X from the ImSX cache. On first access the
// column is computed (computeImSXj) and marked in isComputedImSX.
// The returned NumericVector is a copy of the column, not a view.
inline NumericVector getImSXj(int j) {
  const bool needsFill = !isComputedImSX[j];
  if (needsFill) {
    computeImSXj(j);
    isComputedImSX[j] = true;
  }
  return ImSX(_, j);
}

// Fills the global XtXgap (length 2 * L). getXtX reads XtXgap[j - i] for pairs
// with i >= 2L - 1, i.e. XtXgap[d] presumably holds the Gram entry
// transpose((I - S) X)[, i] %*% ((I - S) X)[, i + d] / n for two columns clear
// of the left boundary.
//
// v = maxVec = column 2L - 1 of (I - S) %*% X. Per computeImSXj, its non-zero
// rows lie in L..3L-1, all of which have full smoothing windows (normaliser
// cusumKernel[2L]). This loop computes
//   ImStY[i] = (sum_{|m - i| <= L} (1 - (m - i)^2 / nh2) v[m] / cusumKernel[2L] - v[i]) / n
// for i = 0..2L-1, using the same sliding sums as computeImStR.
// Rows of v below L are treated as zero: nothing ever leaves the window, and
// v[i] is not subtracted for i < L. XtXgap is the reversed vector of
// partial sums of ImStY.
inline void computeXtXgap() {
  NumericVector maxVec = getImSXj(2 * L - 1);
  
  NumericVector ImStY(2 * L);
  // Running window sums around centre i (see computeImStR for the meaning).
  double cumsumConstant = 0.0;
  double cumsumLinear = 0.0;
  double cumsumQuadratic = 0.0;
  
  // Centre 0: the only non-zero row inside the window 0..L is row L (offset +L).
  cumsumConstant += maxVec[L];
  cumsumLinear -= L * maxVec[L];
  cumsumQuadratic += L2 * maxVec[L];
  
  ImStY[0] = ((cumsumConstant - cumsumQuadratic / nh2) / cusumKernel[2 * L]) / n;
  
  // Centres 1..L-1: v[i] is treated as zero, so it is not subtracted.
  int i = 1;
  for ( ; i < L; ++i) {
    cumsumQuadratic = cumsumQuadratic + 2.0 * cumsumLinear + cumsumConstant + L2 * maxVec[i + L];
    cumsumLinear = cumsumLinear + cumsumConstant - L * maxVec[i + L];
    cumsumConstant = cumsumConstant + maxVec[i + L];
    ImStY[i] = ((cumsumConstant - cumsumQuadratic / nh2) / cusumKernel[2 * L]) / n;
  }
  
  // Centres L..2L-1: subtract v[i]. The leaving row i - L - 1 is < L, so it
  // contributes nothing and no removal term is needed.
  for ( ; i < 2 * L; ++i) {
    cumsumQuadratic = cumsumQuadratic + 2.0 * cumsumLinear + cumsumConstant + L2 * maxVec[i + L];
    cumsumLinear = cumsumLinear + cumsumConstant - L * maxVec[i + L];
    cumsumConstant = cumsumConstant + maxVec[i + L];
    ImStY[i] = ((cumsumConstant - cumsumQuadratic / nh2) / cusumKernel[2 * L] - maxVec[i]) / n;
  }
  
  NumericVector partialXImStY(2 * L);
  std::partial_sum(ImStY.begin(), ImStY.end(), partialXImStY.begin());
  // Reverse so that XtXgap[d] = partialXImStY[2L - 1 - d], indexed by column lag d.
  XtXgap = rev(partialXImStY);
  return;
}

// Returns entry (i, j) (0-based predictor indices) of the Gram matrix
// transpose((I - S) %*% X) %*% ((I - S) %*% X) / n. It avoids materialising
// that matrix by using:
//   * symmetry (i <= j after the swap);
//   * banding: column j of (I - S) X is non-zero only on rows j - L + 1 .. j + L
//     (computeImSXj), so columns 2L or more apart do not overlap;
//   * a left/right mirror k -> n - 2 - k for i > n / 2. NOTE(review): this
//     assumes the design is reversal-symmetric and n is large relative to L,
//     since right-boundary columns are never built -- confirm;
//   * XtXgap[j - i] for pairs clear of the left boundary (i >= 2L - 1);
//   * otherwise a lazily filled cache XtX / isComputedXtX indexed by (i, j),
//     with i < 2L - 1 and j < i + 2L.
inline double getXtX(int i, int j) {
  if (i > j) {
    int help = j;
    j = i;
    i = help;
  }
  
  // Supports do not overlap.
  if (j - i >= 2 * L) {
    return 0.0;
  }
  
  // Mirror into the left half.
  if (i > n / 2) {
    int help = n - j - 2;
    j = n - i - 2;
    i = help;
  }
  
  // Interior pair: the value depends only on the lag j - i.
  if (i >= 2 * L - 1) {
    return XtXgap[j - i];
  }
  
  if (isComputedXtX(i, j)) {
    return XtX(i, j);
  }
  
  // compute and store XtX(i, j)
  isComputedXtX(i, j) = true;
  double ret = 0.0;
  if (j > 2 * L - 1) {
    // Column j is clear of the left boundary, so it equals column 2L - 1
    // shifted down by `gap` rows; reuse the cached column 2L - 1.
    NumericVector iVec = getImSXj(i);
    
    int gap = j - 2 * L + 1;
    NumericVector maxVec = getImSXj(2 * L - 1);
    
    // Overlap of the supports of columns i and j.
    for (int k = j - L + 1; k <= i + L; ++k) {
      ret += iVec[k] * maxVec[k - gap];
    }
  } else {
    NumericVector iVec = getImSXj(i);
    NumericVector jVec = getImSXj(j);
    
    // Overlap of the supports of columns i and j, clipped at row 0.
    for (int k = std::max(j - L + 1, 0); k <= i + L; ++k) {
      ret += iVec[k] * jVec[k];
    }
  }
  
  XtX(i, j) = ret / n;
  
  return XtX(i, j);
}

// Soft-thresholding operator for lasso coordinate updates. For t >= 0 this is
// sign(u) * max(|u| - t, 0): u is shrunk toward zero by t, and the result is
// exactly 0 when |u| <= t (and when u is NaN).
// The sign of u is tested before the threshold.
inline double soft_thresh(double u, double t) {
  if (u > 0) {
    return (u > t) ? u - t : 0;
  }
  return (u < -t) ? u + t : 0;
}

// Cyclic coordinate descent for the lasso restricted to the active set.
//
// With b = betaActive, G = XtXactive (dense |A| x |A| block whose diagonal is
// also in XtXdiag) and c_k = Xty[A[k]], each sweep visits every active
// coordinate k once and sets
//   b_k <- soft_thresh(r_k + G_kk * b_k, lam) / G_kk,   r_k = c_k - (G b)_k,
// which is the exact coordinate minimiser of 0.5 b'Gb - c'b + lam * ||b||_1.
// Because the update divides by G_kk, columns need not have unit norm.
//
// Sweeps stop when the largest per-coordinate objective decrease in a sweep
// is <= thresh, or after maxit sweeps (a suggested scale for thresh is
// 1e-7 * null deviance).
//
// betaActive is a warm start on entry and is updated in place. A holds the
// 0-based predictor index of each active slot. Always returns 0.
int beta_active(std::vector<double>& betaActive, std::vector<std::vector<double> >& XtXactive, std::vector<double>& XtXdiag,
                IntegerVector const A, int const A_size, double const thresh, int const maxit, double lam) {
  int sweeps = 0;
  double maxDecrease;

  do {
    maxDecrease = 0;

    for (int k = 0; k < A_size; ++k) {
      const std::vector<double>& gramRow = XtXactive[k];

      // Residual correlation with the current b_k still included.
      double grad = Xty[A[k]];
      for (int col = 0; col < A_size; ++col) {
        grad -= gramRow[col] * betaActive[col];
      }

      const double diagK = XtXdiag[k];
      const double oldBeta = betaActive[k];
      const double newBeta = soft_thresh(grad + diagK * oldBeta, lam) / diagK;
      const double step = newBeta - oldBeta;
      const double shrink = std::abs(oldBeta) - std::abs(newBeta);

      betaActive[k] = newBeta;

      // Decrease of the penalised objective caused by this coordinate move.
      const double decrease = step * (grad - step * diagK / 2) + lam * shrink;
      maxDecrease = std::max(decrease, maxDecrease);
    }
    ++sweeps;
  } while (maxDecrease > thresh && sweeps < maxit);

  return 0;
}

// Moves predictor `idx` into the active set. It:
//   * appends 0 to betaActive;
//   * appends getXtX(idx, idx) to XtXdiag;
//   * grows the dense Gram block XtXactive by one row and one column;
//   * writes idx to A[*A_size] and increments *A_size.
// All Gram entries come from getXtX, so they are also cached there.
inline void appendToActiveSet(int const idx, std::vector<double>& betaActive,
                              std::vector<std::vector<double> >& XtXactive,
                              std::vector<double>& XtXdiag, IntegerVector A, int * A_size) {
  betaActive.push_back(0.0);
  const double diagEntry = getXtX(idx, idx);
  XtXdiag.push_back(diagEntry);

  std::vector<double> newRow;
  // Reserve only the current row length. Reserving n entries per active row
  // would cost O(n * |A|) memory; rows still grow geometrically later.
  newRow.reserve(*A_size + 1);
  for (int j = 0; j < *A_size; j++) {
    const double cross = getXtX(A[j], idx);
    newRow.push_back(cross);
    XtXactive[j].push_back(cross);
  }
  newRow.push_back(diagEntry);
  XtXactive.push_back(newRow);

  A[(*A_size)++] = idx;
}

// KKT check with strong-rule screening, as at the bottom of pg 20 of
// https://statweb.stanford.edu/~tibs/ftp/strong.pdf
//
// A (active), S (strong) and R (rest) are disjoint sets of 0-based predictor
// indices, stored in the first *A_size / *S_size / *R_size slots.
// Return value:
//   0 - no KKT violations; nothing changed except possible R -> S moves;
//   1 - violators in S were moved to A (R is not examined in that case);
//   2 - violators in R were moved to A.
// Any non-zero value means beta_active must be run again.
int KKT_check(std::vector<double>& betaActive, std::vector<std::vector<double> >& XtXactive, std::vector<double>& XtXdiag,
              IntegerVector A, int * A_size, IntegerVector S, int * S_size, IntegerVector R, int * R_size, double const lam, double const lam_next) {
  int out = 0;

  // Pass 1: check the strong set S; violators move from S to A.
  for (int k = 0; k < *S_size; k++) {
    double X_ktR = Xty[S[k]];
    for (int j = 0; j < *A_size; j++) {
      X_ktR -= getXtX(A[j], S[k]) * betaActive[j];
    }

    if (std::abs(X_ktR) > lam) {
      appendToActiveSet(S[k], betaActive, XtXactive, XtXdiag, A, A_size);
      S[k] = S[--(*S_size)]; // swap-remove from S, then re-examine slot k
      out = 1;
      --k;
    }
  }

  if (out == 1) return out;

  // Pass 2: no violations in S or A. Check the remaining predictors R.
  // Violators move to A; predictors passing the strong rule for lam_next
  // move to S.
  double strong_thresh = 2 * lam_next - lam;
  for (int k = 0; k < *R_size; k++) {
    double X_ktR = Xty[R[k]];
    for (int j = 0; j < *A_size; j++) {
      X_ktR -= getXtX(A[j], R[k]) * betaActive[j];
    }

    if (std::abs(X_ktR) > lam) {
      appendToActiveSet(R[k], betaActive, XtXactive, XtXdiag, A, A_size);
      R[k] = R[--(*R_size)]; // swap-remove from R, then re-examine slot k
      out = 2;
      --k;
    } else if (std::abs(X_ktR) >= strong_thresh) {
      S[(*S_size)++] = R[k]; // Add elem to S
      R[k] = R[--(*R_size)]; // swap-remove from R, then re-examine slot k
      --k;
    }
  }
  return out;
}

// [[Rcpp::export(name = ".lassoImSX")]]
NumericMatrix lassoImSX(NumericVector y2, double bandwidth, NumericVector cusumKernel2, NumericVector Xty2, NumericMatrix XtX2,
                         LogicalMatrix isComputedXtX2, NumericVector XtXgap2, NumericMatrix ImSX2, LogicalVector isComputedImSX2,
                         NumericVector const lambda, double const thresh, int const maxit) {
  // Lasso path over the penalties `lambda` for the design (I - S) %*% X. It
  // uses coordinate descent on the active set (beta_active) plus strong-rule
  // screening and KKT checks (KKT_check). Each penalty is warm-started from
  // the previous solution.
  //
  // The cache arguments are rebound to the module globals; XtX, ImSX and
  // their flags are filled further in place. This function does not compute
  // cusumKernel, Xty or XtXgap, so they must already be filled (e.g. by
  // .initializeKernel for the same y2 and bandwidth).
  //
  // Returns an (n - 1) x length(lambda) matrix whose column l holds the
  // coefficients for lambda[l]. An empty lambda yields an (n - 1) x 0 matrix.
  n = y2.size();
  h = bandwidth;    // bandwidth
  nh = n * h;         // n * h
  L = nh + TOLERANCE;    // window length as.integer(nh)
  nh2 = nh * nh;        // nh * nh
  L2 = L * L;         // L * L
  Lp1 = L + 1;  // L + 1
  L2p1 = Lp1 * Lp1;       // Lp1 * Lp1

  y = y2;
  cusumKernel = cusumKernel2;    // cusum(kernel(-L:L) / nh)
  Xty = Xty2;            // transpose((I - S) %*% X) %*% y / n
  XtX = XtX2;            // upper left corner of the Gram matrix (cache)
  isComputedXtX = isComputedXtX2;  // true if entry of XtX is computed already
  XtXgap = XtXgap2;         // other entries of the Gram matrix
  ImSX = ImSX2;           // the first 2 * L columns of (I - S) %*% X
  isComputedImSX = isComputedImSX2; // true if column in ImSX is computed already

  const int p = n - 1;
  const int nLambda = lambda.size();
  NumericMatrix beta = NumericMatrix(p, nLambda);
  if (nLambda == 0) {
    // Nothing to fit; also avoids indexing lambda out of range below.
    return beta;
  }

  // Active, strong and remaining predictor sets; initially everything is in R.
  IntegerVector A = IntegerVector(p);
  IntegerVector S = IntegerVector(p);
  IntegerVector R = IntegerVector(p);
  for (int i = 0; i < p; ++i) {
    R[i] = i;
  }
  int A_size = 0;
  int S_size = 0;
  int R_size = p;

  std::vector<double> betaActive;
  std::vector< std::vector<double> > XtXactive;
  std::vector<double> XtXdiag;
  betaActive.reserve(n);
  XtXactive.reserve(n);
  XtXdiag.reserve(n);

  for (int l = 0; l < nLambda; l++) {
    // The strong rule screens with the next penalty; the last step has no
    // next penalty, so it reuses the current one.
    const double lamNext = (l + 1 < nLambda) ? lambda[l + 1] : lambda[l];
    do {
      beta_active(betaActive, XtXactive, XtXdiag, A, A_size, thresh, maxit, lambda[l]);
    } while (KKT_check(betaActive, XtXactive, XtXdiag, A, &A_size, S, &S_size, R, &R_size, lambda[l], lamNext) > 0);

    // Copy active set to matrix beta
    for (int k = 0; k < A_size; k++) {
      beta(A[k], l) = betaActive[k];
    }
  }

  return beta;
}

// [[Rcpp::export(name = ".getXty")]]
NumericVector getXty(NumericVector y, double bandwidth, NumericVector cusumKernel2, NumericVector Xty2, NumericMatrix XtX2,
                     LogicalMatrix isComputedXtX2, NumericVector XtXgap2, NumericMatrix ImSX2, LogicalVector isComputedImSX2) {
  // Returns the Xty2 argument unchanged (same underlying R object; Rcpp
  // vectors are shallow handles). No module globals are read or written.
  // The other parameters are unused; presumably they are kept so the R-side
  // signature matches the other exported helpers.
  NumericVector result = Xty2;
  return result;
}

// [[Rcpp::export(name = ".getXtX")]]
NumericMatrix getXtX(NumericVector y2, double bandwidth) {
  // Builds fresh caches for (y2, bandwidth) and returns the dense
  // (n - 1) x (n - 1) matrix whose (i, j) entry is getXtX(i, j), i.e. the Gram
  // matrix transpose((I - S) %*% X) %*% ((I - S) %*% X) / n.
  // Time and memory are O(n^2), so this appears intended for inspection on
  // small inputs. The global Xty is not touched.
  n = y2.size();
  h = bandwidth;
  nh = n * h;
  L = nh + TOLERANCE;   // as.integer(nh); TOLERANCE guards against round-off just below an integer
  nh2 = nh * nh;
  L2 = L * L;
  Lp1 = L + 1;
  L2p1 = Lp1 * Lp1;

  y = y2;

  // Fresh, zero-initialised caches sized for window length L.
  const int twoL = 2 * L;
  cusumKernel = NumericVector(twoL + 1);
  computeCusumKernel();
  XtX = NumericMatrix(twoL - 1, 2 * twoL - 2);
  isComputedXtX = LogicalMatrix(twoL - 1, 2 * twoL - 2);
  XtXgap = NumericVector(twoL);
  ImSX = NumericMatrix(3 * L, twoL);
  isComputedImSX = LogicalVector(twoL);
  computeXtXgap();

  const int p = n - 1;
  NumericMatrix gram(p, p);
  for (int row = 0; row < p; ++row) {
    for (int col = 0; col < p; ++col) {
      gram(row, col) = getXtX(row, col);
    }
  }
  return gram;
}

// [[Rcpp::export(name = ".initializeKernel")]]
void initializeKernel(NumericVector y2, double bandwidth, NumericVector cusumKernel2, NumericVector Xty2, NumericMatrix XtX2,
                      LogicalMatrix isComputedXtX2, NumericVector XtXgap2, NumericMatrix ImSX2, LogicalVector isComputedImSX2) {
  // Binds the caller's R objects to the module globals and eagerly fills
  // three of them in place: cusumKernel, Xty and XtXgap. Filling XtXgap also
  // fills column 2L - 1 of ImSX. Rcpp vectors are shallow handles, so the
  // results are visible in the R objects passed in.

  // Scalar window parameters.
  n = y2.size();
  h = bandwidth;
  nh = n * h;
  L = nh + TOLERANCE;   // as.integer(nh); TOLERANCE guards against round-off just below an integer
  nh2 = nh * nh;
  L2 = L * L;
  Lp1 = L + 1;
  L2p1 = Lp1 * Lp1;

  // Alias the caller's buffers.
  y = y2;
  cusumKernel = cusumKernel2;
  Xty = Xty2;
  XtX = XtX2;
  isComputedXtX = isComputedXtX2;
  XtXgap = XtXgap2;
  ImSX = ImSX2;
  isComputedImSX = isComputedImSX2;

  // cusumKernel must be filled first: both later steps divide by its entries.
  computeCusumKernel();
  computeXty();
  computeXtXgap();
}

// [[Rcpp::export(name = ".postProcessing")]]
Eigen::VectorXd postProcessing(IntegerVector const J, NumericVector y2, double bandwidth, NumericVector cusumKernel2, NumericVector Xty2, NumericMatrix XtX2,
                               LogicalMatrix isComputedXtX2, NumericVector XtXgap2, NumericMatrix ImSX2, LogicalVector isComputedImSX2) {
  // Unpenalised refit on the selected predictors. It solves the normal
  // equations  XtX[J, J] b = Xty[J]  with a sparse QR decomposition and
  // returns b (length J.size()).
  //
  // J holds 0-based predictor indices, which index Xty directly. The cache
  // arguments are rebound to the module globals; Gram entries come from
  // getXtX and are cached in place. Calls Rcpp::stop if the decomposition or
  // the solve fails.
  n = y2.size();   // use the argument: the global y may be stale or unset here
  h = bandwidth;    // bandwidth
  nh = n * h;         // n * h
  L = nh + TOLERANCE;    // window length as.integer(nh)
  nh2 = nh * nh;        // nh * nh
  L2 = L * L;         // L * L
  Lp1 = L + 1;  // L + 1
  L2p1 = Lp1 * Lp1;       // Lp1 * Lp1
  
  y = y2;
  cusumKernel = cusumKernel2;    // cusum(kernel(-L:L) / nh)
  Xty = Xty2;            // transpose((I - S) %*% X) %*% y / n
  XtX = XtX2;            // upper left corner of the Gram matrix (cache)
  isComputedXtX = isComputedXtX2;  // true if entry of XtX is computed already
  XtXgap = XtXgap2;         // other entries of the Gram matrix
  ImSX = ImSX2;           // the first 2 * L columns of (I - S) %*% X
  isComputedImSX = isComputedImSX2; // true if column in ImSX is computed already
  
  typedef Eigen::Triplet<double> trip;
  const int m = J.size();
  std::vector<trip> tripletList;
  // Diagonal (m) plus both off-diagonal triangles (m * (m - 1)) = m * m triplets.
  tripletList.reserve(static_cast<std::size_t>(m) * static_cast<std::size_t>(m));
  Eigen::VectorXd XtyJ(m);
  
  for (int k = 0; k < m; ++k) {
    XtyJ[k] = Xty[J[k]];
    tripletList.push_back(trip(k, k, getXtX(J[k], J[k])));
    for (int j = 0; j < k; ++j) {
      const double value = getXtX(J[j], J[k]);
      tripletList.push_back(trip(j, k, value));
      tripletList.push_back(trip(k, j, value));
    }
  }
  Eigen::SparseMatrix<double> mat(m, m);
  mat.setFromTriplets(tripletList.begin(), tripletList.end());
  
  Eigen::SparseQR <Eigen::SparseMatrix<double>, Eigen::COLAMDOrdering<int> > solver;
  solver.compute(mat);
  if(solver.info() != Eigen::Success) {
    stop("decomposition failed"); //TODO better error message
  }
  Eigen::VectorXd ret = solver.solve(XtyJ);
  if(solver.info() != Eigen::Success) {
    stop("solving failed"); //TODO better error message
  }
  
  return ret;
}
