#' Read texts
#'
#' Read texts and create a \code{keyATM_docs} object, which is a list of texts.
#'
#' Texts given as a \code{data.frame}, a tibble, or text files are split on
#' single spaces, so they should be preprocessed beforehand.
#'
#' @param texts input. keyATM takes quanteda dfm (dgCMatrix), data.frame, \pkg{tibble} tbl_df, or a vector of file paths.
#'   A \code{data.frame} or tibble must store the texts in a column named \code{text}.
#' @param encoding character. Only used when \code{texts} is a vector of file paths. Default is \code{UTF-8}.
#' @param check logical. If \code{TRUE}, check whether there is anything wrong with the structure of texts. Default is \code{TRUE}.
#'
#' @return a \code{keyATM_docs} object: a list whose elements are the split texts (one character vector of words per document). The length of the list equals the number of documents.
#'
#' @examples
#' \dontrun{
#'  # Use quanteda dfm
#'  keyATM_docs <- keyATM_read(texts = quanteda_dfm) 
#'   
#'  # Use data.frame or tibble (texts should be stored in a column named `text`)
#'  keyATM_docs <- keyATM_read(texts = data_frame_object) 
#'  keyATM_docs <- keyATM_read(texts = tibble_object) 
#' 
#'  # Use a vector that stores full paths to the text files 
#'  files <- list.files(doc_folder, pattern = "*.txt", full.names = TRUE) 
#'  keyATM_docs <- keyATM_read(texts = files) 
#' 
#' }
#' @import magrittr
#' @importFrom rlang .data
#' @export
keyATM_read <- function(texts, encoding = "UTF-8", check = TRUE)
{
  # Detect input.
  # inherits() and is.character() always return a single TRUE/FALSE.
  # `class(x) == "..."` returns one value per class and fails inside `if`
  # for multi-class objects (e.g. a base matrix is c("matrix", "array")).
  if (inherits(texts, "tbl")) {
    message("Using tibble.")
    text_dfm <- NULL
    files <- NULL
    text_df <- texts
  } else if (inherits(texts, "data.frame")) {
    message("Using data.frame.")
    text_dfm <- NULL
    files <- NULL
    text_df <- tibble::as_tibble(texts)
  } else if (inherits(texts, "dfm") || inherits(texts, "dgCMatrix")) {
    message("Using quanteda dfm.")
    text_dfm <- texts
    files <- NULL
    text_df <- NULL
  } else if (is.character(texts)) {
    warning("Reading from files. Please make sure files are preprocessed.", immediate. = TRUE)
    text_dfm <- NULL
    files <- texts
    text_df <- NULL
    message(paste0("Encoding: ", encoding))
  } else {
    stop("Check `texts` argument.\n
         It can take quanteda dfm, data.frame, tibble, and a vector of characters.")
  }

  # Read texts
  if (!is.null(text_dfm)) {
    # quanteda dfm: read_dfm_cpp() (defined elsewhere in the package)
    # builds the per-document word lists from the matrix and its column names.
    vocabulary <- colnames(text_dfm)
    W_raw <- list()
    W_raw <- read_dfm_cpp(text_dfm, W_raw, vocabulary)
  } else {
    if (is.null(text_df)) {
      # One document per file; the lines of a file are joined with "\n"
      file_texts <- vapply(files,
                           function(x) {
                             paste0(readLines(x, encoding = encoding), collapse = "\n")
                           },
                           character(1), USE.NAMES = FALSE)
      text_df <- tibble::tibble(text = file_texts)
    }

    if (!"text" %in% names(text_df)) {
      stop("`texts` should have a column named `text`.")
    }

    # Split each text on single spaces: one character vector per document
    W_raw <- stringr::str_split(text_df$text, pattern = " ")
  }

  # Check the structure of the texts (vocabulary and zero-length documents)
  if (check) {
    check_vocabulary(unique(unlist(W_raw, use.names = FALSE, recursive = FALSE)))
    get_doc_index(W_raw, check = TRUE)  # called for its warnings only
  }

  # Assign class
  class(W_raw) <- c("keyATM_docs", class(W_raw))

  return(W_raw)
}


#' @noRd
#' @export
print.keyATM_docs <- function(x, ...)
{
  # Print a one-line description with the number of documents.
  cat(paste0("keyATM_docs object of ",
             length(x), " documents",
             ".\n"))
  # Print methods conventionally return their argument invisibly
  invisible(x)
}


#' @noRd
#' @export
summary.keyATM_docs <- function(object, ...)
{
  # Print the number of documents, document-length statistics (in tokens),
  # and the vocabulary size.
  # lengths() always returns an integer vector, even for an empty object.
  doc_len <- lengths(object, use.names = FALSE)

  if (length(doc_len) == 0) {
    # No documents: length statistics are undefined
    cat("keyATM_docs object of: 0 documents.\n")
    return(invisible(NULL))
  }

  cat(paste0("keyATM_docs object of: ",
             length(object), " documents",
             ".\n",
             "Length of documents:",
             "\n  Avg: ", round(mean(doc_len), 3),
             "\n  Min: ", round(min(doc_len), 3),
             "\n  Max: ", round(max(doc_len), 3),
             "\n   SD: ", round(stats::sd(doc_len), 3),
             "\nNumber of unique words: ", length(unique(unlist(object, use.names = FALSE, recursive = FALSE))),
             "\n"))
}


#' Visualize keywords
#'
#' Visualize the proportion of keywords in the documents.
#'
#' Within each keyword topic, keywords are ranked by their frequency in
#' \code{docs}; the share of all tokens (in percent) is plotted against that rank.
#'
#' @param docs a keyATM_docs object, generated by \code{keyATM_read()} function
#' @param keywords a list of keywords
#' @param prune logical. If \code{TRUE}, prune keywords that do not appear in `docs`. Default is \code{TRUE}.
#' @param label_size the size of keyword labels in the output plot. Default is \code{3.2}.
#' @return keyATM_fig object: a list with the ggplot \code{figure}, the plotted \code{values}, and the (pruned) \code{keywords}
#' @examples
#' \dontrun{
#'  # Prepare a keyATM_docs object
#'  keyATM_docs <- keyATM_read(input)
#'   
#'  # Keywords are in a list  
#'  keywords <- list(Education = c("education", "child", "student"),
#'                   Health    = c("public", "health", "program"))
#'
#'  # Visualize keywords
#'  keyATM_viz <- visualize_keywords(keyATM_docs, keywords)
#'
#'  # View a figure
#'  keyATM_viz
#' 
#'  # Save a figure 
#'  save_fig(keyATM_viz, filename)
#' }
#' @import magrittr
#' @import ggplot2
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
visualize_keywords <- function(docs, keywords, prune = TRUE, label_size = 3.2)
{
  # Check type
  check_arg_type(docs, "keyATM_docs", "Please use `keyATM_read()` to read texts.")
  check_arg_type(keywords, "list")
  # Each topic's keywords must be a character vector (checked for errors only)
  for (topic_words in keywords) {
    check_arg_type(topic_words, "character")
  }

  unlisted <- unlist(docs, recursive = FALSE, use.names = FALSE)

  # Check keywords
  keywords <- check_keywords(unique(unlisted), keywords, prune)

  # Corpus-wide word frequencies, most frequent first.
  # summarize() over a single grouping variable already returns an ungrouped
  # tibble; ungroup() takes no extra arguments here.
  totalwords <- length(unlisted)
  word_freq <- tibble::tibble(Word = unlisted) %>%
    dplyr::group_by(.data$Word) %>%
    dplyr::summarize(WordCount = dplyr::n()) %>%
    dplyr::ungroup() %>%
    dplyr::mutate(`Proportion(%)` = round(.data$WordCount / totalwords * 100, 3)) %>%
    dplyr::arrange(dplyr::desc(.data$WordCount)) %>%
    dplyr::mutate(Ranking = seq_len(dplyr::n()))

  keywords <- lapply(keywords, function(x) {unlist(strsplit(x, " "))})
  ext_k <- length(keywords)
  max_num_words <- max(lengths(keywords, use.names = FALSE))

  # One row per (topic, keyword); built in one step instead of growing a
  # data.frame with rbind() inside a loop.
  tnames <- names(keywords)
  if (is.null(tnames)) {
    topic_labels <- paste0("Topic", seq_len(ext_k))
  } else {
    topic_labels <- paste0(seq_len(ext_k), "_", tnames)
  }
  topic_col <- rep(topic_labels, lengths(keywords, use.names = FALSE))
  keywords_df <- data.frame(
    Topic = factor(topic_col, levels = unique(topic_col)),
    Word = unlist(keywords, use.names = FALSE),
    stringsAsFactors = FALSE
  )

  # Rank keywords within each topic by their corpus frequency
  temp <- dplyr::right_join(word_freq, keywords_df, by = "Word") %>%
    dplyr::group_by(.data$Topic) %>%
    dplyr::arrange(dplyr::desc(.data$WordCount)) %>%
    dplyr::mutate(Ranking = seq_len(dplyr::n())) %>%
    dplyr::arrange(.data$Topic, .data$Ranking)

  # Visualize
  fig <-
    ggplot(temp, aes(x = .data$Ranking, y = .data$`Proportion(%)`, colour = .data$Topic)) +
      geom_line() +
      geom_point() +
      ggrepel::geom_label_repel(aes(label = .data$Word), size = label_size,
                       box.padding = 0.20, label.padding = 0.12,
                       arrow = arrow(angle = 10, length = unit(0.10, "inches"),
                                   ends = "last", type = "closed"),
                       show.legend = FALSE) +
      scale_x_continuous(breaks = 1:max_num_words) +
      xlab("Ranking") + ylab("Proportion (%)") +
      theme_bw()

  keyATM_viz <- list(figure = fig, values = temp, keywords = keywords)
  class(keyATM_viz) <- c("keyATM_fig", "keyATM_viz", class(keyATM_viz))

  return(keyATM_viz)
}


check_keywords <- function(unique_words, keywords, prune)
{
  # Validate keywords against the corpus vocabulary.
  #
  # unique_words: character vector of words that appear in the documents.
  # keywords:     list of character vectors, one per keyword topic.
  # prune:        if TRUE, drop missing keywords with a warning;
  #               otherwise raise an error listing them.
  # Returns the (possibly pruned) keyword list. Errors if any topic is left
  # with no keywords.

  # Distinct keywords that never occur in the corpus (a keyword shared by
  # several topics is reported once)
  keywords_flat <- unlist(keywords, use.names = FALSE, recursive = FALSE)
  non_existent <- unique(keywords_flat[!keywords_flat %in% unique_words])
  missing_txt <- paste(non_existent, collapse = ", ")

  if (prune) {
    # Prune keywords
    if (length(non_existent) == 1) {
      warning("A keyword will be pruned because it does not appear in documents: ",
              missing_txt, immediate. = TRUE)
    } else if (length(non_existent) > 1) {
      warning("Keywords will be pruned because they do not appear in documents: ",
              missing_txt, immediate. = TRUE)
    }

    keywords <- lapply(keywords, function(x) {x[!x %in% non_existent]})
  } else {
    # Raise error
    if (length(non_existent) == 1) {
      stop("A keyword not found in texts: ", missing_txt)
    } else if (length(non_existent) > 1) {
      stop("Keywords not found in texts: ", missing_txt)
    }
  }

  # Check there is at least one keyword in each topic
  check_zero <- which(lengths(keywords, use.names = FALSE) == 0)

  if (length(check_zero) != 0) {
    # Unnamed keyword lists are reported by topic position
    topic_labels <- names(keywords)
    if (is.null(topic_labels)) {
      topic_labels <- as.character(seq_along(keywords))
    }
    stop(paste0("All keywords are pruned. Please check: ",
                paste(topic_labels[check_zero], collapse = ", ")))
  }

  return(keywords)
}


get_doc_index <- function(docs, check = FALSE)
{
  # Return the (unnamed, integer) positions of documents with at least one
  # token, warning about any zero-length documents. The warning text depends
  # on `check`: TRUE is used while reading texts, FALSE when documents are
  # dropped before fitting.
  doc_lengths <- lengths(docs, use.names = FALSE)
  empty_docs <- which(doc_lengths == 0)

  if (length(empty_docs) > 0) {
    idx_txt <- paste(empty_docs, collapse = ", ")
    if (!check) {
      warning("Number of documents dropped because of 0 length: ", length(empty_docs), "\n",
              "Document index to check: ", idx_txt,
              immediate. = TRUE)
    } else {
      warning("Number of documents with 0 length: ", length(empty_docs), "\n",
              "This may cause invalid covariates or time index.", "\n",
              "Please review the preprocessing steps.", "\n",
              "Document index to check: ", idx_txt,
              immediate. = TRUE)
    }
  }

  which(doc_lengths != 0)
}


#' Fit a keyATM model 
#' 
#' keyATM_fit is wrapped by keyATM() and weightedLDA()
#'
#' Validates the inputs, drops zero-length documents, maps words to 0-based
#' integer ids, draws the initial topic (\code{Z}) and, for keyATM models,
#' keyword-switch (\code{S}) assignments, and then runs the sampler through
#' \code{fitting_models()} unless \code{options$iterations} is 0.
#'
#' @param docs a \code{keyATM_docs} object created by \code{keyATM_read()}.
#' @param model one of \code{"base"}, \code{"cov"}, \code{"hmm"}, \code{"lda"},
#'   \code{"ldacov"}, \code{"ldahmm"}, \code{"label"}.
#' @param no_keyword_topics a single non-negative number of topics without keywords.
#' @param keywords a list of character vectors (keyATM models only).
#' @param model_settings,priors,options lists validated by \code{check_arg()}.
#' @return an initialized \code{keyATM_model} object when
#'   \code{options$iterations} is 0, otherwise the object returned by
#'   \code{fitting_models()}.
#' @keywords internal
keyATM_fit <- function(docs, model, no_keyword_topics,
                       keywords = list(), model_settings = list(),
                       priors = list(), options = list()) 
{
  ##
  ## Check
  ##

  # Check type
  check_arg_type(docs, "keyATM_docs", "Please use `keyATM_read()` to read texts.")
  # is.numeric() is TRUE for both integer and double vectors
  if (!is.numeric(no_keyword_topics))
    stop("`no_keyword_topics` is neither numeric nor integer.")
  if (length(no_keyword_topics) != 1L || is.na(no_keyword_topics) || no_keyword_topics < 0)
    stop("`no_keyword_topics` should be a single non-negative number.")

  no_keyword_topics <- as.integer(no_keyword_topics)

  if (length(model) != 1L ||
      !model %in% c("base", "cov", "hmm", "lda", "ldacov", "ldahmm", "label")) {
    stop("Please select a correct model.")
  }

  info <- list(
                models_keyATM = c("base", "cov", "hmm", "label"),
                models_lda = c("lda", "ldacov", "ldahmm")
              )
  keywords <- check_arg(keywords, "keywords", model, info)

  # Get Info (zero-length documents are dropped; their positions are kept in
  # use_doc_index so document-level settings can be subset to match)
  info$use_doc_index <- get_doc_index(docs)
  docs <- docs[info$use_doc_index]
  info$num_doc <- length(docs)
  if (info$num_doc == 0) {
    stop("All documents have 0 length. Please review the preprocessing steps.")
  }
  info$keyword_k <- length(keywords)
  info$total_k <- length(keywords) + no_keyword_topics

  # detectCores() returns NA when the number of cores cannot be determined
  detected_cores <- parallel::detectCores(all.tests = FALSE, logical = TRUE)
  if (is.na(detected_cores)) {
    detected_cores <- 1L
  }
  info$num_core <- max(1, detected_cores - 2L)

  # Set default values
  model_settings <- check_arg(model_settings, "model_settings", model, info)
  priors <- check_arg(priors, "priors", model, info)
  options <- check_arg(options, "options", model, info)
  info$parallel_init <- options$parallel_init


  ##
  ## Initialization
  ##
  message("Initializing the model...")
  set.seed(options$seed)

  # W: each document as a vector of 0-based word ids
  info$wd_names <- unique(unlist(docs, use.names = FALSE, recursive = FALSE))
  check_vocabulary(info$wd_names)

  info$wd_map <- myhashmap(info$wd_names, seq_along(info$wd_names) - 1L)

  if (info$parallel_init) {
    W <- parallel::mclapply(docs, function(x) { myhashmap_getvec(info$wd_map, x) }, mc.cores = info$num_core)
  } else {
    W <- lapply(docs, function(x) { myhashmap_getvec(info$wd_map, x) })
  }

  # Check keywords
  keywords <- check_keywords(info$wd_names, keywords, options$prune)

  keywords_raw <- keywords  # keep raw keywords (not word_id)
  keywords_id <- lapply(keywords, function(x) { myhashmap_getvec(info$wd_map, x) })
  info$keywords_id <- unlist(keywords_id, use.names = FALSE, recursive = FALSE)

  # Assign S and Z
  if (model %in% info$models_keyATM) {
    res <- make_sz_key(W, keywords, info)
  } else {
    # LDA based models
    res <- make_sz_lda(W, info)
  }
  S <- res$S
  Z <- res$Z
  rm(res)

  # Organize: containers for values stored during sampling
  stored_values <- list(vocab_weights = rep(-1, length(info$wd_names)),
                        doc_index = info$use_doc_index)

  if (model %in% c("base", "lda", "label")) {
    if (options$estimate_alpha)
      stored_values$alpha_iter <- list()
  }

  if (model %in% c("hmm", "ldahmm")) {
    options$estimate_alpha <- 1
    stored_values$alpha_iter <- list()
  }

  if (model %in% c("cov", "ldacov")) {
    stored_values$Lambda_iter <- list()
  }

  if (model %in% c("hmm", "ldahmm")) {
    stored_values$R_iter <- list()

    if (options$store_transition_matrix) {
      stored_values$P_iter <- list()
    }
  }

  if (model %in% info$models_keyATM) {
    if (options$store_pi)
      stored_values$pi_vectors <- list()
  }

  if (options$store_theta)
    stored_values$Z_tables <- list()

  key_model <- list(
                    W = W, Z = Z, S = S,
                    model = abb_model_name(model),
                    keywords = keywords_id, keywords_raw = keywords_raw,
                    no_keyword_topics = no_keyword_topics,
                    vocab = info$wd_names,
                    model_settings = model_settings,
                    priors = priors,
                    options = options,
                    stored_values = stored_values,
                    model_fit = list(),
                    call = match.call()
                   )

  rm(info)
  class(key_model) <- c("keyATM_model", model, class(key_model))

  if (options$iterations == 0) {
    message("`options$iterations` is 0. keyATM returns an initialized object.")
    return(key_model)
  }


  ##
  ## Fitting
  ##
  return(fitting_models(key_model, model, options))
}


fitting_models <- function(key_model, model, options)
{
  # Run the sampler that matches `model` for `options$iterations` iterations
  # (RNG seeded with `options$seed`). Returns the fitted object with the
  # "keyATM_fitted" class prepended.
  message(paste0("Fitting the model. ", options$iterations, " iterations..."))
  set.seed(options$seed)

  # Read the full element name; `$` would otherwise rely on partial matching.
  iter <- options$iterations

  key_model <- switch(model,
    base   = keyATM_fit_base(key_model, iter = iter),
    cov    = keyATM_fit_cov(key_model, iter = iter),
    hmm    = keyATM_fit_HMM(key_model, iter = iter),
    lda    = keyATM_fit_LDA(key_model, iter = iter),
    ldacov = keyATM_fit_LDAcov(key_model, iter = iter),
    ldahmm = keyATM_fit_LDAHMM(key_model, iter = iter),
    label  = keyATM_fit_label(key_model, iter = iter),
    stop("Please check `model`.")
  )

  class(key_model) <- c("keyATM_fitted", class(key_model))
  return(key_model)
}


#' @noRd
#' @export
print.keyATM_model <- function(x, ...)
{
  # Describe an initialized (not yet fitted) model object.
  cat(
      paste0(
             "keyATM_model object for the ",
             x$model,
             " model.\nThis is an initialized object without fitting the model.",
             "\n"
            )
     )
  # Print methods conventionally return their argument invisibly
  invisible(x)
}


check_arg <- function(obj, name, model, info = list())
{
  # Route `obj` to the validator for the argument called `name` and return
  # its validated (defaults-filled) value. An unrecognized `name` yields NULL.
  switch(name,
    keywords       = check_arg_keywords(obj, model, info),
    model_settings = check_arg_model_settings(obj, model, info),
    priors         = check_arg_priors(obj, model, info),
    options        = check_arg_options(obj, model, info),
    vb_options     = check_arg_vboptions(obj, model, info)
  )
}


check_arg_keywords <- function(keywords, model, info)
{
  # Validate `keywords` for `model`. keyATM models require a non-empty list
  # of character vectors; LDA models must not receive keywords.
  # For keyATM models, topics are renamed "<index>" or "<index>_<name>".
  check_arg_type(keywords, "list")

  if (length(keywords) == 0 && model %in% info$models_keyATM) {
    stop("Please provide keywords.")
  }

  if (length(keywords) != 0 && model %in% info$models_lda) {
    stop("This model does not take keywords.")
  }

  # Name of keywords topic
  if (model %in% info$models_keyATM) {
    for (topic_words in keywords) {
      check_arg_type(topic_words, "character")
    }
    topic_ids <- seq_along(keywords)
    if (is.null(names(keywords))) {
      names(keywords) <- as.character(topic_ids)
    } else {
      names(keywords) <- paste0(topic_ids, "_", names(keywords))
    }
  }

  return(keywords)
}


show_unused_arguments <- function(obj, name, allowed_arguments)
{
  # Raise an error naming every element of `obj` whose name is not in
  # `allowed_arguments`; otherwise return NULL invisibly.
  supplied <- names(obj)
  unknown <- supplied[!is.element(supplied, allowed_arguments)]

  if (length(unknown) > 0) {
    msg <- paste0("keyATM doesn't recognize some of the arguments ",
                  "in ", name, ": ",
                  paste(unknown, collapse = ", "))
    stop(msg)
  }
}


check_arg_model_settings <- function(obj, model, info)
{
  # Validate `model_settings` for `model`, fill in defaults, and subset
  # document-level settings (covariates, time index, labels) to the documents
  # kept in `info$use_doc_index`. Unknown elements raise an error.
  check_arg_type(obj, "list")
  allowed_arguments <- c()

  if (model %in% c("base", "lda", "hmm", "ldahmm", "label")) {
    # Slice Sampling Settings
    if (is.null(obj$slice_min)) {
      obj$slice_min <- 1e-9
    } else {
      if (!is.numeric(obj$slice_min)) {
        stop("`model_settings$slice_min` should be a numeric value.")
      }
      if (obj$slice_min <= 0) {
        stop("`model_settings$slice_min` should be a positive value.")
      }
    }

    if (is.null(obj$slice_max)) {
      obj$slice_max <- 100
    } else {
      if (!is.numeric(obj$slice_max)) {
        stop("`model_settings$slice_max` should be a numeric value.")
      }
      if (obj$slice_max <= 0) {
        stop("`model_settings$slice_max` should be a positive value.")
      }
    }

    allowed_arguments <- c(allowed_arguments, "slice_min", "slice_max")
  }

  # check model settings for covariate model
  if (model %in% c("cov", "ldacov")) {
    if (is.null(obj$covariates_data)) {
      stop("Please provide `model_settings$covariates_data`.")
    }

    obj$covariates_data <- as.data.frame(obj$covariates_data)[info$use_doc_index, , drop = FALSE]

    if (nrow(obj$covariates_data) != info$num_doc) {
      stop("The row of `model_settings$covariates_data` should be the same as the number of documents.")
    }

    if (anyNA(obj$covariates_data)) {
      stop("Covariate data should not contain missing values.")
    }

    if (is.null(obj$covariates_formula)) {
      warning("`covariates_formula` is not provided. keyATM uses the matrix as it is.", immediate. = TRUE)
      obj$covariates_formula <- NULL  # do not need to change the matrix
      obj$covariates_data_use <- obj$covariates_data
    } else if (is.formula(obj$covariates_formula)) {
      message("Convert covariates data using `model_settings$covariates_formula`.")
      obj$covariates_data_use <- stats::model.matrix(obj$covariates_formula,
                                                     as.data.frame(obj$covariates_data))
    } else {
      stop("Check `model_settings$covariates_formula`.")
    }

    if (is.null(obj$standardize)) {
      obj$standardize <- TRUE
    }

    # Check the covariates form a valid (non-collinear) design by regressing
    # random noise on them: aliased columns get NA coefficients.
    temp <- as.data.frame(obj$covariates_data_use)
    temp$y <- stats::rnorm(nrow(obj$covariates_data_use))

    if ("(Intercept)" %in% colnames(obj$covariates_data_use)) {
      fit <- stats::lm(y ~ 0 + ., data = temp)  # data.frame already includes the intercept
    } else {
      fit <- stats::lm(y ~ ., data = temp)  # data.frame does not have an intercept
    }
    if (anyNA(fit$coefficients)) {
      stop("Covariates are invalid.")
    }

    if (obj$standardize) {
      standardize <- function(x) {return((x - mean(x)) / stats::sd(x))}

      if ("(Intercept)" %in% colnames(obj$covariates_data_use)) {
        # Do not standardize the intercept (first column)
        colnames_keep <- colnames(obj$covariates_data_use)
        obj$covariates_data_use <- cbind(obj$covariates_data_use[, 1, drop = FALSE],
                                         apply(as.matrix(obj$covariates_data_use[, -1, drop = FALSE]), 2,
                                               standardize)
                                        )
        colnames(obj$covariates_data_use) <- colnames_keep
      } else {
        obj$covariates_data_use <- apply(obj$covariates_data_use, 2, standardize)
      }
    }

    # Slice Sampling Settings
    if (is.null(obj$slice_min)) {
      obj$slice_min <- -5.0
    } else {
      if (!is.numeric(obj$slice_min)) {
        stop("`model_settings$slice_min` should be a numeric value.")
      }
    }

    if (is.null(obj$slice_max)) {
      obj$slice_max <- 5.0
    } else {
      if (!is.numeric(obj$slice_max)) {
        stop("`model_settings$slice_max` should be a numeric value.")
      }
    }

    # MH option
    if (is.null(obj$mh_use)) {
      obj$mh_use <- 0
    } else {
      obj$mh_use <- as.integer(obj$mh_use)
      if (!obj$mh_use %in% c(0, 1)) {
        stop("`model_settings$mh_use` should be TRUE/FALSE (0/1)")
      }
    }

    allowed_arguments <- c(allowed_arguments, "covariates_data", "covariates_data_use",
                           "slice_min", "slice_max", "mh_use",
                           "covariates_formula", "standardize", "info")
  }  # cov model end

  # check model settings for dynamic model
  if (model %in% c("hmm", "ldahmm")) {
    if (is.null(obj$num_states)) {
      stop("`model_settings$num_states` is not provided.")
    }

    if (is.null(obj$time_index)) {
      stop("`model_settings$time_index` is not provided.")
    }

    obj$time_index <- obj$time_index[info$use_doc_index]

    if (length(obj$time_index) != info$num_doc) {
      stop("The length of the `model_settings$time_index` does not match with the number of documents.")
    }

    if (anyNA(obj$time_index)) {
      stop("`model_settings$time_index` should not contain missing values.")
    }

    if (min(obj$time_index) != 1 || max(obj$time_index) > info$num_doc) {
      stop("`model_settings$time_index` should start from 1 and not exceed the number of documents.")
    }

    if (max(obj$time_index) < obj$num_states)
      stop("`model_settings$num_states` should not exceed the maximum of `model_settings$time_index`.")

    # Consecutive documents either share a time point (step 0) or move to the
    # next one (step 1). diff() compares each element with its predecessor.
    steps <- unique(diff(obj$time_index))
    if (any(!steps %in% c(0, 1)))
      stop("`model_settings$time_index` does not increment by 1.")

    obj$time_index <- as.integer(obj$time_index)

    allowed_arguments <- c(allowed_arguments, "num_states", "time_index")
  }

  # check model settings for label model
  if (model %in% "label") {
    if (is.null(obj$labels)) {
      stop("`model_settings$labels` is not provided.")
    }

    obj$labels <- obj$labels[info$use_doc_index]

    if (length(obj$labels) != info$num_doc) {
      stop("The length of `model_settings$labels` does not match with the number of documents")
    }
    if (max(obj$labels, na.rm = TRUE) > info$keyword_k || min(obj$labels, na.rm = TRUE) <= 0) {
      stop("`model_settings$labels` must only contain integer values less than the total number of the keyword topics for labeled documents and `NA` should be assigned to non-labeled documents.")
    }

    # Check for whole numbers before as.integer(), which would silently truncate
    labelled <- obj$labels[!is.na(obj$labels)]
    if (!isTRUE(all(labelled == floor(labelled)))) {
      stop("`model_settings$labels` must only contain integer values for labeled documents and `NA` should be assigned to non-labeled documents")
    }

    obj$labels[is.na(obj$labels)] <- 0  # unlabeled documents become -1 after the shift below
    obj$labels <- as.integer(obj$labels) - 1L  # index starts from 0 in C++; -1 marks unlabeled documents

    allowed_arguments <- c(allowed_arguments, "labels")
  }

  show_unused_arguments(obj, "`model_settings`", allowed_arguments)
  return(obj)
}


check_arg_priors <- function(obj, model, info)
{
  # Validate `priors` for `model` and fill in defaults:
  #   gamma  (keyATM models): total_k x 2 matrix, default all 1;
  #          rows for topics without keywords are forced to 0
  #   beta   (all models):    default 0.01
  #   beta_s (keyATM models): default 0.1
  #   alpha  (base/lda/label): length-total_k vector, default 1/total_k each
  check_arg_type(obj, "list")
  # Base arguments
  allowed_arguments <- c("beta")

  # prior of pi
  if (model %in% info$models_keyATM) {
    if (is.null(obj$gamma)) {
      obj$gamma <- matrix(1.0, nrow = info$total_k, ncol = 2)
    }

    # dim() is NULL for a plain vector; report that as a dimension problem
    # rather than failing on a zero-length condition.
    gamma_dim <- dim(obj$gamma)
    if (length(gamma_dim) != 2L || gamma_dim[1] != info$total_k || gamma_dim[2] != 2) {
      stop("Check the dimension of `priors$gamma`")
    }

    if (info$keyword_k < info$total_k) {
      # Regular topics are used in keyATM models
      # Priors for regular topics should be 0
      no_keyword_rows <- (info$keyword_k + 1):info$total_k
      if (sum(obj$gamma[no_keyword_rows, ]) != 0) {
        obj$gamma[no_keyword_rows, ] <- 0
      }
    }

    allowed_arguments <- c(allowed_arguments, "gamma")
  }


  # beta
  if (!"beta" %in% names(obj)) {
    obj$beta <- 0.01
  }

  if (model %in% info$models_keyATM) {
    if (!"beta_s" %in% names(obj)) {
      obj$beta_s <- 0.1
    }
    allowed_arguments <- c(allowed_arguments, "beta_s")
  }


  # alpha
  if (model %in% c("base", "lda", "label")) {
    if (is.null(obj$alpha)) {
      obj$alpha <- rep(1/info$total_k, info$total_k)
    }
    if (length(obj$alpha) != info$total_k) {
      stop("Starting alpha must be a vector of length ", info$total_k)
    }
    allowed_arguments <- c(allowed_arguments, "alpha")
  }

  show_unused_arguments(obj, "`priors`", allowed_arguments)
  return(obj)
}


check_arg_options <- function(obj, model, info)
{
  # Validate `options` for `model`, fill in defaults, and coerce 0/1 flags to
  # integers. Unknown elements raise an error.
  check_arg_type(obj, "list")
  allowed_arguments <- c("seed", "llk_per", "thinning",
                         "iterations", "verbose",
                         "use_weights", "weights_type",
                         "prune", "store_theta", "slice_shape",
                         "parallel_init")

  # TRUE for a single, non-missing, non-negative whole number.
  # `&&` short-circuits, so a non-numeric value is rejected here instead of
  # erroring inside `<` or `%%` before the intended message is shown.
  is_count <- function(x) {
    is.numeric(x) && length(x) == 1L && !is.na(x) && x >= 0 && x %% 1 == 0
  }

  # llk_per
  if (is.null(obj$llk_per))
    obj$llk_per <- 10L

  if (!is_count(obj$llk_per)) {
    stop("An invalid value in `options$llk_per`")
  }


  # verbose
  if (is.null(obj$verbose)) {
    obj$verbose <- 0L
  } else {
    obj$verbose <- as.integer(obj$verbose)
    if (!obj$verbose %in% c(0, 1)) {
      stop("An invalid value in `options$verbose`")
    }
  }

  # thinning
  if (is.null(obj$thinning))
    obj$thinning <- 5L

  if (!is_count(obj$thinning)) {
    stop("An invalid value in `options$thinning`")
  }

  # seed
  if (is.null(obj$seed))
    obj$seed <- floor(stats::runif(1)*1e5)

  # iterations
  if (is.null(obj$iterations))
    obj$iterations <- 1500L
  if (!is_count(obj$iterations)) {
    stop("An invalid value in `options$iterations`")
  }

  # Store theta
  if (is.null(obj$store_theta)) {
    obj$store_theta <- 0L
  } else {
    obj$store_theta <- as.integer(obj$store_theta)
    if (!obj$store_theta %in% c(0, 1)) {
      stop("An invalid value in `options$store_theta`")
    }
  }

  # Store pi
  if (model %in% info$models_keyATM) {
    if (is.null(obj$store_pi)) {
      obj$store_pi <- 0L
    } else {
      obj$store_pi <- as.integer(obj$store_pi)
      if (!obj$store_pi %in% c(0, 1)) {
        stop("An invalid value in `options$store_pi`")
      }
    }
    allowed_arguments <- c(allowed_arguments, "store_pi")
  }


  # Estimate alpha
  if (model %in% c("base", "lda", "label")) {
    if (is.null(obj$estimate_alpha)) {
      obj$estimate_alpha <- 1L
    } else {
      obj$estimate_alpha <- as.integer(obj$estimate_alpha)
      if (!obj$estimate_alpha %in% c(0, 1)) {
        stop("An invalid value in `options$estimate_alpha`")
      }
    }
    allowed_arguments <- c(allowed_arguments, "estimate_alpha")
  }

  # Slice shape
  if (is.null(obj$slice_shape)) {
    # parameter for slice sampling
    obj$slice_shape <- 1.2
  }
  if (!is.numeric(obj$slice_shape) || length(obj$slice_shape) != 1L ||
      is.na(obj$slice_shape) || obj$slice_shape < 0) {
    stop("An invalid value in `options$slice_shape`")
  }

  # Use weights
  if (is.null(obj$use_weights)) {
    obj$use_weights <- 1L
  } else {
    obj$use_weights <- as.integer(obj$use_weights)
    if (!obj$use_weights %in% c(0, 1)) {
      stop("An invalid value in `options$use_weights`")
    }
  }

  # Type of the weights
  if (is.null(obj$weights_type)) {
    obj$weights_type <- "information-theory"
  } else {
    if (!obj$weights_type %in% c("information-theory", "information-theory-normalized",
                                 "inv-freq", "inv-freq-normalized"))
    {
      stop("An invalid value in `options$weights_type`")
    }
  }

  # Prune keywords
  if (is.null(obj$prune)) {
    obj$prune <- 1L
  } else {
    obj$prune <- as.integer(obj$prune)
    if (!obj$prune %in% c(0, 1)) {
      stop("An invalid value in `options$prune`")
    }
  }

  # Store transition matrix in Dynamic models (0/1 integer like other flags)
  if (model %in% c("hmm", "ldahmm")) {
    if (is.null(obj$store_transition_matrix)) {
      obj$store_transition_matrix <- 0L
    } else {
      obj$store_transition_matrix <- as.integer(obj$store_transition_matrix)
    }
    if (!obj$store_transition_matrix %in% c(0, 1)) {
      stop("An invalid value in `options$store_transition_matrix`")
    }
    allowed_arguments <- c(allowed_arguments, "store_transition_matrix")
  }

  # Use parallel function in initialization
  if (!"parallel_init" %in% names(obj)) {
    obj$parallel_init <- FALSE
  } else {
    if (!obj$parallel_init %in% c(0, 1, FALSE, TRUE)) {
      stop("`options$parallel_init` should be TRUE/FALSE")
    }
  }

  # Check unused arguments
  show_unused_arguments(obj, "`options`", allowed_arguments)
  return(obj)
}


check_vocabulary <- function(vocab)
{
  # Sanity-check a character vector of unique words.
  # Errors on a single space, an empty string, or NA used as a word.
  # Warns when entries match the all-uppercase pattern, or contain a tab
  # or a line break.
  if (" " %in% vocab) {
    stop("A space is recognized as a vocabulary. Please remove an empty document or consider using quanteda::dfm.")
  }

  if ("" %in% vocab) {
    stop('A blank `""` is recognized as a vocabulary. Please review preprocessing steps.')
  }

  # NA words would otherwise make the pattern checks below return NA and
  # fail inside `if` with an unhelpful message.
  if (anyNA(vocab)) {
    stop("`NA` is recognized as a vocabulary. Please review preprocessing steps.")
  }

  if (any(stringr::str_detect(vocab, "^[:upper:]+$"))) {
    warning('Upper case letters are used. Please review preprocessing steps.', immediate. = TRUE)
  }

  if (any(stringr::str_detect(vocab, "\t"))) {
    warning('Tab is detected in the vocabulary. Please review preprocessing steps.', immediate. = TRUE)
  }

  if (any(stringr::str_detect(vocab, "\n"))) {
    warning('A line break is detected in the vocabulary. Please review preprocessing steps.', immediate. = TRUE)
  }
}


make_sz_key <- function(W, keywords, info)
{
  # Draw initial topic assignments (Z) and keyword-switch indicators (S) for
  # keyATM models.
  #
  # W:        list with one element per document; in keyATM_fit() these are
  #           word-id vectors from myhashmap_getvec() over a map to 0-based ids.
  # keywords: list of keyword vectors per keyword topic; only their lengths
  #           are used here, to align topic ids with info$keywords_id.
  # info:     reads keywords_id, keyword_k, total_k, parallel_init, num_core.
  # Returns list(S = ..., Z = ...), each a list parallel to W.
  #
  # NOTE(review): draws use sample(); the order of RNG calls determines results
  # under a fixed seed, so restructuring this function changes outputs.

  # zs_assigner maps keywords to category ids
  key_wdids <- info$keywords_id
  # 0-based keyword-topic id for each entry of key_wdids
  cat_ids <- rep(1:(info$keyword_k) - 1L, unlist(lapply(keywords, length)))

  if (length(key_wdids) == length(unique(key_wdids))) {
    #
    # No keyword appears more than once
    #
    # Hash map from keyword id to its single topic id (helper defined elsewhere)
    zs_assigner <- myhashmap_keyint(as.integer(key_wdids), as.integer(cat_ids))

    # if the word is a keyword, assign the appropriate (0 start) Z, else a random Z
    topicvec <- 1:(info$total_k) - 1L
    make_z <- function(s, topicvec) {
      # assumes the lookup returns NA for non-keywords (implied by the fill below)
      zz <- myhashmap_getvec_keyint(zs_assigner, s)
      zz[is.na(zz)] <- sample(topicvec, sum(is.na(zz)), replace = TRUE)
      return(zz)
    }

  } else {
    #
    # Some keywords appear multiple times
    #
    # Each distinct keyword id maps to a comma-separated string of all topic
    # ids it belongs to, e.g. "0,2".
    keys_df <- data.frame(wid = key_wdids, cat = cat_ids)
    keys_char <- sapply(unique(key_wdids),
                        function(x) {
                          paste(as.character(keys_df[keys_df$wid == x, "cat"]), collapse=",")
                        })
    zs_hashtable <- myhashmap_keyint(as.integer(unique(key_wdids)), keys_char)

    # For each token, pick one of its keyword topics uniformly at random.
    zs_assigner <- function(s) {
      topic <- myhashmap_getvec_keyint(zs_hashtable, s)
      topic <- strsplit(as.character(topic), split=",")  # strsplit() requires a character vector
      # Candidates are character, so sample(x, 1) draws from x itself even
      # when x has length 1 (no 1:x expansion as with a single number)
      topic <- lapply(topic, sample, 1)
      topic <- as.integer(unlist(topic, use.names = FALSE))
      return(topic)
    }

    # if the word is a seed, assign the appropriate (0 start) Z, else a random Z
    topicvec <- 1:(info$total_k) - 1L
    make_z <- function(s, topicvec) {
      zz <- zs_assigner(s) # if it is a seed word, we already know the topic
      zz[is.na(zz)] <- sample(topicvec, sum(is.na(zz)), replace = TRUE)
      return(zz)
    }
  }

  ## ss indicates whether the word comes from a seed topic-word distribution or not
  make_s <- function(s) {
    key <- as.numeric(s %in% key_wdids) # 1 if they're a seed
    # Use s structure (overwrite the ids in place so S has the same shape as W)
    s[key == 0] <- 0L # non-keyword words have s = 0
    s[key == 1] <- sample(0:1, length(s[key == 1]), prob = c(0.3, 0.7), replace = TRUE)
      # keyword tokens get s = 1 with probability 0.7, otherwise s = 0
    return(s)
  }

  # NOTE(review): with mc.set.seed = FALSE the forked workers start from the
  # parent's RNG state (see parallel::mclapply docs), so parallel draws may
  # repeat across workers and differ from the sequential path — confirm intended.
  if (info$parallel_init) {
    S <- parallel::mclapply(W, make_s, mc.cores = info$num_core, mc.set.seed = FALSE)
    Z <- parallel::mclapply(W, make_z, topicvec, mc.cores = info$num_core, mc.set.seed = FALSE)
  } else {
    S <- lapply(W, make_s)
    Z <- lapply(W, make_z, topicvec)
  }
 
 
  return(list(S = S, Z = Z))
}


make_sz_lda <- function(W, info)
{
  # Initial state for LDA-based models: every token gets a topic drawn
  # uniformly from the 0-based ids 0, ..., total_k - 1. These models have no
  # keyword switch, so S is an empty list.
  topic_ids <- (1:info$total_k) - 1L

  draw_initial_topics <- function(doc, topics) {
    as.integer(sample(topics, length(doc), replace = TRUE))
  }

  Z <- if (info$parallel_init) {
    parallel::mclapply(W, draw_initial_topics, topic_ids,
                       mc.cores = info$num_core, mc.set.seed = FALSE)
  } else {
    lapply(W, draw_initial_topics, topic_ids)
  }

  list(S = list(), Z = Z)
}
