# -------------------------------------------------------------------------------
#   This file is part of 'unityForest'.
#
# 'unityForest' is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# 'unityForest' is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with 'unityForest'. If not, see <http://www.gnu.org/licenses/>.
#
#  NOTE: 'unityForest' is a fork of the popular R package 'ranger', written by Marvin N. Wright.
#  Most R and C++ code is identical with that of 'ranger'. The package 'unityForest'
#  was written by taking the original 'ranger' code and making any
#  changes necessary to implement diversity forests.
#
# -------------------------------------------------------------------------------

##' Implements the algorithm for selecting and visualizing covariate-representative tree roots (CRTRs) as described in Hornung & Hapfelmeier (2026).\cr
##' CRTRs are tree roots extracted from a unity forest that characterize the conditions under which a given variable exhibits its strongest effect on the outcome. The function selects one representative tree root for each variable and visualizes its structure to facilitate interpretation. CRTRs are essential for analyzing the effects identified by the unity VIM (\code{\link{unityfor}}). See the 'Details' section below for more details.
##'
##' @details
##' Further details on the descriptions below are provided in Hornung & Hapfelmeier (2026).
##' 
##' \strong{Covariate-representative tree roots (CRTRs).}
##' Covariate-representative tree roots (CRTRs) (Hornung & Hapfelmeier, 2026) are tree fragments (or 'tree roots' - the first few splits in the trees) extracted from a fitted unity forest (\code{\link{unityfor}}) that characterize for given variables the conditions under which each variable exerts its strongest influence on the prediction.
##' 
##' Technically, for a given variable, the algorithm identifies tree roots in which this variable attains particularly high split scores (top-scoring splits). From these tree roots, a representative root is extracted (Laabs et al., 2024) that best reflects the conditions under which this variable has its strongest effect.
##' 
##' \strong{Interpretation and subgroup effects.}
##' If a variable has a strong marginal effect, the corresponding CRTR typically contains a split on this variable at the root node (first split in the tree). In contrast, if a variable has little marginal effect but interacts with another variable, the CRTR may first split on that other variable, thereby defining a subgroup in which the variable of interest exhibits a strong conditional effect.
##' 
##' From a substantive perspective, CRTRs enable the exploration of variable effects that are generally not detectable by conventional methods focusing on marginal associations. In particular, CRTRs can reveal variables that have weak marginal effects but act strongly within specific subgroups defined by interactions with other variables.
##' 
##' \strong{Relation to unity VIM.}
##' CRTRs are closely related to the unity variable importance measure (unity VIM) (\code{\link{unityfor}}). The unity VIM quantifies the strength of variable effects under the conditions in which they are strongest. Analogously, CRTRs visualize these conditions by displaying the tree structures that give rise to the respective unity VIM values.
##' 
##' Accordingly, the CRTR algorithm can be used to visualize and interpret the effects identified by the unity VIM. By default, CRTRs are constructed and visualized for the five variables with the largest unity VIM values.
##' 
##' \strong{Scope of applicability.}
##' CRTRs should primarily be examined for variables with sufficiently large unity VIM values. Constructing CRTRs for variables with negligible importance may lead to overinterpretation, as apparent patterns may reflect random structure rather than meaningful effects.
##' 
##' \strong{Shaded regions in the visualization.}
##' For improved interpretability, parts of the CRTRs are shaded out by default. Specifically, only the nodes containing the top-scoring splits for the variable of interest and their ancestor nodes are shown prominently.
##' 
##' This design is motivated by two considerations. First, the purpose of CRTRs is to depict the conditions under which a variable exhibits its strongest effects - conditions that are defined by the ancestors of the nodes with top-scoring splits. Second, the remaining regions of the tree are of limited interpretive value. Since each CRTR is derived from tree roots selected for strong effects of a specific variable, the splitting patterns along the highlighted paths are specific  for that variable. In contrast, shaded regions reflect arbitrary aspects of the overall association structure in the data and may include splits on non-informative variables, as each tree root is grown from a (small) random subset of all available variables.
##' 
##' Note that additional splits on the variable of interest may occur within shaded regions and can still be relevant. However, these splits do not represent the conditions under which the variable attains its strongest effects.
##' 
##' \strong{In-bag data for top-scoring split visualizations.}
##' The boxplots and density plots illustrating the discriminatory power of the top-scoring splits are computed exclusively based on the in-bag observations of the corresponding trees. This is consistent with the construction of the CRTRs themselves, which are derived from in-bag data only.
##' 
##' @title Select and visualize covariate-representative tree roots (CRTRs)
##' 
##' @param object Object of class \code{unityfor}. 
##' @param vars Optional vector of variable names for which CRTRs should be obtained. If provided, \code{numvars} and \code{indvars} are ignored.
##' @param numvars The number of the variables with the largest unity VIM values for which CRTRs should be obtained.
##' @param indvars The indices of the variables with the largest unity VIM values for which CRTRs should be obtained. For example, if \code{indvars = c(1, 3)}, the CRTRs for the variables with the largest and third-largest unity VIM values are obtained.
##' @param num.threads Number of threads. Default is number of CPUs available.
##' @param plotit Whether or not the CRTRs should be plotted or merely returned (invisibly). Default is \code{TRUE}.
##' @param highlight_relevant Whether or not the nodes not containing the top-scoring splits for the variables of interest or their ancestor nodes should be shaded out. Default is \code{TRUE}. See the 'Details' section below for explanation.
##' @param box_plots Whether boxplots should be used to show the outcome class-specific distributions of the variable values in the nodes with top-scoring splits (see 'Details' section for explanation). For classification only. Default is \code{TRUE}.
##' @param density_plots Whether kernel density plots should be used to show the outcome class-specific distributions of the variable values in the nodes with top-scoring splits (see 'Details' section for explanation). For classification only. Default is \code{TRUE}.
##' @param add_split_line Whether in the boxplots and/or density plots a line at the split point of the corresponding node should be drawn. Default is \code{TRUE}.
##' @param verbose Verbose output on or off. Default is \code{TRUE}.
##' @return Object of class \code{unityfor.reprTrees} with elements
##'   \item{\code{rules}}{List. In-bag statistics on the outcome at each node in the CRTRs. For classification, this provides the class frequencies and the numbers of observations representing each class.}
##'   \item{\code{plots}}{List. Generated ggplot2 plots.}
##'   \item{\code{var.names}}{Labels of the variables for which CRTRs were selected.} 
##'   \item{\code{var.names.all}}{Names of all independent variables in the dataset.}
##'   \item{\code{num.independent.variables}}{Number of independent variables in the dataset.}
##'   \item{\code{num.samples}}{Number of observations in the dataset.}
##'   \item{\code{treetype}}{Tree type.}
##'   \item{\code{forest}}{Sub-forest that contains only the CRTRs.}
##' @examples
##' \donttest{
##'
##' ## Load package:
##' 
##' library("unityForest")
##' 
##' 
##' ## Set seed to make results reproducible:
##' 
##' set.seed(1234)
##' 
##' 
##' ## Load wine dataset:
##' 
##' data(wine)
##' 
##' 
##' ## Construct unity forest and calculate unity VIM values:
##' 
##' model <- unityfor(dependent.variable.name = "C", data = wine,
##'                   importance = "unity", num.trees = 2000)
##' 
##' # NOTE: num.trees = 2000 (in the above) would be too small for practical 
##' # purposes. This quite small number of trees was simply used to keep the
##' # runtime of the example short.
##' # The default number of trees is num.trees = 20000.
##' 
##' 
##' ## Visualize the CRTRs for the five variables with the largest unity VIM
##' ## values:
##' 
##' reprTrees(model, box_plots = FALSE, density_plots = FALSE)
##' 
##' 
##' ## Visualize the CRTRs for the variables with the second-largest and 
##' ## third-largest unity VIM values:
##' 
##' reprTrees(model, indvars = c(2, 3), box_plots = FALSE, density_plots = FALSE)
##' 
##' 
##' ## Visualize the CRTRs for the variables with the second-largest and 
##' ## third-largest unity VIM values, where density plots are shown to 
##' ## visualize the outcome class-specific distributions of the variable 
##' ## values in the nodes with top-scoring splits:
##' 
##' reprTrees(model, indvars = c(2, 3), box_plots = FALSE, density_plots = TRUE)
##' 
##' 
##' ## Visualize the CRTRs for the variables with the second-largest and
##' ## third-largest unity VIM values, where both density plots and boxplots
##' ## are shown to visualize the outcome class-specific distributions of the
##' ## variable values in the top-scoring splits; the split points are not
##' ## indicated in these plots:
##' 
##' ps <- reprTrees(model, indvars = c(2, 3), add_split_line = FALSE)
##' 
##' 
##' ## Save one of the CRTRs with the corresponding density plot:
##' 
##' library("patchwork")
##' library("ggplot2")
##' 
##' p <- ps$plots[[1]]$tree_plot / ps$plots[[1]]$density_plot +
##'      patchwork::plot_layout(heights = c(2, 1))
##' p
##' 
##' # outfile <- file.path(tempdir(), "figure_xy.pdf")
##' # ggsave(outfile, device = cairo_pdf, plot = p, width = 18, 
##' #        height = 14)
##' 
##' 
##' # Note: The plots can be manipulated with the usual ggplot2 syntax, e.g.:
##' 
##' ps$plots[[1]]$density_plot + xlab("Proline") + labs(title = NULL, y = NULL) +
##'   theme(
##'     legend.position = c(0.95, 0.95),
##'     legend.justification = c(1, 1)
##'   )
##' 
##' }
##'
##' @author Roman Hornung
##' @references
##' \itemize{
##'   \item Hornung, R., Hapfelmeier, A. (2026). Unity Forests: Improving Interaction Modelling and Interpretability in Random Forests. arXiv:2601.07003, <\doi{10.48550/arXiv.2601.07003}>.
##'   \item Laabs, B.-H., Westenberger, A., & K\"onig, I. R. (2024). Identification of representative trees in random forests based on a new tree-based distance measure. Advances in Data Analysis and Classification 18(2):363-380, <\doi{10.1007/s11634-023-00537-7}>.
##'   }
##' @seealso \code{\link{unityfor}}
##' @encoding UTF-8
##' @useDynLib unityForest, .registration = TRUE
##' @importFrom Rcpp evalCpp
##' @importFrom rlang .data
##' @import stats 
##' @import utils
##' @export
reprTrees <- function(object, vars=NULL, numvars=5, indvars=NULL, num.threads = NULL, plotit=TRUE,
                         highlight_relevant = TRUE, box_plots = TRUE, density_plots = TRUE, add_split_line = TRUE, verbose = TRUE) {
  
  ## Check the class of the input object before any of its elements are accessed
  if (!inherits(object, "unityfor")) {
    stop("Error: Invalid class of input object.")
  }
  
  predict.all <- FALSE
  num.trees <- object$num.trees
  
  forest <- object$forest
  data <- object$data
  
  if (is.null(forest)) {
    stop("Error: No saved forest in unityfor object. Please set write.forest to TRUE when calling unityfor.")
  }
  
  var.imp <- object$variable.importance
  
  ## Determine the variables for which CRTRs are obtained ('repr.var.names').
  ## 'vars' takes precedence; otherwise the variables are chosen by their rank
  ## with respect to the unity VIM values ('indvars' first, then 'numvars').
  if (is.null(vars)) {
    
    if (is.null(var.imp)) {
      stop("If no variables are provided via 'vars', the fitted 'unityfor' object must contain unity VIM values ('$variable.importance'). Please refit using 'importance=\"unity\"'.'")
    }
    
    ## Variable names ordered by decreasing unity VIM value
    ranked.var.names <- names(sort(var.imp, decreasing=TRUE))
    
    if (is.null(indvars)) {
      
      ## A value below 1 would select the wrong variables (index 0 is dropped
      ## silently by R) or fail with an obscure error, so reject it explicitly
      if (!is.numeric(numvars) || length(numvars) != 1 || is.na(numvars) || numvars < 1) {
        stop("Error: 'numvars' must be a single number greater than or equal to 1.")
      }
      
      if (numvars > length(var.imp)) {
        numvars <- length(var.imp)
        warning(paste0("The value of 'numvars' was larger than the number of variables. --> numvars set to ", length(var.imp), "."))
      }
      repr.var.names <- ranked.var.names[seq_len(floor(numvars))]
      
    } else {
      
      ## Element-wise indicator of valid ranks. (Wrapping this in any() would
      ## yield a single TRUE that is recycled and marks ALL elements as invalid.)
      in_range <- indvars %in% seq_along(var.imp)
      
      if (!all(in_range)) {
        if (!any(in_range)) {
          stop("No elements of 'indvars' are within 1:length(variable.importance).")
        }
        indvars_outofbound <- sort(indvars[!in_range])
        warning(paste0("The following elements of 'indvars' were not within 1:length(variable.importance): ", paste(indvars_outofbound, collapse=", "), ". These are removed from 'indvars'."))
        indvars <- indvars[in_range]
      }
      repr.var.names <- ranked.var.names[indvars]
      
    }
    
  } else {
    
    ## Element-wise indicator of the names that are independent variables of
    ## the forest; only the unknown names are dropped
    is_known <- vars %in% forest$independent.variable.names
    
    if (!all(is_known)) {
      if (!any(is_known)) {
        stop("No elements of 'vars' are within the independent variables names of the unityfor object.")
      }
      vars_notincluded <- vars[!is_known]
      warning(paste0("The following elements of 'vars' were not within the independent variables names of the unityfor object: ", paste(vars_notincluded, collapse=", "), ". These are removed from 'vars'."))
      vars <- vars[is_known]
    }
    repr.var.names <- vars
    
  }
  
  variable.names <- colnames(data)
  
  ## Check that the forest contains all elements required by the C++ code
  if (is.null(forest$dependent.varID) || is.null(forest$num.trees) ||
      is.null(forest$child.nodeIDs) || is.null(forest$split.varIDs) ||
      is.null(forest$split.values) || is.null(forest$independent.variable.names) ||
      is.null(forest$treetype)) {
    stop("Error: Invalid forest object.")
  }
  
  ## Convert to data matrix
  data.final <- data.matrix(data)
  
  ## Check missing values
  if (any(is.na(data.final))) {
    offending_columns <- colnames(data.final)[colSums(is.na(data.final)) > 0]
    stop("Missing data in columns: ",
         paste0(offending_columns, collapse = ", "), ".", call. = FALSE)
  }
  
  if (sum(!(forest$independent.variable.names %in% variable.names)) > 0) {
    stop("Error: One or more independent variables not found in data.")
  }
  
  ## Num threads
  ## Default 0 -> detect from system in C++.
  if (is.null(num.threads)) {
    num.threads <- 0
  } else if (!is.numeric(num.threads) || num.threads < 0) {
    stop("Error: Invalid value for num.threads")
  }
  
  ## Seed handed to the C++ code, drawn from R's random number generator so
  ## that results are reproducible via set.seed()
  seed <- runif(1 , 0, .Machine$integer.max)
  
  ## Numeric tree type code expected by the C++ code
  if (forest$treetype == "Classification") {
    treetype <- 1
  } else if (forest$treetype == "Regression") {
    treetype <- 3
  } else if (forest$treetype == "Probability estimation") {
    treetype <- 9
  } else {
    stop("Error: Unknown tree type.")
  }
  
  ## Defaults for variables not needed (placeholders for the C++ interface)
  dependent.variable.name <- ""
  mtry <- 0
  importance <- 0
  min.node.size <- 0
  split.select.weights <- list(c(0, 0))
  use.split.select.weights <- FALSE
  always.split.variables <- c("0", "0")
  use.always.split.variables <- FALSE
  status.variable.name <- "status"
  prediction.mode <- TRUE
  write.forest <- FALSE
  replace <- TRUE
  probability <- FALSE
  unordered.factor.variables <- c("0", "0")
  use.unordered.factor.variables <- FALSE
  save.memory <- FALSE
  splitrule <- 1
  alpha <- 0
  minprop <- 0
  case.weights <- c(0, 0)
  use.case.weights <- FALSE
  class.weights <- c(0, 0)
  keep.inbag <- FALSE
  sample.fraction <- 1
  holdout <- FALSE
  num.random.splits <- 1
  order.snps <- FALSE
  oob.error <- FALSE
  max.depth <- 0
  inbag <- list(c(0,0))
  use.inbag <- FALSE
  nsplits <- 0
  proptry <- 0
  prediction.type <- 1
  prop.best.splits <- object$prop.best.splits
  
  ## Switch that makes the C++ code select representative tree roots
  repr.tree.mode <- TRUE
  
  ## Use sparse matrix
  if (inherits(data.final, "dgCMatrix")) {
    sparse.data <- data.final
    data.final <- matrix(c(0, 0))
    use.sparse.data <- TRUE
  } else {
    sparse.data <- Matrix::Matrix(matrix(c(0, 0)))
    use.sparse.data <- FALSE
  }
  
  ## Call divfor
  result <- divforCpp(treetype, dependent.variable.name, data.final, variable.names, mtry,
                      num.trees, verbose, seed, num.threads, write.forest, importance,
                      min.node.size, min_node_size_root=0, split.select.weights, use.split.select.weights,
                      always.split.variables, use.always.split.variables,
                      status.variable.name, prediction.mode, forest, snp_data=as.matrix(0), replace, probability,
                      unordered.factor.variables, use.unordered.factor.variables, save.memory, splitrule,
                      case.weights, use.case.weights, class.weights, 
                      predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout, 
                      prediction.type, num.random.splits, sparse.data, use.sparse.data,
                      order.snps, oob.error, max.depth, max_depth_root=3, num_cand_trees=1000, inbag, use.inbag, nsplits, npairs=0, proptry, prop_var_root=0, divfortype=4, 
                      promispairs=list(0,0), eim_mode=0, metricind=numeric(0), prop.best.splits,
                      repr.tree.mode, repr.var.names)
  
  ## This check has to come before anything is assigned to 'result', because
  ## any assignment would make 'result' non-empty and the check ineffective
  if (length(result) == 0) {
    stop("User interrupt or internal error.")
  }
  
  result$forest$covariate.levels <- forest$covariate.levels
  
  
  ## Build (and optionally show) the plots for each selected variable; the
  ## i-th tree of the returned sub-forest belongs to the i-th variable
  plots_and_rules <- list()
  for(i in seq_along(repr.var.names)) {
    
    plots_and_rules[[i]] <- .plot_representative_tree(i, result$forest, data, names(data)[forest$dependent.varID+1], repr.var.names, forest$independent.variable.names, plotit, highlight_relevant, box_plots, density_plots, add_split_line)
    
    if(plotit) {
      for (j in seq_along(plots_and_rules[[i]]$ps)) {
        print(plots_and_rules[[i]]$ps[[j]])
        if(j < length(plots_and_rules[[i]]$ps)) {
          readline(prompt="Press [enter] for next plot.")
        }
      }
    }
    
    ## Pause between the plots of consecutive variables
    if(i < length(repr.var.names) && plotit) {
      readline(prompt="Press [enter] for next plot.")
    }
  }
  
  rules <- lapply(plots_and_rules, function(x) x$rules)
  plots <- lapply(plots_and_rules, function(x) x$ps)
  
  names(rules) <- names(plots) <- repr.var.names
  
  
  ## Prepare results
  
  result$rules <- rules
  result$plots <- plots
  result$num.samples <- nrow(data.final)
  result$var.names <- repr.var.names
  result$treetype <- forest$treetype
  result$var.names.all <- forest$independent.variable.names
  result$predictions <- result$num.trees <- NULL
  
  
  # Reorder the result list, so that the items in the list have a meaningful
  # ordering. Only the positions of the known elements are permuted among
  # themselves; any other element keeps its position:
  
  res_names_all <- c("rules", "plots", "var.names",  "var.names.all", "num.independent.variables", "num.samples", "treetype", "forest")
  res_names <- names(result)
  
  ind_cl <- which(res_names %in% res_names_all)
  
  res_names_sub <- res_names[ind_cl]
  res_names_all_sub <- res_names_all[res_names_all %in% res_names]
  
  reorderind <- as.numeric(factor(res_names_all_sub, levels=res_names_sub))
  
  new_order <- seq_along(result)
  new_order[ind_cl] <- ind_cl[reorderind]
  
  result <- result[new_order]
  
  
  
  class(result) <- "unityfor.reprTrees"
  return(result)
  
}




