#' A function to plot the fitted mortality rates, accompanied by credible intervals, from posterior samples generated for stochastic mortality models
#'
#' Plot the fitted mortality rates, accompanied by credible intervals (at a user-specified level), using the posterior samples stored in a "fit_result" object.
#' @param result object of type either "fit_result" or "BayesMoFo".
#' @param pred_int A numeric value (between 0 and 1) specifying the credible level of uncertainty bands. Default is \code{pred_int=0.95} (\eqn{95\%} intervals). 
#' @param plot_type A character string (\code{c("age","time")}) to indicate whether to plot by age (default) or by time/year.
#' @param plot_ages A numeric vector specifying which ages to plot. If not specified, all ages that were used to fit the model (i.e. \code{fit_result$death$ages}) are plotted. One panel will be constructed per age when \code{plot_type="time"}, with a maximum of nine panels. If exceeded, only the first nine ages will be plotted.
#' @param plot_years A numeric vector specifying which years to plot. If not specified, all years that were used to fit the model (i.e. \code{fit_result$death$years}) are plotted, followed by the forecast years when forecasts are available. One panel will be constructed per year when \code{plot_type="age"}, with a maximum of nine panels. If exceeded, only the first nine years will be plotted.
#' @param legends A logical value to indicate if legends of the plots should be shown (default) or suppressed (e.g. to aid visibility).
#' @return A plot illustrating the median fitted and forecast mortality rates, accompanied by credible intervals.
#' @keywords graphics visualization plots
#' @concept fitted death rates
#' @concept forecast death rates
#' @concept credible intervals
#' @importFrom graphics legend lines par points
#' @importFrom stats dbinom dpois quantile sd
#' @export
#' @examples
#' \donttest{
#' #load and prepare data
#' data("dxt_array_product");data("Ext_array_product")
#' death<-preparedata_fn(dxt_array_product,strat_name = c("ACI","DB","SCI"),ages=35:65)
#' expo<-preparedata_fn(Ext_array_product,strat_name = c("ACI","DB","SCI"),ages=35:65)
#' 
#' #fit any mortality model
#' runBayesMoFo_result<-runBayesMoFo(death=death,expo=expo,models="APCI",n_iter=1000,forecast=TRUE)
#' 
#' #default plot
#' plot_rates_fn(runBayesMoFo_result)
#' 
#' #plot by age and changing pre-specified arguments 
#' plot_rates_fn(runBayesMoFo_result,pred_int=0.8,plot_ages=40:60,plot_years=c(2017,2020))
#' 
#' #plot by time/year
#' plot_rates_fn(runBayesMoFo_result,plot_type="time",plot_ages=c(40,50,60))
#' }

