## util.R
##
##   Copyright (C) 2018, 2019, David Bolin
##
##   This program is free software: you can redistribute it and/or modify
##   it under the terms of the GNU General Public License as published by
##   the Free Software Foundation, either version 3 of the License, or
##   (at your option) any later version.
##
##   This program is distributed in the hope that it will be useful,
##   but WITHOUT ANY WARRANTY; without even the implied warranty of
##   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
##   GNU General Public License for more details.
##
##   You should have received a copy of the GNU General Public License
##   along with this program.  If not, see <http://www.gnu.org/licenses/>.


#' Observation matrix for finite element discretization on R
#'
#' A finite element discretization on R can be written as
#' \eqn{u(s) = \sum_i^n u_i \varphi_i(s)}{u(s) = \sum_i^n u_i \varphi_i(s)}
#' where \eqn{\varphi_i(s)} is a piecewise linear
#' "hat function" centered at location
#' \eqn{x_i}{x_i}. This function computes an
#' \eqn{m\times n}{m x n} matrix \eqn{A}{A}
#' that links the basis function in the expansion to specified locations
#' \eqn{s = (s_1,\ldots, s_m)} in the domain through
#' \eqn{A_ij = \varphi_j(s_i)}{A_ij = \varphi_j(s_i)}.
#'
#' @param x The locations of the nodes in the FEM discretization, sorted in
#' increasing order.
#' @param loc The locations \eqn{(s_1,\ldots, s_m)}. All locations must lie
#' in `[min(x), max(x)]`.
#'
#' @return The sparse `length(loc)` x `length(x)` matrix `A`. If `loc` is
#' empty, a sparse matrix with zero rows is returned.
#' @export
#' @author David Bolin \email{davidbolin@@gmail.com}
#' @seealso [rSPDE.fem1d()]
#'
#' @examples
#' # create mass and stiffness matrices for a FEM discretization on [0,1]
#' x <- seq(from = 0, to = 1, length.out = 101)
#' fem <- rSPDE.fem1d(x)
#'
#' # create the observation matrix for some locations in the domain
#' obs.loc <- runif(n = 10, min = 0, max = 1)
#' A <- rSPDE.A1d(x, obs.loc)
rSPDE.A1d <- function(x, loc) {
  n.x <- length(x)
  n.loc <- length(loc)

  # Nothing to interpolate: return an empty 0 x n.x matrix instead of
  # failing inside sparseMatrix() (1:0 would produce the indices c(1, 0)).
  if (n.loc == 0) {
    return(sparseMatrix(
      i = integer(0), j = integer(0), x = numeric(0),
      dims = c(0, n.x)
    ))
  }

  if (min(loc) < min(x) || max(loc) > max(x)) {
    stop("locations outside support of basis")
  }

  # Number of nodes at or to the left of each location; for sorted `x` this
  # is the index of the left neighbour of the location.
  j.left <- vapply(loc, function(s) sum((s - x) >= 0), integer(1),
    USE.NAMES = FALSE
  )
  j.right <- j.left + 1L

  # Distances to the left (column 1) and right (column 2) neighbours.
  dist <- matrix(1, n.loc, 2)
  dist[, 1] <- loc - x[j.left]
  has.right <- j.right <= n.x
  dist[has.right, 2] <- x[j.right[has.right]] - loc[has.right]
  # A location on the last node has no right neighbour: pair it with node
  # n.x - 1 and keep the dummy distance 1, which yields the weights (1, 0).
  j.right[!has.right] <- j.right[!has.right] - 2L

  # Linear interpolation: each node gets 1 - (own distance) / (total distance).
  weights <- 1 - dist / rowSums(dist)

  sparseMatrix(
    i = rep(seq_len(n.loc), times = 2),
    j = c(j.left, j.right),
    x = as.vector(weights),
    dims = c(n.loc, n.x)
  )
}


#' Finite element calculations for problems on R
#'
#' This function computes mass and stiffness matrices
#' for a FEM approximation on R, assuming
#' Neumann boundary conditions.
#' These matrices are needed when discretizing the
#' operators in rational approximations.
#'
#' @param x Locations of the nodes in the FEM approximation, sorted in
#' increasing order. At least two nodes are required.
#'
#' @return The function returns a list with the following elements
#' \item{G }{The stiffness matrix with elements \eqn{(\nabla \phi_i, \nabla \phi_j)}.}
#' \item{C }{The mass matrix with elements \eqn{(\phi_i, \phi_j)}.}
#' \item{Cd }{Mass lumped mass matrix (diagonal matrix of the row sums of `C`).}
#' \item{B }{Matrix with elements \eqn{(\nabla \phi_i, \phi_j)}. Its diagonal
#' is set to zero, so the boundary contributions at the first and last node
#' are not included.}
#' @export
#' @author David Bolin \email{davidbolin@@gmail.com}
#' @seealso [rSPDE.A1d()]
#' @examples
#' # create mass and stiffness matrices for a FEM discretization on [0,1]
#' x <- seq(from = 0, to = 1, length.out = 101)
#' fem <- rSPDE.fem1d(x)
rSPDE.fem1d <- function(x) {
  n <- length(x)
  # A single node has no elements, so no FEM matrices can be assembled.
  if (n < 2) {
    stop("At least two nodes are needed to build the FEM matrices.")
  }

  # Element lengths: h[i] = x[i + 1] - x[i].
  h <- diff(x)
  h.inv <- 1 / h

  # Diagonal entries collect the contributions of the element to the left
  # and the element to the right of each node; a missing neighbouring
  # element (first/last node) contributes zero.
  G <- bandSparse(
    n = n, m = n, k = c(-1, 0, 1),
    diagonals = list(-h.inv, c(h.inv, 0) + c(0, h.inv), -h.inv)
  )
  C <- bandSparse(
    n = n, m = n, k = c(-1, 0, 1),
    diagonals = list(h / 6, (c(h, 0) + c(0, h)) / 3, h / 6)
  )

  Cd <- Diagonal(n = n, x = rowSums(C))

  # (phi_i', phi_{i-1}) = 1/2 and (phi_i', phi_{i+1}) = -1/2 for any spacing.
  B <- bandSparse(
    n = n, m = n, k = c(-1, 0, 1),
    diagonals = list(rep(0.5, n - 1), rep(0, n), rep(-0.5, n - 1))
  )

  return(list(G = G, C = C, Cd = Cd, B = B))
}

#' Finite element calculations for problems in 2D
#'
#' This function computes mass and stiffness matrices for a mesh in 2D, assuming
#' Neumann boundary conditions.
#'
#' @param FV Matrix where each row defines a triangle (three vertex indices
#' into the rows of `P`).
#' @param P Locations of the nodes in the mesh (one node per row; a 2 x n
#' matrix is transposed automatically).
#'
#' @return The function returns a list with the following elements
#' \item{G }{The stiffness matrix with elements \eqn{(\nabla \phi_i, \nabla \phi_j)}.}
#' \item{C }{The mass matrix with elements \eqn{(\phi_i, \phi_j)}.}
#' \item{Cd }{The mass lumped matrix with diagonal elements \eqn{(\phi_i, 1)}.}
#' \item{Hxx }{Matrix with elements \eqn{(\partial_x \phi_i, \partial_x \phi_j)}.}
#' \item{Hyy }{Matrix with elements \eqn{(\partial_y \phi_i, \partial_y \phi_j)}.}
#' \item{Hxy }{Matrix with elements \eqn{(\partial_x \phi_i, \partial_y \phi_j)}.}
#' \item{Hyx }{Matrix with elements \eqn{(\partial_y \phi_i, \partial_x \phi_j)}.}
#' \item{Bx }{Matrix with elements \eqn{(\partial_x \phi_i, \phi_j)}.}
#' \item{By }{Matrix with elements \eqn{(\partial_y \phi_i, \phi_j)}.}
#' @export
#' @author David Bolin \email{davidbolin@@gmail.com}
#' @seealso [rSPDE.fem1d()]
#' @examples
#' P <- rbind(c(0, 0), c(1, 0), c(1, 1), c(0, 1))
#' FV <- rbind(c(1, 2, 3), c(1, 3, 4))
#' fem <- rSPDE.fem2d(FV, P)
rSPDE.fem2d <- function(FV, P) {
  d <- ncol(FV) - 1
  if (d != 2) {
    stop("Only 2d supported")
  }
  if (ncol(P) != d) {
    P <- t(P)
  }
  if (ncol(P) != d) {
    stop("Wrong dimension of P")
  }

  nV <- nrow(P)
  nF <- nrow(FV)

  # Triplet storage: triangle f owns rows 3 * (f - 1) + 1:3. Entry [r, c] of
  # that 3 x 3 block couples vertex FV[f, r] (row) with vertex FV[f, c]
  # (column).
  Gi <- matrix(0, nrow = nF * 3, ncol = 3)
  Gj <- Gz <- Cz <- Gxx <- Gxy <- Gyx <- Gyy <- Bxz <- Byz <- Gi

  # Outer products of the reference-element gradients: with
  # d_x phi = (-1, 1, 0) and d_y phi = (-1, 0, 1) on the reference triangle,
  # Mab[r, c] = (d_a phi_r) * (d_b phi_c).
  Mxx <- matrix(c(1, -1, 0, -1, 1, 0, 0, 0, 0), 3, 3)
  Myy <- matrix(c(1, 0, -1, 0, 0, 0, -1, 0, 1), 3, 3)
  Mxy <- matrix(c(1, -1, 0, 0, 0, 0, -1, 1, 0), 3, 3)
  Myx <- matrix(c(1, 0, -1, -1, 0, 1, 0, 0, 0), 3, 3)

  for (f in seq_len(nF)) {
    dd <- 3 * (f - 1) + (1:3)
    Gi[dd, ] <- FV[f, ] %*% t(rep(1, 3))
    Gj[dd, ] <- t(Gi[dd, ])

    xy <- t(P[FV[f, ], ])
    m1 <- rbind(rep(1, 3), xy)
    m2 <- rbind(rep(0, 2), diag(1, 2))
    # Row r of m is the (constant) gradient of the r-th local basis function.
    m <- solve(m1, m2)
    # |det(m1)| is twice the area of the triangle.
    ddet <- abs(det(m1))
    Gz[dd, ] <- ddet * (m %*% t(m)) / 2
    Cz[dd, ] <- ddet * (rep(1, 3) + diag(3)) / 24

    # Jacobian of the affine map from the reference triangle.
    Bk <- matrix(c(
      xy[1, 2] - xy[1, 1],
      xy[2, 2] - xy[2, 1],
      xy[1, 3] - xy[1, 1],
      xy[2, 3] - xy[2, 1]
    ), 2, 2)

    Bki <- solve(Bk)
    Cxx <- Bki %*% matrix(c(1, 0, 0, 0), 2, 2) %*% t(Bki)
    Cyy <- Bki %*% matrix(c(0, 0, 0, 1), 2, 2) %*% t(Bki)
    Cxy <- Bki %*% matrix(c(0, 0, 1, 0), 2, 2) %*% t(Bki)
    Cyx <- Bki %*% matrix(c(0, 1, 0, 0), 2, 2) %*% t(Bki)

    Gxx[dd, ] <- ddet * (Cxx[1, 1] * Mxx + Cxx[1, 2] * Mxy + Cxx[2, 1] * Myx + Cxx[2, 2] * Myy) / 2
    Gyy[dd, ] <- ddet * (Cyy[1, 1] * Mxx + Cyy[1, 2] * Mxy + Cyy[2, 1] * Myx + Cyy[2, 2] * Myy) / 2
    Gxy[dd, ] <- ddet * (Cxy[1, 1] * Mxx + Cxy[1, 2] * Mxy + Cxy[2, 1] * Myx + Cxy[2, 2] * Myy) / 2
    Gyx[dd, ] <- ddet * (Cyx[1, 1] * Mxx + Cyx[1, 2] * Mxy + Cyx[2, 1] * Myx + Cyx[2, 2] * Myy) / 2

    # ab_k solves a * (x - x_k) + b * (y - y_k) = 1 at the two other vertices,
    # so phi_k = 1 - a * (x - x_k) - b * (y - y_k) and grad(phi_k) = -(a, b).
    ab1 <- solve(matrix(c(xy[1,2]-xy[1,1], xy[1,3]-xy[1,1],
                          xy[2,2]-xy[2,1], xy[2,3]-xy[2,1]),2,2),rep(1,2))
    ab2 <- solve(matrix(c(xy[1,1]-xy[1,2], xy[1,3]-xy[1,2],
                          xy[2,1]-xy[2,2], xy[2,3]-xy[2,2]),2,2),rep(1,2))
    ab3 <- solve(matrix(c(xy[1,1]-xy[1,3], xy[1,2]-xy[1,3],
                          xy[2,1]-xy[2,3], xy[2,2]-xy[2,3]),2,2),rep(1,2))

    # (d phi_r, phi_c) = d phi_r * area / 3; the length-3 vector fills the
    # block column-wise, i.e. it varies with the row vertex r.
    Bxz[dd, ] <- -c(ab1[1], ab2[1], ab3[1]) * ddet / 6
    Byz[dd, ] <- -c(ab1[2], ab2[2], ab3[2]) * ddet / 6
  }

  # Assemble an nV x nV sparse matrix from the per-triangle blocks;
  # duplicated (i, j) pairs are summed by sparseMatrix().
  assemble <- function(values) {
    Matrix::sparseMatrix(
      i = as.vector(Gi), j = as.vector(Gj),
      x = as.vector(values), dims = c(nV, nV)
    )
  }

  Ce <- assemble(Cz)
  return(list(
    G = assemble(Gz), C = Ce,
    Cd = Matrix::Diagonal(n = nV, x = Matrix::colSums(Ce)),
    Hxx = assemble(Gxx), Hyy = assemble(Gyy),
    Hxy = assemble(Gxy), Hyx = assemble(Gyx),
    Bx = assemble(Bxz), By = assemble(Byz)
  ))
}
#' Warnings free loading of add-on packages
#'
#' Turn off all warnings for require(), to allow clean completion
#' of examples that require unavailable Suggested packages.
#'
#' @param package The name of a package, given as a character string.
#' @param lib.loc a character vector describing the location of R library trees
#' to search through, or `NULL`.  The default value of `NULL`
#' corresponds to all libraries currently known to `.libPaths()`.
#' Non-existent library trees are silently ignored.
#' @param character.only a logical indicating whether `package` can be
#' assumed to be a character string.
#'
#' @return `require.nowarnings` returns (invisibly)
#' `TRUE` if it succeeds, otherwise `FALSE`
#' @details `require.nowarnings(package)` acts the same as
#' `require(package, quietly = TRUE)` but with warnings turned off.
#' In particular, no warning or error is given if the package is unavailable.
#' Most cases should use `requireNamespace(package,
#' quietly = TRUE)` instead,
#' which doesn't produce warnings.
#' @seealso [require()]
#' @export
#' @examples
#' ## This should produce no output:
#' if (require.nowarnings(nonexistent)) {
#'   message("Package loaded successfully")
#' }
#'
require.nowarnings <- function(
    package, lib.loc = NULL,
    character.only = FALSE) {
  # Unless the caller passed a string, turn the unevaluated argument
  # (a bare package name) into its character representation.
  pkg_name <- if (character.only) {
    package
  } else {
    as.character(substitute(package))
  }
  suppressWarnings(
    require(pkg_name, lib.loc = lib.loc, quietly = TRUE, character.only = TRUE)
  )
}

