/*
 * Decompiled with CFR 0.152.
 */
package dr.evolution.continuous;

import dr.evolution.continuous.Continuous;
import dr.evolution.continuous.Contrastable;
import dr.evolution.io.NewickImporter;
import dr.evolution.tree.MutableTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.SimpleNode;
import dr.evolution.tree.SimpleTree;
import dr.geo.math.SphericalPolarCoordinates;
import dr.matrix.Matrix;
import dr.matrix.MutableMatrix;
import java.io.StringReader;

/**
 * Scores continuous traits on a strictly bifurcating tree using contrasts
 * computed between the two children of every internal node.
 *
 * <p>The tree is first mirrored into a tree of {@link ContrastedTraitNode}s;
 * contrasts, contrast variances and weighted-mean trait values are then
 * computed bottom-up, and finally folded into a single score (printed as
 * {@code logL} by {@link #main(String[])}).
 *
 * <p>NOTE(review): the arithmetic looks like Felsenstein's independent
 * contrasts likelihood; confirm against the original project documentation.
 */
public class ContinuousTraitLikelihood {
    /**
     * Computes the likelihood score for the named traits and reports the trait
     * values reconstructed at the root.
     *
     * <p>As a side effect, the trait value computed for every node is stored on
     * {@code mutableTree} as a node attribute under the trait's name.
     *
     * @param mutableTree       tree whose taxa carry the trait attributes; must be strictly bifurcating
     * @param stringArray       names of the taxon attributes to use as traits
     * @param contrastableArray output array; element {@code i} receives the root value of trait {@code i}.
     *                          Its length must not exceed {@code stringArray.length}
     * @param d                 exponent applied to each branch length (parent height minus child height)
     *                          before it is used as a variance
     * @return the score: the single-trait formula when exactly one trait is named, otherwise the
     *         multi-trait formula
     * @throws IllegalArgumentException if the tree is not strictly bifurcating, or a tip is missing a
     *                                  trait attribute or carries one of an unsupported type
     */
    public double calculateLikelihood(MutableTree mutableTree, String[] stringArray, Contrastable[] contrastableArray, double d) {
        ContrastedTraitNode root = new ContrastedTraitNode(mutableTree, mutableTree.getRoot(), stringArray);
        root.calculateContrasts(d);
        for (int trait = 0; trait < contrastableArray.length; ++trait) {
            contrastableArray[trait] = root.getTraitValue(trait);
        }
        return this.calculateTraitsLikelihood(root);
    }

    /**
     * Dispatches to the single-trait or multi-trait formula depending on how
     * many traits the root node carries.
     *
     * @param root root of a contrast tree whose contrasts have already been calculated
     * @return the score from the selected formula
     */
    private double calculateTraitsLikelihood(ContrastedTraitNode root) {
        int traitCount = root.getTraitCount();
        if (traitCount == 1) {
            return this.calculateSingleTraitLikelihood(root);
        }
        return this.calculateMultipleTraitsLikelihood(root, traitCount);
    }