# Function for plotting one of the representative trees:
.plot_representative_tree <- function(treeind, forest, data, dep.var.name, var.names, var.names.all, plotit=TRUE, highlight_relevant=TRUE, box_plots=TRUE, density_plots=TRUE, add_split_line=TRUE) {
  
  # Subset the data to only contain the in-bag observations of the tree:
  inbag_ind <- which(forest$inbag.counts[[treeind]]!=0)
  data <- data[inbag_ind,]
  data[-ncol(data)] <- lapply(data[-ncol(data)], as.numeric)
  
  # Extract the information from the forest for the treeind-th tree:
  scoreval <- forest$score.values[[treeind]]
  nodeID_in_root <- forest$nodeID.in.root[[treeind]]
  child_left <- forest$child.nodeIDs[[treeind]][[1]] + 1
  child_right <- forest$child.nodeIDs[[treeind]][[2]] + 1
  split_varIDs <- forest$split.varIDs[[treeind]] + 1
  split_values <- forest$split.values[[treeind]]
  is_in_best <- forest$is.in.best[[treeind]]
  is_cat <- child_left != 1 & nodeID_in_root[child_left] != 0 & split_varIDs %in% which(sapply(forest$covariate.levels, length) > 0)
  
  # Determine the depth of the tree because this informs, how broad the first splits
  # will be plotted:
  
  depth_count <- 1
  curr_child_nodes <- c()
  curr_depths <- rep(99, length(nodeID_in_root))
  
  for(i in seq(along=nodeID_in_root)) {
    
    nodeID_in_root_temp <- nodeID_in_root[i]
    
    if ((i == 1) || (child_left[i] != 1 && nodeID_in_root[child_left[i]] != 0))
    {
      
      if (nodeID_in_root_temp %in% curr_child_nodes) {
        depth_count <- depth_count + 1
        curr_child_nodes <- c()
      }
      
      curr_child_nodes <- c(curr_child_nodes, nodeID_in_root[child_left[i]], nodeID_in_root[child_right[i]])
      
      curr_depths[i] <- depth_count
      
    }
    
  }
  
  depth_tree <- depth_count
  
  # Width of the first split:
  max_width <- 2^(depth_tree-1)
  
  # Names of categorical variables:
  cat_indicator <- sapply(forest$covariate.levels, length)>0
  var.names.cat <- var.names.all[cat_indicator]
  var.cat.levels <- forest$covariate.levels[cat_indicator]
  
  # This will contain the labels showing the class distributions at the nodes:
  y_label <- rep("", length(nodeID_in_root))
  
  # This will contain the vectors of indices of observations at each node:
  indices_nodes <- vector("list", length = length(nodeID_in_root))
  indices_nodes[[1]] <- 1:nrow(data)
  
  # First class distribution label:
  nums_raw <- table(data[indices_nodes[[1]], dep.var.name])
  freqs_raw <- nums_raw/sum(nums_raw)
  
  ytab <- round(freqs_raw, 2) # CHANGED: 3)
  y_label[1] <- paste(paste0(names(ytab), ": ", ytab), collapse=", ")
  
  
  # First line of the data.frame that will contain the horizontal lines showing
  # the splits in the plot:
  left_label <- format_split_label(split_var=var.names.all[split_varIDs[1]], split_val=split_values[1], 
                                      var.names.cat=var.names.cat, cov_levels=var.cat.levels, data=data, 
                                      ind_curr=indices_nodes[[1]], direction="left")
  right_label <- format_split_label(split_var=var.names.all[split_varIDs[1]], split_val=split_values[1], 
                                       var.names.cat=var.names.cat, cov_levels=var.cat.levels, data=data, 
                                       ind_curr=indices_nodes[[1]], direction="right")
  xlines_df <- data.frame(x = -max_width/2, xend = max_width/2, y = 0, yend = 0, left_label = left_label, 
                          right_label=right_label, 
                          scoreval=scoreval[1], is_in_best=is_in_best[1], ind_in_full_tree=1, highlight=1,
                          stringsAsFactors = FALSE)
  
  # First line of the data.frame that will contain the class distribution labels:
  y_label_df <- data.frame(x = 0, y = 0.9, y_label=y_label[1], ind_in_full_tree=1, highlight=1, stringsAsFactors = FALSE)
  
  
  # This will contain the class distributions in numerical format, will be
  # returned by the function:
  rules_dist <- vector("list", length = length(nodeID_in_root))
  
  rules_dist_new <- rbind(freqs_raw, nums_raw)
  rownames(rules_dist_new) <- c("frequencies", "numbers")
  
  rules_dist[[1]] <- rules_dist_new
  
  # This will contain the rules corresponding to the nodes in the forest:
  rules <- rep("root node", length(nodeID_in_root))
  
  
  # This loop cycles through the nodes in the tree and adds the the plot
  # information, namely the coordinates of the horizontal and vertical lines,
  # as well as the labels shown in the plot, and further visualization information:
  for(i in 2:length(nodeID_in_root)) {
    
    nodeID_in_root_temp <- nodeID_in_root[i]
    
    # Only the tree roots are plotted:
    if (nodeID_in_root[i] != 0)
    {
      
      parent_left <- which(child_left==i)
      ind_parent <- c(parent_left, which(child_right==i))
      
      if(length(parent_left) > 0) {
        indices_nodes[[i]] <- indices_nodes[[ind_parent]][data[indices_nodes[[ind_parent]],var.names.all[split_varIDs[ind_parent]]] <= split_values[ind_parent]]
      } else {
        indices_nodes[[i]] <- indices_nodes[[ind_parent]][data[indices_nodes[[ind_parent]],var.names.all[split_varIDs[ind_parent]]] > split_values[ind_parent]]
      }
      
      nums_raw <- table(data[indices_nodes[[i]], dep.var.name])
      freqs_raw <- nums_raw/sum(nums_raw)
      ytab <- round(freqs_raw, 2) # CHANGED: 3)
      y_label[i] <- paste(paste0(names(ytab), ": ", ytab), collapse=", ")
      
      rules_dist_new <- rbind(freqs_raw, nums_raw)
      rownames(rules_dist_new) <- c("frequencies", "numbers")
      
      rules_dist[[i]] <- rules_dist_new
      
      parent_line <- xlines_df[xlines_df$ind_in_full_tree==ind_parent,]    
      
      rules[i] <- paste0(rules[ind_parent], ", ", ifelse(length(parent_left) > 0, parent_line$left_label, parent_line$right_label))
      
      # If the node has child nodes in the tree root, add horizontal line (and class distribution information):
      if (child_left[i] != 1 && nodeID_in_root[child_left[i]] != 0)
      {
        
        if(length(parent_left) > 0) {
          x_new <- parent_line$x - max_width/(2*curr_depths[i])
          xend_new <- parent_line$x + max_width/(2*curr_depths[i])
        } else {
          x_new <- parent_line$xend - max_width/(2*curr_depths[i])
          xend_new <- parent_line$xend + max_width/(2*curr_depths[i])
        }
        
        left_label <- format_split_label(split_var=var.names.all[split_varIDs[i]], split_val=split_values[i], 
                                            var.names.cat=var.names.cat, cov_levels=var.cat.levels, data=data, 
                                            ind_curr=indices_nodes[[i]], direction="left")
        right_label <- format_split_label(split_var=var.names.all[split_varIDs[i]], split_val=split_values[i], 
                                             var.names.cat=var.names.cat, cov_levels=var.cat.levels, data=data, 
                                             ind_curr=indices_nodes[[i]], direction="right")
        
        new_row <- data.frame(x = x_new, xend = xend_new, y = -curr_depths[i]+1, yend = -curr_depths[i]+1, 
                              left_label = left_label, right_label=right_label, 
                              scoreval=scoreval[i], is_in_best=is_in_best[i], ind_in_full_tree=i, highlight=0, 
                              stringsAsFactors = FALSE)
        
        if (is_in_best[i] == 1 & var.names.all[split_varIDs[i]]==var.names[treeind]) {

          xlines_df$highlight[which(xlines_df$ind_in_full_tree == ind_parent)] <- 1
          y_label_df$highlight[which(y_label_df$ind_in_full_tree == ind_parent)] <- 1
          
          parent_temp <- ind_parent
          
          while (parent_temp != 1) {
            parent_left <- which(child_left==parent_temp)
            ind_parent <- c(parent_left, which(child_right==parent_temp))

            xlines_df$highlight[which(xlines_df$ind_in_full_tree == ind_parent)] <- 1
            y_label_df$highlight[which(y_label_df$ind_in_full_tree == ind_parent)] <- 1
            parent_temp <- ind_parent
          }
          
          new_row$highlight <- 1
          
        }
        
        xlines_df <- rbind(xlines_df, new_row)
        
        y_label_df <- rbind(y_label_df, data.frame(x = (new_row$x + new_row$xend)/2, y = new_row$y + 0.85, y_label=y_label[i], ind_in_full_tree=i, highlight=new_row$highlight))
        
      }
      else {
        
        # If the node does not have nodes in the tree root, just add class distribution information:
        if(length(parent_left) > 0) {
          x_new <- parent_line$x
        } else {
          x_new <- parent_line$xend
        }
        
        y_label_df <- rbind(y_label_df, data.frame(x = x_new, y = parent_line$y - 0.15, y_label=y_label[i], ind_in_full_tree=i, highlight=0))
      }
      
    }
    
  }
  
  
  # In the above loop, not all labels of the class distributions were marked as
  # highlighted, only the direct ancestors, for example only the right nodes,
  # but not the left ones, or the other way round.
  # Highlight the rest:
  
  highlight_inds <- which(y_label_df$highlight==1)[-1]
  
  for (i in seq(along=highlight_inds)) {
    temp_id <- y_label_df$ind_in_full_tree[highlight_inds[i]]
    if (length(which(child_right==temp_id)) > 0) {
      y_label_df$highlight[which(y_label_df$ind_in_full_tree==child_left[which(child_right==temp_id)])] <- 1
    } else {
      y_label_df$highlight[which(y_label_df$ind_in_full_tree==child_right[which(child_left==temp_id)])] <- 1
    }
  }
  
  # Add the name of the variable used for spltting:
  xlines_df$split_var <- var.names.all[split_varIDs[xlines_df$ind_in_full_tree]]
  
  # Mark the labels of the class distributions for the best splits:
  best_inds <- which(xlines_df$is_in_best & xlines_df$split_var==var.names[treeind]) #  xlines_df$ind_in_full_tree)#xlines_df$highlight==1)
  for(i in seq(along=best_inds)) {
    temp_id <- xlines_df$ind_in_full_tree[best_inds[i]]
    y_label_df$highlight[which(y_label_df$ind_in_full_tree==child_left[temp_id])] <- 1
    y_label_df$highlight[which(y_label_df$ind_in_full_tree==child_right[temp_id])] <- 1
  }
  
  # The widths of the horizontal lines are determined by scores that express whether
  # the corresponding variables are used more often in the best trees associated with
  # the variables than in the complete forest - however, for the variable for which
  # the representative tree was obtained this does not make sense because this variable
  # is obviously in all best trees that use this variable.
  # Therefore for the variable for which the representative tree was obtained we
  # use as width the mean score value from the score values of the splits in the
  # other variables:
  xlines_df$scoreval[xlines_df$split_var==var.names[treeind]] <- mean(xlines_df$scoreval[xlines_df$split_var!=var.names[treeind]])
  
  
  
  # Make the vertical lines using the information on the horizontal lines:
  ylines_df <- data.frame(x = 0, xend = 0, y = 0, yend = 0.8, highlight = 1)
  for(i in 2:nrow(xlines_df)) {
    x_new <- (xlines_df$x[i] + xlines_df$xend[i])/2
    ylines_df <- rbind(ylines_df, data.frame(x=x_new, xend=x_new, y=xlines_df$y[i], yend=xlines_df$y[i]+0.8, highlight = xlines_df$highlight[i]))
  }
  
  
  # Subset the rules vectors so that they only contain nodes from the tree roots:
  incl_ind <- which(sapply(rules_dist, length)!=0)
  
  rules <- rules[incl_ind]
  rules_dist <- rules_dist[incl_ind]
  
  rules <- gsub("root node, ", "", rules)
  
  names(rules_dist) <- rules
  
  
  ps <- list()
  
  # Make the plot:
  
  if (highlight_relevant) {
    
    y_label_df$color <- 2
    y_label_df$color[y_label_df$highlight==0] <- 4
    
    y_label_df$color <- factor(y_label_df$color)
    
    
    ylines_df$color <- 2
    ylines_df$color[ylines_df$highlight==0] <- 4
    
    ylines_df$color <- factor(ylines_df$color)
    
    
    incl_bool <- y_label_df$x %in% c(min(y_label_df$x), max(y_label_df$x))
    y_label_df_extr <- y_label_df[incl_bool,]
    y_label_df_nonextr <- y_label_df[!incl_bool,]
    
    xlines_df$color <- 1
    xlines_df$color[xlines_df$split_var==var.names[treeind]] <- 2
    xlines_df$color[xlines_df$highlight==0] <- 5
    xlines_df$color[xlines_df$split_var==var.names[treeind] & xlines_df$highlight==0] <- 4
    xlines_df$color[xlines_df$is_in_best==TRUE & xlines_df$split_var==var.names[treeind]] <- 3
    
    xlines_df$color <- factor(xlines_df$color)
    
    xlines_df_labels <- xlines_df
    xlines_df_labels$color <- 2
    xlines_df_labels$color[xlines_df_labels$highlight==0] <- 4
    
    xlines_df_labels$color <- factor(xlines_df_labels$color)
    
    p <- ggplot2::ggplot() + ggplot2::geom_segment(data=ylines_df, ggplot2::aes(x=.data$x, xend=.data$xend, y=.data$y, yend=.data$yend, color=.data$color), linewidth=2) +
      ggplot2::geom_segment(data=xlines_df, ggplot2::aes(x=.data$x, xend=.data$xend, y=.data$y, yend=.data$yend, linewidth=.data$scoreval, color=.data$color, linetype=factor(as.numeric(.data$color==3)))) +
      ggplot2::scale_color_manual(values = c("1" = "grey50", "2" = "black", "3" = "green3", "4" = "grey65", "5" = "grey90")) +
      ggplot2::scale_linetype_manual(values = c("0" = "solid", "1" = "twodash")) +
      ggplot2::geom_text(data = xlines_df_labels, ggplot2::aes(x = .data$x + 0.2*(.data$xend - .data$x), y = .data$y + 0.2, label = .data$left_label, color=.data$color), size = 4) +
      ggplot2::geom_text(data = xlines_df_labels, ggplot2::aes(x = .data$x + 0.8*(.data$xend - .data$x), y = .data$y + 0.2, label = .data$right_label, color=.data$color), size = 4) +
      ggplot2::geom_text(data = y_label_df_nonextr, ggplot2::aes(x = .data$x, y = .data$y, label = .data$y_label, color=.data$color), size = 4) +
      ggrepel::geom_text_repel(data = y_label_df_extr, ggplot2::aes(x = .data$x, y = .data$y, label = .data$y_label, color=.data$color), size = 4, direction = "x") + #, box.padding = .3, force = 1)
      ggplot2::ggtitle(var.names[treeind]) +
      ggplot2::theme_void() +
      ggplot2::theme(legend.position = "none",
            plot.title = ggplot2::element_text(size = 18, hjust = 0.5, margin = ggplot2::margin(t = 10, b = 10), face="bold") )
    
  } else {
    
    incl_bool <- y_label_df$x %in% c(min(y_label_df$x), max(y_label_df$x))
    y_label_df_extr <- y_label_df[incl_bool,]
    y_label_df_nonextr <- y_label_df[!incl_bool,]
    
    xlines_df$color <- 1
    xlines_df$color[xlines_df$split_var==var.names[treeind]] <- 2
    xlines_df$color[xlines_df$is_in_best==TRUE & xlines_df$split_var==var.names[treeind]] <- 3
    xlines_df$color <- factor(xlines_df$color)
    
    p <- ggplot2::ggplot() + ggplot2::geom_segment(data=ylines_df, ggplot2::aes(x=.data$x, xend=.data$xend, y=.data$y, yend=.data$yend), linewidth=2) +
      ggplot2::geom_segment(data=xlines_df, ggplot2::aes(x=.data$x, xend=.data$xend, y=.data$y, yend=.data$yend, linewidth=.data$scoreval, color=.data$color, linetype=factor(as.numeric(.data$color==3)))) +
      ggplot2::scale_color_manual(values = c("1" = "grey60", "2" = "black", "3" = "green3")) +
      ggplot2::scale_linetype_manual(values = c("0" = "solid", "1" = "twodash")) +
      ggplot2::geom_text(data = xlines_df, ggplot2::aes(x = .data$x + 0.2*(.data$xend - .data$x), y = .data$y + 0.2, label = .data$left_label), size = 4) +
      ggplot2::geom_text(data = xlines_df, ggplot2::aes(x = .data$x + 0.8*(.data$xend - .data$x), y = .data$y + 0.2, label = .data$right_label), size = 4) +
      ggplot2::geom_text(data = y_label_df_nonextr, ggplot2::aes(x = .data$x, y = .data$y, label = .data$y_label), size = 4) +
      ggrepel::geom_text_repel(data = y_label_df_extr, ggplot2::aes(x = .data$x, y = .data$y, label = .data$y_label), size = 4, direction = "x") + #, box.padding = .3, force = 1)
      ggplot2::ggtitle(var.names[treeind]) +
      ggplot2::theme_void() +
      ggplot2::theme(legend.position = "none",
            plot.title = ggplot2::element_text(size = 18, hjust = 0.5, margin = ggplot2::margin(t = 10, b = 10), face="bold") )
    
  }
  
  ps[[length(ps) + 1]] <- p
  names(ps)[length(ps)] <- "tree_plot"
  
  
  indices_best_splits <- xlines_df$ind_in_full_tree[xlines_df$is_in_best==1 & xlines_df$split_var==var.names[treeind]]
  
  
  if (box_plots | density_plots) {
    
    for (i in seq(along=indices_best_splits)) {
      
      index_temp <- indices_best_splits[i]
      
      data_best_split <- data[indices_nodes[[index_temp]],]
      split_var <- var.names.all[split_varIDs[index_temp]]
      split_value <- split_values[index_temp]
      split_label <- names(rules_dist)[which(y_label_df$ind_in_full_tree == index_temp)]
      
      counts <- make_count_labels(data_best_split, dep.var.name, split_var)
      
      if (split_label=="root node")
        plot_title <- paste0("Marginal influence of ", split_var)
      else
        plot_title <- paste0("Influence of ", split_var, " for subgroup '", split_label, "'")
      
      xtemp <- data_best_split[,split_var]
      ytemp <- data_best_split[,dep.var.name]
      
      if (split_var %in% var.names.cat) {
        levelstemp <- var.cat.levels[[which(var.names.cat==split_var)]]
        xtemp <- factor(levelstemp[xtemp], levels=levelstemp)
      }
      
      if (!inherits(xtemp, "factor") && length(unique(xtemp)) > 2) {
        if (box_plots) {
          p <- plotVarBoxplotUFO(x=xtemp, y=ytemp, counts=counts, split_value=split_value, x_label=dep.var.name, y_label=split_var, plot_title=plot_title, add_split_line=add_split_line)
          ps[[length(ps) + 1]] <- p
          names(ps)[length(ps)] <- ifelse(length(indices_best_splits) > 1, paste0("box_plot_", i), "box_plot")
        }
        
        if (density_plots) {
          p <- plotVarDensityUFO(x=xtemp, y=ytemp, counts=counts, split_value=split_value, x_label=split_var, y_label="", legend_title=dep.var.name, plot_title=plot_title, add_split_line=add_split_line)
          ps[[length(ps) + 1]] <- p
          names(ps)[length(ps)] <- ifelse(length(indices_best_splits) > 1, paste0("density_plot_", i), "density_plot")
        }
      } else {
        p <- plotVarBarplotUFO(x=xtemp, y=ytemp, split_value=split_value, x_label=split_var, legend_title=dep.var.name, plot_title=plot_title, add_split_line=add_split_line)
        ps[[length(ps) + 1]] <- p
        names(ps)[length(ps)] <- ifelse(length(indices_best_splits) > 1, paste0("bar_plot_", i), "bar_plot")
      }
      
    }
    
  }
  
  return(list(rules=rules_dist, ps=ps))
}