#' @name get.initial.values.rSPDE
#' @title Initial values for log-likelihood optimization in rSPDE models
#' with a latent stationary Gaussian Matern model
#' @description Auxiliary function to obtain domain-based initial values for
#' log-likelihood optimization in rSPDE models
#' with a latent stationary Gaussian Matern model
#' @param mesh A mesh created with `fmesher` (`fm_mesh_1d` or `fm_mesh_2d`).
#' @param mesh.range The range of the mesh.
#' @param graph.obj A `metric_graph` object. To be used in case both `mesh` and `mesh.range` are `NULL`.
#' @param dim The dimension of the domain. Overridden when `mesh` or
#' `graph.obj` is provided.
#' @param B.sigma Matrix with specification of log-linear model for \eqn{\sigma}. Will be used if `parameterization = 'matern'`.
#' @param B.range Matrix with specification of log-linear model for \eqn{\rho}, which is a range-like parameter (it is exactly the range parameter in the stationary case). Will be used if `parameterization = 'matern'`.
#' @param parameterization Which parameterization to use? `matern` uses range, std. deviation and nu (smoothness). `spde` uses kappa, tau and nu (smoothness). The default is `matern`.
#' @param B.tau Matrix with specification of log-linear model for \eqn{\tau}. Will be used if `parameterization = 'spde'`.
#' @param B.kappa Matrix with specification of log-linear model for \eqn{\kappa}. Will be used if `parameterization = 'spde'`.
#' @param nu The smoothness parameter. Required when `include.nu = FALSE`;
#' when `include.nu = TRUE` it is replaced by the initial guess described
#' under `nu.upper.bound`.
#' @param include.nu Should we also provide an initial guess for nu?
#' @param n.spde The number of basis functions in the mesh model.
#' @param log.scale Should the results be provided in log scale?
#' @param nu.upper.bound Optional upper bound for nu. When `include.nu = TRUE`
#' the initial guess for nu is `min(1, nu.upper.bound / 2)`, or 1 if no
#' bound is given.
#' @return A numeric vector with the prior means of the log-linear model
#' parameters (`theta.prior.mean` from the internal parameter helper),
#' followed by the initial guess for `log(nu)` when `include.nu = TRUE`.
#' With the default `B.*` matrices, the first entry is linked to the
#' std. deviation (`matern`) or tau (`spde`) and the second to the range
#' (`matern`) or kappa (`spde`). If `log.scale = FALSE`, the values are
#' exponentiated.
#' @export
#'

get.initial.values.rSPDE <- function(mesh = NULL, mesh.range = NULL,
                                     graph.obj = NULL,
                                     n.spde = 1,
                                     dim = NULL, B.tau = NULL, B.kappa = NULL,
                                     B.sigma = NULL, B.range = NULL, nu = NULL,
                                     parameterization = c("matern", "spde"),
                                     include.nu = TRUE, log.scale = TRUE,
                                     nu.upper.bound = NULL) {
  if (is.null(mesh) && is.null(mesh.range) && is.null(graph.obj)) {
    stop("You should either provide mesh, mesh.range or graph.obj!")
  }

  parameterization <- parameterization[[1]]

  if (!parameterization %in% c("matern", "spde")) {
    stop("parameterization should be either 'matern' or 'spde'!")
  }

  if (is.null(mesh) && is.null(graph.obj) && is.null(dim)) {
    stop("If you don't provide mesh, you have to provide dim!")
  }

  if (!is.null(mesh)) {
    if (!inherits(mesh, c("fm_mesh_1d", "fm_mesh_2d"))) {
      stop("The mesh should be created using fmesher!")
    }
    dim <- fmesher::fm_manifold_dim(mesh)
  }

  if (!is.null(graph.obj)) {
    if (!inherits(graph.obj, "metric_graph")) {
      stop("graph.obj should be a metric_graph object.")
    }
    dim <- 1
  }

  if (include.nu) {
    nu <- if (is.null(nu.upper.bound)) 1 else min(1, nu.upper.bound / 2)
  } else if (is.null(nu)) {
    stop("If include.nu is FALSE, then nu must be provided!")
  }

  # Default log-linear models: theta_1 drives the first parameter and
  # theta_2 the second, for whichever parameterization is used.
  if (parameterization == "matern") {
    if (is.null(B.sigma)) {
      B.sigma <- matrix(c(0, 1, 0), 1, 3)
    }
    if (is.null(B.range)) {
      B.range <- matrix(c(0, 0, 1), 1, 3)
    }
  } else {
    if (is.null(B.tau)) {
      B.tau <- matrix(c(0, 1, 0), 1, 3)
    }
    if (is.null(B.kappa)) {
      B.kappa <- matrix(c(0, 0, 1), 1, 3)
    }
  }

  if (is.null(graph.obj)) {
    param <- get_parameters_rSPDE(
      mesh = mesh,
      alpha = nu + dim / 2,
      B.tau = B.tau,
      B.kappa = B.kappa,
      B.sigma = B.sigma,
      B.range = B.range,
      nu.nominal = nu,
      alpha.nominal = nu + dim / 2,
      parameterization = parameterization,
      prior.std.dev.nominal = 1,
      prior.range.nominal = NULL,
      prior.tau = NULL,
      prior.kappa = NULL,
      theta.prior.mean = NULL,
      theta.prior.prec = 0.1,
      mesh.range = mesh.range,
      d = dim,
      n.spde = n.spde
    )
  } else {
    # Metric graphs are one-dimensional, hence alpha = nu + 1/2.
    param <- get_parameters_rSPDE_graph(
      graph_obj = graph.obj,
      alpha = nu + 1 / 2,
      B.tau = B.tau,
      B.kappa = B.kappa,
      B.sigma = B.sigma,
      B.range = B.range,
      nu.nominal = nu,
      alpha.nominal = nu + 1 / 2,
      parameterization = parameterization,
      prior.std.dev.nominal = 1,
      prior.range.nominal = NULL,
      prior.tau = NULL,
      prior.kappa = NULL,
      theta.prior.mean = NULL,
      theta.prior.prec = 0.1
    )
  }

  initial <- param$theta.prior.mean
  if (include.nu) {
    initial <- c(initial, log(nu))
  }

  if (log.scale) {
    return(initial)
  }
  exp(initial)
}




#' @name cut_decimals
#' @title Approximation function for covariance-based rSPDE models
#' @description Returns the fractional part of `nu`, clamped so that it lies
#' between 0.001 and 0.999. Used when computing the precision matrix for
#' covariance-based rSPDE models.
#' @param nu A real number
#' @return The clamped fractional part of `nu`.
#' @noRd

cut_decimals <- function(nu) {
  lower <- 10^(-3)
  upper <- 0.999
  frac <- nu - floor(nu)
  # Clamp to [lower, upper] so the rational tables are never evaluated
  # exactly at 0 or 1.
  if (frac < lower) {
    return(lower)
  }
  if (frac > upper) {
    return(upper)
  }
  frac
}

#' @name check_class_inla_rspde
#' @title Check if the object inherits from inla_rspde class
#' @description Check if the object inherits from `inla_rspde` or
#' `inla_rspde_matern1d`.
#' @param model A model to test if it inherits from inla_rspde
#' @return Invisibly `NULL`; gives an error if the object inherits from
#' neither class.
#' @noRd

check_class_inla_rspde <- function(model) {
  accepted_classes <- c("inla_rspde", "inla_rspde_matern1d")
  if (inherits(model, accepted_classes)) {
    return(invisible(NULL))
  }
  stop("You should provide a rSPDE model!")
}

#' @name fem_mesh_order_1d
#' @title Get fem_mesh_matrices for 1d meshes
#' @description Rebuilds a 1d mesh from the node locations of `inla_mesh`,
#' lumps its mass matrix into the diagonal matrix `c0` and computes
#' `g1 = G` and \eqn{g_k = G C^{-1} g_{k-1}} for \eqn{k = 2, \ldots,} `m_order`.
#' @param inla_mesh A 1d mesh; only its `loc` element is used.
#' @param m_order the order of the FEM matrices (a number >= 1; fractional
#' values are truncated).
#' @return A list with elements `c0`, `g1`, ..., `g<m_order>`.
#' @noRd


fem_mesh_order_1d <- function(inla_mesh, m_order) {
  # Validate up front: 1:m_order with m_order < 1 used to fail with an
  # opaque "attempt to select less than one element" error.
  if (length(m_order) != 1 || is.na(m_order) || m_order < 1) {
    stop("m_order must be a single number greater than or equal to 1.")
  }
  n_order <- floor(m_order)

  mesh_1d <- fm_mesh_1d(inla_mesh[["loc"]])
  fem_mesh <- fm_fem(mesh_1d)

  # Mass lumping: diagonal matrix of the row sums of the mass matrix.
  C <- fem_mesh$c0
  C <- Matrix::Diagonal(dim(C)[1], rowSums(C))
  C <- as(C, "TsparseMatrix")
  Ci <- C
  Ci@x <- 1 / (C@x)

  G <- fem_mesh$g1
  GCi <- G %*% Ci

  Gk <- vector("list", n_order)
  Gk[[1]] <- G
  for (i in seq_len(n_order)[-1]) {
    Gk[[i]] <- GCi %*% Gk[[i - 1]]
  }

  fem_mesh_matrices <- list(c0 = C)
  for (i in seq_len(n_order)) {
    fem_mesh_matrices[[paste0("g", i)]] <- Gk[[i]]
  }
  return(fem_mesh_matrices)
}

#' @name generic_fem_mesh_order
#' @title Get fem_mesh_matrices from C and G matrices
#' @description Lumps the mass matrix `C` into the diagonal matrix `c0` and
#' computes `g1 = G` and \eqn{g_k = G C^{-1} g_{k-1}} for
#' \eqn{k = 2, \ldots,} `m_order`.
#' @param fem_matrices A list with objects C and G
#' @param m_order the order of the FEM matrices (a number >= 1; fractional
#' values are truncated).
#' @return A list with elements `c0`, `g1`, ..., `g<m_order>`.
#' @noRd


