/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous;

import dr.inference.distribution.NormalStatisticsProvider;
import dr.inference.model.Parameter;
import dr.inference.model.ScaledMatrixParameter;
import dr.inference.model.ScaledParameter;
import dr.inference.model.TransformedParameter;
import dr.inference.operators.repeatedMeasures.MultiplicativeGammaGibbsHelper;
import dr.math.distributions.MultivariateGammaLikelihood;
import dr.util.Transform;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/**
 * Gamma likelihood on the squared Euclidean norms of the columns of a matrix.
 *
 * <p>For each column {@code j}, the squared norm {@code c_j^2} of the column-scale
 * parameter is modelled as Gamma with shape {@code rowDimension / 2} and a scale
 * parameter built as {@code 2.0} times the reciprocal of {@code globalPrecision[j]}.
 * This is what results when the column's {@code rowDimension} entries are independent
 * normals with mean 0 and variance {@code scale / 2}; {@link #getNormalMean(int)} and
 * {@link #getNormalSD(int)} report exactly those moments.
 *
 * <p>The constructor does not check that {@code globalPrecision} and the column norms
 * have equal dimension; the XML {@link #PARSER} performs that check.
 */
public class NormalMatrixNormLikelihood
extends MultivariateGammaLikelihood
implements MultiplicativeGammaGibbsHelper,
NormalStatisticsProvider {

    /** Number of rows of the underlying matrix (entries per column). */
    private final int rowDimension;

    /** Column norms; the likelihood is evaluated on their squares. */
    private final Parameter columnNorms;

    /**
     * Parses {@code <normalMatrixNormLikelihood>} elements. Reads a {@code globalPrecision}
     * child containing a {@link Parameter} and a {@code matrix} child containing a
     * {@link ScaledMatrixParameter}, whose scale parameter supplies the column norms.
     */
    public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private static final String NORMAL_NORM_LIKELIHOOD = "normalMatrixNormLikelihood";
        private static final String GLOBAL_PRECISION = "globalPrecision";
        private static final String MATRIX = "matrix";

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            String id = xo.hasAttribute("id") ? xo.getId() : NORMAL_NORM_LIKELIHOOD;

            XMLObject precisionXo = xo.getChild(GLOBAL_PRECISION);
            Parameter globalPrecision = (Parameter) precisionXo.getChild(Parameter.class);

            XMLObject matrixXo = xo.getChild(MATRIX);
            ScaledMatrixParameter matrix =
                    (ScaledMatrixParameter) matrixXo.getChild(ScaledMatrixParameter.class);
            Parameter columnNorms = matrix.getScaleParameter();

            // One precision per column is required.
            if (columnNorms.getDimension() != globalPrecision.getDimension()) {
                throw new XMLParseException("incompatible dimensions: the " + GLOBAL_PRECISION
                        + " parameter with id `" + globalPrecision.getId() + "` has dimension "
                        + globalPrecision.getDimension() + ", while the " + MATRIX
                        + " parameter with id `" + matrix.getId() + "` has "
                        + columnNorms.getDimension() + " columns.");
            }
            return new NormalMatrixNormLikelihood(id, globalPrecision, columnNorms,
                    matrix.getRowDimension());
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return new XMLSyntaxRule[]{
                    new ElementRule(GLOBAL_PRECISION, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                    new ElementRule(MATRIX, new XMLSyntaxRule[]{new ElementRule(ScaledMatrixParameter.class)})
            };
        }

        @Override
        public String getParserDescription() {
            return "gamma likelihood on the squared norm of the columns of a matrix";
        }

        @Override
        public Class getReturnType() {
            return NormalMatrixNormLikelihood.class;
        }

        @Override
        public String getParserName() {
            return NORMAL_NORM_LIKELIHOOD;
        }
    };

    /**
     * Creates the likelihood.
     *
     * @param name            identifier of this likelihood
     * @param globalPrecision per-column precision; expected to have the same dimension as
     *                        {@code columnNorms} (not checked here)
     * @param columnNorms     column norms whose squares form the Gamma-distributed data
     * @param rowDimension    number of rows of the matrix; the Gamma shape is
     *                        {@code rowDimension / 2}
     */
    public NormalMatrixNormLikelihood(String name, Parameter globalPrecision, Parameter columnNorms, int rowDimension) {
        super(name,
                setupShape(rowDimension, columnNorms.getDimension()),
                setupScale(globalPrecision),
                setupSquaredNorms(columnNorms));
        this.columnNorms = columnNorms;
        this.rowDimension = rowDimension;
    }

    /** Builds a constant shape parameter of length {@code columnCount}, each entry {@code rowCount / 2}. */
    private static Parameter setupShape(int rowCount, int columnCount) {
        return new Parameter.Default(columnCount, (double) rowCount / 2.0);
    }

    /** Builds the scale parameter as {@code 2.0} times the element-wise reciprocal of {@code precision}. */
    private static Parameter setupScale(Parameter precision) {
        // Declared as Parameter so ScaledParameter's constructor overload is chosen exactly as before.
        Parameter reciprocalPrecision = new TransformedParameter(precision, new Transform.ReciprocalTransform());
        return new ScaledParameter(2.0, reciprocalPrecision);
    }

    /** Builds a view of {@code norms} with every element squared. */
    private static Parameter setupSquaredNorms(Parameter norms) {
        return new TransformedParameter(norms, new Transform.PowerTransform(2.0));
    }

    /**
     * Returns the sum of squared entries of column {@code n}, which is the squared
     * column norm held in the Gamma data parameter.
     */
    @Override
    public double computeSumSquaredErrors(int n) {
        return this.data.getParameterValue(n);
    }

    /** Returns the column-norm parameter (not its square). */
    @Override
    public Parameter getParameter() {
        return this.columnNorms;
    }

    /**
     * Gradient of the log density with respect to the column norms.
     *
     * <p>Assumes the superclass returns the gradient with respect to its data (the squared
     * norms). Each entry is multiplied by {@code d(c^2)/dc = 2c} (chain rule).
     */
    @Override
    public double[] getGradientLogDensity() {
        double[] gradient = super.getGradientLogDensity();
        for (int i = 0; i < this.dim; ++i) {
            gradient[i] *= 2.0 * this.columnNorms.getParameterValue(i);
        }
        return gradient;
    }

    @Override
    public int getRowDimension() {
        return this.rowDimension;
    }

    @Override
    public int getColumnDimension() {
        return this.dim;
    }

    /** Matrix entries are modelled with mean zero. */
    @Override
    public double getNormalMean(int n) {
        return 0.0;
    }

    /**
     * Standard deviation of matrix entry {@code n}: {@code sqrt(scale[column] / 2)}.
     *
     * <p>The column is taken as {@code n / rowDimension}. This presumes entries are flattened
     * column by column (TODO confirm against callers).
     */
    @Override
    public double getNormalSD(int n) {
        int column = n / this.rowDimension;
        return Math.sqrt(this.scale.getParameterValue(column) / 2.0);
    }
}