    /**
     * Multi-trait score.
     *
     * <p>Builds a symmetric {@code traitCount x traitCount} matrix whose entry
     * {@code (i, j)} is the mean over internal nodes of
     * {@code contrast[i] * contrast[j] / contrastVariance}, takes its determinant
     * and inverse, and accumulates a quadratic form per internal node.
     *
     * <p>NOTE(review): matrix exceptions are only printed to {@code System.out};
     * the computation then continues with whatever values are in place (for
     * example a determinant of {@code 0.0}). This is preserved from the original
     * because the behavior of the {@code Matrix.Util} helpers is not visible here.
     *
     * @param root       root of the contrast tree
     * @param traitCount number of traits carried by every node
     * @return {@code -sum / 2} of the accumulated terms
     */
    private double calculateMultipleTraitsLikelihood(ContrastedTraitNode root, int traitCount) {
        SimpleTree contrastTree = new SimpleTree(root);
        int internalCount = contrastTree.getInternalNodeCount();

        // Fill the upper triangle and mirror it: the matrix is symmetric by construction.
        double[][] meanProducts = new double[traitCount][traitCount];
        for (int i = 0; i < traitCount; ++i) {
            for (int j = i; j < traitCount; ++j) {
                double sum = 0.0;
                for (int k = 0; k < internalCount; ++k) {
                    ContrastedTraitNode node = (ContrastedTraitNode)contrastTree.getInternalNode(k);
                    sum += node.contrast[i] * node.contrast[j] / node.contrastVariance;
                }
                sum /= (double)internalCount;
                meanProducts[i][j] = sum;
                meanProducts[j][i] = sum;
            }
        }

        MutableMatrix quadraticForm = Matrix.Util.createMutableMatrix(new double[1][1]);
        // Used first as the input to det(), then reused below as the output of the first product.
        MutableMatrix scratch = Matrix.Util.createMutableMatrix(meanProducts);
        double determinant = 0.0;
        try {
            determinant = Matrix.Util.det(scratch);
        }
        catch (Matrix.NotSquareException notSquareException) {
            notSquareException.printStackTrace(System.out);
        }
        MutableMatrix inverse = Matrix.Util.createMutableMatrix(meanProducts);
        try {
            Matrix.Util.invert(inverse);
        }
        catch (Matrix.NotSquareException notSquareException) {
            notSquareException.printStackTrace(System.out);
        }

        double total = 0.0;
        int countPlusOne = internalCount + 1;
        for (int k = 0; k < internalCount; ++k) {
            ContrastedTraitNode node = (ContrastedTraitNode)contrastTree.getInternalNode(k);
            double[] contrasts = node.getTraitContrasts();
            Matrix row = Matrix.Util.createRowVector(contrasts);
            Matrix column = Matrix.Util.createColumnVector(contrasts);
            try {
                // row * (inverse * column), computed in two steps through the scratch matrix.
                Matrix.Util.product(inverse, column, scratch);
                Matrix.Util.product(row, scratch, quadraticForm);
            }
            catch (Matrix.WrongDimensionException wrongDimensionException) {
                wrongDimensionException.printStackTrace(System.out);
            }
            total += quadraticForm.getElement(0, 0) / node.getContrastVariance();
            total += (double)traitCount * Math.log(node.getContrastVariance());
        }
        total += (double)traitCount * Math.log(root.getNodeVariance());
        total += (double)countPlusOne * Math.log(determinant);
        total += (double)(countPlusOne * traitCount) * Math.log(Math.PI * 2);
        return -total / 2.0;
    }

    /**
     * Single-trait score.
     *
     * <p>Sums {@code contrast^2 / contrastVariance} and {@code log(contrastVariance)}
     * over the internal nodes (adding {@code log(nodeVariance)} for the root), then
     * combines them using the mean of the squared-contrast terms as the scale.
     *
     * @param root root of the contrast tree
     * @return {@code -sum / 2} of the accumulated terms
     */
    private double calculateSingleTraitLikelihood(ContrastedTraitNode root) {
        SimpleTree contrastTree = new SimpleTree(root);
        int internalCount = contrastTree.getInternalNodeCount();

        double sumScaledSquares = 0.0;
        double sumLogVariance = 0.0;
        for (int k = 0; k < internalCount; ++k) {
            ContrastedTraitNode node = (ContrastedTraitNode)contrastTree.getInternalNode(k);
            double contrast = node.getTraitContrasts()[0];
            double variance = node.getContrastVariance();
            sumScaledSquares += contrast * contrast / variance;
            sumLogVariance += Math.log(variance);
            if (node.isRoot()) {
                sumLogVariance += Math.log(node.getNodeVariance());
            }
        }

        // Mean of the scaled squared contrasts; used as the scale in the terms below.
        double scale = sumScaledSquares / (double)internalCount;
        int countPlusOne = internalCount + 1;
        double total = (double)countPlusOne * Math.log(Math.PI * 2 * scale);
        total += sumLogVariance;
        total += sumScaledSquares / scale;
        return -total / 2.0;
    }