generic_fem_mesh_order <- function(fem_matrices, m_order) {
  # Validate up front: 1:m_order with m_order < 1 used to fail with an
  # opaque "attempt to select less than one element" error.
  if (length(m_order) != 1 || is.na(m_order) || m_order < 1) {
    stop("m_order must be a single number greater than or equal to 1.")
  }
  n_order <- floor(m_order)

  # Exact [[ ]] lookups: `$C` would silently partially match e.g. `Cd`
  # if `C` were missing.
  C <- fem_matrices[["C"]]
  C <- Matrix::Diagonal(dim(C)[1], rowSums(C))
  C <- INLA::inla.as.sparse(C)
  Ci <- C
  Ci@x <- 1 / (C@x)

  G <- fem_matrices[["G"]]
  GCi <- G %*% Ci

  Gk <- vector("list", n_order)
  Gk[[1]] <- G
  for (i in seq_len(n_order)[-1]) {
    Gk[[i]] <- GCi %*% Gk[[i - 1]]
  }

  fem_mesh_matrices <- list(c0 = C)
  for (i in seq_len(n_order)) {
    fem_mesh_matrices[[paste0("g", i)]] <- Gk[[i]]
  }
  return(fem_mesh_matrices)
}


#' @name get.sparsity.graph.rspde
#' @title Sparsity graph for rSPDE models
#' @description Creates the sparsity graph for rSPDE models
#' @param mesh An INLA mesh, optional
#' @param fem_mesh_matrices A list containing the FEM-related matrices.
#' The list should contain elements c0, g1, g2, g3, etc. Optional,
#' should be provided if mesh is not provided.
#' @param dim The dimension, optional. Should be provided if mesh
#' is not provided.
#' @param nu The smoothness parameter
#' @param force_non_integer Should nu be treated as non_integer?
#' @param rspde.order The order of the covariance-based rational SPDE approach.
#' @return The sparsity graph for rSPDE models to be used in R-INLA interface.
#' @noRd

get.sparsity.graph.rspde <- function(mesh = NULL,
                                     fem_mesh_matrices = NULL,
                                     nu,
                                     force_non_integer = FALSE,
                                     rspde.order = 2,
                                     dim = NULL) {
  if (is.null(mesh)) {
    if (is.null(dim)) {
      stop("If an INLA mesh is not provided, you should provide the dimension!")
    }
  } else {
    dim <- fmesher::fm_manifold_dim(mesh)
    if (!fmesher::fm_manifold(mesh, c("R1", "R2"))) {
      # FL: Is this actually required? Is fm_fem() etc support not sufficient?
      stop("The mesh should be from a flat manifold.")
    }
  }

  # Only the "sharp" block structure is currently in use.
  sharp <- TRUE
  alpha <- nu + dim / 2
  m_alpha <- floor(alpha)
  integer_alpha <- if (force_non_integer) FALSE else alpha %% 1 == 0

  # Sparsity pattern for non-integer alpha: rspde.order copies of the
  # higher-order pattern followed by one block of lower order.
  non_integer_graph <- function(fem) {
    if (!sharp) {
      return(kronecker(
        diag(rep(1, rspde.order + 1)),
        fem[[paste0("g", m_alpha + 1)]]
      ))
    }
    if (m_alpha > 0) {
      upper_block <- fem[[paste0("g", m_alpha + 1)]]
      lower_block <- fem[[paste0("g", m_alpha)]]
    } else {
      upper_block <- fem[["g1"]]
      lower_block <- fem[["c0"]]
    }
    bdiag(kronecker(diag(rep(1, rspde.order)), upper_block), lower_block)
  }

  if (!is.null(fem_mesh_matrices)) {
    if (integer_alpha) {
      return(fem_mesh_matrices[[paste0("g", m_alpha)]])
    }
    return(non_integer_graph(fem_mesh_matrices))
  }

  if (is.null(mesh)) {
    stop("You should provide either mesh or fem_mesh_matrices!")
  }

  if (integer_alpha) {
    mesh_fem <- fm_fem(mesh, order = m_alpha)
    return(mesh_fem[[paste0("g", m_alpha)]])
  }

  mesh_fem <- if (dim == 2) {
    fm_fem(mesh, order = m_alpha + 1)
  } else {
    fem_mesh_order_1d(mesh, m_order = m_alpha + 1)
  }
  non_integer_graph(mesh_fem)
}


#' @name build_sparse_matrix_rspde
#' @title Create sparse matrix from entries and graph
#' @description Builds a symmetric sparse matrix whose upper-triangular
#' pattern is taken from `graph` and whose values are `entries`.
#' @param entries The entries of the precision matrix, in the order of the
#' stored upper-triangular (i <= j) entries of `graph` in triplet form.
#' @param graph The sparsity graph of the precision matrix (a square matrix).
#' @return A symmetric sparse matrix with the same dimensions as `graph`.
#' @noRd

build_sparse_matrix_rspde <- function(entries, graph) {
  # Previously a NULL graph fell through to `return(Q)` and failed with
  # "object 'Q' not found".
  if (is.null(graph)) {
    stop("A sparsity graph must be provided to build the precision matrix.")
  }
  graph <- as(graph, "TsparseMatrix")
  upper <- which(graph@i <= graph@j)
  # dims taken from the graph so that a pattern with empty trailing
  # rows/columns does not shrink the result.
  Matrix::sparseMatrix(
    i = graph@i[upper], j = graph@j[upper], x = entries,
    symmetric = TRUE, index1 = FALSE, dims = dim(graph)
  )
}


#' @name analyze_sparsity_rspde
#' @title Analyze sparsity of matrices in the rSPDE approach
#' @description Auxiliary function to analyze sparsity of matrices
#' in the rSPDE approach
#' @param nu.upper.bound Upper bound for the smoothness parameter
#' @param dim The dimension of the domain
#' @param rspde.order The order of the rational approximation (currently
#' unused by this function).
#' @param fem_mesh_matrices A list containing FEM-related matrices.
#' The list should contain elements c0, g1, g2, g3, etc.
#' @param include_lower_order Logical. Should the lower-order terms
#' be included? They are needed for the cases
#' when alpha = nu + d/2 is integer or for when sharp is set to TRUE.
#' @param include_higher_order Logical. Should be included for when nu
#' is estimated or for when alpha = nu + d/2 is not an integer.
#' @return A list with elements `positions_matrices`, `idx_matrices` and
#' `positions_matrices_less` describing the sparsity of the precision
#' matrices.
#' @noRd

analyze_sparsity_rspde <- function(nu.upper.bound, dim, rspde.order,
                                   fem_mesh_matrices,
                                   include_lower_order = TRUE,
                                   include_higher_order = TRUE) {
  beta <- nu.upper.bound / 2 + dim / 4
  m_alpha <- floor(2 * beta)

  C_list <- symmetric_part_matrix(fem_mesh_matrices[["c0"]])

  # G_lists[[k]] holds the upper-triangular pattern of g_k. g1 is always
  # needed, g2..g_(m_alpha) for the lower-order terms and g_(m_alpha + 1)
  # only for the higher-order terms.
  top <- max(1, if (include_higher_order) m_alpha + 1 else m_alpha)
  G_lists <- lapply(seq_len(top), function(k) {
    symmetric_part_matrix(fem_mesh_matrices[[paste0("g", k)]])
  })

  # Positions of the c0, g1, ..., g_(m_alpha) entries inside the pattern of
  # the highest-order matrix g_(m_alpha + 1).
  positions_matrices <- list()
  if (include_higher_order) {
    M_top <- G_lists[[m_alpha + 1]][["M"]]
    positions_matrices[[1]] <- match(C_list[["M"]], M_top)
    for (i in seq_len(m_alpha)) {
      positions_matrices[[i + 1]] <- match(G_lists[[i]][["M"]], M_top)
    }
  }

  idx_matrices <- list(C_list[["idx"]])
  for (i in seq_len(m_alpha)) {
    idx_matrices[[i + 1]] <- G_lists[[i]][["idx"]]
  }
  if (include_higher_order) {
    idx_matrices[[m_alpha + 2]] <- G_lists[[m_alpha + 1]][["idx"]]
  }

  # Positions of the lower-order matrices inside the pattern of g_(m_alpha)
  # (or g1 when m_alpha is 0).
  positions_matrices_less <- NULL
  if (include_lower_order) {
    M_low <- G_lists[[max(m_alpha, 1)]][["M"]]
    positions_matrices_less <- list(match(C_list[["M"]], M_low))
    if (m_alpha > 1) {
      for (i in seq_len(m_alpha - 1)) {
        positions_matrices_less[[i + 1]] <- match(G_lists[[i]][["M"]], M_low)
      }
    } else if (m_alpha == 1) {
      positions_matrices_less[[2]] <- seq_along(M_low)
    }
  }

  return(list(
    positions_matrices = positions_matrices,
    idx_matrices = idx_matrices,
    positions_matrices_less = positions_matrices_less
  ))
}

#' @name symmetric_part_matrix
#' @title Gets the upper triangular part of a matrix
#' @description Finds the stored entries of `M` (in triplet form) that lie on
#' or above the diagonal.
#' @param M A matrix or a sparse matrix
#' @return A list with elements `M`, a list of zero-based index pairs
#' `c(i, j)` (one per upper-triangular entry), and `idx`, the positions of
#' those entries among the stored triplets of `M`.
#' @noRd

symmetric_part_matrix <- function(M) {
  M <- as(M, "TsparseMatrix")
  idx <- which(M@i <= M@j)
  sM <- cbind(M@i[idx], M@j[idx])
  colnames(sM) <- NULL
  # seq_len() rather than seq(): with no upper-triangular entries, seq(0)
  # is c(1, 0) and split() produced two bogus empty groups.
  return(list(M = split(sM, seq_len(nrow(sM))), idx = idx))
}


#' @name get.roots
#' @title Get roots of the polynomials used in the operator based rational
#' approximation.
#' @description Interpolates the tabulated rational coefficients
#' (`m<order>table`) at `beta`. Values of `beta` above 2 are first shifted
#' down by an integer amount.
#' @param order order of the rational approximation (1, 2, 3 or 4).
#' @param beta value of beta to get the coefficients for.
#' @param type_interp Type of interpolation. Options are "linear" or "spline".
#' @return A list with elements `rb` (length `order + 1`), `rc` (length
#' `order`) and `factor`.
#' @noRd
get.roots <- function(order, beta, type_interp = "linear") {
  if (!(order %in% c(1, 2, 3, 4))) {
    stop("order must be one of the values 1,2,3,4.")
  }
  if (beta > 2) {
    beta <- beta - floor(beta - 1)
  }
  mt <- get(paste0("m", order, "table"))

  # One-dimensional interpolant of a table column, evaluated at beta.
  if (type_interp == "linear") {
    interpolate <- function(values) approx(mt$beta, values, xout = beta)$y
  } else if (type_interp == "spline") {
    interpolate <- function(values) spline(mt$beta, values, xout = beta)$y
  } else {
    stop("invalid type. The options are 'linear' and 'spline'.")
  }

  # The order-1 table has a single "rc" column; higher orders use rc.1, ...
  rc_columns <- if (order == 1) "rc" else paste0("rc.", seq_len(order))
  rb_columns <- paste0("rb.", seq_len(order + 1))

  rc <- vapply(rc_columns, function(col) interpolate(mt[[col]]),
    numeric(1),
    USE.NAMES = FALSE
  )
  rb <- vapply(rb_columns, function(col) interpolate(mt[[col]]),
    numeric(1),
    USE.NAMES = FALSE
  )
  scale_factor <- interpolate(mt$factor)

  list(rb = rb, rc = rc, factor = scale_factor)
}

#' @name get_rational_coefficients
#' @title Get table with rational coefficients
#' @description Get table with rational coefficients
#' @param order order of the rational approximation
#' @param type_rational_approx Type of the rational
#' approximation. Options are "mix", "chebfun", "brasil" and "chebfunLB".
#' "mix" takes rows 1-500 from the "brasil" table and the remaining rows
#' from the "chebfun" table.
#' @return The coefficient table (same structure as the stored tables).
#' @noRd

get_rational_coefficients <- function(order, type_rational_approx) {
  if (type_rational_approx == "chebfun") {
    mt <- get(paste0("m", order, "t"))
  } else if (type_rational_approx == "brasil") {
    mt <- get(paste0("m_brasil", order, "t"))
  } else if (type_rational_approx == "chebfunLB") {
    mt <- get(paste0("m_chebfun", order, "t"))
  } else if (type_rational_approx == "mix") {
    # Start from a copy of the "brasil" table so the column names (and the
    # table class) are kept: callers access columns with `mt$alpha` etc.
    # The previous version built a bare matrix and assigned the chebfun part
    # with `mt[501:999] <-` (missing comma), i.e. by element instead of by row.
    mt <- get(paste0("m_brasil", order, "t"))
    mt_chebfun <- get(paste0("m", order, "t"))
    from_chebfun <- seq_len(nrow(mt)) > 500
    mt[from_chebfun, ] <- mt_chebfun[from_chebfun, ]
  } else {
    stop("The options are 'mix', 'chebfun', 'brasil' and 'chebfunLB'!")
  }
  return(mt)
}


