/*
 * Decompiled with CFR 0.152.
 */
package org.drugis.mtc.convergence;

import org.apache.commons.math3.stat.correlation.Covariance;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.descriptive.moment.Variance;
import org.drugis.mtc.MCMCResults;
import org.drugis.mtc.Parameter;
import org.drugis.mtc.summary.SummaryUtil;
import org.drugis.mtc.util.WindowResults;

/**
 * Computes between-chain / within-chain variance statistics for a single
 * parameter of a multi-chain MCMC run, and derives from them a corrected
 * scale reduction factor ({@link #getCorrPSRF()}).
 *
 * <p>All per-chain and pooled statistics are taken over the samples returned by
 * {@code SummaryUtil.getOneChainLastHalfSamples} and
 * {@code SummaryUtil.getAllChainsLastHalfSamples}; the sample count {@code n}
 * used in the formulas is {@code getNumberOfSamples() / 2}.
 *
 * <p>Several formulas divide by {@code (numberOfChains - 1)}; with a single
 * chain the results are therefore not finite numbers.
 */
public class GelmanRubinConvergence {
    private final MCMCResults d_results;
    private final Parameter d_parameter;
    private static final Mean s_mean = new Mean();
    private static final Variance s_var = new Variance();

    /**
     * Creates a diagnostic for one parameter of the given results.
     *
     * @param results   the MCMC results to analyse; when assertions are enabled
     *                  its number of samples must be even
     * @param parameter the parameter whose samples are analysed
     */
    public GelmanRubinConvergence(MCMCResults results, Parameter parameter) {
        assert (results.getNumberOfSamples() % 2 == 0);
        this.d_results = results;
        this.d_parameter = parameter;
    }

    /**
     * Convenience shortcut for {@code new GelmanRubinConvergence(results, parameter).getCorrPSRF()}.
     *
     * @param results   the MCMC results to analyse
     * @param parameter the parameter to diagnose
     * @return the value of {@link #getCorrPSRF()} for the given input
     */
    public static double diagnose(MCMCResults results, Parameter parameter) {
        return new GelmanRubinConvergence(results, parameter).getCorrPSRF();
    }

    /**
     * Same as {@link #diagnose(MCMCResults, Parameter)}, applied to
     * {@code new WindowResults(results, 0, nSamples)}.
     *
     * @param results   the MCMC results to analyse
     * @param parameter the parameter to diagnose
     * @param nSamples  third argument passed to the {@code WindowResults} constructor
     * @return the value of {@link #getCorrPSRF()} for the windowed results
     */
    public static double diagnose(MCMCResults results, Parameter parameter, int nSamples) {
        return GelmanRubinConvergence.diagnose(new WindowResults(results, 0, nSamples), parameter);
    }

    /**
     * Returns {@link #getVHat()} for {@code new WindowResults(results, 0, nSamples)}.
     *
     * @param results   the MCMC results to analyse
     * @param parameter the parameter to analyse
     * @param nSamples  third argument passed to the {@code WindowResults} constructor
     * @return the pooled variance estimate of the windowed results
     */
    public static double calculatePooledVariance(MCMCResults results, Parameter parameter, int nSamples) {
        return new GelmanRubinConvergence(new WindowResults(results, 0, nSamples), parameter).getVHat();
    }

    /**
     * Returns {@link #getWithinChainVar()} for {@code new WindowResults(results, 0, nSamples)}.
     *
     * @param results   the MCMC results to analyse
     * @param parameter the parameter to analyse
     * @param nSamples  third argument passed to the {@code WindowResults} constructor
     * @return the within-chain variance of the windowed results
     */
    public static double calculateWithinChainVariance(MCMCResults results, Parameter parameter, int nSamples) {
        return new GelmanRubinConvergence(new WindowResults(results, 0, nSamples), parameter).getWithinChainVar();
    }

    /**
     * @param c index of the chain
     * @return the mean of the last-half samples of chain {@code c}
     */
    public double oneChainMean(int c) {
        return SummaryUtil.evaluate(s_mean, SummaryUtil.getOneChainLastHalfSamples(this.d_results, this.d_parameter, c));
    }

    /**
     * @param c index of the chain
     * @return the variance of the last-half samples of chain {@code c}
     */
    public double oneChainVar(int c) {
        return SummaryUtil.evaluate(s_var, SummaryUtil.getOneChainLastHalfSamples(this.d_results, this.d_parameter, c));
    }

    /**
     * @return the mean of the last-half samples of all chains taken together
     */
    public double allChainMean() {
        return SummaryUtil.evaluate(s_mean, SummaryUtil.getAllChainsLastHalfSamples(this.d_results, this.d_parameter));
    }

