## Aims
#
# 1. Reproduce upsetplot
# 2. Extend the functionality of upsetplot
#   2.1 Allow non-unique intersections
#   2.2 Support more subplot types (boxplot, points), useful for showing properties of subsets


## Structures
#
# 1. plot and subplots
# 2. themes
# 3. data processing



## (PART) Plot

#' Plot an upset plot
#'
#' This function generates an upset plot by creating a composite plot from
#' three subplots built with ggplot2 and assembled with aplot: a main panel
#' (intersection vs. set), a top panel (intersection sizes) and a left panel
#' (set sizes).
#'
#' @param list a list of sets. Elements without a name are labelled
#'   `Set_<position>`.
#' @param nintersects number of intersections to show. If NULL, all
#'   intersections will show.
#' @param order.intersect.by one of 'size' or 'name'
#' @param order.set.by one of 'size' or 'name'
#' @return an upset plot: the composite object returned by
#'   `aplot::insert_left()`, with class `a_upset_plot` prepended.
#'
#' @export
#'
#' @examples
#'  list = list(A = sample(LETTERS, 20),
#'              B = sample(LETTERS, 22),
#'              C = sample(LETTERS, 14),
#'              D = sample(LETTERS, 30, replace = TRUE))
#'  upset_plot(list)
#'  upset_plot(list, order.intersect.by = "name")
#'  upset_plot(list, nintersects = 6)
upset_plot <- function(list,
                       nintersects = NULL,
                       order.intersect.by = c("size", "name"),
                       order.set.by = c("size", "name")){
  # process arguments
  # Every set needs a non-empty name, because the names are used as labels
  # and factor levels downstream. Fill in only the missing ones so that
  # partially named lists work as well as fully unnamed ones.
  set_names <- names(list)
  if (is.null(set_names)) {
    set_names <- rep("", length(list))
  }
  unnamed <- is.na(set_names) | !nzchar(set_names)
  set_names[unnamed] <- paste("Set", which(unnamed), sep = "_")
  names(list) <- set_names

  order.intersect.by <- match.arg(order.intersect.by)
  order.set.by <- match.arg(order.set.by)

  # tidy data for the three subplots
  data <- tidy_main_subsets(list,
                            nintersects = nintersects,
                            order.intersect.by = order.intersect.by,
                            order.set.by = order.set.by)

  # subplot main
  p_main <- upsetplot_main(data$main_data)

  # subplot top
  p_top <- upsetplot_top(data$top_data)

  # subplot left
  p_left <- upsetplot_left(data$left_data)

  # combine into a plot; height/width are passed straight to aplot
  pp <- aplot::insert_top(p_main, p_top, height = 4) |>
    aplot::insert_left(p_left, width = .2)

  # use 'a_' prefix for all `aplot` objects
  class(pp) <- c("a_upset_plot", class(pp))

  pp
}

# Main subplot: one point per (intersection, set) pair present in `data`,
# with the points of each intersection joined by a line.
#
# data: a data frame with columns `id` (intersection) and `set`.
# Returns a ggplot object.
upsetplot_main <- function(data){
  ggplot2::ggplot(data, ggplot2::aes(.data$id, .data$set)) +
    ggplot2::geom_point(size = 4, color = "grey30", na.rm = FALSE) +
    # `linewidth` replaces `size` for line geoms (ggplot2 >= 3.4.0), where
    # `size` triggers a deprecation warning
    ggplot2::geom_path(ggplot2::aes(group = .data$id),
                       linewidth = 1.5,
                       color = "grey30",
                       na.rm = FALSE) +
    ggplot2::labs(x = "Set Intersection", y = "") +
    theme_upset_main()
}

# Top subplot: bar chart of intersection sizes.
#
# data: a data frame with columns `id` (intersection) and `size`.
# Returns a ggplot object.
upsetplot_top <- function(data){
  ggplot2::ggplot(data, ggplot2::aes(.data$id, .data$size)) +
    ggplot2::geom_col() +
    ggplot2::labs(x = "", y = "Intersection Size") +
    theme_upset_top()
}