# Format a single split value for display in the tree plot.
#
# The number of decimals grows as the magnitude shrinks, so that values of
# very different size are shown with comparable precision:
#   |x| >= 1000  -> 0 decimals      |x| >= 0.01   -> 4 decimals
#   |x| >= 10    -> 1 decimal       |x| >= 0.001  -> 5 decimals
#   |x| >= 1     -> 2 decimals      |x| >= 1e-4   -> 6 decimals
#   |x| >= 0.1   -> 3 decimals      otherwise     -> scientific, 2 decimals
# An exact zero is returned as "0"; without this special case it would fall
# through all magnitude tests and be shown as "0.00e+00".
#
# x: a single non-missing number.
# Returns a character string.
format_number <- function(x) {
  if (x == 0) {
    return("0")
  }
  
  abs_x <- abs(x)
  
  # Lower magnitude bounds, largest first; the i-th bound uses i - 1 decimals:
  cutoffs <- c(1000, 10, 1, 0.1, 0.01, 0.001, 1e-4)
  hit <- which(abs_x >= cutoffs)
  
  if (length(hit) == 0) {
    return(formatC(x, format = "e", digits = 2))
  }
  
  formatC(x, format = "f", digits = hit[1] - 1)
}


# Build the text label for one side of a split.
#
# split_var      Name of the split variable.
# split_val      Split point. For a categorical variable it is compared with
#                the integer codes stored in 'data'.
# var.names.cat  Names of the variables that are treated as categorical.
# cov_levels     List of level labels; the entry at the position of
#                'split_var' within 'var.names.cat' is used.
# data           Data frame that is indexed as data[ind_curr, split_var].
# ind_curr       Row indices of the observations in the current node.
# direction      "left" (values <= split point) or "right" (values > split
#                point). Defaults to "left".
#
# Returns "<var> <= <value>" / "<var> > <value>" (with the less-or-equal sign
# as a Unicode character) for non-categorical variables, and
# "<var>:  <level>, <level>, ..." listing the categories present in the node
# on the requested side for categorical variables.
format_split_label <- function(split_var, split_val, var.names.cat, cov_levels, data, ind_curr, direction=c("left", "right")) {
  # Resolve the choice to a single string; otherwise the default would be the
  # length-2 vector c("left", "right") and the comparisons below would fail.
  direction <- match.arg(direction)
  go_left <- direction == "left"
  
  if (!(split_var %in% var.names.cat)) {
    op_temp <- if (go_left) " \u2264 " else " > "
    return(paste0(split_var, op_temp, format_number(split_val)))
  }
  
  # Category codes that actually occur in the current node:
  inds_temp <- sort(unique(data[ind_curr, split_var]))
  if (go_left) {
    inds_node <- inds_temp[inds_temp <= split_val]
  } else {
    inds_node <- inds_temp[inds_temp > split_val]
  }
  
  categ_node <- cov_levels[[which(var.names.cat==split_var)]][inds_node]
  paste0(split_var, ":  ", paste(categ_node, collapse=", "))
}