    /**
     * Between-chain variance: the sum of squared deviations of each chain mean
     * from {@link #allChainMean()}, multiplied by {@code numberOfSamples / 2}
     * and divided by {@code numberOfChains - 1}.
     *
     * @return the between-chain variance
     */
    public double getBetweenChainVar() {
        final double grandMean = this.allChainMean();
        final int chains = this.d_results.getNumberOfChains();
        double squaredDeviations = 0.0;
        for (int c = 0; c < chains; ++c) {
            squaredDeviations += Math.pow(this.oneChainMean(c) - grandMean, 2.0);
        }
        // Expression kept in its original form so the floating-point result is unchanged.
        return (double)this.d_results.getNumberOfSamples() * squaredDeviations / 2.0 / (double)(chains - 1);
    }

    /**
     * @return half the number of samples of the underlying results (integer division)
     */
    public int getNSamples() {
        return this.d_results.getNumberOfSamples() / 2;
    }

    /**
     * @return the number of chains of the underlying results
     */
    public int getNChains() {
        return this.d_results.getNumberOfChains();
    }

    /**
     * @return the mean of the per-chain variances returned by {@link #getVariances()}
     */
    public double getWithinChainVar() {
        return s_mean.evaluate(this.getVariances());
    }

    /**
     * @return a new array holding {@link #oneChainVar(int)} for every chain, in chain order
     */
    public double[] getVariances() {
        final int chains = this.getNChains();
        final double[] variances = new double[chains];
        for (int c = 0; c < chains; ++c) {
            variances[c] = this.oneChainVar(c);
        }
        return variances;
    }

    /**
     * @return a new array holding {@link #oneChainMean(int)} for every chain, in chain order
     */
    public double[] getMeans() {
        final int chains = this.getNChains();
        final double[] means = new double[chains];
        for (int c = 0; c < chains; ++c) {
            means[c] = this.oneChainMean(c);
        }
        return means;
    }

    /**
     * Computes {@code W * (n - 1) / n + B / n}, where {@code W} is
     * {@link #getWithinChainVar()}, {@code B} is {@link #getBetweenChainVar()}
     * and {@code n} is {@link #getNSamples()}.
     *
     * @return the weighted combination of within- and between-chain variance
     */
    public double getSigmaSquaredHat() {
        return sigmaSquaredHat(this.getWithinChainVar(), this.getBetweenChainVar(), this.getNSamples());
    }

    /**
     * Computes {@code sigmaSquaredHat + B / (m * n)}, where {@code m} is the
     * number of chains and {@code n} is {@link #getNSamples()}.
     *
     * @return the pooled variance estimate
     */
    public double getVHat() {
        final int n = this.getNSamples();
        // The between-chain variance requires a pass over every chain; compute it once
        // and reuse it for both terms.
        final double betweenVar = this.getBetweenChainVar();
        final double sigma2 = sigmaSquaredHat(this.getWithinChainVar(), betweenVar, n);
        // Multiply in double so a large chains * samples product cannot overflow int.
        return sigma2 + betweenVar / ((double)this.d_results.getNumberOfChains() * (double)n);
    }

    /**
     * Computes {@code sqrt((d + 3) / (d + 1) * VHat / W)}, where {@code d} is
     * {@link #getDegreesOfFreedom()}, {@code VHat} is {@link #getVHat()} and
     * {@code W} is {@link #getWithinChainVar()}.
     *
     * @return the corrected scale reduction factor
     */
    public double getCorrPSRF() {
        final double d = this.getDegreesOfFreedom();
        final double dfactor = (d + 3.0) / (d + 1.0);
        return Math.sqrt(dfactor * this.getVHat() / this.getWithinChainVar());
    }

    /**
     * Computes {@code 2 * VHat^2 / var(VHat)}, where {@code var(VHat)} is
     * assembled from the variance of the per-chain variances, the between-chain
     * variance, and the covariances of the per-chain variances with the
     * per-chain means and squared means.
     *
     * @return the estimated degrees of freedom
     */
    public double getDegreesOfFreedom() {
        final double m = this.getNChains();
        final double n = this.getNSamples();

        // Each of these walks the chain samples, so evaluate them once and reuse.
        final double[] variances = this.getVariances();
        final double[] means = this.getMeans();
        final double betweenVar = this.getBetweenChainVar();
        final double vHat = this.getVHat();

        final double[] squaredMeans = new double[means.length];
        for (int c = 0; c < means.length; ++c) {
            squaredMeans[c] = means[c] * means[c];
        }

        final Covariance cov = new Covariance();
        final double varW = s_var.evaluate(variances) / m;
        final double varB = 2.0 * betweenVar * betweenVar / (m - 1.0);
        final double covWB = n / m * (cov.covariance(variances, squaredMeans) - 2.0 * this.allChainMean() * cov.covariance(variances, means));
        final double varV = (Math.pow(n - 1.0, 2.0) * varW + Math.pow(1.0 + 1.0 / m, 2.0) * varB + 2.0 * (n - 1.0) * (1.0 + 1.0 / m) * covWB) / (n * n);
        return 2.0 * vHat * vHat / varV;
    }

    /** Shared formula behind {@link #getSigmaSquaredHat()} and {@link #getVHat()}. */
    private static double sigmaSquaredHat(double withinVar, double betweenVar, int n) {
        return withinVar * (double)(n - 1) / (double)n + betweenVar / (double)n;
    }
}

