#include "PsiFunction.h"
using namespace Rcpp;

// DEBUG(STRING) currently expands to nothing: the trailing backslash splices
// the following line onto the #define, and that line is entirely a comment.
// The commented-out text shows the intended debug output, which would stream
// STRING to Rcpp::Rcout.
#define DEBUG(STRING)                                          \
// /Rcpp::Rcout << STRING << std::endl;

// Sign of val as an int: +1 when val is above T(0), -1 when below it, and 0
// otherwise (including when neither comparison holds).
template <typename T> int sgn(T val) {
  const bool isPositive = T(0) < val;
  const bool isNegative = val < T(0);
  return static_cast<int>(isPositive) - static_cast<int>(isNegative);
}

// class PsiFunction
// Default constructor; nothing to initialise here.
PsiFunction::PsiFunction() {}

// Human-readable label of this psi function (rho(x) = x^2 / 2).
const std::string PsiFunction::name() const {
  const std::string label("classic (x^2/2)");
  return label;
}

// Builds the display string: name(), then the literal " psi function", then
// whatever showDefaults() returns.
const std::string PsiFunction::show() const {
  std::string text = this->name();
  text += " psi function";
  text += this->showDefaults();
  return text;
}

// Base implementation has no tuning parameters, so the argument is ignored.
void PsiFunction::chgDefaults(NumericVector tuningParameters) {}

// Current tuning parameters; the base class has none, so the vector is empty.
NumericVector PsiFunction::tDefs() const {
  NumericVector noParameters(0);
  return noParameters;
}

// Textual description of the tuning parameters; empty for the base class.
const std::string PsiFunction::showDefaults() const {
  return std::string();
}

// rho(x) = x^2 / 2.
const double PsiFunction::rhoFun(const double x) {
  const double squared = x * x;
  return squared / 2.;
}

// psi(x) = x (identity).
const double PsiFunction::psiFun(const double x) {
  const double unchanged = x;
  return unchanged;
}

// Weight function; constant 1 regardless of x.
const double PsiFunction::wgtFun(const double x) {
  return 1.0;
}

// Derivative of psi; constant 1 because psi(x) = x.
const double PsiFunction::DpsiFun(const double x) {
  return 1.0;
}

// Derivative of the weight function; constant 0 because the weight is constant.
const double PsiFunction::DwgtFun(const double x) {
  return 0.0;
}

// Closed-form value returned for the expectation of rho: 0.5.
const double PsiFunction::Erho() {
  const double expectation = 0.5;
  return expectation;
}

// Closed-form value returned for the expectation of psi^2: 1.
const double PsiFunction::Epsi2() {
  return 1.0;
}

// Closed-form value returned for the expectation of psi': 1.
const double PsiFunction::EDpsi() {
  return 1.0;
}

// Destructor; the base class releases nothing here.
PsiFunction::~PsiFunction() {}

// Square of psiFun(x); psiFun is evaluated once and multiplied by itself.
const double PsiFunction::psi2Fun(const double x) {
  const double psiValue = this->psiFun(x);
  const double squared = psiValue * psiValue;
  return squared;
}

// end class PsiFunction

// class PsiFunctionNumIntExp 
// Initialises integration_ from a heap-allocated DqagIntegration (the
// destructor of this class deletes &integration_) and then calls reset() so
// that all cached expectations start out as NA.
PsiFunctionNumIntExp::PsiFunctionNumIntExp() : 
  PsiFunction(), integration_(*(new DqagIntegration())) {
  reset();
}

// Descriptive label for this class.
const std::string PsiFunctionNumIntExp::name() const {
  const std::string label(
    "PsiFunction with expectations computed using numerical integration");
  return label;
}

// Changing the tuning parameters invalidates the cached expectations, so they
// are cleared first; the call is then forwarded to the base class.
void PsiFunctionNumIntExp::chgDefaults(NumericVector tuningParameters) {
  this->reset();
  this->PsiFunction::chgDefaults(tuningParameters);
}