plot_rates_fn <- function(result, pred_int = 0.95, plot_type = "age",
                          plot_ages = NULL, plot_years = NULL,
                          legends = TRUE) {

  # ---- Argument checks ----
  # Fail early with a clear message instead of silently drawing nothing
  # (unknown plot_type) or passing a nonsensical probability downstream.
  if (!is.numeric(pred_int) || length(pred_int) != 1L || is.na(pred_int) ||
      pred_int < 0 || pred_int > 1) {
    stop("'pred_int' must be a single number between 0 and 1.", call. = FALSE)
  }
  plot_type <- match.arg(plot_type, c("age", "time"))

  # ---- Unpack the fitted object ----
  # A wrapper object is recognised by the presence of a "BayesMoFo_obj"
  # element; in that case the fit to plot is taken from result$result$best.
  if ("BayesMoFo_obj" %in% names(result)) {
    fit_result <- result$result$best
  } else {
    fit_result <- result
  }

  death <- fit_result$death
  expo_data <- fit_result$expo$data
  if (fit_result$family == "binomial") {
    # Under the binomial family the exposures are shifted by half the deaths
    # and rounded before computing crude rates.
    # NOTE(review): presumably a central-to-initial exposure conversion.
    expo_data <- round(expo_data + 0.5 * death$data)
  }

  mcmc_object <- coda::as.mcmc(as.matrix(fit_result$post_sample))

  # Lower bound, median and upper bound of the central credible interval.
  probs <- c(0.5 * (1 - pred_int), 0.5, 1 - 0.5 * (1 - pred_int))

  n_strat <- death$n_strat
  n_ages <- death$n_ages
  n_years <- death$n_years
  forecast <- fit_result$forecast
  last_year <- death$years[n_years]

  strat_names <- dimnames(death$data)[[1]]
  ages_names <- dimnames(death$data)[[2]]
  years_names <- dimnames(death$data)[[3]]

  crude_rates <- death$data / expo_data
  crude_rates <- provideDimnames(
    crude_rates,
    base = list(strat_names, ages_names, years_names)
  )

  # ---- Time axis: observed years, plus h forecast years if available ----
  if (forecast) {
    h <- fit_result$h
    future_years <- last_year + seq_len(h)
    all_years <- c(death$years, future_years)
    all_years_names <- c(years_names, as.character(future_years))
    n_time <- n_years + h
  } else {
    all_years <- death$years
    all_years_names <- years_names
    n_time <- n_years
  }

  # ---- Posterior quantiles of the rates ----
  # Columns whose name starts with "q" hold the sampled rates; each column is
  # summarised by quantile_fn at the three probabilities above.
  rates_mat <- mcmc_object[, startsWith(colnames(mcmc_object), "q")]
  rates_pn <- apply(rates_mat, 2, quantile_fn, probs = probs)

  # Build every "q[i,j,k]" column name at once. expand.grid varies its first
  # factor fastest, which matches the column-major fill order of array().
  cells <- expand.grid(
    i = seq_len(n_strat),
    j = seq_len(n_ages),
    k = seq_len(n_time)
  )
  q_names <- paste0("q[", cells$i, ",", cells$j, ",", cells$k, "]")
  missing_q <- setdiff(q_names, colnames(rates_pn))
  if (length(missing_q) > 0) {
    stop(
      "Posterior samples do not contain the expected rate column(s), e.g. ",
      missing_q[1], ".",
      call. = FALSE
    )
  }

  rate_dims <- c(n_strat, n_ages, n_time)
  rate_dimnames <- list(strat_names, ages_names, all_years_names)
  rates_lower <- array(unname(rates_pn[1, q_names]),
                       dim = rate_dims, dimnames = rate_dimnames)
  rates_median <- array(unname(rates_pn[2, q_names]),
                        dim = rate_dims, dimnames = rate_dimnames)
  rates_upper <- array(unname(rates_pn[3, q_names]),
                       dim = rate_dims, dimnames = rate_dimnames)

  # ---- Defaults for the plotting ranges ----
  if (is.null(plot_ages)) {
    plot_ages <- death$ages
  }
  if (is.null(plot_years)) {
    plot_years <- all_years
  }
  if (length(plot_ages) == 0 || length(plot_years) == 0) {
    stop("'plot_ages' and 'plot_years' must each contain at least one value.",
         call. = FALSE)
  }

  # ---- Panel layout ----
  # One panel per year (plot by age) or per age (plot by time), at most nine.
  if (plot_type == "age") {
    if (length(plot_years) > 9) {
      warning("Too many years selected, only printing the first 9 years.")
      plot_years <- plot_years[1:9]
    }
    n_panels <- length(plot_years)
  } else {
    if (length(plot_ages) > 9) {
      warning("Too many ages selected, only printing the first 9 ages.")
      plot_ages <- plot_ages[1:9]
    }
    n_panels <- length(plot_ages)
  }

  # Validate after trimming so that values beyond the ninth panel, which are
  # never drawn, cannot trigger an error.
  bad_ages <- setdiff(as.character(plot_ages), ages_names)
  if (length(bad_ages) > 0) {
    stop("'plot_ages' contains ages not available in the fitted model: ",
         paste(bad_ages, collapse = ", "), ".", call. = FALSE)
  }
  bad_years <- setdiff(as.character(plot_years), all_years_names)
  if (length(bad_years) > 0) {
    stop("'plot_years' contains years not available in the fitted model: ",
         paste(bad_years, collapse = ", "), ".", call. = FALSE)
  }

  # Restore the caller's graphical parameters on exit.
  oldpar <- par(no.readonly = TRUE)
  on.exit(par(oldpar), add = TRUE)

  if (n_panels <= 3) {
    par(mfrow = c(1, n_panels))
  } else if (n_panels <= 6) {
    par(mfrow = c(2, ceiling(n_panels / 2)))
  } else {
    par(mfrow = c(3, 3))
  }

  # Common y-axis across panels, based on all rate quantiles; non-finite
  # values (e.g. log(0)) are ignored so that plot() still gets a valid ylim.
  yrange_plot <- range(log(rates_pn), finite = TRUE)

  # Stratum j is drawn in colour j + 1.
  strat_cols <- seq_len(n_strat) + 1

  if (plot_type == "age") {
    # ---- Plot by age: one panel per year ----
    age_idx <- as.character(plot_ages)
    for (i in seq_along(plot_years)) {
      year_idx <- as.character(plot_years[i])
      plot(NULL, xlim = range(plot_ages), ylim = yrange_plot,
           main = plot_years[i], xlab = "age", ylab = "log death rates")
      for (j in seq_len(n_strat)) {
        lines(plot_ages, log(rates_median[j, age_idx, year_idx]),
              type = "l", col = strat_cols[j])
        lines(plot_ages, log(rates_lower[j, age_idx, year_idx]),
              lty = 2, col = strat_cols[j])
        lines(plot_ages, log(rates_upper[j, age_idx, year_idx]),
              lty = 2, col = strat_cols[j])
        # Crude rates exist only for observed (non-forecast) years.
        if (plot_years[i] <= last_year) {
          points(plot_ages, log(crude_rates[j, age_idx, year_idx]),
                 col = strat_cols[j], pch = 19, cex = 0.5)
        }
      }
      if (legends) {
        legend("bottomright", strat_names, lty = 1, col = strat_cols)
      }
    }
  } else {
    # ---- Plot by time/year: one panel per age ----
    year_idx <- as.character(plot_years)
    # Crude rates exist only for observed (non-forecast) years.
    observed_years <- plot_years[plot_years <= last_year]
    observed_idx <- as.character(observed_years)
    for (i in seq_along(plot_ages)) {
      age_idx <- as.character(plot_ages[i])
      plot(NULL, xlim = range(plot_years), ylim = yrange_plot,
           main = paste0("Age=", plot_ages[i]), xlab = "year",
           ylab = "log death rates")
      for (j in seq_len(n_strat)) {
        lines(plot_years, log(rates_median[j, age_idx, year_idx]),
              type = "l", col = strat_cols[j])
        lines(plot_years, log(rates_lower[j, age_idx, year_idx]),
              lty = 2, col = strat_cols[j])
        lines(plot_years, log(rates_upper[j, age_idx, year_idx]),
              lty = 2, col = strat_cols[j])
        points(observed_years, log(crude_rates[j, age_idx, observed_idx]),
               col = strat_cols[j], pch = 19, cex = 0.75)
      }
      if (legends) {
        legend("bottomright", strat_names, lty = 1, col = strat_cols)
      }
      # Vertical marker at the last observed year, i.e. where forecasts begin.
      if (forecast) {
        graphics::abline(v = last_year, lty = 3)
      }
    }
  }

  # Kept from the original implementation so the (invisible) return value is
  # unchanged for existing callers.
  invisible(gc())
}