# get a 2D lookup matrix mapping cells of the transition matrix to indices of a flattened and condensed rate vector (whose entries are unique independent rate variables of the model)
# some entries in the index_matrix (the ones on the diagonal, and for the sequential models "SUEDE" and "SRD" also the ones for non-sequential transitions) will map to zero, which means that they are not represented in the condensed rate vector
# the returned index_matrix is to be used in conjunction with the function get_transition_matrix_from_rate_vector(..)
# The convention used is the same as by the function ape::ace(), i.e. index.matrix[] is synchronized with the transition_matrix
# rate_model can be "ER" or "SYM" or "ARD" or "SUEDE" or "SRD", or a custom index_matrix as if it was generated by this function (in which case it is merely used to determine Nrates)
get_transition_index_matrix = function(Nstates, rate_model){
	# Build the index matrix (Nstates x Nstates) and the number of independent rates for a given rate model.
	# Arguments:
	#   Nstates:    number of states, i.e. number of rows & columns of the index matrix
	#   rate_model: either a single model name ("ER", "SYM", "ARD", "SUEDE" or "SRD"),
	#               or a custom 2D index matrix of size Nstates x Nstates (returned as-is)
	# Returns a list with two elements:
	#   index_matrix: 2D matrix whose entry [r,c] is the index of the rate used for transition r-->c, or 0 if the transition has no rate assigned
	#   Nrates:       number of independent rates (for a custom matrix: its maximum entry)
	# Stops with an error if rate_model is an unknown name, is not a single non-NA string, or is a custom matrix of the wrong shape
	if(is.character(rate_model)){
		# the comparisons below need a scalar condition, so fail early with a clear message otherwise
		if((length(rate_model) != 1) || is.na(rate_model)) stop("ERROR: rate_model must be a single non-NA model name when specified as character")
		index_matrix = matrix(0, Nstates, Nstates)
		if(rate_model == "ER"){
			# a single rate shared by all off-diagonal transitions
			Nrates = 1;
			index_matrix[] = 1;
			diag(index_matrix) = 0;

		}else if(rate_model == "ARD"){
			# every off-diagonal transition gets its own rate, enumerated in column-major order
			Nrates = Nstates * (Nstates - 1);
			index_matrix[col(index_matrix) != row(index_matrix)] = seq_len(Nrates);

		}else if(rate_model == "SYM"){
			# transitions r-->c and c-->r share the same rate
			Nrates 			= Nstates * (Nstates - 1)/2;
			lower_triangle 	= (col(index_matrix) < row(index_matrix));
			index_matrix[lower_triangle] = seq_len(Nrates)
			# mirror the lower triangle into the upper triangle, then restore the lower triangle (which the transpose emptied)
			index_matrix 	= t(index_matrix)
			index_matrix[lower_triangle] = seq_len(Nrates)

		}else if(rate_model == "SUEDE"){
			# only sequential transitions allowed (i-->i+1, i-->i-1), all up-rates are equal, all down-rates are equal
			Nrates = 2;
			index_matrix[col(index_matrix) == row(index_matrix)+1] = 1; # "up" rate, i-->i+1
			index_matrix[col(index_matrix) == row(index_matrix)-1] = 2; # "down" rate, i-->i-1

		}else if(rate_model == "SRD"){
			# only sequential transitions allowed (i-->i+1, i-->i-1), all those rates can be different
			# clamp at zero so that degenerate Nstates (0 or 1) yield empty sequences and a non-negative Nrates
			Nsteps = max(0, Nstates - 1);
			Nrates = 2*Nsteps;
			index_matrix[col(index_matrix) == row(index_matrix)+1] = seq_len(Nsteps); 			# "up" rates, i-->i+1
			index_matrix[col(index_matrix) == row(index_matrix)-1] = Nsteps + seq_len(Nsteps); 	# "down" rates, i-->i-1

		}else{
			stop(sprintf("ERROR: Unknown rate_model '%s'",rate_model))
		}
	}else{
		# custom index matrix: objects without dimensions (plain vectors, NULL) would otherwise trigger a cryptic zero-length-condition error below
		if(is.null(dim(rate_model))) stop("ERROR: rate_model must be either a model name (character) or a 2D index matrix")
		if(ncol(rate_model)!=Nstates) stop(sprintf("ERROR: Wrong number of columns in rate model (expected %d, found %d)",Nstates,ncol(rate_model)));
		if(nrow(rate_model)!=Nstates) stop(sprintf("ERROR: Wrong number of rows in rate model (expected %d, found %d)",Nstates,nrow(rate_model)));
		index_matrix = rate_model
		Nrates = max(rate_model)
	}

	return(list(index_matrix=index_matrix, Nrates=Nrates))
}
