# ************************************
# Author: Andreas Alfons
#         Erasmus University Rotterdam
# ************************************


#' Matrix completion via nuclear-norm regularization with hyperparameter tuning
#' 
#' Perform matrix completion via nuclear-norm regularization based on  
#' \code{\link[softImpute]{softImpute}()}.  The regularization parameter 
#' is thereby selected via repeated holdout validation or cross-validation.  
#' Note that this uses the convenience wrapper \code{\link{soft_impute}()}, 
#' whose default behavior is different from that of the original function.
#' 
#' @inheritParams soft_impute
#' @param splits  an object inheriting from class \code{"split_control"}, as 
#' generated by \code{\link{holdout_control}()} for repeated holdout validation 
#' or \code{\link{cv_folds_control}()} for \eqn{K}-fold cross-validation, or a 
#' list of index vectors giving different validation sets of observed cells as 
#' generated by \code{\link{create_splits}()}.  Cells in the validation set
#' will be set to \code{NA} for fitting the algorithm with the training set of 
#' observed cells.
#' @param \dots  additional arguments to be passed down to 
#' \code{\link{soft_impute}()}.
#' @param discretize  a logical indicating whether to include a discretization 
#' step after fitting the algorithm (defaults to \code{TRUE}).  In case of 
#' discrete rating-scale data, this can be used to map the imputed values to 
#' the discrete rating scale of the observed values.
#' 
#' @return 
#' An object of class \code{"soft_impute_tuned"} with the following components: 
#' \item{lambda}{a numeric vector containing the values of the regularization 
#' parameter.}
#' \item{tuning_loss}{a numeric vector containing the (average) values of the 
#' loss function on the validation set(s) for each value of the regularization 
#' parameter.}
#' \item{lambda_opt}{numeric; the optimal value of the regularization 
#' parameter.}
#' \item{fit}{an object of class \code{"\link{soft_impute}"} containing the 
#' results from the algorithm with the optimal regularization parameter on the 
#' full (observed) data matrix.}
#' 
#' The class structure is still experimental and may change in the future. 
#' The following accessor functions are available:
#' \itemize{
#'   \item \code{\link{get_completed}()} to extract the imputed data matrix (with the 
#'   optimal value of the regularization parameter),
#'   \item \code{\link{get_lambda}()} to extract the optimal value of the 
#'   regularization parameter.
#' }
#' 
#' @author Andreas Alfons
#' 
#' @inherit soft_impute references
#' 
#' @seealso 
#' \code{\link{soft_impute}()}, \code{\link{fraction_grid}()}, 
#' 
#' \code{\link{holdout_control}()}, \code{\link{cv_folds_control}()}, 
#' \code{\link{create_splits}()}
#' 
#' @examples
#' # toy example derived from MovieLens 100K dataset
#' data("MovieLensToy")
#' # Soft-Impute with discretization step and hyperparameter tuning
#' set.seed(20250723)
#' fit <- soft_impute_tune(MovieLensToy, 
#'                         lambda = fraction_grid(nb_lambda = 6, 
#'                                                reverse = TRUE),
#'                         splits = holdout_control(R = 5))
#' # extract discretized completed matrix with optimal 
#' # regularization parameter
#' X_hat <- get_completed(fit)
#' head(X_hat)
#' # extract optimal value of regularization parameter
#' get_lambda(fit)
#' 
#' @keywords multivariate
#' 
#' @export

soft_impute_tune <- function(X, lambda = fraction_grid(reverse = TRUE), 
                             relative = TRUE, splits = holdout_control(), 
                             ..., discretize = TRUE, values = NULL) {
  
  # initializations
  X <- as.matrix(X)
  
  # check values of tuning parameter: duplicates are removed, missing values 
  # are dropped by sort(), and the candidates are put in decreasing order 
  # (strongest penalization first)
  lambda <- sort(unique(lambda), decreasing = TRUE)
  if (length(lambda) == 0L) {
    stop("no values of 'lambda' supplied")
  }
  if (length(lambda) == 1L) {
    stop("only one value of 'lambda'; use function soft_impute() instead")
  }
  relative <- isTRUE(relative)
  
  # create splits for tuning parameter validation: a control object is 
  # converted to a list of index vectors of observed cells, whereas anything 
  # else is used as supplied
  if (inherits(splits, "split_control")) {
    observed <- which(!is.na(X))
    splits <- create_splits(observed, control = splits)
  }
  
  # apply soft_impute() to the different training data sets
  fit_train <- lapply(splits, function(indices, ...) {
    # create training data where elements from test set are set to NA
    X_train <- X
    X_train[indices] <- NA_real_
    # apply soft_impute() to training data (discretization is skipped so that 
    # the loss below is computed from the raw predictions)
    soft_impute(X_train, lambda = lambda, relative = relative, ..., 
                discretize = FALSE)
  }, ...)
  
  # extract predictions for the elements in the different test sets and compute 
  # prediction loss
  tuning_loss <- mapply(function(indices, fit) {
    # extract elements from test set as a vector
    X_test <- X[indices]
    # compute prediction loss (squared error) for each value of lambda: 
    # vapply() guarantees a numeric result with one column per lambda even if 
    # the test set is empty, where sapply() would return a list
    vapply(fit$X, function(X_hat) (X_test - X_hat[indices])^2, 
           numeric(length(X_test)))
  }, indices = splits, fit = fit_train, SIMPLIFY = FALSE, USE.NAMES = FALSE)
  # combine results for tuning parameter validation into one matrix
  # (each column holds the prediction losses for the corresponding lambda)
  tuning_loss <- do.call(rbind, tuning_loss)
  # compute column means (mean squared error)
  tuning_loss <- colMeans(tuning_loss)
  
  # select the optimal lambda: in the unlikely case of ties, we select the 
  # lambda with stronger penalization (which.min() returns the first minimum 
  # and lambda is sorted in decreasing order)
  which_opt <- unname(which.min(tuning_loss))
  if (length(which_opt) == 0L) {
    # which.min() skips missing values, so an empty result means that the 
    # loss could not be computed for any candidate value
    stop("could not select optimal 'lambda': validation loss is missing ", 
         "for all candidate values")
  }
  lambda_opt <- lambda[which_opt]
  
  # apply soft_impute() with optimal tuning parameter to the full data
  fit_opt <- soft_impute(X, lambda = lambda_opt, relative = relative, ..., 
                         discretize = discretize, values = values)
  
  ## construct list of relevant output
  out <- list(lambda = lambda, tuning_loss = tuning_loss, 
              lambda_opt = lambda_opt, fit = fit_opt)
  class(out) <- "soft_impute_tuned"
  out
  
}