// Cached expectation of rho: computed through computeErho() the first time
// (or whenever the cache holds NA) and returned from the cache afterwards.
const double PsiFunctionNumIntExp::Erho() {
  if (!NumericVector::is_na(Erho_)) {
    return Erho_;
  }
  Erho_ = computeErho();
  return Erho_;
}

// Cached expectation of psi^2: computed through computeEpsi2() the first time
// (or whenever the cache holds NA) and returned from the cache afterwards.
const double PsiFunctionNumIntExp::Epsi2() {
  if (!NumericVector::is_na(Epsi2_)) {
    return Epsi2_;
  }
  Epsi2_ = computeEpsi2();
  return Epsi2_;
}

// Cached expectation of psi': computed through computeEDpsi() the first time
// (or whenever the cache holds NA) and returned from the cache afterwards.
const double PsiFunctionNumIntExp::EDpsi() {
  if (!NumericVector::is_na(EDpsi_)) {
    return EDpsi_;
  }
  EDpsi_ = computeEDpsi();
  return EDpsi_;
}

// Releases the integrator that the constructor allocated with new.
PsiFunctionNumIntExp::~PsiFunctionNumIntExp() {
  delete &integration_;
}

// Marks all three cached expectations as "not computed" by setting them to
// NA_REAL; the accessors recompute on the next call.
void PsiFunctionNumIntExp::reset() {
  Erho_ = Epsi2_ = EDpsi_ = NA_REAL;
}

const double PsiFunctionNumIntExp::computeErho() {
  DEBUG("Called computeErho()")
  return integrate(&PsiFunction::rhoFun);
}

// Evaluates the expectation of psi2Fun by handing it to integrate().
const double PsiFunctionNumIntExp::computeEpsi2() {
  const double expectation = integrate(&PsiFunction::psi2Fun);
  return expectation;
}

// Evaluates the expectation of DpsiFun by handing it to integrate().
const double PsiFunctionNumIntExp::computeEDpsi() {
  const double expectation = integrate(&PsiFunction::DpsiFun);
  return expectation;
}

// Integrates the member function fptr, weighted by the standard normal density
// (see psiFunctionIntegrandNorm), using integration_.ninfInf(). The integrand
// callback receives a two-slot array: slot 0 is this object, slot 1 the address
// of the member-function pointer. Both only need to live for this call.
double PsiFunctionNumIntExp::integrate(Fptr fptr) {
  const void *payload[2] = {this, &fptr};
  // The integrator interface takes a non-const void**.
  void **callbackData = const_cast<void**>(payload);
  const double integral =
    integration_.ninfInf(psiFunctionIntegrandNorm, callbackData);
  return integral;
}

// end class PsiFunctionNumIntExp

// class PsiFunctionPropII
// Default constructor: wraps a freshly allocated SmoothPsi with default tuning
// parameters and initialises this class's own integration_ from a
// heap-allocated DqagIntegration.
// NOTE(review): the destructor visible in this file deletes only
// &integration_, so the SmoothPsi allocated here is never freed.
PsiFunctionPropII::PsiFunctionPropII() : 
  PsiFunctionNumIntExp(), base_(new SmoothPsi()), integration_(*(new DqagIntegration())) {
  // Rcpp::Rcout << "illegal constructor called!!" << std::endl;
}

// Wraps an existing psi function. The pointer is stored as-is and is not
// deleted by this class's destructor, so the caller keeps ownership of base.
PsiFunctionPropII::PsiFunctionPropII(PsiFunction* base) : 
  PsiFunctionNumIntExp(), base_(base), integration_(*(new DqagIntegration())) 
{}

// Releases the integrator allocated by this class's constructors; base_ is
// left untouched.
PsiFunctionPropII::~PsiFunctionPropII() {
  delete &integration_;
}