# Count the rows per distinct value of a column and derive a label position.
#
# df:   data frame.
# yvar: name (string) of the column to group by.
#
# Returns an ungrouped data frame with one row per distinct value of 'yvar',
# the number of rows 'n', and 'y' = group maximum of 'yvar' plus 5% of the
# group range of 'yvar'.
# NOTE(review): the rows are grouped by the same column whose maximum and
# range are taken, so within a group the range term is zero for non-missing
# values - confirm that this is intended.
get_counts_with_y <- function(df, yvar) {
  by_value <- dplyr::group_by(df, .data[[yvar]])
  
  per_value <- dplyr::summarise(
    by_value,
    n = dplyr::n(),
    y = max(.data[[yvar]], na.rm = TRUE) +
      0.05 * diff(range(.data[[yvar]], na.rm = TRUE)),
    .groups = "drop"
  )
  
  per_value
}
#get_counts_with_y <- function(df, yvar) {
#  dplyr::summarise(
#    dplyr::group_by(df, y),
#    n = dplyr::n(),
#    y = max(df[[yvar]], na.rm = TRUE) +
#        0.05 * diff(range(df[[yvar]], na.rm = TRUE)),
#    .groups = "drop"
#  )
#}

# Build the per-class count annotations used in the covariate plots.
#
# dataset:      data frame containing both columns named below.
# dep.var.name: name of the column that defines the classes.
# split_var:    name of the column whose values determine the label height.
#
# Returns a data frame with one row per class that occurs in 'dataset':
#   n        number of values in the class (missing values included),
#   y        class maximum of 'split_var' plus 5% of its class range,
#   category the class as a factor carrying the levels of the class column,
#   label    the text "n = <n>".
make_count_labels <- function(dataset, dep.var.name, split_var) {
  outcome <- dataset[[dep.var.name]]
  covariate <- dataset[[split_var]]
  
  # Summary for the covariate values of a single class:
  summarise_class <- function(vals) {
    value_range <- range(vals, na.rm = TRUE)
    data.frame(
      n = length(vals),
      y = max(vals, na.rm = TRUE) + 0.05 * diff(value_range)
    )
  }
  
  per_class <- by(covariate, outcome, summarise_class)
  
  # Stacking the one-row data frames puts the class names into the row names:
  counts <- do.call(rbind, per_class)
  class_names <- rownames(counts)
  rownames(counts) <- NULL
  
  counts$category <- factor(class_names, levels = levels(outcome))
  counts$label <- paste0("n = ", counts$n)
  counts
}