    /**
     * Small demonstration: builds a four-taxon tree, attaches two traits
     * ({@code U1}, {@code U2}) to the taxa, and prints the score and root trait
     * values for both traits jointly and for each trait alone.
     *
     * @param stringArray ignored
     * @throws Exception if the Newick string cannot be imported
     */
    public static void main(String[] stringArray) throws Exception {
        String newick = "((A:1, B:1):1,(C:1, D:1):1);";
        NewickImporter importer = new NewickImporter(new StringReader(newick));
        MutableTree tree = (MutableTree)importer.importTree(null);

        String[] demoTraits = {"U1", "U2"};
        double[][] tipValues = {
            {1.1, 1.95, 3.15, 4.39},
            {5.2, 3.8, 3.1, 1.95},
        };
        for (int trait = 0; trait < demoTraits.length; ++trait) {
            for (int taxon = 0; taxon < tipValues[trait].length; ++taxon) {
                tree.setTaxonAttribute(taxon, demoTraits[trait], new Continuous(tipValues[trait][taxon]));
            }
        }

        ContinuousTraitLikelihood likelihood = new ContinuousTraitLikelihood();

        Contrastable[] jointRootValues = new Contrastable[2];
        double jointLogL = likelihood.calculateLikelihood(tree, new String[]{"U1", "U2"}, jointRootValues, 1.0);
        System.out.println("logL = " + jointLogL);
        System.out.println("mle(trait1) = " + jointRootValues[0]);
        System.out.println("mle(trait2) = " + jointRootValues[1]);

        Contrastable[] singleRootValue = new Contrastable[1];
        System.out.println("logL (trait1) = " + likelihood.calculateLikelihood(tree, new String[]{"U1"}, singleRootValue, 1.0));
        System.out.println("mle(trait1) = " + singleRootValue[0]);
        System.out.println("logL (trait2) = " + likelihood.calculateLikelihood(tree, new String[]{"U2"}, singleRootValue, 1.0));
        System.out.println("mle(trait2) = " + singleRootValue[0]);
    }