// Name of the wrapped psi function with ", Proposal II" appended.
const std::string PsiFunctionPropII::name() const {
  std::string label = base_->name();
  label += ", Proposal II";
  return label;
}

// Forwards the new tuning parameters to the wrapped psi function, then lets
// the parent class clear its cached expectations.
void PsiFunctionPropII::chgDefaults(NumericVector x) {
  PsiFunction *const wrapped = base_;
  wrapped->chgDefaults(x);
  this->PsiFunctionNumIntExp::chgDefaults(x);
}

// Tuning parameters are those of the wrapped psi function.
NumericVector PsiFunctionPropII::tDefs() const {
  NumericVector wrappedDefaults = base_->tDefs();
  return wrappedDefaults;
}

// rho(x) is the integral of this class's psiFun from 0 to x, evaluated
// numerically via integrate().
//
// Non-finite arguments are not integrated. Because psiFun is odd (an even
// weight times an odd psi), rho is even, so -Inf must map to the same value as
// +Inf. The previous code returned x unchanged and therefore gave -Inf for
// x = -Inf. NaN/NA inputs are still returned unchanged.
const double PsiFunctionPropII::rhoFun(const double x) {
  if (!R_FINITE(x))
    return x < 0. ? -x : x;  // NaN compares false and is passed through
  return integrate(&PsiFunction::psiFun, x);
}

// Proposal II psi: the wrapped weight function times the wrapped psi.
const double PsiFunctionPropII::psiFun(const double x) {
  const double weight = base_->wgtFun(x);
  const double score = base_->psiFun(x);
  return weight * score;
}

// Proposal II weight: the square of the wrapped weight function.
const double PsiFunctionPropII::wgtFun(const double x) {
  const double weight = base_->wgtFun(x);
  const double squared = weight * weight;
  return squared;
}

// Derivative of wgt(x) * psi(x) of the wrapped function, via the product rule.
const double PsiFunctionPropII::DpsiFun(const double x) {
  const double weightTimesDpsi = base_->wgtFun(x) * base_->DpsiFun(x);
  const double dweightTimesPsi = base_->DwgtFun(x) * base_->psiFun(x);
  return weightTimesDpsi + dweightTimesPsi;
}

// Derivative of wgt(x)^2 of the wrapped function: 2 * wgt(x) * wgt'(x).
const double PsiFunctionPropII::DwgtFun(const double x) {
  const double weight = base_->wgtFun(x);
  const double weightSlope = base_->DwgtFun(x);
  return 2. * weight * weightSlope;
}

// Read-only access to the wrapped psi function.
const PsiFunction* PsiFunctionPropII::base() const {
  const PsiFunction* wrapped = base_;
  return wrapped;
}

// Tuning-parameter description is taken verbatim from the wrapped psi function.
const std::string PsiFunctionPropII::showDefaults() const {
  const std::string wrappedText = base_->showDefaults();
  return wrappedText;
}

// Integrates the member function fptr (unweighted, see psiFunctionIntegrand)
// with integration_.aB(), passing 0 and b as the two bounds. The integrand
// callback receives a two-slot array: slot 0 is this object, slot 1 the address
// of the member-function pointer.
double PsiFunctionPropII::integrate(Fptr fptr, double b) {
  double lower = 0.;
  const void *payload[2] = {this, &fptr};
  // The integrator interface takes a non-const void**.
  void **callbackData = const_cast<void**>(payload);
  const double integral =
    integration_.aB(psiFunctionIntegrand, callbackData, &lower, &b);
  return integral;
}

// end class PsiFunctionPropII

// class HuberPsi
// Default constructor: an empty parameter vector makes chgDefaults() fall back
// to its built-in default for k.
HuberPsi::HuberPsi() : PsiFunction() {
  NumericVector noParameters(0);
  chgDefaults(noParameters);
}