# Left subplot: horizontal bar chart of set sizes, drawn on a reversed x
# axis with the y axis placed on the right-hand side.
#
# data: a data frame with columns `set` and `size`.
# Returns a ggplot object.
upsetplot_left <- function(data){
  ggplot2::ggplot(data, ggplot2::aes(x = .data$size, y = .data$set)) +
    ggplot2::geom_col(orientation = "y") +
    ggplot2::scale_y_discrete(position = "right") +
    ggplot2::scale_x_reverse() +
    ggplot2::labs(x = "Set Size") +
    theme_upset_left()
}

## (PART) Theme

# Theme of the main subplot: `theme_bw()` without y-axis title and without
# bottom x-axis ticks/text, plus a negative top margin (presumably to close
# the gap to the top subplot -- confirm visually).
theme_upset_main <- function(){
  ggplot2::theme_bw() +
    ggplot2::theme(
      axis.title.y = ggplot2::element_blank(),
      axis.ticks.x.bottom = ggplot2::element_blank(),
      axis.text.x.bottom = ggplot2::element_blank(),
      # panel.border = element_blank(),
      plot.margin = ggplot2::margin(t = -20)
    )
}

# Theme of the top subplot: `theme_bw()` without bottom x-axis ticks/text,
# plus a negative bottom margin (presumably to close the gap to the main
# subplot -- confirm visually).
theme_upset_top <- function(){
  ggplot2::theme_bw() +
    ggplot2::theme(
      axis.ticks.x.bottom = ggplot2::element_blank(),
      axis.text.x.bottom = ggplot2::element_blank(),
      # panel.border = element_blank(),
      plot.margin = ggplot2::margin(b = -20, unit = "pt")
    )
}

# Theme of the left subplot: `theme_bw()` without y-axis ticks/title/text,
# panel border or major grid lines, plus a negative right margin (presumably
# to close the gap to the main subplot -- confirm visually).
theme_upset_left <- function(){
  ggplot2::theme_bw() +
    ggplot2::theme(
      axis.ticks.y = ggplot2::element_blank(),
      axis.title.y = ggplot2::element_blank(),
      axis.text.y = ggplot2::element_blank(),
      panel.border = ggplot2::element_blank(),
      panel.grid.major = ggplot2::element_blank(),
      plot.margin = ggplot2::margin(r = -20)
    )
}


## (PART) retrieve tidy data from primary subset datasets

##' @importFrom forcats as_factor
tidy_main_subsets <- function(list,
                              nintersects,
                              order.intersect.by,
                              order.set.by){
  # Build the data frames behind the three upset subplots.
  #
  # list:               a named list of sets
  # nintersects:        number of intersections to keep, or NULL to keep all
  # order.intersect.by: column of the subset table ("size" or "name") used to
  #                     order the intersections
  # order.set.by:       column of the set table ("size" or "name") used to
  #                     order the sets
  #
  # Returns a list with elements `top_data`, `left_data` and `main_data`.
  data <- get_all_subsets(list)
  set_name <- names(list)

  # top data: one row per intersection; the level order of `id` is the
  # requested intersection order (descending)
  top_data <- data |>
    dplyr::select(c("id", "name", "item", "size")) |>
    dplyr::mutate(id = forcats::fct_reorder(.data$id,
                                            .data[[order.intersect.by]],
                                            .desc = TRUE))

  # left data: one row per set
  # NOTE(review): `size` counts every element of a set, duplicates included,
  # while intersection sizes come from set operations -- confirm whether
  # duplicated items should be counted here.
  left_data <- dplyr::tibble(set = set_name,
                             name = set_name,
                             size = lengths(list)) |>
    dplyr::mutate(set = forcats::fct_reorder(.data$set,
                                             .data[[order.set.by]],
                                             .desc = TRUE))

  # main data: one row per (intersection, member set). The id is a "/"
  # separated string of set positions, so splitting it yields the positions
  # of the member sets. The column is selected by name because `.data` is
  # deprecated inside tidyselect arguments.
  main_data <- data |>
    dplyr::select(c("id")) |>
    dplyr::mutate(set = .data$id) |>
    tidyr::separate_longer_delim("set", delim = "/")
  main_data$set <- factor(set_name[as.integer(main_data$set)],
                          levels = levels(left_data$set))
  # Share the intersection order with `top_data`. Left as character, the ids
  # would be laid out alphabetically on the main panel's x axis and the
  # ordering requested through `order.intersect.by` would be lost.
  main_data$id <- factor(main_data$id, levels = levels(top_data$id))

  # filter intersections: keep the first `nintersects` in display order
  if (!is.null(nintersects)){
    keep_id <- utils::head(levels(top_data$id), nintersects)
    main_data <- main_data |> dplyr::filter(.data$id %in% keep_id)
    top_data <- top_data |> dplyr::filter(.data$id %in% keep_id)
  }

  # return result as a list
  list(top_data = top_data,
       left_data = left_data,
       main_data = main_data)
}