#' @name interp_rational_coefficients
#' @title Get list with interpolated rational coefficients
#' @description Get list with interpolated rational coefficients for specific
#' value of alpha. Only the clamped fractional part of `alpha` (see
#' `cut_decimals`) is used.
#' @param order order of the rational approximation
#' @param type_rational_approx Type of the rational
#' approximation. Options are "mix", "chebfun", "brasil"
#' and "chebfunLB"
#' @param type_interp Type of interpolation. Options are "linear"
#' (linear interpolation), "log" (log-linear interpolation), "spline" (spline
#' interpolation) and "logspline" (log-spline interpolation).
#' @param alpha Value of alpha for the coefficients.
#' @return A list with elements `k`, `r` (length `order`) and `p`
#' (length `order`).
#' @noRd
interp_rational_coefficients <- function(order,
                                         type_rational_approx,
                                         type_interp = "spline",
                                         alpha) {
  mt <- get_rational_coefficients(
    order = order,
    type_rational_approx = type_rational_approx
  )
  alpha <- cut_decimals(alpha)

  # Pick the 1d interpolant (evaluated at alpha) for the requested type.
  if (type_interp == "linear" || type_interp == "log") {
    interp_at_alpha <- function(values) approx(mt$alpha, values, xout = alpha)$y
  } else if (type_interp == "spline" || type_interp == "logspline") {
    interp_at_alpha <- function(values) spline(mt$alpha, values, xout = alpha)$y
  } else {
    stop("invalid type. The options are 'linear', 'log', 'spline' and 'logspline'.")
  }
  on_log_scale <- type_interp == "log" || type_interp == "logspline"

  idx <- seq_len(order)
  if (on_log_scale) {
    # r and k are interpolated through log(.), p through log(-p); this
    # presumes r, k > 0 and p < 0 in the tables.
    r <- vapply(idx, function(i) {
      exp(interp_at_alpha(log(mt[[paste0("r", i)]])))
    }, numeric(1))
    p <- vapply(idx, function(i) {
      -exp(interp_at_alpha(log(-mt[[paste0("p", i)]])))
    }, numeric(1))
    k <- exp(interp_at_alpha(log(mt$k)))
  } else {
    r <- vapply(idx, function(i) interp_at_alpha(mt[[paste0("r", i)]]), numeric(1))
    p <- vapply(idx, function(i) interp_at_alpha(mt[[paste0("p", i)]]), numeric(1))
    k <- interp_at_alpha(mt$k)
  }

  list(k = k, r = r, p = p)
}

#' Changing the type of the rational approximation
#'
#' @param x A `CBrSPDEobj` or an `inla_rspde` object
#' @param value The type of rational approximation.
#' The current options are "chebfun", "brasil" and "chebfunLB".
#' Only the first element of `value` is used.
#'
#' @return An object of the same class with the new rational approximation.
#' @export
#'
`rational.type<-` <- function(x, value) {
  object <- x

  # Only the first element is used; it is validated before any refit.
  type_rational_approximation <- value[[1]]
  if (!(type_rational_approximation %in% c("chebfun", "brasil", "chebfunLB"))) {
    stop('The possible types are "chebfun", "brasil" and "chebfunLB"!')
  }

  if (inherits(object, "CBrSPDEobj")) {
    # Pass the validated scalar (not the raw `value`) so that both branches
    # store the same thing.
    model <- update(object, type_rational_approximation = type_rational_approximation)
  } else if (inherits(object, "inla_rspde")) {
    # Rebuild the INLA model with every stored setting, changing only the
    # rational approximation type. `[["nu"]]` avoids partial matching of `$nu`.
    model <- rspde.matern(object$mesh,
      nu.upper.bound = object$nu.upper.bound,
      rspde.order = object$rspde.order,
      nu = object[["nu"]],
      debug = object$debug,
      parameterization = object$parameterization,
      theta.prior.mean = object$theta.prior.mean,
      theta.prior.prec = object$theta.prior.prec,
      start.theta = object$start.theta,
      prior.nu = object$prior.nu,
      start.nu = object$start.nu,
      prior.nu.dist = object$prior.nu.dist,
      type.rational.approx = type_rational_approximation
    )
  } else {
    stop("The object must be of class 'CBrSPDE' or 'inla_rspde'!")
  }
  return(model)
}



#' Get type of rational approximation.
#'
#' Reads the rational approximation type stored in a fitted or specified
#' model. `rSPDEobj` objects always report "chebfun".
#'
#' @param object A `CBrSPDEobj` object or an `inla_rspde` object.
#'
#' @return The type of rational approximation.
#' @export
#'
rational.type <- function(object) {
  if (inherits(object, "CBrSPDEobj")) {
    approx_type <- object$type_rational_approximation
  } else if (inherits(object, "inla_rspde")) {
    approx_type <- object$type.rational.approx
  } else if (inherits(object, "rSPDEobj")) {
    approx_type <- "chebfun"
  } else {
    stop("Not a valid rSPDE object!")
  }
  approx_type
}