// Constructs a Huber psi function whose tuning constant is taken from k via
// chgDefaults().
HuberPsi::HuberPsi(NumericVector k) : PsiFunction() {
  this->chgDefaults(k);
}

// Label of this psi function.
const std::string HuberPsi::name() const {
  const std::string label("Huber");
  return label;
}

// Sets the tuning constant k_ to the first element of k, or to 1.345 when k
// is empty. Any further elements are ignored.
void HuberPsi::chgDefaults(NumericVector k) {
  const bool hasValue = k.size() >= 1;
  k_ = hasValue ? k[0] : 1.345;
}

// Current tuning parameters as a length-1 numeric vector named "k".
NumericVector HuberPsi::tDefs() const {
  NumericVector defaults(1);
  defaults[0] = k_;
  defaults.names() = CharacterVector::create("k");
  return defaults;
}

// Formats the tuning constant as " (k = <value>)" with up to five significant
// digits.
//
// The bounded std::snprintf replaces std::sprintf: the previous 20-byte buffer
// held the longest possible output (19 characters plus the terminator) with no
// spare room and nothing guarding the write.
const std::string HuberPsi::showDefaults() const {
  char buffer[32];
  std::snprintf(buffer, sizeof(buffer), " (k = %.5g)", k_);
  return std::string(buffer);
}

// Huber rho: quadratic (x^2 / 2) while |x| does not exceed k_, linear
// (k_ * (|x| - k_ / 2)) beyond that.
const double HuberPsi::rhoFun(const double x) {
  const double magnitude = std::abs(x);
  return (magnitude > k_) ? k_ * (magnitude - k_ / 2.)
                          : magnitude * magnitude / 2.;
}

// Huber psi: x limited to the interval [-k_, k_].
const double HuberPsi::psiFun(const double x) {
  if (x < -k_)
    return -k_;
  if (x > k_)
    return k_;
  return x;
}

// Huber weight: k_ / |x| outside [-k_, k_], 1 otherwise.
const double HuberPsi::wgtFun(const double x) {
  const bool outside = x < -k_ || x > k_;
  return outside ? k_ / std::abs(x) : 1.;
}

// Derivative of the Huber psi: 0 outside [-k_, k_], 1 otherwise.
const double HuberPsi::DpsiFun(const double x) {
  const bool outside = x < -k_ || x > k_;
  return outside ? 0. : 1.;
}

// Derivative of the Huber weight: k_ / x^2 below -k_, -k_ / x^2 above k_, and
// 0 otherwise.
const double HuberPsi::DwgtFun(const double x) {
  if (x < -k_)
    return k_ / (x * x);
  if (x > k_)
    return -k_ / (x * x);
  return 0.;
}

// Closed-form expectation of the Huber rho, built from the standard normal
// upper-tail probability and density at k_.
const double HuberPsi::Erho() {
  // pnorm_0(k_, 0, 0): second argument 0 selects the upper tail, third the
  // non-log scale.
  const double upperTail = stats::pnorm_0(k_, 0, 0);
  const double bracket = stats::dnorm_0(k_, 0) - k_ * upperTail;
  return .5 - upperTail + k_ * bracket;
}

// Closed-form expectation of the squared Huber psi. When k_ is not below 10
// the formula is skipped and 1 is returned directly.
const double HuberPsi::Epsi2() {
  if (k_ < 10.) {
    const double density = stats::dnorm_0(k_, 0);
    const double upperTail = stats::pnorm_0(k_, 0, 0);
    const double inner = k_ * density + (1. - k_*k_) * upperTail;
    return 1. - 2.*inner;
  }
  return 1.;
}

// Closed-form expectation of the Huber psi derivative: twice the standard
// normal lower-tail probability at k_, minus 1.
const double HuberPsi::EDpsi() {
  const double lowerTail = stats::pnorm_0(k_, 1, 0);
  return 2. * lowerTail - 1.;
}