    /**
     * A node of the mirrored tree that carries, per trait, the contrast between
     * its two children and the trait value assigned to the node itself.
     */
    class ContrastedTraitNode
    extends SimpleNode {
        /** Per-trait difference between the first and second child's trait values. */
        private double[] contrast;
        /** Sum of the two children's variances (see {@link #calculateContrasts(double)}). */
        private double contrastVariance;
        /** Per-trait value at this node: read from the taxon at tips, weighted mean elsewhere. */
        private Contrastable[] traitValue;
        /** Variance carried up to the parent; stays {@code 0.0} at tips. */
        private double nodeVariance;
        /** Source tree, kept so computed trait values can be written back as node attributes. */
        private MutableTree tree;
        /** The node of {@link #tree} that this node mirrors. */
        private NodeRef node;
        private String[] traitNames;

        /**
         * Mirrors {@code nodeRef} and, recursively, everything below it.
         *
         * <p>For a tip, each named attribute is read from the node's taxon,
         * converted to a {@link Contrastable}, and also stored on the source tree
         * as a node attribute of the same name.
         *
         * @param mutableTree source tree
         * @param nodeRef     node of {@code mutableTree} to mirror
         * @param stringArray names of the taxon attributes to use as traits
         * @throws IllegalArgumentException if an internal node does not have exactly two children, if a
         *                                  tip lacks one of the attributes, or if an attribute is not a
         *                                  {@code Number}, {@code String}, {@code Continuous} or
         *                                  {@code SphericalPolarCoordinates}
         */
        public ContrastedTraitNode(MutableTree mutableTree, NodeRef nodeRef, String[] stringArray) {
            this.init(mutableTree, nodeRef, stringArray.length);
            if (!mutableTree.isExternal(nodeRef)) {
                if (mutableTree.getChildCount(nodeRef) != 2) {
                    throw new IllegalArgumentException("Tree must be strictly bifurcating!");
                }
                this.addChild(new ContrastedTraitNode(mutableTree, mutableTree.getChild(nodeRef, 0), stringArray));
                this.addChild(new ContrastedTraitNode(mutableTree, mutableTree.getChild(nodeRef, 1), stringArray));
            } else {
                for (int i = 0; i < stringArray.length; ++i) {
                    Object attribute = mutableTree.getNodeTaxon(nodeRef).getAttribute(stringArray[i]);
                    if (attribute == null) {
                        throw new IllegalArgumentException("attribute " + stringArray[i] + " does not exist in " + mutableTree.getTaxonId(nodeRef.getNumber()));
                    }
                    if (attribute instanceof Number) {
                        this.traitValue[i] = new Continuous(((Number)attribute).doubleValue());
                    } else if (attribute instanceof String) {
                        this.traitValue[i] = new Continuous(Double.parseDouble((String)attribute));
                    } else if (attribute instanceof Continuous) {
                        this.traitValue[i] = (Continuous)attribute;
                    } else if (attribute instanceof SphericalPolarCoordinates) {
                        this.traitValue[i] = (SphericalPolarCoordinates)attribute;
                    } else {
                        // Fail here rather than leaving a null trait value that would only
                        // surface later as a NullPointerException while computing contrasts.
                        throw new IllegalArgumentException("attribute " + stringArray[i] + " of " + mutableTree.getTaxonId(nodeRef.getNumber()) + " has unsupported type " + attribute.getClass().getName());
                    }
                    mutableTree.setNodeAttribute(nodeRef, stringArray[i], this.traitValue[i]);
                }
            }
            this.traitNames = stringArray;
        }

        /**
         * Copies height, rate, id, number and taxon from the source node and
         * allocates zeroed per-trait storage.
         *
         * @param mutableTree source tree
         * @param nodeRef     node being mirrored
         * @param n           number of traits
         */
        private void init(MutableTree mutableTree, NodeRef nodeRef, int n) {
            this.setHeight(mutableTree.getNodeHeight(nodeRef));
            this.setRate(mutableTree.getNodeRate(nodeRef));
            this.setId(mutableTree.getTaxonId(nodeRef.getNumber()));
            this.setNumber(nodeRef.getNumber());
            this.setTaxon(mutableTree.getNodeTaxon(nodeRef));
            this.contrast = new double[n];
            this.contrastVariance = 0.0;
            this.traitValue = new Contrastable[n];
            this.nodeVariance = 0.0;
            this.tree = mutableTree;
            this.node = nodeRef;
        }

        /** @return the per-trait contrasts (the internal array, not a copy) */
        public double[] getTraitContrasts() {
            return this.contrast;
        }

        /** @return the variance associated with this node's contrasts */
        public double getContrastVariance() {
            return this.contrastVariance;
        }

        /** @return the variance this node contributes to its parent's calculation */
        public double getNodeVariance() {
            return this.nodeVariance;
        }

        /**
         * @param n trait index
         * @return the value of trait {@code n} at this node
         */
        public Contrastable getTraitValue(int n) {
            return this.traitValue[n];
        }

        /** @return the number of traits carried by this node */
        public int getTraitCount() {
            return this.traitValue.length;
        }

        /**
         * Computes contrasts, variances and trait values for this node and all of
         * its descendants, children first. Tips are left untouched.
         *
         * <p>For an internal node, each child's variance is the child's
         * {@code nodeVariance} plus the branch length to that child raised to the
         * power {@code d}. The contrast variance is the sum of the two, the node
         * variance is their product divided by their sum, and each trait value is
         * the mean of the children's values weighted by the inverse variances.
         * Every computed trait value is also written to the source tree as a node
         * attribute.
         *
         * @param d exponent applied to branch lengths
         */
        private void calculateContrasts(double d) {
            if (this.isExternal()) {
                return;
            }
            ContrastedTraitNode first = (ContrastedTraitNode)this.getChild(0);
            ContrastedTraitNode second = (ContrastedTraitNode)this.getChild(1);
            first.calculateContrasts(d);
            second.calculateContrasts(d);

            double firstVariance = first.nodeVariance + Math.pow(this.getHeight() - first.getHeight(), d);
            double secondVariance = second.nodeVariance + Math.pow(this.getHeight() - second.getHeight(), d);
            this.contrastVariance = firstVariance + secondVariance;
            this.nodeVariance = firstVariance * secondVariance / (firstVariance + secondVariance);

            double firstWeight = 1.0 / firstVariance;
            double secondWeight = 1.0 / secondVariance;
            for (int i = 0; i < this.getTraitCount(); ++i) {
                this.contrast[i] = first.traitValue[i].getDifference(second.traitValue[i]);
                this.traitValue[i] = first.traitValue[i].getWeightedMean(firstWeight, first.traitValue[i], secondWeight, second.traitValue[i]);
                this.tree.setNodeAttribute(this.node, this.traitNames[i], this.traitValue[i]);
            }
        }
    }
}