# Draw class-wise boxplots of a covariate and annotate each class with a
# count label.
#
# x               Covariate values: a numeric (double or integer) vector or a
#                 factor. A factor is converted to its integer codes for
#                 plotting and the axis ticks are labelled with the levels
#                 that occur in 'x'; a warning is issued if the factor is
#                 unordered.
# y               Class of each observation (same length as 'x'); mapped to
#                 the horizontal axis.
# counts          Data frame with the columns 'category', 'y' and 'label';
#                 'label' is drawn at horizontal position 'category' and
#                 vertical position 'y'.
# split_value     Position of the dashed green split line on the vertical
#                 (covariate) axis.
# x_label         Title of the vertical axis, i.e. the axis showing 'x';
#                 "" hides that title.
# y_label         Title of the horizontal axis, i.e. the axis showing the
#                 classes; "" gives the title "class".
# plot_title      Plot title; "" adds no title.
# add_split_line  If TRUE, the line at 'split_value' is drawn.
#
# Returns a ggplot object.
plotVarBoxplotUFO <- function(x, y, counts, split_value, x_label="", y_label="", plot_title="", add_split_line=TRUE) {
  
  # Only needed for factors: axis ticks that show the category names.
  covariate_scale <- NULL
  
  if (is.factor(x)) {
    
    # Ordered factors also inherit from "factor", so the ordering has to be
    # tested explicitly to warn for unordered factors only:
    if (!is.ordered(x))
      warning("The plot is likely not meaningful because the variable is an unordered factor.")
    
    x_levels <- levels(x)[levels(x) %in% unique(x)]
    
    # For plotting, the factor variable is transformed to a continuous variable:
    x <- as.numeric(x)
    
    covariate_scale <- ggplot2::scale_y_continuous(breaks=sort(unique(x)), labels=x_levels)
    
  } else if (!is.numeric(x)) {
    # is.numeric() covers integer as well as double vectors; everything else
    # cannot be drawn as a boxplot.
    stop("'x' must be a numeric vector or a factor.")
  }
  
  plotdata <- data.frame(x=x, y=y)
  
  # Boxplot with the classes on the horizontal and the covariate on the
  # vertical axis:
  p <- ggplot2::ggplot(plotdata, ggplot2::aes(x =.data$y, y=.data$x)) +
    ggplot2::geom_text(data = counts, ggplot2::aes(x = .data$category, y = .data$y, label = .data$label), inherit.aes = FALSE, color = "grey40", size = 5) +
    ggplot2::geom_boxplot() +
    ggplot2::theme_bw() +
    ggplot2::theme(
      axis.text.x = ggplot2::element_text(color = "black", size = 15),
      axis.text.y = ggplot2::element_text(color = "black", size = 11),
      axis.title = ggplot2::element_text(size = 15),
      plot.title = ggplot2::element_text(size = 15)
    )
  
  if (!is.null(covariate_scale))
    p <- p + covariate_scale
  
  if (add_split_line) 
    p <- p + ggplot2::geom_hline(yintercept = split_value, linetype = "dashed", color = "green3", linewidth = 1)
  
  # Add labels to the plot if provided. 'x_label' belongs to the vertical
  # axis, so an empty 'x_label' has to blank the title of that same axis:
  
  if (x_label=="")
    p <- p + ggplot2::theme(axis.title.y=ggplot2::element_blank())
  else
    p <- p + ggplot2::ylab(x_label)
  
  if (y_label=="")
    p <- p + ggplot2::xlab("class")
  else
    p <- p + ggplot2::xlab(y_label)
  
  if (plot_title!="")
    p <- p + ggplot2::ggtitle(plot_title)
  
  p
  
}