// Destructor; nothing is released here.
HuberPsi::~HuberPsi() {}

// end class HuberPsi

// class SmoothPsi
// Default constructor: an empty parameter vector makes chgDefaults() fall back
// to its built-in defaults for k and s.
SmoothPsi::SmoothPsi() : PsiFunctionNumIntExp() {
  NumericVector noParameters(0);
  chgDefaults(noParameters);
}

// Constructs a smoothed Huber psi function with tuning parameters taken from
// tuningParameters via chgDefaults().
SmoothPsi::SmoothPsi(NumericVector tuningParameters) : PsiFunctionNumIntExp() {
  this->chgDefaults(tuningParameters);
}

// Label of this psi function.
const std::string SmoothPsi::name() const {
  const std::string label("smoothed Huber");
  return label;
}

// Updates the tuning parameters and the constants derived from them.
//   tuningParameters[0] -> k_ (default 1.345 when absent)
//   tuningParameters[1] -> s_ (default 10 when absent)
// The parent call comes first so that cached expectations are cleared.
void SmoothPsi::chgDefaults(NumericVector tuningParameters) {
  PsiFunctionNumIntExp::chgDefaults(tuningParameters);

  const bool hasK = tuningParameters.size() >= 1;
  const bool hasS = tuningParameters.size() >= 2;
  k_ = hasK ? tuningParameters[0] : 1.345;
  s_ = hasS ? tuningParameters[1] : 10.;

  // Derived constants used by rhoFun, psiFun, DpsiFun, wgtFun and DwgtFun.
  a_ = std::pow(s_, 1. / (s_ + 1.));
  c_ = k_ - std::pow(a_, -s_);
  d_ = c_ - a_;
}

// Current tuning parameters as a length-2 numeric vector named "k" and "s".
NumericVector SmoothPsi::tDefs() const {
  NumericVector defaults(2);
  defaults[0] = k_;
  defaults[1] = s_;
  defaults.names() = CharacterVector::create("k", "s");
  return defaults;
}

const double SmoothPsi::rhoFun(const double x) {
  double ax = std::abs(x);
  if (ax <= c_) {
    return x * x / 2.;
  } else {
    return c_ * c_ / 2. + k_ * (ax - c_) -
      (std::pow(ax - d_, 1. - s_) - std::pow(a_, 1. - s_)) / (1. - s_);
  }
}

// Smoothed Huber psi: identity while |x| <= c_; beyond c_ it is
// sign(x) * (k_ - (|x| - d_)^(-s_)).
const double SmoothPsi::psiFun(const double x) {
  const double magnitude = std::abs(x);
  if (magnitude <= c_) {
    return x;
  }
  const double tail = k_ - std::pow(magnitude - d_, -s_);
  return sgn(x) * tail;
}

// Derivative of the smoothed Huber psi: 1 while |x| <= c_, otherwise
// s_ * (|x| - d_)^(-s_ - 1).
const double SmoothPsi::DpsiFun(const double x) {
  const double magnitude = std::abs(x);
  if (magnitude <= c_) {
    return 1.;
  }
  const double decay = std::pow(magnitude - d_, -s_ - 1.);
  return s_ * decay;
}

// Smoothed Huber weight: 1 while |x| <= c_, otherwise
// (k_ - (|x| - d_)^(-s_)) / |x|.
const double SmoothPsi::wgtFun(const double x) {
  const double magnitude = std::abs(x);
  if (magnitude <= c_) {
    return 1.;
  }
  const double tail = k_ - std::pow(magnitude - d_, -s_);
  return tail / magnitude;
}

// Derivative of the smoothed Huber weight: 0 while |x| <= c_, otherwise the
// quotient-rule derivative of (k_ - (|x| - d_)^(-s_)) / |x|.
const double SmoothPsi::DwgtFun(const double x) {
  const double magnitude = std::abs(x);
  if (magnitude <= c_) {
    return 0.;
  }
  const double shifted = magnitude - d_;
  const double slopeTerm = std::pow(shifted, -s_ - 1.) * s_ / x;
  const double valueTerm = (k_ - std::pow(shifted, -s_)) / (x * magnitude);
  return slopeTerm - valueTerm;
}

