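###############################################################################
## Estimate the plug-in unconstrained bandwidth matrix for the psi_0
## functional (used by the 2-sample test kde.test)
##
## Returns
## Plug-in bandwidth matrix H, or list(H, PI) holding H and the minimised
## criterion value when amise=TRUE
###############################################################################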

## Arguments:
##   x           - data matrix, one row per observation
##   nstage      - 2: estimate psi_2 by kernel functional estimation (kfe)
##                 with an optimised pilot bandwidth; otherwise use the
##                 normal-scale psi_2 from psins()
##   Hstart      - starting matrix for the stage-1 pilot optimisation
##                 (default: normal-reference matrix with r=4)
##   deriv.order - only 0 is supported
##   binned, double.loop, verbose - forwarded to kfe(); verbose also traces
##                 the stage-2 optimisation
##   bgridsize   - accepted for interface compatibility; not used here
##   amise       - if TRUE, also return the minimised stage-2 criterion
## Optimisation is over vech(A) with H = A %*% A, A = invvech(vech(A)).
Hpi.kfe <- function(x, nstage=2, Hstart, deriv.order=0, binned=FALSE, bgridsize, double.loop=FALSE, amise=FALSE, verbose=FALSE)
{
  if (deriv.order!=0) stop("Currently only deriv.order=0 is implemented")

  n <- nrow(x)
  d <- ncol(x)
  S <- var(x)

  ## normal-reference bandwidth matrix for an r-th order functional
  norm.ref <- function(r) 2*(2/(n*(d+r)))^(2/(d+r+2)) * S
  ## bandwidth matrix from its vech'd "square-root" parametrisation
  to.H <- function(v) { A <- invvech(v); A %*% A }

  if (missing(Hstart)) Hstart <- norm.ref(4)
  A.pilot <- matrix.sqrt(Hstart)
  K0 <- dmvnorm.deriv(x=rep(0,d), mu=rep(0,d), Sigma=diag(d), deriv.order=0)
  D2K0 <- t(dmvnorm.deriv(x=rep(0,d), mu=rep(0,d), Sigma=diag(d), deriv.order=2))

  if (nstage == 2) {
    ## stage 1: choose the pilot bandwidth for psi_2, then estimate psi_2
    psi4.ns <- psins(r=4, Sigma=S, deriv.vec=TRUE)
    pilot.crit <- function(v) {
      H <- to.H(v)
      Hinv12 <- matrix.sqrt(chol2inv(chol(H)))
      crit <- 1/(det(H)^(1/2)*n)*((Hinv12 %x% Hinv12) %*% D2K0) + 1/2* t(vec(H) %x% diag(d^2)) %*% psi4.ns
      sum(crit^2)
    }
    pilot.opt <- optim(vech(A.pilot), pilot.crit, method="BFGS")
    psi2.hat <- kfe(x=x, G=to.H(pilot.opt$par), deriv.order=2, double.loop=double.loop, add.index=FALSE, binned=binned, verbose=verbose)
  } else {
    psi2.hat <- psins(r=2, Sigma=S, deriv.vec=TRUE)
  }

  ## stage 2: minimise the squared criterion for the final bandwidth,
  ## always starting from the normal-reference matrix with r=2
  final.crit <- function(v) {
    H <- to.H(v)
    crit <- 1/(det(H)^(1/2)*n)*K0 + 1/2* t(vec(H)) %*% psi2.hat
    sum(crit^2)
  }
  final.opt <- optim(vech(matrix.sqrt(norm.ref(2))), final.crit, method="BFGS", control=list(trace=as.numeric(verbose)))
  H <- to.H(final.opt$par)

  if (amise) list(H = H, PI=final.opt$value) else H
}


###############################################################################
## Plug-in diagonal bandwidth matrix for the psi_0 functional.
##
## Same two-stage procedure as Hpi.kfe, but the optimisation runs over the
## diagonal entries s of H^(1/2), with H = diag(s) %*% diag(s).
##
## Arguments:
##   x           - data matrix, one row per observation
##   nstage      - 2: estimate psi_2 by kfe() with an optimised diagonal
##                 pilot; otherwise use the normal-scale psi_2 from psins()
##   Hstart      - starting matrix for the stage-1 pilot optimisation
##                 (default: normal-reference matrix with r=4); only its
##                 diagonal (after matrix.sqrt) is used
##   deriv.order - only 0 is supported
##   binned, double.loop, verbose - forwarded to kfe(); verbose also traces
##                 the stage-2 optimisation
##   amise       - if TRUE, also return the minimised stage-2 criterion
## Returns the diagonal bandwidth matrix H, or list(H, PI) when amise=TRUE.
###############################################################################
Hpi.diag.kfe <- function(x, nstage=2, Hstart, deriv.order=0, binned=FALSE, double.loop=FALSE, amise=FALSE, verbose=FALSE)
{
  if (deriv.order!=0) stop("Currently only deriv.order=0 is implemented")

  n <- nrow(x)
  d <- ncol(x)

  ## Diagonal bandwidth matrix from the optimisation parameter s.
  ## nrow= is essential: for d=1, diag(s) with a length-1 s would build an
  ## identity matrix of size floor(s) instead of the 1x1 matrix s.
  diag.H <- function(s) { D <- diag(s, nrow=length(s)); D %*% D }

  ## use normal reference bandwidth as initial condition
  if (missing(Hstart)) { r <- 4; Hstart <- 2*(2/(n*(d+r)))^(2/(d+r+2)) * var(x) }
  Hstart <- matrix.sqrt(Hstart)
  D2K0 <- t(dmvnorm.deriv(x=rep(0,d), mu=rep(0,d), Sigma=diag(d), deriv.order=2))
  K0 <- dmvnorm.deriv(x=rep(0,d), mu=rep(0,d), Sigma=diag(d), deriv.order=0)

  if (nstage==2)
  {
    ## stage 1: pilot bandwidth for psi_2, then kernel estimate of psi_2
    psi4.ns <- psins(r=4, Sigma=var(x), deriv.vec=TRUE)

    amse2.temp <- function(diagH)
    {
      H <- diag.H(diagH)
      Hinv <- chol2inv(chol(H))
      Hinv12 <- matrix.sqrt(Hinv)
      amse2 <- 1/(det(H)^(1/2)*n)*((Hinv12 %x% Hinv12) %*% D2K0) + 1/2* t(vec(H) %x% diag(d^2)) %*% psi4.ns
      return(sum(amse2^2))
    }
    result <- optim(diag(Hstart), amse2.temp, method="BFGS")
    H2 <- diag.H(result$par)

    psi2.hat <- kfe(x=x, G=H2, deriv.order=2, double.loop=double.loop, add.index=FALSE, binned=binned, verbose=verbose)
  }
  else
    psi2.hat <- psins(r=2, Sigma=var(x), deriv.vec=TRUE)

  ## stage 2: final bandwidth, starting from the normal-reference matrix (r=2)
  amse.temp <- function(diagH)
  {
    H <- diag.H(diagH)
    amse <- 1/(det(H)^(1/2)*n)*K0 + 1/2* t(vec(H)) %*% psi2.hat
    return(sum(amse^2))
  }
  r <- 2; Hstart <- 2*(2/(n*(d+r)))^(2/(d+r+2)) * var(x)
  Hstart <- matrix.sqrt(Hstart)
  ## trace the optimisation when verbose, consistent with Hpi.kfe
  result <- optim(diag(Hstart), amse.temp, method="BFGS", control=list(trace=as.numeric(verbose)))
  H <- diag.H(result$par)

  if (!amise) return(H)
  else return(list(H = H, PI=result$value))
}


#######################################################################################################
## Test statistic for multivariate 2-sample test
#######################################################################################################

## Arguments:
##   x1, x2      - data matrices for the two samples (same number of columns)
##   H1, H2      - bandwidth matrices (default: Hpi.kfe on each sample)
##   psi1, psi2  - within-sample functionals (default: eta.kfe.y)
##   fhat1/fhat2 - optional precomputed density estimates with $eval.points
##                 and $estimate; if given, they supply psi12/psi21 and an
##                 empirical var.fhat1/var.fhat2 (overriding supplied values)
##   var.fhat1/2 - optional variance terms (default: delta-method estimate)
##   double.loop, binned, bgridsize, verbose - forwarded to helpers
##   pre.scale   - if TRUE, pre-scale the pooled sample before testing
## Returns a list with the statistic Tstat, its z-score, upper-tail p-value,
## estimated mean/variance, and the intermediate quantities.
kde.test <- function(x1, x2, H1, H2, psi1, psi2, fhat1, fhat2, var.fhat1, var.fhat2, double.loop=FALSE, binned=FALSE, bgridsize, verbose=FALSE, pre.scale=FALSE)
{
  n1 <- nrow(x1)
  n2 <- nrow(x2)
  d <- ncol(x1)
  K0 <- drop(dmvnorm.deriv(x=rep(0,d), mu=rep(0,d), Sigma=diag(d), deriv.order=0))

  if (pre.scale)
  {
    ## The pooled sample is scaled jointly. The call resolves to the
    ## pre.scale() function: R skips the logical argument of the same name
    ## when looking up a function to call.
    x12.star <- pre.scale(rbind(x1,x2))
    ## drop=FALSE keeps a matrix even when a sample has a single row
    x1 <- x12.star[seq_len(n1), , drop=FALSE]
    x2 <- x12.star[n1 + seq_len(n2), , drop=FALSE]
  }

  ## kernel estimation for components of test statistic
  if (missing(H1)) H1 <- Hpi.kfe(x1, nstage=2, double.loop=double.loop, deriv.order=0, binned=binned, bgridsize=bgridsize, verbose=verbose)
  if (missing(H2)) H2 <- Hpi.kfe(x2, nstage=2, double.loop=double.loop, deriv.order=0, binned=binned, bgridsize=bgridsize, verbose=verbose)

  symm <- FALSE ## don't use symmetriser matrices in psi functional calculations
  if (missing(psi1)) psi1 <- eta.kfe.y(x=x1, y=x1, G=H1, verbose=verbose, symm=symm)
  if (missing(psi2)) psi2 <- eta.kfe.y(x=x2, y=x2, G=H2, verbose=verbose, symm=symm)

  if (!missing(fhat1)) {
    ## fhat1 at the pooled points (nearest grid point): the x2 part gives
    ## psi12, the x1 part gives an empirical var.fhat1
    fhat1 <- find.nearest.gridpts(x=rbind(x1,x2), gridx=fhat1$eval.points, f=fhat1$estimate)$fx
    psi12 <- sum(tail(fhat1, n=n2))/n2
    var.fhat1 <- var(head(fhat1, n=n1))
  } else {
    if (missing(var.fhat1)) var.fhat1 <- .kde.test.var.fhat(x1, binned=binned)
    psi12 <- eta.kfe.y(x=x1, G=H1, y=x2, verbose=verbose, symm=symm)
  }

  if (!missing(fhat2)) {
    ## fhat2 at the pooled points: the x1 part gives psi21, the x2 part gives
    ## an empirical var.fhat2
    fhat2 <- find.nearest.gridpts(x=rbind(x1,x2), gridx=fhat2$eval.points, f=fhat2$estimate)$fx
    psi21 <- sum(head(fhat2, n=n1))/n1
    var.fhat2 <- var(tail(fhat2, n=n2))
  } else {
    if (missing(var.fhat2)) var.fhat2 <- .kde.test.var.fhat(x2, binned=binned)
    ## symm passed explicitly, consistent with psi1, psi2 and psi12
    psi21 <- eta.kfe.y(x=x2, G=H2, y=x1, verbose=verbose, symm=symm)
  }

  ## test statistic + its parameters
  T.hat <- drop(psi1 + psi2 - (psi12 + psi21))
  muT.hat <- (n1^(-1)*det(H1)^(-1/2) + n2^(-1)*det(H2)^(-1/2))*K0
  varT.hat <- 3*(n1*var.fhat1 + n2*var.fhat2)/(n1+n2) *(1/n1+1/n2)
  zstat <- (T.hat-muT.hat)/sqrt(varT.hat)
  ## Upper-tail p-value computed directly: 1-pnorm(z) loses precision and
  ## underflows to 0 for large z. A NaN zstat now yields a NaN p-value
  ## instead of an error in an if() test.
  pval <- pnorm(zstat, lower.tail=FALSE)

  val <- list(Tstat=T.hat, zstat=zstat, pvalue=pval, mean=muT.hat, var=varT.hat, var.fhat1=var.fhat1, var.fhat2=var.fhat2, n1=n1, n2=n2, H1=H1, H2=H2, psi1=psi1, psi12=psi12, psi21=psi21, psi2=psi2)
  return(val)
}

## Private helper for kde.test: delta-method estimate g' S g of the variance
## term for one sample, where S = var(x) and g is the first-derivative kernel
## estimate (kdde) at the sample mean. The bandwidth comes from Hamise.mixt
## with a single normal component of covariance S. When binned, each gradient
## component is read off the binned estimate at the grid point nearest the
## mean.
.kde.test.var.fhat <- function(x, binned=FALSE)
{
  d <- ncol(x)
  S <- var(x)
  xbar <- apply(x, 2, mean)
  H.r1 <- Hamise.mixt(mus=rep(0,d), Sigmas=S, samp=nrow(x), props=1, deriv.order=1)
  if (binned) {
    fhat.r1.est <- kdde(x=x, H=H.r1, deriv.order=1, binned=TRUE)
    fhat.r1 <- matrix(0, nrow=1, ncol=d)
    for (i in seq_len(d))
      fhat.r1[,i] <- find.nearest.gridpts(x=xbar, gridx=fhat.r1.est$eval.points, f=fhat.r1.est$estimate[[i]])$fx
  } else {
    fhat.r1 <- kdde(x=x, H=H.r1, deriv.order=1, eval.points=xbar)$estimate
  }
  drop(fhat.r1 %*% S %*% t(fhat.r1))
}