#' Changing the order of the rational approximation
#'
#' @param x A `CBrSPDEobj`, `rSPDEobj` or `inla_rspde` object, or an A matrix
#' (with attribute `inla_rspde_Amatrix`) or index object (class
#' `inla_rspde_index`) built for an `inla_rspde` model.
#' @param value The order of rational approximation. Only the first element
#' is used.
#'
#' @return An object of the same class with the new order
#' of rational approximation. For integer-smoothness models a positive order
#' is refused: the object is returned unchanged with a warning.
#' @export
#'
`rational.order<-` <- function(x, value) {
  object <- x
  rspde.order <- value[[1]]

  if (inherits(object, "CBrSPDEobj") || inherits(object, "rSPDEobj")) {
    return(update(object, m = rspde.order))
  }

  if (inherits(object, "inla_rspde")) {
    # isTRUE(): a missing `integer.nu` entry is treated as "not integer".
    if (rspde.order > 0 && isTRUE(object$integer.nu)) {
      warning("The order was not changed since there is no
      rational approximation (an integer model was
      considered).")
      return(object)
    }
    # Rebuild the model with all stored settings and the new order.
    # `[["nu"]]` avoids partial matching of `$nu`.
    return(rspde.matern(object$mesh,
      nu.upper.bound = object$nu.upper.bound,
      rspde.order = rspde.order,
      nu = object[["nu"]],
      debug = object$debug,
      parameterization = object$parameterization,
      theta.prior.mean = object$theta.prior.mean,
      theta.prior.prec = object$theta.prior.prec,
      start.theta = object$start.theta,
      prior.nu = object$prior.nu,
      start.nu = object$start.nu,
      prior.nu.dist = object$prior.nu.dist,
      type.rational.approx = object$type.rational.approx
    ))
  }

  if (!is.null(attr(object, "inla_rspde_Amatrix"))) {
    integer_nu <- attr(object, "integer_nu")
    # Check before building anything: integer models keep their matrix.
    if (isTRUE(integer_nu) && rspde.order > 0) {
      warning("The order was not changed since there is
      no rational approximation (an integer model was
      considered).")
      return(object)
    }
    # The stored matrix is (old order + 1) horizontal copies of the base
    # A matrix; recover the first copy and replicate it for the new order.
    old_rspde.order <- attr(object, "rspde.order")
    orig_dim <- ncol(object) / (old_rspde.order + 1)
    A <- object[, seq_len(orig_dim), drop = FALSE]
    Abar <- kronecker(matrix(1, 1, rspde.order + 1), A)
    attr(Abar, "inla_rspde_Amatrix") <- TRUE
    attr(Abar, "rspde.order") <- rspde.order
    attr(Abar, "integer_nu") <- integer_nu
    return(Abar)
  }

  if (inherits(object, "inla_rspde_index")) {
    integer_nu <- attr(object, "integer_nu")

    if (isTRUE(integer_nu) && rspde.order > 0) {
      warning("The order was not changed since there is
      no rational approximation (an integer model was
      considered).")
      return(object)
    }

    n_mesh <- attr(object, "n.mesh")
    name <- attr(object, "name")
    n.group <- attr(object, "n.group")
    n.repl <- attr(object, "n.repl")

    factor_rspde <- rspde.order + 1

    name.group <- paste0(name, ".group")
    name.repl <- paste0(name, ".repl")

    out <- list()
    # Block i covers mesh indices ((i - 1) * n_mesh + 1):(i * n_mesh),
    # repeated over groups and then replicates.
    out[[name]] <- as.vector(sapply(seq_len(factor_rspde), function(i) {
      rep(rep(((i - 1) * n_mesh + 1):(i * n_mesh),
        times = n.group
      ), times = n.repl)
    }))
    out[[name.group]] <- rep(rep(rep(seq_len(n.group), each = n_mesh),
      times = n.repl
    ), times = factor_rspde)
    out[[name.repl]] <- rep(rep(seq_len(n.repl), each = n_mesh * n.group),
      times = factor_rspde
    )
    class(out) <- c("inla_rspde_index", class(out))
    attr(out, "rspde.order") <- rspde.order
    attr(out, "integer_nu") <- integer_nu
    attr(out, "n.mesh") <- n_mesh
    attr(out, "name") <- name
    attr(out, "n.group") <- n.group
    attr(out, "n.repl") <- n.repl
    return(out)
  }

  stop("The object must be of class 'CBrSPDE' or 'inla_rspde'!")
}


#' Get the order of rational approximation.
#'
#' Model objects store the order in `m` (`CBrSPDEobj`, `rSPDEobj`) or
#' `rspde.order` (`inla_rspde`); A matrices and index objects carry it in the
#' `rspde.order` attribute.
#'
#' @param object A `CBrSPDEobj` object or an `inla_rspde` object.
#'
#' @return The order of rational approximation.
#' @export
#'
rational.order <- function(object) {
  if (inherits(object, c("CBrSPDEobj", "rSPDEobj"))) {
    return(object$m)
  }
  if (inherits(object, "inla_rspde")) {
    return(object$rspde.order)
  }
  carries_order_attr <- !is.null(attr(object, "inla_rspde_Amatrix")) ||
    inherits(object, "inla_rspde_index")
  if (carries_order_attr) {
    return(attr(object, "rspde.order"))
  }
  stop("Not a valid rSPDE object!")
}


#' Check user input.
#'
#' Validates that `param` is numeric, has length `dim` and, when bounds are
#' supplied, lies within them.
#'
#' @param param A parameter to validate.
#' @param label Label for the parameter (used in error messages).
#' @param lower_bound Optional lower bound for the parameter.
#' @param dim Expected length of the parameter (default is 1 for scalar).
#' @param upper_bound Optional upper bound for the parameter.
#'
#' @return The validated parameter, unchanged.
#' @noRd
#'
rspde_check_user_input <- function(param, label, lower_bound = NULL, dim = 1, upper_bound = NULL) {
  if (!is.numeric(param)) {
    stop(paste(label, "should be a numeric value!"))
  }

  if (length(param) != dim) {
    if (dim == 1) {
      stop(paste(label, "should be a single numeric value!"))
    } else {
      stop(paste(label, "should have a length of", dim, "!"))
    }
  }

  # Comparing NA against a bound yields NA, which makes `if` fail with an
  # unhelpful message. Report missing values explicitly whenever a bound
  # check is requested; without bounds, NA is passed through as before.
  has_bounds <- !is.null(lower_bound) || !is.null(upper_bound)
  if (has_bounds && anyNA(param)) {
    stop(paste(label, "should not contain missing values!"))
  }

  if (!is.null(lower_bound) && any(param < lower_bound)) {
    stop(paste(label, "should be greater than or equal to", lower_bound, "!"))
  }

  if (!is.null(upper_bound) && any(param > upper_bound)) {
    stop(paste(label, "should be less than or equal to", upper_bound, "!"))
  }

  return(param)
}


#' Determine the free likelihood parameters (SPDE parameterization)
#'
#' A parameter is free when its argument is `NULL`; any supplied value fixes
#' it.
#'
#' @param kappa Fixed value of kappa, or `NULL` to estimate it.
#' @param tau Fixed value of tau, or `NULL` to estimate it.
#' @param nu Fixed value of nu, or `NULL` to estimate it.
#' @param sigma.e Fixed value of sigma.e, or `NULL` to estimate it.
#'
#' @return Character vector with the names of the free parameters, in the
#' order "tau", "kappa", "nu", "sigma.e".
#' @noRd

likelihood_process_inputs_spde <- function(kappa, tau, nu, sigma.e) {
  is_fixed <- c(
    tau = !is.null(tau),
    kappa = !is.null(kappa),
    nu = !is.null(nu),
    sigma.e = !is.null(sigma.e)
  )
  free_params <- names(is_fixed)[!is_fixed]
  if (length(free_params) == 0) {
    stop("You should leave at least one parameter free.")
  }
  free_params
}

#' Determine the free likelihood parameters (Matern parameterization)
#'
#' A parameter is free when its argument is `NULL`; any supplied value fixes
#' it.
#'
#' @param range Fixed value of the range, or `NULL` to estimate it.
#' @param sigma Fixed value of sigma, or `NULL` to estimate it.
#' @param nu Fixed value of nu, or `NULL` to estimate it.
#' @param sigma.e Fixed value of sigma.e, or `NULL` to estimate it.
#'
#' @return Character vector with the names of the free parameters, in the
#' order "sigma", "range", "nu", "sigma.e".
#' @noRd

likelihood_process_inputs_matern <- function(range, sigma, nu, sigma.e) {
  is_fixed <- c(
    sigma = !is.null(sigma),
    range = !is.null(range),
    nu = !is.null(nu),
    sigma.e = !is.null(sigma.e)
  )
  free_params <- names(is_fixed)[!is_fixed]
  if (length(free_params) == 0) {
    stop("You should leave at least one parameter free.")
  }
  free_params
}

#' Extract one parameter value from the optimizer's theta vector
#'
#' @param theta Vector (or list) of parameter values, ordered as `param_vector`.
#' @param param_vector Names of the parameters stored in `theta`.
#' @param which_par Name of the parameter to extract.
#' @param logscale If `TRUE`, `theta` holds log values and the result is
#' exponentiated.
#'
#' @return The parameter value on its natural scale.
#' @noRd

likelihood_process_parameters <- function(theta, param_vector, which_par, logscale) {
  position <- which(param_vector == which_par)
  raw_value <- theta[[position]]
  if (logscale) exp(raw_value) else raw_value
}


#' Get priors and starting values for an rSPDE model
#'
#' Builds the `B.tau` and `B.kappa` matrices for the chosen parameterization
#' and, when they are not supplied, the prior precision and prior mean of
#' theta. Based on `INLA::param2.matern.orig()`.
#'
#' @param mesh An fmesher mesh (`fm_mesh_1d` or `fm_mesh_2d`), or `NULL`. When
#' `NULL`, `d` and `n.spde` must be provided.
#' @param alpha Currently unused.
#' @param B.tau,B.kappa Basis matrices used by the "spde" parameterization.
#' @param B.sigma,B.range Basis matrices used by the "matern" and "matern2"
#' parameterizations; they are converted into `B.tau` and `B.kappa`.
#' @param nu.nominal,alpha.nominal Nominal smoothness values used in the
#' conversions and the default priors.
#' @param parameterization One of "spde", "matern" or "matern2".
#' @param prior.std.dev.nominal,prior.range.nominal Nominal standard deviation
#' and range used for the default prior mean.
#' @param prior.tau,prior.kappa Optional nominal values of tau and kappa.
#' @param theta.prior.mean Prior mean of theta; computed when `NULL`.
#' @param theta.prior.prec Prior precision of theta (matrix, or vector of
#' diagonal values); defaults to `diag(0.1, n.theta)`.
#' @param mesh.range Optional extent of the domain, used for the default
#' nominal range (20% of it).
#' @param d Dimension of the domain (only used when `mesh` is `NULL`).
#' @param n.spde Number of mesh degrees of freedom (only used when `mesh` is
#' `NULL`).
#' @return A list with elements `B.tau`, `B.kappa`, `theta.prior.mean` and
#' `theta.prior.prec`.
#' @noRd
get_parameters_rSPDE <- function(
    mesh, alpha,
    B.tau,
    B.kappa,
    B.sigma,
    B.range,
    nu.nominal,
    alpha.nominal,
    parameterization,
    prior.std.dev.nominal,
    prior.range.nominal,
    prior.tau,
    prior.kappa,
    theta.prior.mean,
    theta.prior.prec,
    mesh.range = NULL,
    d = NULL,
    n.spde = NULL) {
  if (!is.null(mesh)) {
    if (!inherits(mesh, c("fm_mesh_1d", "fm_mesh_2d"))) {
      stop("The mesh should be created using fmesher!")
    }

    d <- fmesher::fm_manifold_dim(mesh)
    n.spde <- fmesher::fm_dof(mesh)
  } else {
    if (is.null(d)) {
      stop("If you do not provide the mesh, you must provide the dimension!")
    }
    if (is.null(n.spde)) {
      stop("If you do not provide the mesh, you must provide n.spde!")
    }
  }

  if (is.null(B.tau) && is.null(B.sigma)) {
    stop("One of B.tau or B.sigma must not be NULL.")
  }
  if (is.null(B.kappa) && is.null(B.range)) {
    stop("One of B.kappa or B.range must not be NULL.")
  }
  # Fail early: an unknown value would otherwise only surface later as an
  # undefined `n.theta`.
  if (!(parameterization %in% c("spde", "matern", "matern2"))) {
    stop("parameterization must be one of 'spde', 'matern' or 'matern2'.")
  }

  if (parameterization == "spde") {
    n.theta <- ncol(B.kappa) - 1L

    B.kappa <- prepare_B_matrices(
      B.kappa, n.spde,
      n.theta
    )
    B.tau <- prepare_B_matrices(B.tau, n.spde, n.theta)
  } else if (parameterization == "matern") {
    n.theta <- ncol(B.sigma) - 1L

    B.sigma <- prepare_B_matrices(
      B.sigma, n.spde,
      n.theta
    )
    B.range <- prepare_B_matrices(
      B.range, n.spde,
      n.theta
    )

    # log(kappa) = 0.5 * log(8 * nu) - log(range)
    B.kappa <- cbind(
      0.5 * log(8 * nu.nominal) - B.range[, 1],
      -B.range[, -1, drop = FALSE]
    )

    B.tau <- cbind(
      0.5 * (lgamma(nu.nominal) - lgamma(alpha.nominal) -
        d / 2 * log(4 * pi)) - nu.nominal * B.kappa[, 1] -
        B.sigma[, 1],
      -nu.nominal * B.kappa[, -1, drop = FALSE] -
        B.sigma[, -1, drop = FALSE]
    )
  } else {
    # parameterization == "matern2": B.sigma enters with a factor 0.5 and
    # B.kappa is -B.range.
    n.theta <- ncol(B.sigma) - 1L

    B.sigma <- prepare_B_matrices(
      B.sigma, n.spde,
      n.theta
    )
    B.range <- prepare_B_matrices(
      B.range, n.spde,
      n.theta
    )

    B.kappa <- -B.range

    B.tau <- cbind(
      0.5 * (lgamma(nu.nominal) - lgamma(alpha.nominal) -
        d / 2 * log(4 * pi)) - nu.nominal * B.kappa[, 1] -
        0.5 * B.sigma[, 1],
      -nu.nominal * B.kappa[, -1, drop = FALSE] -
        0.5 * B.sigma[, -1, drop = FALSE]
    )
  }

  if (is.null(theta.prior.prec)) {
    theta.prior.prec <- diag(0.1, n.theta, n.theta)
  } else {
    theta.prior.prec <- as.matrix(theta.prior.prec)
    if (ncol(theta.prior.prec) == 1) {
      # A single column is read as the diagonal of the precision matrix.
      theta.prior.prec <- diag(
        as.vector(theta.prior.prec),
        n.theta, n.theta
      )
    }
    if ((nrow(theta.prior.prec) != n.theta) || (ncol(theta.prior.prec) !=
      n.theta)) {
      stop(paste(
        "Size of theta.prior.prec is (", paste(dim(theta.prior.prec),
          collapse = ",", sep = ""
        ), ") but should be (",
        paste(c(n.theta, n.theta), collapse = ",", sep = ""),
        ")."
      ))
    }
  }

  if (is.null(theta.prior.mean)) {
    if (is.null(prior.range.nominal)) {
      if (is.null(mesh.range)) {
        if (is.null(mesh)) {
          stop("If you do not provide the mesh, you must provide mesh.range, prior.range.nominal or theta.prior.mean!")
        }
        # Largest coordinate extent of a 2d mesh, or the interval length of
        # a 1d mesh.
        if (d == 2) {
          mesh.range <- max(
            diff(range(mesh$loc[, 1])),
            diff(range(mesh$loc[, 2])),
            diff(range(mesh$loc[, 3]))
          )
        } else {
          mesh.range <- diff(mesh$interval)
        }
      }
      prior.range.nominal <- mesh.range * 0.2
    }
    if (is.null(prior.kappa)) {
      prior.kappa <- sqrt(8 * nu.nominal) / prior.range.nominal
    }
    if (is.null(prior.tau)) {
      prior.tau <- sqrt(gamma(nu.nominal) / gamma(alpha.nominal) / ((4 *
        pi)^(d / 2) * prior.kappa^(2 * nu.nominal) * prior.std.dev.nominal^2))
    }
    if (n.theta > 0) {
      # Least-squares solve so that the basis reproduces the nominal values.
      if (parameterization == "spde") {
        theta.prior.mean <- qr.solve(rbind(
          B.tau[, -1, drop = FALSE],
          B.kappa[, -1, drop = FALSE]
        ), c(log(prior.tau) -
          B.tau[, 1], log(prior.kappa) - B.kappa[, 1]))
      } else if (parameterization == "matern") {
        theta.prior.mean <- qr.solve(rbind(
          B.sigma[, -1, drop = FALSE],
          B.range[, -1, drop = FALSE]
        ), c(log(prior.std.dev.nominal) -
          B.sigma[, 1], log(prior.range.nominal) - B.range[, 1]))
      } else if (parameterization == "matern2") {
        theta.prior.mean <- qr.solve(rbind(
          B.sigma[, -1, drop = FALSE],
          B.range[, -1, drop = FALSE]
        ), c(2 * log(prior.std.dev.nominal) -
          B.sigma[, 1], -log(prior.kappa) - B.range[, 1]))
      }
    } else {
      theta.prior.mean <- rep(0, n.theta)
    }
  }
  param <- list(
    B.tau = B.tau,
    B.kappa = B.kappa, theta.prior.mean = theta.prior.mean,
    theta.prior.prec = theta.prior.prec
  )
  return(param)
}

#' Check a B matrix and expand it to `n.spde` rows
#'
#' Accepts a scalar or a vector of length `n.spde` (offset column only), a
#' vector of length `1 + n.theta` (one row shared by all nodes), or a matrix
#' with 1 or `n.spde` rows and 1 or `1 + n.theta` columns. Based on
#' `INLA:::inla.spde.homogenise_B_matrix()`.
#'
#' @param B Numeric vector or matrix.
#' @param n.spde Number of rows of the result.
#' @param n.theta Number of parameters; the result has `1 + n.theta` columns.
#' @return An `n.spde` x `(1 + n.theta)` matrix.
#' @noRd
prepare_B_matrices <- function(B, n.spde, n.theta) {
  if (!is.numeric(B)) {
    stop("B matrix must be numeric.")
  }
  n_cols <- 1 + n.theta

  if (!is.matrix(B)) {
    # A vector is read first as an offset column (length 1 or n.spde) and
    # only otherwise as a single row of coefficients.
    if (length(B) == 1 || length(B) == n.spde) {
      return(cbind(B, matrix(0, n.spde, n.theta)))
    }
    if (length(B) == n_cols) {
      return(matrix(B, n.spde, n_cols, byrow = TRUE))
    }
    stop(paste(
      "Length of B must be 1,", as.character(n_cols),
      "or", as.character(n.spde)
    ))
  }

  if (nrow(B) != 1 && nrow(B) != n.spde) {
    stop(paste("B matrix must have either 1 or", as.character(n.spde), "rows."))
  }
  if (ncol(B) != 1 && ncol(B) != n_cols) {
    stop(paste("B matrix must have 1 or", as.character(n_cols), "columns."))
  }
  if (ncol(B) == 1) {
    return(cbind(as.vector(B), matrix(0, n.spde, n.theta)))
  }
  # Past this point B has exactly 1 + n.theta columns.
  if (nrow(B) == 1) {
    return(matrix(as.vector(B), n.spde, n_cols, byrow = TRUE))
  }
  if (nrow(B) == n.spde) {
    return(B)
  }
  stop("Unrecognised structure for B matrix")
}



#' Get priors and starting values for an rSPDE model on a metric graph
#'
#' Graph counterpart of `get_parameters_rSPDE()` with `d = 1`, supporting the
#' "spde" and "matern" parameterizations. When no nominal range is given, it
#' is set to 20% of the largest finite geodesic distance between vertices.
#' Based on `INLA::param2.matern.orig()`.
#'
#' @param graph_obj A `metric_graph` object with a mesh (`graph_obj$mesh$C`
#' is used to get the number of degrees of freedom).
#' @param alpha Currently unused.
#' @param B.tau,B.kappa Basis matrices used by the "spde" parameterization.
#' @param B.sigma,B.range Basis matrices used by the "matern" parameterization.
#' @param nu.nominal,alpha.nominal Nominal smoothness values.
#' @param parameterization Either "spde" or "matern".
#' @param prior.std.dev.nominal,prior.range.nominal Nominal standard deviation
#' and range used for the default prior mean.
#' @param prior.tau,prior.kappa Optional nominal values of tau and kappa.
#' @param theta.prior.mean Prior mean of theta; computed when `NULL`.
#' @param theta.prior.prec Prior precision of theta; defaults to
#' `diag(0.1, n.theta)`.
#' @return A list with elements `B.tau`, `B.kappa`, `theta.prior.mean` and
#' `theta.prior.prec`.
#' @noRd
get_parameters_rSPDE_graph <- function(
    graph_obj, alpha,
    B.tau,
    B.kappa,
    B.sigma,
    B.range,
    nu.nominal,
    alpha.nominal,
    parameterization,
    prior.std.dev.nominal,
    prior.range.nominal,
    prior.tau,
    prior.kappa,
    theta.prior.mean,
    theta.prior.prec) {
  if (!inherits(graph_obj, "metric_graph")) {
    stop("The graph object should be of class metric_graph!")
  }
  if (is.null(B.tau) && is.null(B.sigma)) {
    stop("One of B.tau or B.sigma must not be NULL.")
  }
  if (is.null(B.kappa) && is.null(B.range)) {
    stop("One of B.kappa or B.range must not be NULL.")
  }
  # Only these two are handled below; anything else would otherwise fail
  # later with an undefined `n.theta`.
  if (!(parameterization %in% c("spde", "matern"))) {
    stop("parameterization must be either 'spde' or 'matern'.")
  }

  d <- 1
  n.spde <- nrow(graph_obj$mesh$C)

  if (parameterization == "spde") {
    n.theta <- ncol(B.kappa) - 1L
    B.kappa <- prepare_B_matrices(
      B.kappa, n.spde,
      n.theta
    )
    B.tau <- prepare_B_matrices(B.tau, n.spde, n.theta)
  } else {
    n.theta <- ncol(B.sigma) - 1L
    B.sigma <- prepare_B_matrices(
      B.sigma, n.spde,
      n.theta
    )
    B.range <- prepare_B_matrices(
      B.range, n.spde,
      n.theta
    )

    B.kappa <- cbind(
      0.5 * log(8 * nu.nominal) - B.range[, 1],
      -B.range[, -1, drop = FALSE]
    )

    B.tau <- cbind(0.5 * (lgamma(nu.nominal) - lgamma(alpha.nominal) -
      d / 2 * log(4 * pi)) - nu.nominal * B.kappa[, 1] -
      B.sigma[, 1], -nu.nominal * B.kappa[, -1, drop = FALSE] -
      B.sigma[, -1, drop = FALSE])
  }

  if (is.null(theta.prior.prec)) {
    theta.prior.prec <- diag(0.1, n.theta, n.theta)
  } else {
    theta.prior.prec <- as.matrix(theta.prior.prec)
    if (ncol(theta.prior.prec) == 1) {
      # A single column is read as the diagonal of the precision matrix.
      theta.prior.prec <- diag(
        as.vector(theta.prior.prec),
        n.theta, n.theta
      )
    }
    if ((nrow(theta.prior.prec) != n.theta) || (ncol(theta.prior.prec) !=
      n.theta)) {
      stop(paste(
        "Size of theta.prior.prec is (", paste(dim(theta.prior.prec),
          collapse = ",", sep = ""
        ), ") but should be (",
        paste(c(n.theta, n.theta), collapse = ",", sep = ""),
        ")."
      ))
    }
  }

  if (is.null(theta.prior.mean)) {
    if (is.null(prior.range.nominal)) {
      # Geodesic distances between vertices are computed on demand.
      if (is.null(graph_obj$geo_dist)) {
        graph_obj$compute_geodist(obs = FALSE)
      } else if (is.null(graph_obj$geo_dist[[".vertices"]])) {
        graph_obj$compute_geodist(obs = FALSE)
      }
      finite_geodist <- is.finite(graph_obj$geo_dist[[".vertices"]])
      finite_geodist <- graph_obj$geo_dist[[".vertices"]][finite_geodist]
      prior.range.nominal <- max(finite_geodist) * 0.2
    }
    if (is.null(prior.kappa)) {
      prior.kappa <- sqrt(8 * nu.nominal) / prior.range.nominal
    }
    if (is.null(prior.tau)) {
      # Same normalising constant (4 * pi)^(d / 2) as in B.tau above; with
      # d = 1 this is sqrt(4 * pi), not 4 * pi.
      prior.tau <- sqrt(gamma(nu.nominal) / gamma(alpha.nominal) / ((4 *
        pi)^(d / 2) * prior.kappa^(2 * nu.nominal) * prior.std.dev.nominal^2))
    }
    if (n.theta > 0) {
      if (parameterization == "spde") {
        theta.prior.mean <- qr.solve(rbind(
          B.tau[, -1, drop = FALSE],
          B.kappa[, -1, drop = FALSE]
        ), c(log(prior.tau) -
          B.tau[, 1], log(prior.kappa) - B.kappa[, 1]))
      } else if (parameterization == "matern") {
        theta.prior.mean <- qr.solve(rbind(
          B.sigma[, -1, drop = FALSE],
          B.range[, -1, drop = FALSE]
        ), c(log(prior.std.dev.nominal) -
          B.sigma[, 1], log(prior.range.nominal) - B.range[, 1]))
      }
    } else {
      theta.prior.mean <- rep(0, n.theta)
    }
  }
  param <- list(
    B.tau = B.tau,
    B.kappa = B.kappa, theta.prior.mean = theta.prior.mean,
    theta.prior.prec = theta.prior.prec
  )
  return(param)
}



#' Convert Matern basis matrices into SPDE basis matrices
#'
#' Translates `B.sigma` and `B.range` into the equivalent `B.tau` and
#' `B.kappa` for a given nominal smoothness and dimension, using
#' `alpha.nominal = nu.nominal + d / 2`.
#'
#' @param B.sigma,B.range Basis matrices (or vectors) accepted by
#' `prepare_B_matrices()`.
#' @param n.spde Number of rows of the expanded matrices.
#' @param nu.nominal Nominal smoothness.
#' @param d Dimension of the domain.
#' @return A list with elements `B.tau` and `B.kappa`.
#' @noRd
convert_B_matrices <- function(B.sigma, B.range, n.spde, nu.nominal, d) {
  n.theta <- ncol(B.sigma) - 1L
  alpha.nominal <- nu.nominal + d / 2

  B.sigma <- prepare_B_matrices(B.sigma, n.spde, n.theta)
  B.range <- prepare_B_matrices(B.range, n.spde, n.theta)

  # Constant part of the log(tau) offset; depends only on the smoothness
  # and the dimension.
  log_tau_const <- 0.5 * (lgamma(nu.nominal) - lgamma(alpha.nominal) -
    d / 2 * log(4 * pi))

  B.kappa <- cbind(
    0.5 * log(8 * nu.nominal) - B.range[, 1],
    -B.range[, -1, drop = FALSE]
  )
  B.tau <- cbind(
    log_tau_const - nu.nominal * B.kappa[, 1] - B.sigma[, 1],
    -nu.nominal * B.kappa[, -1, drop = FALSE] - B.sigma[, -1, drop = FALSE]
  )

  list(B.tau = B.tau, B.kappa = B.kappa)
}

#' Convert rspde_lme estimates from (tau, kappa) to (sigma, range)
#'
#' Maps the SPDE parameters to the Matern standard deviation and range, and
#' propagates the observed information through the Jacobian of the inverse
#' map to obtain standard errors on the new scale. If the transformed
#' information matrix cannot be inverted, the standard errors are `NA`.
#'
#' @param likelihood Not used in the computation.
#' @param d Dimension of the domain.
#' @param nu Smoothness parameter.
#' @param par Numeric vector whose first two entries are tau and kappa.
#' @param hessian 2 x 2 matrix treated as the observed Fisher information
#' for (tau, kappa).
#' @return A list with `coeff = c(sigma, range)` and `std_random`, the
#' standard errors of sigma and range.
#' @noRd
change_parameterization_lme <- function(likelihood, d, nu, par, hessian) {
  tau <- par[1]
  kappa <- par[2]

  # range = c_range / kappa and sigma = c_sigma / (tau * kappa^nu).
  c_range <- sqrt(8 * nu)
  c_sigma <- sqrt(gamma(nu) / ((4 * pi)^(d / 2) * gamma(nu + d / 2)))

  sigma <- c_sigma / (tau * kappa^nu)
  range <- c_range / kappa

  # Jacobian of (tau, kappa) with respect to (sigma, range): rows follow
  # (tau, kappa), columns follow (sigma, range).
  jacobian <- matrix(c(
    -c_sigma / (kappa^nu * sigma^2), nu * range^(nu - 1) * c_sigma / (sigma * c_range^nu),
    0, -c_range / range^2
  ), nrow = 2, byrow = TRUE)

  info_new <- t(jacobian) %*% hessian %*% jacobian

  # The second-order term (log-likelihood gradient times the Hessian of the
  # map) is deliberately omitted: the gradient is approximately zero at the
  # optimum, and numerical experiments suggested the first-order
  # approximation is generally better.

  cov_new <- tryCatch(solve(info_new), error = function(e) matrix(NA, nrow(info_new), ncol(info_new)))

  list(coeff = c(sigma, range), std_random = sqrt(diag(cov_new)))
}



#' Give a result the same shape as the original input
#'
#' S4 inputs (e.g. sparse matrices) are left alone and `v` is returned as is.
#' Otherwise `v` is converted to a base matrix and given the dimensions of
#' `orig_v`; a dimensionless `orig_v` therefore yields a plain vector.
#'
#' @param v Values to reshape.
#' @param orig_v Object whose type and dimensions should be mirrored.
#' @return `v`, reshaped as described above.
#' @noRd
#'

return_same_input_type_matrix_vector <- function(v, orig_v) {
  if (!isS4(orig_v)) {
    reshaped <- as.matrix(v)
    dim(reshaped) <- dim(orig_v)
    v <- reshaped
  }
  v
}



#' Find rows that are not entirely missing
#'
#' Ignores the bookkeeping entries `.edge_number`, `.distance_on_edge`,
#' `.coord_x`, `.coord_y` and `.group`, and flags each row for which at least
#' one remaining variable is not `NA`.
#'
#' @param data_list A named list of variables (assumed to be vectors); the
#' first remaining variable fixes the number of rows.
#' @return A logical vector with one entry per row: `TRUE` when the row has at
#' least one non-missing value. Empty input gives `logical(0)`.
#' @noRd
#'
idx_not_all_NA <- function(data_list) {
  data_list[[".edge_number"]] <- NULL
  data_list[[".distance_on_edge"]] <- NULL
  data_list[[".coord_x"]] <- NULL
  data_list[[".coord_y"]] <- NULL
  data_list[[".group"]] <- NULL
  data_names <- names(data_list)
  n_data <- length(data_list[[data_names[1]]])
  row_idx <- seq_len(n_data)
  # Rows x variables missingness matrix, built one variable at a time instead
  # of looping over rows (and safe for zero rows, unlike 1:n_data).
  is_missing <- matrix(
    unlist(lapply(data_list, function(dat) is.na(dat[row_idx])), use.names = FALSE),
    nrow = n_data
  )
  rowSums(!is_missing) > 0
}

#' Find rows without any missing value
#'
#' Ignores the bookkeeping entries `.edge_number`, `.distance_on_edge`,
#' `.coord_x`, `.coord_y` and `.group`, and flags each row for which every
#' remaining variable is non-missing.
#'
#' @param data_list A named list of variables (assumed to be vectors); the
#' first remaining variable fixes the number of rows.
#' @return A logical vector with one entry per row: `TRUE` when the row has no
#' `NA`. Empty input gives `logical(0)`.
#' @noRd
#'
idx_not_any_NA <- function(data_list) {
  data_list[[".edge_number"]] <- NULL
  data_list[[".distance_on_edge"]] <- NULL
  data_list[[".coord_x"]] <- NULL
  data_list[[".coord_y"]] <- NULL
  data_list[[".group"]] <- NULL
  data_names <- names(data_list)
  n_data <- length(data_list[[data_names[1]]])
  row_idx <- seq_len(n_data)
  # Rows x variables missingness matrix, built one variable at a time instead
  # of looping over rows (and safe for zero rows, unlike 1:n_data).
  is_missing <- matrix(
    unlist(lapply(data_list, function(dat) is.na(dat[row_idx])), use.names = FALSE),
    nrow = n_data
  )
  rowSums(is_missing) == 0
}


#' Subset every variable of a data object to the given rows
#'
#' A `SpatialPointsDataFrame` is subset by rows directly. Any other input is
#' treated as a list of variables: objects with dimensions are subset by
#' rows (keeping them as matrices), others by position.
#'
#' @param data A `SpatialPointsDataFrame` or a list-like object of variables.
#' @param idx Row indices (or a logical vector) to keep.
#' @return The subset object; a list for non-spatial input.
#' @noRd
#'

select_indexes <- function(data, idx) {
  if (!inherits(data, "SpatialPointsDataFrame")) {
    take_rows <- function(variable) {
      if (!is.null(dim(variable))) {
        return(variable[idx, , drop = FALSE])
      }
      variable[idx]
    }
    return(lapply(data, take_rows))
  }
  data[idx, , drop = FALSE]
}



#' Create train and test splits to be used in the `cross_validation` function
#'
#' Train and test splits. Only rows without missing values are assigned to
#' the splits.
#'
#' @param data A `data.frame`, `SpatialPointsDataFrame` or `metric_graph_data` object (any object with rows, for which `is.na()` returns a matrix).
#' @param cv_type The type of the folding to be carried out. The options are `k-fold` for `k`-fold cross-validation, in which case the parameter `k` should be provided,
#' `loo`, for leave-one-out and `lpo` for leave-percentage-out, in this case, the parameter `percentage` should be given, and also the `number_folds`
#' with the number of folds to be done. The default is `k-fold`.
#' @param k The number of folds to be used in `k`-fold cross-validation. Will only be used if `cv_type` is `k-fold`.
#' @param percentage The percentage (from 1 to 99) of the data to be used to train the model. Will only be used if `cv_type` is `lpo`.
#' @param number_folds Number of folds to be done if `cv_type` is `lpo`.
#' @return A list with two elements, `train` containing the training indices and `test` containing the test indices.
#' @export

create_train_test_indices <- function(data, cv_type = c("k-fold", "loo", "lpo"),
                                      k = 5, percentage = 20, number_folds = 10) {
  # With the default vector of choices, `cv_type == "k-fold"` has length 3
  # and `if` fails; match.arg() resolves the default to "k-fold".
  cv_type <- match.arg(cv_type)

  # Table whose rows are the observations, used for the NA screening.
  if (inherits(data, "SpatialPointsDataFrame")) {
    data_table <- data@data
  } else if (inherits(data, "metric_graph_data")) {
    data_table <- as.data.frame(data)
  } else {
    data_table <- data
  }

  # Only rows without any missing value take part in the splits.
  complete_rows <- rowSums(is.na(data_table)) == 0
  idx <- seq_len(nrow(data_table))[complete_rows]

  # Elements of `idx` outside the positions `pos`. Unlike `idx[-pos]`, an
  # empty `pos` keeps every index instead of dropping all of them.
  complement_of <- function(pos) idx[!(seq_along(idx) %in% pos)]
  subset_of <- function(pos) idx[pos]

  if (cv_type == "k-fold") {
    # Shuffle, then bin the shuffled values into k intervals; the positions
    # whose shuffled value falls in interval i form test fold i. Indexing with
    # sample.int() matches sample(idx) for two or more rows and stays correct
    # for a single row.
    folds <- cut(idx[sample.int(length(idx))], breaks = k, labels = FALSE)
    test_list_idx <- lapply(seq_len(k), function(i) which(folds == i))
    test_list <- lapply(test_list_idx, subset_of)
    train_list <- lapply(test_list_idx, complement_of)
  } else if (cv_type == "loo") {
    train_list <- lapply(seq_along(idx), function(i) idx[-i])
    test_list <- as.list(idx)
  } else {
    n_Y <- length(idx)
    test_list_idx <- vector("list", number_folds)
    # Folds are drawn from last to first; keep this order, since changing it
    # changes which random draw each fold receives for a given seed.
    for (i in rev(seq_len(number_folds))) {
      test_list_idx[[i]] <- sample(seq_along(idx), size = (1 - percentage / 100) * n_Y)
    }
    train_list <- lapply(test_list_idx, complement_of)
    test_list <- lapply(test_list_idx, subset_of)
  }
  return(list(train = train_list, test = test_list))
}

#' Stop with an informative error if required packages are unavailable
#'
#' @param packages Character vector of package names.
#' @param func Name of the calling function, used in the error message.
#' @return `NULL`, invisibly, when every package can be loaded.
#' @noRd
check_packages <- function(packages, func) {
  available <- vapply(packages, requireNamespace, logical(1), quietly = TRUE)
  missing_pkgs <- packages[!available]
  if (length(missing_pkgs) > 0) {
    pkg_list <- paste0("'", missing_pkgs, "'", collapse = ", ")
    stop(paste0("Needed package(s) ", pkg_list,
                " not installed, but are needed by ", func))
  }
}

#' Get the path of the cgeneric shared library
#'
#' @param shared_lib "INLA" to use the library shipped with INLA, "rSPDE" to
#' use the one bundled with the installed rSPDE package, or "detect" to use the
#' rSPDE library when the file exists and fall back to INLA's otherwise.
#' @return A single, unnamed character string with the library path.
#' @noRd
get_shared_library <- function(shared_lib) {
  # Path of the library bundled with rSPDE (".dll" on Windows, ".so"
  # elsewhere). Plain if/else instead of ifelse(), which would copy the
  # "sysname" name from the Sys.info() test onto the returned path.
  local_library_path <- function() {
    lib_dir <- system.file("shared", package = "rSPDE")
    lib_file <- if (Sys.info()[["sysname"]] == "Windows") {
      "rspde_cgeneric_models.dll"
    } else {
      "rspde_cgeneric_models.so"
    }
    paste0(lib_dir, "/", lib_file)
  }

  if (shared_lib == "INLA") {
    return(INLA::inla.external.lib("rSPDE"))
  }
  if (shared_lib == "rSPDE") {
    return(local_library_path())
  }
  if (shared_lib == "detect") {
    lib_path <- local_library_path()
    return(if (file.exists(lib_path)) lib_path else INLA::inla.external.lib("rSPDE"))
  }
  stop("'shared_lib' must be 'INLA', 'rSPDE', or 'detect'")
}

#' Build a prior specification (mean and precision) with defaults
#'
#' @param prior `NULL` or a list with optional elements `mean` and
#' `precision`.
#' @param default_mean Numeric vector of length `p`, used when `mean` is
#' missing.
#' @param default_precision Positive numeric vector of length `p`, used when
#' `precision` is missing.
#' @param p Expected length of both elements.
#' @return A list containing `mean` and `precision`. Other elements trigger a
#' warning but are kept in the result.
#' @noRd
set_prior <- function(prior, default_mean, default_precision, p = 1) {
  # Validate default parameters
  if (!is.numeric(default_mean) || length(default_mean) != p) {
    stop(paste("default_mean must be a numeric vector of length equal to",p,"."))
  }
  if (!is.numeric(default_precision) || length(default_precision) != p || any(default_precision <= 0)) {
    stop(paste("default_precision must be a positive numeric vector of length equal to",p,"."))
  }

  # Return default prior if none is provided
  if (is.null(prior)) {
    return(list(mean = default_mean, precision = default_precision))
  }
  if (!is.list(prior)) {
    stop("prior must be a list with elements 'mean' and/or 'precision'.")
  }

  # Ensure prior only contains allowed elements
  allowed_elements <- c("mean", "precision")
  invalid_elements <- setdiff(names(prior), allowed_elements)
  if (length(invalid_elements) > 0) {
    warning(sprintf("Invalid elements in prior: %s. Only 'mean' and 'precision' are allowed.",
                    paste(invalid_elements, collapse = ", ")))
  }

  # Exact `[[` lookups below: `$` would partially match a misspelled element
  # such as `means` and silently use it as the mean.

  # Validate and set 'mean'
  if (!is.null(prior[["mean"]])) {
    if (!is.numeric(prior[["mean"]]) || length(prior[["mean"]]) != p) {
      stop(sprintf("'mean' must be a numeric vector of length %d.", p))
    }
  } else {
    prior[["mean"]] <- default_mean
  }

  # Validate and set 'precision'
  if (!is.null(prior[["precision"]])) {
    if (!is.numeric(prior[["precision"]]) || length(prior[["precision"]]) != p ||
      any(prior[["precision"]] <= 0)) {
      stop(sprintf("'precision' must be a positive numeric vector of length %d.", p))
    }
  } else {
    prior[["precision"]] <- default_precision
  }

  return(prior)
}

#' Complete the prior on nu and derive its starting value
#'
#' @param prior.nu `NULL` or a list with optional elements `loglocation`,
#' `mean`, `prec` and `logscale`. Elements of length other than 1 are reduced
#' to their first entry with a warning.
#' @param nu.upper.bound Upper bound for nu.
#' @param nu.prec.inc Amount added to the default `prec`.
#' @param prior.nu.dist Either "beta" or "lognormal".
#' @return A list with the completed `prior.nu` and `start.nu` (the `mean` for
#' the beta prior, `exp(loglocation)` for the lognormal prior).
#' @noRd
handle_prior_nu <- function(prior.nu, nu.upper.bound, nu.prec.inc = 0.01, prior.nu.dist = "lognormal") {
  prior.nu <- if (is.null(prior.nu)) list() else prior.nu

  # loglocation: defaults to log(min(1, nu.upper.bound / 2)).
  if (!is.null(prior.nu$loglocation)) {
    if (length(prior.nu$loglocation) != 1) {
      warning("'prior.nu$loglocation' has length > 1. Only the first element will be used.")
      prior.nu$loglocation <- prior.nu$loglocation[1]
    }
  } else {
    prior.nu$loglocation <- log(min(1, nu.upper.bound / 2))
  }

  # mean: defaults to min(1, nu.upper.bound / 2).
  if (!is.null(prior.nu[["mean"]])) {
    if (length(prior.nu[["mean"]]) != 1) {
      warning("'prior.nu$mean' has length > 1. Only the first element will be used.")
      prior.nu[["mean"]] <- prior.nu[["mean"]][1]
    }
  } else {
    prior.nu[["mean"]] <- min(1, nu.upper.bound / 2)
  }

  # prec: the default depends on the mean relative to the upper bound.
  if (!is.null(prior.nu$prec)) {
    if (length(prior.nu$prec) != 1) {
      warning("'prior.nu$prec' has length > 1. Only the first element will be used.")
      prior.nu$prec <- prior.nu$prec[1]
    }
  } else {
    relative_mean <- prior.nu[["mean"]] / nu.upper.bound
    prior.nu$prec <- max(1 / relative_mean, 1 / (1 - relative_mean)) + nu.prec.inc
  }

  # logscale: defaults to 1.
  if (!is.null(prior.nu[["logscale"]])) {
    if (length(prior.nu[["logscale"]]) != 1) {
      warning("'prior.nu$logscale' has length > 1. Only the first element will be used.")
      prior.nu[["logscale"]] <- prior.nu[["logscale"]][1]
    }
  } else {
    prior.nu[["logscale"]] <- 1
  }

  # Starting value and the message used if it falls outside [0, upper bound].
  if (prior.nu.dist == "beta") {
    start.nu <- prior.nu[["mean"]]
    out_of_range_msg <- "The 'mean' element of 'prior.nu' should be a number between 0 and nu.upper.bound!"
  } else if (prior.nu.dist == "lognormal") {
    start.nu <- exp(prior.nu[["loglocation"]])
    out_of_range_msg <- "The 'loglocation' element of 'prior.nu' should be a number less than log(nu.upper.bound)!"
  } else {
    stop("prior.nu.dist should be either 'beta' or 'lognormal'.")
  }

  if (start.nu < 0 || start.nu > nu.upper.bound) {
    stop(out_of_range_msg)
  }

  list(prior.nu = prior.nu, start.nu = start.nu)
}

#' Transform Anisotropic SPDE Model Parameters to Original Scale
#'
#' @description
#' This function takes a vector of transformed parameters and applies the appropriate
#' transformations to return them in the original scale for use in anisotropic SPDE models.
#'
#' @param theta A numeric vector of length 4 or 5, containing the transformed parameters in this order:
#' \describe{
#'   \item{lhx}{The logarithm of hx.}
#'   \item{lhy}{The logarithm of hy.}
#'   \item{logit_hxy}{The transformed hxy, `log((1 + hxy) / (1 - hxy))`, so that hxy lies in (-1, 1).}
#'   \item{lsigma}{The logarithm of sigma.}
#'   \item{lnu (optional)}{The logit of `nu / nu_upper_bound`, so that nu lies in (0, nu_upper_bound).
#'   If not provided, nu is not returned.}
#' }
#' @param nu_upper_bound (optional) A numeric value representing the upper bound for the smoothness parameter nu.
#' This is only used, and must be provided, if `lnu` is provided.
#'
#' @return A named list with the parameters in the original scale:
#' \describe{
#'   \item{hx}{The original scale for hx (exponential of lhx).}
#'   \item{hy}{The original scale for hy (exponential of lhy).}
#'   \item{hxy}{The original scale for hxy, `2 / (1 + exp(-logit_hxy)) - 1`.}
#'   \item{sigma}{The original scale for sigma (exponential of lsigma).}
#'   \item{nu (optional)}{The original scale for nu, `nu_upper_bound / (1 + exp(-lnu))`.
#'   Only included if `lnu` is provided.}
#' }
#' @export
#'
#' @examples
#' # With lnu
#' theta <- c(log(0.1), log(0.2), log((0.3 + 1) / (1 - 0.3)), log(0.5), log(1))
#' nu_upper_bound <- 2
#' transform_parameters_anisotropic(theta, nu_upper_bound)
#'
#' # Without lnu
#' theta <- c(log(0.1), log(0.2), log((0.3 + 1) / (1 - 0.3)), log(0.5))
#' transform_parameters_anisotropic(theta)
transform_parameters_anisotropic <- function(theta, nu_upper_bound = NULL) {
  if (!is.numeric(theta) || !(length(theta) %in% c(4, 5))) {
    stop("Theta must be a numeric vector of length 4 or 5.")
  }

  # Map the real line onto (-1, 1). tanh(z / 2) equals 2 / (1 + exp(-z)) - 1,
  # but is more accurate near z = 0.
  adjusted_inv_logit <- function(z) {
    tanh(z / 2)
  }

  # Map the real line onto (0, nu_upper_bound). This form stays finite for
  # extreme lnu, whereas exp(lnu) / (1 + exp(lnu)) gives Inf / Inf = NaN
  # for large lnu.
  forward_nu <- function(lnu, nu_upper_bound) {
    nu_upper_bound / (1 + exp(-lnu))
  }

  # Transform the always-present parameters to the original scale
  result <- list(
    hx = exp(theta[1]),
    hy = exp(theta[2]),
    hxy = adjusted_inv_logit(theta[3]),
    sigma = exp(theta[4])
  )

  # If lnu is provided, compute nu
  if (length(theta) == 5) {
    if (is.null(nu_upper_bound)) {
      stop("nu_upper_bound must be provided if lnu is included in theta.")
    }
    result$nu <- forward_nu(theta[5], nu_upper_bound)
  }

  result
}


#' Check that a model's cgeneric symbol is available in its shared library
#'
#' Loads the shared library `model$f$cgeneric$shlib`, checks whether the
#' symbol `model$f$cgeneric$model` is available in that library, and unloads
#' the library again. The unload happens whether or not the lookup succeeds.
#'
#' @param model A model object with elements `f$cgeneric$shlib` and
#'   `f$cgeneric$model`.
#' @return Invisibly, `TRUE` if the symbol was found and `FALSE` otherwise.
#'   A warning is issued in the latter case. Stops if the model object is
#'   malformed or the library file does not exist.
#' @noRd
rspde_check_cgeneric_symbol <- function(model) {
  # Ensure the required fields exist in the model object
  if (!"f" %in% names(model) || !"cgeneric" %in% names(model$f) || 
      !"shlib" %in% names(model$f$cgeneric) || !"model" %in% names(model$f$cgeneric)) {
    stop("There was a problem with the model creation.")
  }

  # Extract the shared library path and the symbol name
  shlib <- model$f$cgeneric$shlib
  symbol <- model$f$cgeneric$model

  # Check if the shared library exists
  if (!file.exists(shlib)) {
    stop(paste("The shared library", shlib, "does not exist."))
  }

  # found is TRUE/FALSE from the lookup, or NA if loading or lookup errored.
  found <- tryCatch({
    dll <- dyn.load(shlib)
    # Restrict the lookup to the library just loaded, so a symbol exported by
    # some other loaded DLL cannot give a false positive. The `finally`
    # clause guarantees the library is unloaded even if the lookup fails.
    tryCatch(
      is.loaded(symbol, PACKAGE = dll[["name"]]),
      finally = dyn.unload(shlib)
    )
  }, error = function(e) {
    warning(paste0("Error while loading the shared library or checking the symbol: ", conditionMessage(e), 
                   ". Please install the latest testing version of INLA. If the problem persists after installing the 
                   latest testing version of INLA, please open an issue at https://github.com/davidbolin/rSPDE/issues, 
                   requesting that this model be added to INLA."))
    NA
  })

  if (isTRUE(found)) {
    return(invisible(TRUE))
  }
  if (identical(found, FALSE)) {
    warning(paste0("The symbol '", symbol, "' is not available in the shared library. Please install the latest testing version of INLA. 
      If the problem persists after installing the latest testing version of INLA, please open an issue at https://github.com/davidbolin/rSPDE/issues, 
      requesting that this model be added to INLA."))
  }
  invisible(FALSE)
}

#' Locate each input value among reference locations, up to a tolerance
#'
#' @param input Values to look up.
#' @param loc Reference locations.
#' @param tolerance Largest absolute difference still counted as a match.
#' @return An integer vector the same length as `input`. Each entry is the
#'   position in `loc` of the first location within `tolerance` of the
#'   corresponding input value. The function stops on the first value without
#'   a match and warns for every value with more than one match.
#' @noRd
match_with_tolerance <- function(input, loc, tolerance = 1e-6) {
  positions <- integer(length(input))

  for (k in seq_along(input)) {
    target <- input[k]
    hits <- which(abs(loc - target) <= tolerance)
    n_hits <- length(hits)

    if (n_hits == 0) {
      stop(sprintf("Error: The input location %.10f is not present in the original locations used to create the model object.", target))
    }
    if (n_hits > 1) {
      warning(sprintf("Warning: Multiple matches found for input location %.10f. Using the first match.", target))
    }

    # The first candidate position is used in every case
    positions[k] <- hits[1]
  }

  positions
}


#' Merge two data frames on a numeric key, matching keys up to a tolerance
#'
#' Columns missing from either data frame are first added, filled with `NA`,
#' so that both share the same columns. Then each row of `new_data` is
#' handled as follows:
#' - If its key lies within `tolerance` of some key in `original_data`, it is
#'   merged into the closest original row. Non-missing values from `new_data`
#'   overwrite the original ones, with a warning when a non-missing original
#'   value differs.
#' - Otherwise the row is appended.
#'
#' Finally, rows whose key is an exact duplicate of an earlier key are
#' dropped, keeping the first occurrence.
#'
#' @param original_data,new_data Data frames containing the column `by`.
#' @param by Name of the numeric key column.
#' @param tolerance Maximum absolute key difference treated as a match.
#' @return A data frame. The result stays a data frame even when it has
#'   only one column.
#' @noRd
merge_with_tolerance <- function(original_data, new_data, by, tolerance = 1e-5) {
  # Ensure column names match by adding missing columns. rep() keeps this
  # valid for zero-row data frames, where assigning a scalar NA fails.
  all_columns <- union(names(original_data), names(new_data))
  for (col in setdiff(all_columns, names(original_data))) {
    original_data[[col]] <- rep(NA, nrow(original_data))
  }
  for (col in setdiff(all_columns, names(new_data))) {
    new_data[[col]] <- rep(NA, nrow(new_data))
  }

  # Extract reference columns
  original_loc <- original_data[[by]]
  new_loc <- new_data[[by]]

  # The `by` column itself is never overwritten
  columns_to_merge <- setdiff(names(new_data), by)

  merged_data <- original_data
  # Rows of new_data without a match; they are appended in a single rbind
  # at the end, instead of growing merged_data inside the loop.
  unmatched <- logical(length(new_loc))

  for (i in seq_along(new_loc)) {
    diffs <- abs(original_loc - new_loc[i])
    # na.rm: NA keys (on either side) never count as a match
    if (!any(diffs <= tolerance, na.rm = TRUE)) {
      unmatched[i] <- TRUE
      next
    }

    # Find the closest match in original_data. drop = FALSE keeps one-column
    # data frames from collapsing to vectors.
    matched_index <- which.min(diffs)
    merged_row <- merged_data[matched_index, , drop = FALSE]
    new_row <- new_data[i, , drop = FALSE]

    # Check for conflicts and replace values in merged_row with new_row
    for (col in columns_to_merge) {
      if (!is.na(new_row[[col]])) {
        if (!is.na(merged_row[[col]]) && merged_row[[col]] != new_row[[col]]) {
          warning(sprintf(
            "Conflicting values in column '%s' for location '%s': original='%s', new='%s'. Using new value.",
            col, new_loc[i], merged_row[[col]], new_row[[col]]
          ))
        }
        merged_row[[col]] <- new_row[[col]]
      }
    }

    merged_data[matched_index, ] <- merged_row
  }

  if (any(unmatched)) {
    merged_data <- rbind(merged_data, new_data[unmatched, , drop = FALSE])
  }

  # Remove duplicates based on the `by` column
  merged_data[!duplicated(merged_data[[by]]), , drop = FALSE]
}


#' Transform Spacetime SPDE Model Parameters to Original Scale
#'
#' @description
#' This function takes a vector of transformed parameters and applies the appropriate
#' transformations to return them in the original scale for use in spacetime SPDE models.
#'
#' @param theta A numeric vector containing the transformed parameters in this order:
#' \describe{
#'   \item{lkappa}{The logarithm of kappa.}
#'   \item{lsigma}{The logarithm of sigma.}
#'   \item{lgamma}{The logarithm of gamma.}
#'   \item{logit_rho (optional)}{Required if `drift` is true. When `is_bounded` is true this is
#'   `log((bound_rho + rho) / (bound_rho - rho))`; otherwise it is rho itself.}
#'   \item{logit_rho2 (optional)}{Required if `drift` is true and `d = 2`. Same scale as `logit_rho`.}
#' }
#' Elements beyond those required by `st_model` are ignored.
#' @param st_model A list containing the spacetime model parameters:
#' \describe{
#'   \item{d}{The dimension (e.g., 1 or 2).}
#'   \item{bound_rho}{The (positive) bound for rho and rho2, used when `is_bounded` is true.}
#'   \item{is_bounded}{A logical value indicating if rho and rho2 are bounded.}
#'   \item{drift}{A logical value indicating if drift is included in the model.}
#' }
#'
#' @return A named list with the parameters in the original scale:
#' \describe{
#'   \item{kappa}{The original scale for kappa (exponential of lkappa).}
#'   \item{sigma}{The original scale for sigma (exponential of lsigma).}
#'   \item{gamma}{The original scale for gamma (exponential of lgamma).}
#'   \item{rho}{The original scale for rho; 0 if the model has no drift.}
#'   \item{rho2}{The original scale for rho2; 0 if the model has no drift or `d != 2`.}
#' }
#' @export
transform_parameters_spacetime <- function(theta, st_model) {
  if (!is.list(st_model) || !all(c("d", "bound_rho", "is_bounded", "drift") %in% names(st_model))) {
    stop("st_model must be a list containing 'd', 'bound_rho', 'is_bounded', and 'drift'.")
  }

  # Extract model parameters (exact-name lookup)
  d <- st_model[["d"]]
  bound <- st_model[["bound_rho"]]
  is_bounded <- st_model[["is_bounded"]]
  drift <- st_model[["drift"]]

  # theta must hold kappa, sigma and gamma, plus rho (and rho2 when d = 2)
  # for drift models. A short theta previously produced NA silently.
  n_required <- 3L
  if (drift) {
    n_required <- n_required + if (d == 2) 2L else 1L
  }
  if (length(theta) < n_required) {
    stop(sprintf("theta must have at least %d elements for this model.", n_required))
  }

  # Map the real line onto (-L, L)
  adjusted_inv_logit <- function(z, L) {
    if (L <= 0) stop("Bound L must be positive.")
    L * (2 / (1 + exp(-z)) - 1)
  }

  # Convert a drift component to the original scale
  to_drift <- function(value) {
    if (is_bounded) adjusted_inv_logit(value, bound) else value
  }

  result <- list(kappa = exp(theta[1]), sigma = exp(theta[2]), gamma = exp(theta[3]))

  # rho and rho2 are always returned; they are 0 when not part of the model
  result$rho <- if (drift) to_drift(theta[4]) else 0.0
  result$rho2 <- if (drift && d == 2) to_drift(theta[5]) else 0.0

  result
}