// Destructor; nothing is released here.
SmoothPsi::~SmoothPsi() {}

// Formats the tuning parameters as " (k = <value>, s = <value>)" with up to
// five significant digits each.
//
// Each %.5g field can expand to 12 characters (e.g. "-1.2345e-308"), so the
// text can reach 37 characters plus the terminator. The previous 30-byte
// buffer written with unbounded std::sprintf could therefore overflow; the
// buffer is now large enough and the write is bounded with std::snprintf.
const std::string SmoothPsi::showDefaults() const {
  char buffer[64];
  std::snprintf(buffer, sizeof(buffer), " (k = %.5g, s = %.5g)", k_, s_);
  return std::string(buffer);
}

// end class SmoothPsi

/*
 TODO: lqq, bisquare, etc, psi functions
 
lqqPsi <- psiFuncCached(rho = function(x, cc) Mpsi(x, cc, "lqq", -1),
  psi = function(x, cc) Mpsi(x, cc, "lqq", 0),
  Dpsi = function(x, cc) Mpsi(x, cc, "lqq", 1),
  wgt = function(x, cc) Mwgt(x, cc, "lqq"),
  Dwgt = function(x, cc) 
  (Mpsi(x, cc, "lqq", 1) - Mwgt(x, cc, "lqq"))/x,
  name = "lqq",
  cc = c(-0.5, 1.5, 0.95, NA))

bisquarePsi <- psiFuncCached(rho = function(x, k) Mpsi(x, k, "biweight", -1),
  psi = function(x, k) Mpsi(x, k, "biweight", 0),
  Dpsi = function(x, k) Mpsi(x, k, "biweight", 1),
  wgt = function(x, k) (1 - (x/k)^2)^2*(abs(x) <= k),
  Dwgt = function(x, k) (-(4*(1-(x/k)^2))*x/k^2)*(abs(x) <= k),
  name = "bisquare",
  k = 4.68)
 */

// Free-function wrapper around PsiFunction::name().
std::string name(PsiFunction* p) {
  std::string label = p->name();
  return label;
}

// Free-function wrapper around PsiFunction::chgDefaults().
void chgDefaults(PsiFunction* p, NumericVector x) {
  p->chgDefaults(x);
}

// Applies the member function fptr of *p to every element of x and returns the
// results in a new vector of the same length.
//
// The index uses R_xlen_t, the type returned by x.size(), so vectors with more
// than INT_MAX elements are handled without overflowing the counter. The
// length is read once instead of on every iteration.
NumericVector compute(PsiFunction* p, Fptr fptr, NumericVector x) {
  const R_xlen_t n = x.size();
  NumericVector result(n);
  for (R_xlen_t i = 0; i < n; ++i) {
    result[i] = (p->*fptr)(x[i]);
  }
  return result;
}

// Element-wise rhoFun of *p over x.
NumericVector rho(PsiFunction* p, NumericVector x) {
  Fptr evaluator = &PsiFunction::rhoFun;
  return compute(p, evaluator, x);
}

// Element-wise psiFun of *p over x.
NumericVector psi(PsiFunction* p, NumericVector x) {
  Fptr evaluator = &PsiFunction::psiFun;
  return compute(p, evaluator, x);
}

// Element-wise wgtFun of *p over x.
NumericVector wgt(PsiFunction* p, NumericVector x) {
  Fptr evaluator = &PsiFunction::wgtFun;
  return compute(p, evaluator, x);
}

// Element-wise DpsiFun of *p over x.
NumericVector Dpsi(PsiFunction* p, NumericVector x) {
  Fptr evaluator = &PsiFunction::DpsiFun;
  return compute(p, evaluator, x);
}