## (PART) build primary data of subsets

#' Get the items/names/ids of subsets from a named list
#'
#' Builds one row per index combination of the input sets, holding the
#' combination's id (set positions), name (set names), items and item count.
#'
#' @param list a named list
#' @param name_separator string used to join set positions into the `id`
#'   column and set names into the `name` column. Default is `/`.
#'
#' @return a tibble with columns `id`, `name`, `item` (a list column) and
#'   `size` (the number of items in each subset)
#' @export
#'
#' @examples
#' list = list(A = sample(LETTERS, 20),
#'             B = sample(LETTERS, 22),
#'             C = sample(LETTERS, 24),
#'             D = sample(LETTERS, 30, replace = TRUE))
#' get_all_subsets(list)
get_all_subsets <- function(list, name_separator = "/"){
  df <- dplyr::tibble(
    id = get_all_subsets_ids(list, sep = name_separator),
    name = get_all_subsets_names(list, sep = name_separator),
    item = get_all_subsets_items(list)
  )
  # lengths() always returns an integer vector, whereas sapply(x, length)
  # degrades to an empty list when there are no subsets
  df$size <- lengths(df$item)
  df
}

# Items of every subset: for each index combination returned by
# `combinations()`, the result of `overlap()` on the input list.
# Returns a list with one element per combination.
get_all_subsets_items <- function(list){
  index_sets <- combinations(length(list))
  lapply(index_sets, function(idx) {
    overlap(list, idx)
  })
}

# Name of every subset: for each index combination returned by
# `combinations()`, the names of the selected sets joined with `sep`.
# Returns a character vector with one element per combination.
get_all_subsets_names <- function(list, sep = "/"){
  set_name <- names(list)
  index_sets <- combinations(length(list))
  # vapply() guarantees a character vector, even when there is no combination
  vapply(index_sets,
         function(i) paste0(set_name[i], collapse = sep),
         character(1))
}

##' @importFrom yulab.utils combinations
get_all_subsets_ids <- function(list, sep = "/"){
  # Id of every subset: for each index combination returned by
  # `combinations()`, the selected set positions joined with `sep`.
  # Returns a character vector with one element per combination.
  index_sets <- combinations(length(list))
  # vapply() guarantees a character vector, even when there is no combination
  vapply(index_sets,
         function(i) paste0(i, collapse = sep),
         character(1))
}

## (PART) subset calculations: the basics

##' @importFrom purrr reduce
overlap <- function(list, idx){
  # Items shared by all sets selected with `idx` that occur in none of the
  # remaining sets.
  #
  # list: a list of sets
  # idx:  positions of the selected sets
  #
  # Returns the set difference between the intersection of the selected sets
  # and the union of the other sets.

  # Combine a list of sets with a binary set operation. A single set is
  # flattened with unlist(); an empty list gives NULL.
  fold_sets <- function(sets, op) {
    n_sets <- length(sets)
    if (n_sets > 1L) {
      return(purrr::reduce(sets, op))
    }
    if (n_sets == 1L) {
      return(unlist(sets))
    }
    NULL
  }

  selected <- list[idx]
  others <- list[-idx]

  setdiff(fold_sets(selected, intersect), fold_sets(others, union))
}