# Plot class-wise kernel density estimates of a covariate, scaled by the
# class proportions and annotated with the class sizes.
#
# x               Covariate values: a numeric (double or integer) vector or a
#                 factor. A factor is converted to its integer codes and the
#                 axis ticks are labelled with the levels that occur in 'x';
#                 a warning is issued if the factor is unordered.
# y               Class of each observation (same length as 'x'); its levels
#                 define the curves. Classes with fewer than two observations
#                 are not drawn.
# counts          Data frame with a column 'n' that is shown as "n = <n>" at
#                 the peak of each curve. If it has a column 'category', its
#                 rows are matched to the drawn classes by that column;
#                 otherwise it must contain exactly one row per drawn class,
#                 in the order of the class levels.
# split_value     Position of the dashed green split line on the covariate
#                 axis.
# x_label         Title of the covariate axis; "" hides that title.
# y_label         If not "", used as the legend title and takes precedence
#                 over 'legend_title'.
# legend_title    Legend title.
# plot_title      Plot title; "" adds no title.
# add_split_line  If TRUE, the line at 'split_value' is drawn.
#
# Returns a ggplot object. For numeric 'x' with more than 1000 observations
# the rug shows a random subset of 1000 observations.
plotVarDensityUFO <- function(x, y, counts, split_value, x_label="", y_label="", legend_title="", plot_title="", add_split_line=TRUE) {
  
  classtab <- table(y)
  
  # The densities are plotted only for classes with at least two observations:
  levels_to_keep <- names(classtab[classtab >= 2])
  
  filterbool <- y %in% levels_to_keep
  
  x <- x[filterbool]
  y <- y[filterbool]
  
  if (length(unique(x)) < length(unique(y)))
    stop("The number of unique variable values must be at least as large as the number of classes.")
  
  allclasses <- levels(y)[levels(y) %in% unique(y)]
  n_classes <- length(allclasses)
  
  classtab <- classtab[classtab >= 2]
  classprob <- as.numeric(classtab/sum(classtab))
  
  # 'counts' may still contain classes that were dropped above because they
  # have a single observation. Keep exactly one row per drawn class, in the
  # order of the curves, so that the labels line up with the density peaks:
  if ("category" %in% names(counts))
    counts <- counts[match(allclasses, as.character(counts$category)), , drop = FALSE]
  
  # The maximum number of different colors used. If the number of classes is larger
  # than this, the different classes are differentiated visually using both
  # colors and line types:
  nmax <- min(c(n_classes, 7))
  
  colors <- scales::hue_pal()(nmax)
  
  if (n_classes == nmax) {
    colorsvec <- colors
    if (n_classes == 2) {
      linetypesvec <- c("solid", "dashed")
    } else if (n_classes == 3) {
      linetypesvec <- c("solid", "longdash", "dotdash")
    } else {
      linetypesvec <- rep("solid", length.out=length(colorsvec))
    }
  } else {
    colorsvec <- rep(colors, length.out=n_classes)
    
    # Each line type is used for one full cycle through the colors. Classes
    # beyond the three cycles are padded with "dotdash" (indexing past the end
    # of the pool would produce NA line types instead):
    linetype_pool <- rep(c("solid", "longdash", "dotdash"), each=nmax)
    linetypesvec <- linetype_pool[seq_len(min(n_classes, length(linetype_pool)))]
    linetypesvec <- c(linetypesvec, rep("dotdash", times=n_classes - length(linetypesvec)))
  }
  
  
  x_is_factor <- is.factor(x)
  
  if (x_is_factor) {
    
    # Ordered factors also inherit from "factor", so the ordering has to be
    # tested explicitly to warn for unordered factors only:
    if (!is.ordered(x))
      warning("The plot is likely not meaningful because the variable is an unordered factor.")
    
    x_levels <- levels(x)[levels(x) %in% unique(x)]
    
    # For plotting, the factor variable is transformed to a continuous variable:
    x <- as.numeric(x)
    
  } else if (!is.numeric(x)) {
    # is.numeric() covers integer as well as double vectors; everything else
    # cannot be passed to density().
    stop("'x' must be a numeric vector or a factor.")
  }
  
  
  # Class-wise density estimates, scaled by the class sizes:
  denstemps <- vector("list", n_classes)
  
  for(i in seq_len(n_classes)) {
    denstemp <- density(x[y==allclasses[i]])
    denstemps[[i]] <- data.frame(x=denstemp$x, y=denstemp$y*classprob[i])
  }
  
  plotdata <- do.call("rbind", denstemps)
  plotdata$class <- factor(rep(allclasses, times=vapply(denstemps, nrow, integer(1))), levels=allclasses)
  
  # The count labels are placed slightly above the peak of each curve:
  coords <- vapply(denstemps, function(d) {
    peak <- which.max(d$y)
    c(d$x[peak], d$y[peak])
  }, numeric(2))
  coords[2,] <- coords[2,] + mean(coords[2,])*0.05
  
  counts$x1 <- coords[1,]
  counts$x2 <- coords[2,]
  
  
  if (!x_is_factor) {
    
    # Density plot for a numeric variable, with a rug on the lower margin:
    
    pointdata <- data.frame(x=x, class=y)
    pointdata$class <- droplevels(pointdata$class)
    
    # If there are more than 1000 observations, the rug plot on the lower margin
    # only shows a random subset of 1000 observations:
    if (nrow(pointdata) > 1000) {
      pointdata <- pointdata[sample(seq_len(nrow(pointdata)), size=1000),]
    }
    
    p <- ggplot2::ggplot(plotdata, ggplot2::aes(x=.data$x, color=.data$class, linetype=.data$class)) + ggplot2::theme_bw() + ggplot2::geom_line(ggplot2::aes(y=.data$y)) +
      ggplot2::geom_text(data = counts, ggplot2::aes(x = .data$x1, y = .data$x2, label = paste0("n = ", .data$n)), color = "grey40", size = 5, inherit.aes = FALSE) +
      ggplot2::scale_color_manual(values=colorsvec) + ggplot2::scale_linetype_manual(values = linetypesvec) +
      ggplot2::ylab("(scaled) density") + ggplot2::geom_rug(data=pointdata, sides="b") +
      ggplot2::theme(
        axis.text = ggplot2::element_text(color = "black", size = 11),
        axis.title = ggplot2::element_text(size = 15),
        plot.title = ggplot2::element_text(size = 15),
        legend.title = ggplot2::element_text(size = 15),
        legend.text = ggplot2::element_text(size = 15)
      ) + ggplot2::labs(color = legend_title, linetype=legend_title)
    
  } else {
    
    # Density plot for a factor variable:
    
    x_unique_sorted <- sort(unique(x))
    
    p <- ggplot2::ggplot(plotdata, ggplot2::aes(x=.data$x, y=.data$y, color=.data$class, linetype=.data$class)) + ggplot2::theme_bw() + ggplot2::geom_line() + 
      ggplot2::geom_text(data = counts, ggplot2::aes(x = .data$x1, y = .data$x2, label = paste0("n = ", .data$n)), color = "grey40", size = 5, inherit.aes = FALSE) +
      ggplot2::scale_color_manual(values=colorsvec) + ggplot2::scale_linetype_manual(values = linetypesvec) +
      # The labels of the categories of the variable are added to the x-axis:
      ggplot2::scale_x_continuous(breaks=x_unique_sorted, labels=x_levels) +
      ggplot2::ylab("density") +
      ggplot2::theme(axis.text.x = ggplot2::element_text(color = "black", size = 13, angle = 90, vjust = 0.5, hjust = 1),
            axis.text.y = ggplot2::element_text(color = "black", size = 11),
            axis.title = ggplot2::element_text(size = 15),
            plot.title = ggplot2::element_text(size = 15),
            legend.title = ggplot2::element_text(size = 15),
            legend.text = ggplot2::element_text(size = 15)
      ) + ggplot2::labs(color = legend_title, linetype=legend_title)
    
  }
  
  if (add_split_line) 
    p <- p + ggplot2::geom_vline(xintercept = split_value, linetype = "dashed", color = "green3", linewidth = 1)
  
  
  # Add labels to the plot if provided:
  
  if (x_label=="")
    p <- p + ggplot2::theme(axis.title.x=ggplot2::element_blank())
  else
    p <- p + ggplot2::xlab(x_label)
  
  if (y_label!="")
    p <- p + ggplot2::labs(colour=y_label, linetype=y_label)
  
  if (plot_title!="")
    p <- p + ggplot2::ggtitle(plot_title)
  
  return(p)
  
}