// Element-wise DwgtFun of *p over x.
NumericVector Dwgt(PsiFunction* p, NumericVector x) {
  Fptr evaluator = &PsiFunction::DwgtFun;
  return compute(p, evaluator, x);
}

// Free-function wrapper around PsiFunction::Erho().
const double Erho(PsiFunction* p) {
  const double expectation = p->Erho();
  return expectation;
}

// Free-function wrapper around PsiFunction::Epsi2().
const double Epsi2(PsiFunction* p) {
  const double expectation = p->Epsi2();
  return expectation;
}

// Free-function wrapper around PsiFunction::EDpsi().
const double EDpsi(PsiFunction* p) {
  const double expectation = p->EDpsi();
  return expectation;
}

// Free-function wrapper around PsiFunction::tDefs().
NumericVector tDefs(PsiFunction* p) {
  NumericVector defaults = p->tDefs();
  return defaults;
}

void psiFunctionIntegrand(double *x, const int n, void *const ex) {
  PsiFunction **pp = static_cast<PsiFunction**>(ex);
  PsiFunction *p = pp[0];
  
  const Fptr **const pfptr  = static_cast<const Fptr **const>(ex);
  const Fptr *fptr = pfptr[1];
  
  for (int i = 0; i < n; i++) {
    double value = x[i];  
    x[i] = (p->**fptr)(value);
  }
  
  return;
}

void psiFunctionIntegrandNorm(double *x, const int n, void *const ex) {
  PsiFunction **pp = static_cast<PsiFunction**>(ex);
  PsiFunction *p = pp[0];
  
  const Fptr **const pfptr  = static_cast<const Fptr **const>(ex);
  const Fptr *fptr = pfptr[1];
  
  for (int i = 0; i < n; i++) {
    double value = x[i];  
    x[i] = (p->**fptr)(value) * stats::dnorm_0(value, 0);
  }
  
  return;
}


// Rcpp module "psi_function_module": registers the psi-function classes so
// they can be constructed and used from R.
RCPP_EXPOSED_CLASS(PsiFunction)
RCPP_MODULE(psi_function_module) {
  
  // Base class. Most methods are bound to the free wrapper functions defined
  // above (taking a PsiFunction* as first argument); "show" is bound directly
  // to the member function.
  class_<PsiFunction>("PsiFunction")
  .constructor()
  .method("name", &name)
  .method("chgDefaults", &chgDefaults)
  .method("rho", &rho)
  .method("psi", &psi)
  .method("wgt", &wgt)
  .method("Dpsi", &Dpsi)
  .method("Dwgt", &Dwgt)
  .method("Epsi2", &Epsi2)
  .method("EDpsi", &EDpsi)
  .method("Erho", &Erho)
  .method("show", &PsiFunction::show)
  .method("tDefs", &tDefs)
  ;
  
  // Huber psi function: default constructor or a numeric vector of tuning
  // parameters.
  class_<HuberPsi>("HuberPsi")
    .derives<PsiFunction>("PsiFunction")
    .constructor()
    .constructor<NumericVector>()
  ;
  
  // Smoothed Huber psi function: default constructor or a numeric vector of
  // tuning parameters.
  class_<SmoothPsi>("SmoothPsi")
    .derives<PsiFunction>("PsiFunction")
    .constructor()
    .constructor<NumericVector>()
  ;
  
  // Proposal II wrapper, exposed to R under the name
  // "PsiFunctionToPropIIPsiFunctionWrapper"; "base" returns the wrapped
  // psi function.
  class_<PsiFunctionPropII>("PsiFunctionToPropIIPsiFunctionWrapper")
    .derives<PsiFunction>("PsiFunction")
    .constructor()
    .constructor<PsiFunction*>()
    .method("base", &PsiFunctionPropII::base)
  ;
}
