/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent.smooth;

import dr.evomodel.coalescent.smooth.SkyGlideLikelihood;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightProxyParameter;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.Reportable;
import java.util.List;

/**
 * Gradient and diagonal-Hessian provider for a {@link SkyGlideLikelihood} with respect to one of
 * the parameters it depends on.
 *
 * <p>Two kinds of parameter are supported:
 * <ul>
 *   <li>the likelihood's log population-size parameter
 *       ({@link SkyGlideLikelihood#getLogPopSizeParameter()}), and</li>
 *   <li>a {@link NodeHeightProxyParameter} whose tree is (by reference) one of the trees returned
 *       by {@link SkyGlideLikelihood#getTrees()}.</li>
 * </ul>
 * Any other parameter is rejected when the object is constructed.
 */
public class SkyGlideGradient
        implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable {

    private final SkyGlideLikelihood likelihood;
    private final Parameter parameter;
    private final WrtParameter wrtParameter;
    /** Tolerance forwarded to the gradient/Hessian report-and-check helpers in {@link #getReport()}. */
    private final double tolerance;
    /**
     * Index into {@code likelihood.getTrees()} of the tree whose node heights are being
     * differentiated, or -1 when differentiating with respect to the log population sizes.
     */
    private final int treeIndex;

    /**
     * Creates a gradient provider for {@code parameter}.
     *
     * @param skyGlideLikelihood the likelihood to differentiate
     * @param parameter either the likelihood's log population-size parameter, or a
     *     {@link NodeHeightProxyParameter} over one of the likelihood's trees
     * @param d tolerance forwarded to the report-and-check helpers used by {@link #getReport()}
     * @throws RuntimeException if {@code parameter} is of neither supported kind, or is a
     *     node-height proxy whose tree is not one of the likelihood's trees
     */
    public SkyGlideGradient(SkyGlideLikelihood skyGlideLikelihood, Parameter parameter, double d) {
        this.likelihood = skyGlideLikelihood;
        this.parameter = parameter;
        this.tolerance = d;

        // The log-pop-size check comes first, exactly as before: identity wins over type.
        if (parameter == skyGlideLikelihood.getLogPopSizeParameter()) {
            this.wrtParameter = WrtParameter.LOG_POP_SIZE;
            this.treeIndex = -1;
        } else if (parameter instanceof NodeHeightProxyParameter) {
            List<TreeModel> trees = skyGlideLikelihood.getTrees();
            TreeModel tree = ((NodeHeightProxyParameter) parameter).getTree();
            int index = indexOfTree(trees, tree);
            if (index < 0) {
                throw new RuntimeException("Parameter not recognized: " + describe(parameter)
                        + " (its tree is not among the trees of this likelihood)");
            }
            this.wrtParameter = WrtParameter.NODE_HEIGHT;
            this.treeIndex = index;
        } else {
            throw new RuntimeException("Parameter not recognized: " + describe(parameter));
        }
    }

    /**
     * Returns the position of {@code tree} in {@code trees}, or -1 if it is absent.
     * Matching is by reference identity ({@code ==}), not {@code equals}.
     */
    private static int indexOfTree(List<TreeModel> trees, TreeModel tree) {
        for (int i = 0; i < trees.size(); ++i) {
            if (trees.get(i) == tree) {
                return i;
            }
        }
        return -1;
    }

    /** Returns a name for {@code parameter} suitable for error messages; tolerates {@code null}. */
    private static String describe(Parameter parameter) {
        return parameter == null ? "null" : parameter.getParameterName();
    }

    @Override
    public Likelihood getLikelihood() {
        return likelihood;
    }

    @Override
    public Parameter getParameter() {
        return parameter;
    }

    @Override
    public int getDimension() {
        return parameter.getDimension();
    }

    /** Delegates to the likelihood's analytic gradient for the selected parameter kind. */
    @Override
    public double[] getGradientLogDensity() {
        return wrtParameter.getGradientLogDensity(likelihood, treeIndex);
    }

    /**
     * Builds the gradient report followed by the Hessian report, separated by a newline.
     * The gradient check uses the parameter kind's bounds.
     */
    @Override
    public String getReport() {
        String gradientReport = GradientWrtParameterProvider.getReportAndCheckForError(
                this,
                wrtParameter.getParameterLowerBound(),
                wrtParameter.getParameterUpperBound(),
                tolerance);
        // NOTE(review): node heights pass a null Hessian tolerance. Presumably this makes the
        // helper report the diagonal Hessian without enforcing a threshold; confirm in
        // HessianWrtParameterProvider.getReportAndCheckForError.
        Double hessianTolerance =
                wrtParameter == WrtParameter.NODE_HEIGHT ? null : Double.valueOf(tolerance);
        String hessianReport =
                HessianWrtParameterProvider.getReportAndCheckForError(this, hessianTolerance);
        return gradientReport + "\n" + hessianReport;
    }

    /** Delegates to the likelihood's analytic diagonal Hessian for the selected parameter kind. */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        return wrtParameter.getDiagonalHessianLogDensity(likelihood, treeIndex);
    }

    /**
     * The full Hessian is not available.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[][] getHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented.");
    }

    /**
     * The parameter kinds this provider can differentiate with respect to.
     * Each kind also carries the bounds used by the gradient check in {@link #getReport()}.
     */
    public enum WrtParameter {
        /** The likelihood's log population-size parameter; unbounded. */
        LOG_POP_SIZE {
            @Override
            double[] getGradientLogDensity(SkyGlideLikelihood likelihood, int treeIndex) {
                // treeIndex is unused for population sizes.
                return likelihood.getGradientWrtLogPopulationSize();
            }

            @Override
            double[] getDiagonalHessianLogDensity(SkyGlideLikelihood likelihood, int treeIndex) {
                return likelihood.getDiagonalHessianLogDensityWrtLogPopSize();
            }

            @Override
            double getParameterLowerBound() {
                return Double.NEGATIVE_INFINITY;
            }

            @Override
            double getParameterUpperBound() {
                return Double.POSITIVE_INFINITY;
            }
        },

        /** Node heights of the tree at {@code treeIndex}; bounded below by 0. */
        NODE_HEIGHT {
            @Override
            double[] getGradientLogDensity(SkyGlideLikelihood likelihood, int treeIndex) {
                return likelihood.getGradientWrtNodeHeight(treeIndex);
            }

            @Override
            double[] getDiagonalHessianLogDensity(SkyGlideLikelihood likelihood, int treeIndex) {
                return likelihood.getDiagonalHessianWrtNodeHeight(treeIndex);
            }

            @Override
            double getParameterLowerBound() {
                return 0.0;
            }

            @Override
            double getParameterUpperBound() {
                return Double.POSITIVE_INFINITY;
            }
        };

        /** Analytic gradient of the log density with respect to this parameter kind. */
        abstract double[] getGradientLogDensity(SkyGlideLikelihood likelihood, int treeIndex);

        /** Analytic diagonal of the Hessian of the log density with respect to this parameter kind. */
        abstract double[] getDiagonalHessianLogDensity(SkyGlideLikelihood likelihood, int treeIndex);

        /** Lower bound of the parameter, used by the gradient check. */
        abstract double getParameterLowerBound();

        /** Upper bound of the parameter, used by the gradient check. */
        abstract double getParameterUpperBound();
    }
}

