#' Calculate Sample Sizes for a Two-Stage Trial Design.
#'
#' Computes the required sample sizes for the interim and final analyses of a two-stage trial design, based on pre-specified operating characteristics such as the type I error rate and the desired power (one minus the type II error rate).
#' @importFrom truncdist rtrunc
#' @param median.1 Numeric. The overall median survival time for SOC.
#' @param median.2 Numeric. The overall median survival time for the experimental arm.
#' @param gprior.E_1 Optional. A numeric vector of length two, representing the shape and scale
#'   parameters of the inverse-gamma prior for \eqn{\mu_0}, the mean survival time before the separation time.
#'   If \code{NULL}, the default is \code{c(4, 3 * median.1 / log(2))}.
#'
#' @param gprior.E_2 Optional. A numeric vector of length two, representing the inverse-gamma prior
#'   for \eqn{\mu_1}, the mean survival time after the separation time.
#'   If \code{NULL}, the default is \code{c(4, 6 * median.1 / log(2))}.
#'
#' @param L Numeric. The lower bound of the separation time window.
#'
#' @param U Numeric. The upper bound of the separation time window.
#'
#' @param S_likely Numeric. The most likely separation time. Defaults to the midpoint of \code{L} and \code{U}.
#'
#' @param trunc.para Numeric vector of length two. Specifies the shape and scale parameters of
#'   the truncated gamma prior for the separation time.
#'
#' @param err1 Numeric. The pre-specified type I error rate (e.g., \code{0.1}).
#'
#' @param err2 Numeric. The pre-specified type II error rate (i.e., \eqn{\beta}, e.g., \code{0.2}).
#'
#' @param FUP Numeric. The duration of follow-up per patient (in months or years). Default is 6.
#'
#' @param rate Numeric. The recruitment rate (e.g., patients per month or year).
#'
#' @param weight Numeric. Weight given to the expected sample size under \eqn{H_0}. Default is \code{0.5}.
#'
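#' @param method Character. Specifies which method to use:
#'   \code{"optimal"} (default) or \code{"suboptimal"}. With \code{"optimal"}, the design parameters (\code{lambda}, \code{gamma}) are re-optimized every time the maximum sample size is increased during the power search; with \code{"suboptimal"}, the parameters from the initial design are kept during the search and re-optimized only once at the end, which requires fewer optimizer calls.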
#'
#' @param earlystop_prob Optional. Numeric. If specified, ensures that the probability of early
#'   stopping under \eqn{H_0} is at least this value, while also maintaining the power at or above \eqn{1 - \beta}
#'   when the separation time is no more than \code{U}.
#' @param seed Optional integer. If provided, sets the seed for reproducibility.
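#' @return A list with two elements: \code{sample_size}, a numeric vector of length 2 whose first element is the required sample size for the interim analysis and whose second element is the total sample size for the final analysis; and \code{optimal}, the optimal design parameters returned by \code{get.optimal_2arm_piecewise()}, whose first two elements are used as \code{lambda} and \code{gamma}.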
#' @examples
#' # Define design parameters
#' median.1 <- 4       # Median survival for standard-of-care
#' median.2 <- 6       # Median survival for experimental arm
#' L <- 2              # Lower bound of separation time
#' U <- 2.5            # Upper bound of separation time
#' S_likely <- 2.28    # Most likely separation time
#' rate <- 6           # Accrual rate (patients/month)
#' FUP <- 6            # Follow-up duration (months)
#' err1 <- 0.1         # Type I error
#' err2 <- 0.15        # Type II error
#'
#' \donttest{
#'   Two_stage_sample_size(
#'     median.1 = median.1,
#'     median.2 = median.2,
#'     L = L,
#'     U = U,
#'     S_likely = S_likely,
#'     err1 = err1,
#'     err2 = err2,
#'     FUP = FUP,
#'     rate = rate,
#'     weight = 0.5,
#'     method = "optimal"
#'   )
#' }
#' @export
Two_stage_sample_size <- function(median.1, median.2, gprior.E_1 = NULL,
                                  gprior.E_2 = NULL, L, U,
                                  S_likely = (L + U) / 2,
                                  trunc.para = c(1, 1), err1, err2,
                                  FUP = 6, rate, weight = 0.5,
                                  method = "optimal", earlystop_prob = NULL,
                                  seed = NULL) {
  # Fail fast on an unknown method: the power search below only makes
  # progress for "optimal" or "suboptimal"; any other value would loop forever.
  method <- match.arg(method, c("optimal", "suboptimal"))

  # Helpers ---------------------------------------------------------------

  # Map overall medians (control, experimental) to the medians used by the
  # piecewise exponential model. Element 1 is the control median; element 2
  # is the median implied by the experimental arm's post-separation hazard.
  median_inuse <- function(median_0, median_1, S) {
    if (median_0 < S) {
      post_median <- median_0
    } else {
      # Solves S / median_0 + (median_1 - S) / post_median = 1, i.e. overall
      # survival of 1/2 at median_1 when the hazard switches at time S.
      post_median <- (median_1 - S) / (1 - S / median_0)
    }
    c(median_0, post_median)
  }

  # Monte Carlo estimate, for each n in 1..nmax, of the mean proportion of
  # the first n patients whose event is observed by the analysis time
  # (arrival of patient n + FUP). Arrivals are cumulative exponential gaps
  # with rate `accrual_rate`; event times come from generate_pe()
  # (presumably piecewise exponential switching at `t` -- defined elsewhere).
  Event_prop <- function(nmax, t, median_1, median_2, FUP = 6, accrual_rate,
                         n_sim = 10000) {
    lambda1 <- log(2) / median_1
    lambda2 <- log(2) / median_2
    event_proportions <- numeric(nmax)

    for (sim in seq_len(n_sim)) {
      # Arrival times are shared by every n within one simulation.
      arrival_times <- cumsum(rexp(nmax, rate = accrual_rate))

      for (n in seq_len(nmax)) {
        cutoff_time <- arrival_times[n] + FUP
        event_times <- generate_pe(n, t, lambda1, lambda2)
        # A patient's event is observed if it happens before the cutoff;
        # otherwise it is censored at the cutoff.
        observed <- arrival_times[seq_len(n)] + event_times[seq_len(n)] <=
          cutoff_time
        event_proportions[n] <- event_proportions[n] + sum(observed) / n
      }
    }

    list(event_proportions = event_proportions / n_sim)
  }

  # Analytic starting value for the maximum sample size. A Schoenfeld-type
  # number of events for hazard ratio median_1 / median_2 (allocation p) is
  # inflated by the simulated event fraction, iterated until it changes by
  # at most 0.05, then rescaled by the ratio of the medians.
  Initial_sample_size <- function(alpha, beta, FUP = 6, accrual_rate,
                                  median_1, median_2, p = 0.5, t) {
    # `event_rate` is the expected fraction of patients with an observed
    # event at analysis; start the fixed-point iteration from 0.5.
    event_rate <- 0.5
    event_cal <- 0
    prev_event_rate <- 1
    d_ini <- (qnorm(alpha / 2) + qnorm(beta))^2 /
      ((p * (1 - p)) * (log(median_1 / median_2))^2)

    while (abs(event_cal - prev_event_rate) > 0.05) {
      n_ini <- d_ini / event_rate

      # Round the total up to an even number before taking the arm share.
      n_even <- ceiling(n_ini)
      if (n_even %% 2 != 0) {
        n_even <- n_even + 1
      }
      nmax <- n_even * p

      exp_arm <- Event_prop(nmax, t, median_1, median_2,
                            FUP = FUP, accrual_rate = accrual_rate)
      ctrl_arm <- Event_prop(n_ini - nmax, t, median_1, median_1,
                             FUP = FUP, accrual_rate = accrual_rate)
      event_cal <- (tail(exp_arm$event_proportions, 1) +
                      tail(ctrl_arm$event_proportions, 1)) / 2

      prev_event_rate <- event_rate
      event_rate <- event_cal
    }

    # Median of a piecewise exponential with hazard lambda_0 before S and
    # lambda_1 after S.
    piecewise_exponential_median <- function(lambda_0, lambda_1, S) {
      target_survival <- 0.5
      T1 <- -log(target_survival) / lambda_0
      if (T1 < S) {
        T1
      } else {
        S - (log(target_survival) + lambda_0 * S) / lambda_1
      }
    }

    sample_size <- (median_1 + median_2) /
      (piecewise_exponential_median(log(2) / median_1, log(2) / median_2,
                                    S = t) + median_1) * nmax
    list(initial_sample_size = ceiling(sample_size))
  }

  # Search n1 in [50%, 90%] of nmax minimising
  #   weight * E[N | H0] / nmax + (1 - weight) * (1 - E[N | H1] / nmax),
  # i.e. a small expected sample size under H0 and little early stopping
  # under H1. The search assumes the metric is roughly unimodal in n1.
  Interim_sample_size <- function(nmax, gprior.E_1, gprior.E_2, median.1,
                                  median.2, S, trunc.para, L, U,
                                  weight = 0.5, lambda, gamma, FUP = 6,
                                  accrual_rate, n_sim = 10000) {
    # Expected number of enrolled patients for design c(n1, nmax) when the
    # true medians are `median_true`.
    average_patients <- function(n1, median_true) {
      getoc_2arm_piecewise(
        gprior.E_1 = gprior.E_1,
        gprior.E_2 = gprior.E_2,
        median.true = median_true,
        Uniform = FALSE,
        lambda = lambda,
        gamma = gamma,
        n.interim = c(n1, nmax),
        L = L,
        U = U,
        S_likely = S,
        FUP = FUP,
        trunc.para = trunc.para,
        rate = accrual_rate,
        nsim = n_sim
      )$average.patients
    }

    # H0 is equal medians c(median.1, median.1), matching the early-stop
    # calculation below. Previously both terms used H1, which made the
    # metric constant when weight = 0.5.
    combined_metric <- function(n1) {
      ess_h0 <- average_patients(n1, c(median.1, median.1))
      ess_h1 <- average_patients(n1, c(median.1, median.2))
      ess_h0 / nmax * weight + (1 - ess_h1 / nmax) * (1 - weight)
    }

    lower <- floor(nmax * 0.5)
    upper <- ceiling(nmax * 0.9)
    best_n_interim <- lower
    best_metric <- Inf

    while (lower <= upper) {
      mid <- floor((lower + upper) / 2)
      metric_mid <- combined_metric(mid)
      metric_left <- if (mid > lower) combined_metric(mid - 1) else Inf
      metric_right <- if (mid < upper) combined_metric(mid + 1) else Inf

      if (metric_mid < best_metric) {
        best_metric <- metric_mid
        best_n_interim <- mid
      }

      if (metric_left > metric_mid && metric_right > metric_mid) {
        break  # local minimum found
      } else if (metric_left < metric_mid) {
        upper <- mid - 1
      } else {
        lower <- mid + 1
      }
    }

    list(n_interim = best_n_interim, combined_metric = best_metric)
  }

  # Grids used whenever the design parameters are re-optimised after the
  # first pass.
  lambda_seq <- seq(0.5, 0.975, by = 0.025)
  gamma_seq <- seq(0, 1, by = 0.1)

  # Calibrate the design parameters for n.interim = n_interim_vec. Extra
  # named arguments (e.g. lambda.seq / gamma.seq) are forwarded via `...`.
  optimize_design <- function(n_interim_vec, ...) {
    get.optimal_2arm_piecewise(
      median.1, median.2,
      gprior.E_1 = gprior.E_1,
      gprior.E_2 = gprior.E_2,
      ...,
      L = L, U = U, S_likely = S_likely,
      Uniform = FALSE, trunc.para = trunc.para,
      err1 = err1,
      n.interim = n_interim_vec,
      rate = rate, FUP = FUP, track = FALSE,
      nsim = 10000, control = TRUE,
      control.point = c(L, U)
    )
  }
  optimize_on_grid <- function(n_interim_vec) {
    optimize_design(n_interim_vec, lambda.seq = lambda_seq,
                    gamma.seq = gamma_seq)
  }

  # Best interim size for total size `n_total` given design parameters.
  find_interim <- function(n_total, optimal) {
    Interim_sample_size(
      n_total, median.1 = median.1, median.2 = median.2,
      gprior.E_1 = gprior.E_1, gprior.E_2 = gprior.E_2, S = S_likely,
      L = L, U = U, weight = weight, trunc.para = trunc.para,
      lambda = optimal[1], gamma = optimal[2], FUP = FUP,
      accrual_rate = rate, n_sim = 10000
    )
  }

  # Rejection rate under H1 with the separation time fixed at U.
  power_at_upper <- function(n_interim_vec, optimal) {
    getoc_2arm_piecewise(
      gprior.E_1 = gprior.E_1, gprior.E_2 = gprior.E_2,
      median.true = c(median.1, median.2), Uniform = FALSE,
      lambda = optimal[1], gamma = optimal[2],
      n.interim = n_interim_vec, L = U, U = U, S_likely = S_likely,
      FUP = FUP, trunc.para = trunc.para, rate = rate, nsim = 10000
    )$reject
  }

  # Probability of stopping at the interim analysis under H0 (equal medians).
  calculate_earlystop_rate <- function(n_interim, n_total, result_optimal) {
    getoc_2arm_piecewise(
      gprior.E_1 = gprior.E_1,
      gprior.E_2 = gprior.E_2,
      median.true = c(median.1, median.1),
      Uniform = FALSE,
      lambda = result_optimal[1],
      gamma = result_optimal[2],
      n.interim = c(n_interim, n_total),
      L = L, U = U, S_likely = S_likely,
      FUP = FUP, trunc.para = trunc.para,
      rate = rate, nsim = 10000
    )$earlystop
  }

  # Design search ---------------------------------------------------------

  medians_inuse <- median_inuse(median.1, median.2, S = S_likely)
  median_1 <- medians_inuse[1]
  median_2 <- medians_inuse[2]
  if (is.null(gprior.E_1)) {
    gprior.E_1 <- c(4, 3 * median_1 / log(2))
  }
  if (is.null(gprior.E_2)) {
    gprior.E_2 <- c(4, 6 * median_1 / log(2))
  }

  # Note: this resets the global RNG state when `seed` is supplied.
  if (!is.null(seed)) {
    set.seed(seed)
  }

  # The user's FUP is passed through (previously the helper's default of 6
  # was always used here).
  nmax_temp <- Initial_sample_size(
    alpha = err1, beta = err2, FUP = FUP, accrual_rate = rate,
    median_1 = median_1, median_2 = median_2, p = 0.5, t = U
  )$initial_sample_size

  result <- optimize_design(c(floor(nmax_temp * 0.7), nmax_temp))
  interim_result <- find_interim(nmax_temp, result$optimal)
  result <- optimize_design(c(interim_result$n_interim, nmax_temp))
  power <- power_at_upper(c(interim_result$n_interim, nmax_temp),
                          result$optimal)

  # Grow the maximum sample size in steps of 5 until power at U exceeds
  # 1 - err2.
  while (power <= (1 - err2)) {
    nmax_temp <- nmax_temp + 5
    if (method == "optimal") {
      result <- optimize_on_grid(c(floor(nmax_temp * 0.7), nmax_temp))
      interim_result <- find_interim(nmax_temp, result$optimal)
      result <- optimize_on_grid(c(interim_result$n_interim, nmax_temp))
    } else {
      # "suboptimal": keep the current design parameters during the search.
      interim_result <- find_interim(nmax_temp, result$optimal)
    }
    power <- power_at_upper(c(interim_result$n_interim, nmax_temp),
                            result$optimal)
  }

  if (method == "suboptimal") {
    result <- optimize_on_grid(c(interim_result$n_interim, nmax_temp))
  }

  n_interim <- interim_result$n_interim

  if (!is.null(earlystop_prob)) {
    earlystop_rate <- calculate_earlystop_rate(n_interim, nmax_temp,
                                               result$optimal)

    # Grow nmax in steps of 2 and take the smallest interim size in
    # [70%, 75%] of nmax that reaches the required early-stop probability.
    while (earlystop_rate < earlystop_prob) {
      nmax_temp <- nmax_temp + 2
      lower_bound <- ceiling(nmax_temp * 0.7)
      upper_bound <- floor(nmax_temp * 0.75)
      # seq(a, b) counts downwards when a > b; use no candidates instead.
      candidates <- if (lower_bound <= upper_bound) {
        seq(lower_bound, upper_bound)
      } else {
        integer(0)
      }

      for (n_candidate in candidates) {
        candidate_rate <- calculate_earlystop_rate(n_candidate, nmax_temp,
                                                   result$optimal)
        if (candidate_rate >= earlystop_prob) {
          n_interim <- n_candidate
          earlystop_rate <- candidate_rate  # ends the while loop
          break
        }
      }
    }

    result <- optimize_on_grid(c(n_interim, nmax_temp))
  }

  list(sample_size = c(n_interim, nmax_temp), optimal = result$optimal)
}
