#include "scalablebayesm.h"
#include <RcppArmadillo.h>

using namespace Rcpp;
using namespace arma;


//[[Rcpp::export]]
Rcpp::List rhierMnlRwMixtureParallel_rcpp_loop(Rcpp::List const& lgtdata,
                                               arma::mat const& Z,
                                               arma::vec const& deltabar, 
                                               arma::mat const& Ad, 
                                               arma::mat const& mubar, 
                                               arma::mat const& Amu, 
                                               int const& nu, 
                                               arma::mat const& V, 
                                               double s, 
                                               int R,
                                               int keep, 
                                               int nprint,
                                               bool drawdelta, 
                                               arma::mat olddelta,
                                               arma::vec const& a, 
                                               arma::vec oldprob, 
                                               arma::mat oldbetas,
                                               arma::vec ind,
                                               bool verbose)
{
  
  // Wayne Taylor 10/01/2014
  // Rico Bumbaca 10/10/2019
  // Kevin Nguyen 08/2022
  // Boyang Yu 06/2023
  
  int nlgt = lgtdata.size();// num of cross sectional units 
  int nvar = V.n_cols;//assigns the value of the number of columns in the V. 
  int nz = Z.n_cols;//assigns the value of the number of columns in the Z
  
  arma::mat rootpi, betabar, ucholinv, incroot;//This line declares and initializes four matrices named rootpi, betabar, ucholinv, incroot
  int mkeep;//This line declares an integer variable named mkeep.
  mnlMetropOnceOut metropout_struct;//This line declares a variable named metropout_struct of a custom user defined type mnlMetropOnceOut.
  Rcpp::List lgtdatai, nmix;//This line declares and initializes two list variables named lgtdatai and nmix.
  
  // convert List to std::vector of struct
  std::vector<moments> lgtdata_vector;
  /*
   1.std::vector is a standard library container class that represents a dynamic array-like structure.
   2.<moments> specifies the type of elements that the vector will hold. In this case, it's moments.
   3.lgtdata_vector is a variable of type std::vector<moments>, representing a vector of moments objects.
   4. This line declares an empty vector named lgtdata_vector that can store moments objects.
   */
  moments lgtdatai_struct;
  /*moments is a custom user-defined type or struct
   * create a variable called 'lgtdatai_struct' with type moment
   */
  //whole idea is to put 'lgtdatai_struct' into a vector called 'lgtdata_vector'
  for (int lgt = 0; lgt<nlgt; lgt++){//c++ start from 0. This is different with R
    lgtdatai = lgtdata[lgt];
    
    lgtdatai_struct.y = as<arma::vec>(lgtdatai["y"]);//coverts the 'y' element of the 'lgtdatai' object to an 'arma:vec' object using an "as" function
    lgtdatai_struct.X = as<arma::mat>(lgtdatai["X"]);//coverts the 'X' element of the 'lgtdatai' object to an 'arma:vec' object using an "as" function
    lgtdatai_struct.hess = as<arma::mat>(lgtdatai["hess"]);
    lgtdata_vector.push_back(lgtdatai_struct);//lgtdata_vector.push_back(lgtdatai_struct) appends the lgtdatai_struct object to the end of the lgtdata_vector vector, increasing its size by one.
  }

  
  // allocate space for draws
  arma::vec oldll = arma::zeros<arma::vec>(nlgt);//creates a arma::vec called 'oldll' and initialize by 0. The length is determined by nlgt
  arma::mat probdraw(floor((R)/keep), oldprob.size());//creates a matrix called probdraw. with number of rows equal to the floor of ((R)/keep), the number of column equals to the size of the "oldprob" vector #similar to pvec
  Rcpp::List compdraw(floor((R)/keep)); //create a Rcpp list object with the number of elements calculated as the floor of "R/keep". (creates floor((R)/keep)'s components in a list)
  
  mat Deltadraw(1,1); if(drawdelta) Deltadraw.zeros(floor((R)/keep), nz*nvar);//enlarge Deltadraw only if the space is required (Edited)
  
  if ((nprint>0) && verbose) startMcmcTimer();
  
  for (int rep = 0; rep<R; rep++){
    // first draw comps,ind,p | {beta_i}, delta
    // ind,p need initialization comps is drawn first in sub-Gibbs
    Rcpp::List mgout;
    if(drawdelta) {
      olddelta.reshape(nvar,nz);
      mgout = rmixGibbs1(oldbetas-Z*trans(olddelta),mubar,Amu,nu,V,a,oldprob,ind);//shift back the beta to find delta since we assume an initial beta
    //
    } else {
      mgout = rmixGibbs1(oldbetas,mubar,Amu,nu,V,a,oldprob,ind);
    }//the results gives p:(a mixture of probabilites) (prob, "pi" in formula), z:indicator of each component (choose a component (we have 3 compnents in total in this example,which to choose)), and comps: comps (a new draw of normal component, mu and var)
    //these are included in prior (prepare for theta)
    
    Rcpp::List oldcomp = mgout["comps"];
    oldprob = as<arma::vec>(mgout["p"]); 
    ind = as<arma::vec>(mgout["z"]);
    
    //now draw delta | {beta_i}, ind, comps
    if(drawdelta) olddelta = drawDelta1(Z,oldbetas,ind,oldcomp,deltabar,Ad);

    
    //loop over all LGT equations drawing beta_i | ind[i],z[i,],mu[ind[i]],rooti[ind[i]]
    for(int lgt = 0; lgt<nlgt; lgt++){
      Rcpp::List oldcomplgt = oldcomp[ind[lgt]-1];//choose specific component
      rootpi = as<arma::mat>(oldcomplgt[1]);//draws for the cholesky root of sigma 
      
      //note: beta_i = Delta*z_i + u_i  Delta is nvar x nz
      if(drawdelta){
        olddelta.reshape(nvar,nz);
        betabar = as<arma::vec>(oldcomplgt[0])+olddelta*vectorise(Z(lgt,arma::span::all));
      } else {
        betabar = as<arma::vec>(oldcomplgt[0]);
      }
      
      if (rep == 0) oldll[lgt] = llmnl(vectorise(oldbetas(lgt,arma::span::all)),lgtdata_vector[lgt].y,lgtdata_vector[lgt].X);

      //compute inc.root
      ucholinv = solve(trimatu(chol(lgtdata_vector[lgt].hess+rootpi*trans(rootpi))), arma::eye(nvar,nvar)); 
      incroot = chol(ucholinv*trans(ucholinv));
      
      metropout_struct = mnlMetropOnce(lgtdata_vector[lgt].y,lgtdata_vector[lgt].X,vectorise(oldbetas(lgt,arma::span::all)),
                                       oldll[lgt],s,incroot,betabar,rootpi);
      
      oldbetas(lgt,arma::span::all) = trans(metropout_struct.betadraw);
      oldll[lgt] = metropout_struct.oldll;  
    }
    
    //print time to completion and draw # every nprint'th draw
    if ((nprint>0) && verbose) if ((rep+1)%nprint==0) infoMcmcTimer(rep, R);
    
    if(((rep+1)>0) && ((rep+1)%keep==0)){
      mkeep = (rep+1)/keep;
      probdraw(mkeep-1, arma::span::all) = trans(oldprob);
      compdraw[mkeep-1] = oldcomp;
      if(drawdelta) Deltadraw(mkeep-1, span::all) = trans(vectorise(olddelta));//added
    }
  }
  if ((nprint>0) && verbose) endMcmcTimer();
//  return(Rcpp::List::create(Rcpp::Named("compdraw") = compdraw,
//                            Rcpp::Named("probdraw")= probdraw));//check the old bayesm file (rhierMnlMixture) because we need to return delta...
//}

if (drawdelta){
  return(Rcpp::List::create(Rcpp::Named("compdraw") = compdraw,
                            Rcpp::Named("probdraw")= probdraw,
                            Rcpp::Named("Deltadraw")= Deltadraw));//added
}else{
  return(Rcpp::List::create(Rcpp::Named("compdraw") = compdraw,
                            Rcpp::Named("probdraw")= probdraw));
} 
}  