# Stacked bar chart of the class distribution for each value of a covariate.
#
# x               Covariate values; one bar is drawn per value.
# y               Class of each observation (same length as 'x'); defines the
#                 fill of the bar segments.
# split_value     Horizontal position of the dashed green split line.
# x_label         Title of the horizontal axis.
# legend_title    Title of the fill legend.
# plot_title      Plot title.
# add_split_line  If TRUE, the line at 'split_value' is drawn.
#
# Returns a ggplot object. The bar heights are the relative class frequencies
# within each value of 'x'; every segment is labelled "n = <count>".
plotVarBarplotUFO <- function(x, y, split_value, x_label="", legend_title="", plot_title="", add_split_line=TRUE) {
  
  # Tabulate all combinations of covariate value and class:
  cell_counts <- dplyr::count(data.frame(x = x, y = y), x, y, .drop = FALSE)
  
  # Within each covariate value, turn the counts into proportions and build
  # the text shown inside the bar segments:
  cell_counts <- dplyr::ungroup(
    dplyr::mutate(
      dplyr::group_by(cell_counts, x),
      prop  = .data$n / sum(.data$n),
      label = paste0("n = ", .data$n)
    )
  )
  
  # Text sizes and colors, applied on top of theme_bw():
  text_theme <- ggplot2::theme(
    axis.text.x = ggplot2::element_text(color = "black", size = 15),
    axis.text.y = ggplot2::element_text(color = "black", size = 11),
    axis.title = ggplot2::element_text(size = 15),
    plot.title = ggplot2::element_text(size = 15),
    legend.title = ggplot2::element_text(size = 15),
    legend.text = ggplot2::element_text(size = 15)
  )
  
  # Count labels centered within their stacked segment:
  segment_labels <- ggplot2::geom_text(
    ggplot2::aes(label = .data$label),
    position = ggplot2::position_stack(vjust = 0.5),
    size = 5
  )
  
  bar_plot <- ggplot2::ggplot(cell_counts, ggplot2::aes(x = .data$x, y = .data$prop, fill = .data$y))
  bar_plot <- bar_plot + ggplot2::geom_col() + segment_labels
  bar_plot <- bar_plot + ggplot2::scale_y_continuous(labels = scales::percent_format())
  bar_plot <- bar_plot + ggplot2::labs(x = x_label, y = "Relative frequency", fill = legend_title)
  bar_plot <- bar_plot + ggplot2::ggtitle(plot_title)
  bar_plot <- bar_plot + ggplot2::theme_bw() + text_theme
  
  if (add_split_line) {
    bar_plot <- bar_plot +
      ggplot2::geom_vline(xintercept = split_value, linetype = "dashed", color = "green3", linewidth = 1)
  }
  
  bar_plot
}
