/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.rexp.Converter;
import org.jpmml.rexp.Formula;
import org.jpmml.rexp.FormulaUtil;
import org.jpmml.rexp.ModelFrameFormulaContext;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RStringVector;

public class HurdleConverter
extends Converter<RGenericVector> {
    private ContinuousLabel label = null;
    private static final String NAME_COUNT = "count";
    private static final String NAME_ZERO = "zero";

    public HurdleConverter(RGenericVector hurdle) {
        super(hurdle);
    }

    @Override
    public PMML encodePMML(RExpEncoder encoder) {
        Model zeroModel = this.encodeModel(NAME_ZERO, encoder);
        OutputField zeroPredictField = ModelUtil.createPredictedField((String)FieldNameUtil.create((String)"predict", (Object[])new Object[]{NAME_ZERO}), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE);
        encoder.createDerivedField(zeroModel, zeroPredictField, true);
        Segment zeroSegment = new Segment((Predicate)True.INSTANCE, zeroModel).setId(NAME_ZERO);
        Model countModel = this.encodeModel(NAME_COUNT, encoder);
        OutputField countPredictField = ModelUtil.createPredictedField((String)FieldNameUtil.create((String)"predict", (Object[])new Object[]{NAME_COUNT}), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE);
        encoder.createDerivedField(countModel, countPredictField, true);
        Segment countSegment = new Segment((Predicate)True.INSTANCE, countModel).setId(NAME_COUNT);
        Apply adjExpression = ExpressionUtil.createApply((String)"exp", (Expression[])new Expression[]{ExpressionUtil.createApply((String)"-", (Expression[])new Expression[]{ExpressionUtil.createApply((String)"ln", (Expression[])new Expression[]{new FieldRef((Field)zeroPredictField)}), ExpressionUtil.createApply((String)"stats::ppois", (Expression[])new Expression[]{ExpressionUtil.createConstant((Number)0), new FieldRef((Field)countPredictField)})})});
        DerivedField adjZeroPredictField = encoder.createDerivedField(FieldNameUtil.create((String)"adjusted", (Object[])new Object[]{zeroPredictField.requireName()}), OpType.CONTINUOUS, DataType.DOUBLE, (Expression)adjExpression);
        Apply targetExpression = ExpressionUtil.createApply((String)"*", (Expression[])new Expression[]{new FieldRef((Field)countPredictField), new FieldRef((Field)adjZeroPredictField)});
        DerivedField targetField = encoder.createDerivedField(FieldNameUtil.create((String)"adjusted", (Object[])new Object[]{countPredictField.requireName()}), OpType.CONTINUOUS, DataType.DOUBLE, (Expression)targetExpression);
        ContinuousFeature feature = new ContinuousFeature((PMMLEncoder)encoder, (Field)targetField);
        Schema schema = new Schema((ModelEncoder)encoder, (Label)this.label, Collections.emptyList());
        RegressionModel fullModel = RegressionModelUtil.createRegression(Collections.singletonList(feature), Collections.singletonList(1.0), null, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.NONE, (Schema)schema);
        OutputField truncatedTargetField = new OutputField(FieldNameUtil.create((String)"truncated", (Object[])new Object[]{this.label.getName()}), OpType.CONTINUOUS, DataType.DOUBLE).setResultFeature(ResultFeature.TRANSFORMED_VALUE).setExpression((Expression)new FieldRef((Field)countPredictField));
        Output output = new Output().addOutputFields(new OutputField[]{truncatedTargetField});
        fullModel.setOutput(output);
        Segment fullSegment = new Segment((Predicate)True.INSTANCE, (Model)fullModel).setId("full");
        List<Model> models = Arrays.asList(zeroModel, countModel, fullModel);
        Segmentation segmentation = new Segmentation(Segmentation.MultipleModelMethod.MODEL_CHAIN, null).addSegments(new Segment[]{zeroSegment, countSegment, fullSegment});
        MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, MiningModelUtil.createMiningSchema(models)).setSegmentation(segmentation);
        return encoder.encodePMML((Model)miningModel);
    }

    private Model encodeModel(String name, RExpEncoder encoder) {
        RegressionModel regressionModel;
        RGenericVector hurdle = (RGenericVector)this.getObject();
        RDoubleVector coefficients = hurdle.getGenericElement("coefficients").getDoubleElement(name);
        RStringVector dist = hurdle.getGenericElement("dist").getStringElement(name);
        RExp terms = (RExp)hurdle.getGenericElement("terms").getElement(name);
        RGenericVector model = hurdle.getGenericElement("model");
        RStringVector coefficientNames = coefficients.names();
        ModelFrameFormulaContext context = new ModelFrameFormulaContext(model);
        Formula formula = FormulaUtil.createFormula(terms, context, encoder);
        switch (name) {
            case "count": {
                FormulaUtil.setLabel(formula, terms, null, encoder);
                ContinuousLabel continuousLabel = (ContinuousLabel)encoder.getLabel();
                DataField dataField = (DataField)encoder.getField(continuousLabel.getName());
                dataField.setDataType(DataType.DOUBLE);
                this.label = new ContinuousLabel((Field)dataField);
                break;
            }
            case "zero": {
                break;
            }
            default: {
                throw new IllegalArgumentException();
            }
        }
        encoder.setLabel((Label)new ContinuousLabel(DataType.DOUBLE));
        List<Feature> features = encoder.getFeatures();
        if (!features.isEmpty()) {
            features.clear();
        }
        List<String> names = FormulaUtil.removeSpecialSymbol(coefficientNames.getDequotedValues(), "(Intercept)");
        FormulaUtil.addFeatures(formula, names, true, encoder);
        features = encoder.getFeatures();
        Schema schema = encoder.createSchema();
        Double intercept = (Double)coefficients.getElement("(Intercept)", false);
        SchemaUtil.checkSize((int)(coefficients.size() - (intercept != null ? 1 : 0)), features);
        ArrayList<Double> featureCoefficients = new ArrayList<Double>();
        for (Feature feature : features) {
            Double coefficient = formula.getCoefficient(feature, coefficients);
            featureCoefficients.add(coefficient);
        }
        block12 : switch (name) {
            case "zero": {
                switch ((String)dist.asScalar()) {
                    case "binomial": {
                        regressionModel = RegressionModelUtil.createRegression(features, featureCoefficients, (Number)intercept, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.LOGIT, (Schema)schema);
                        break block12;
                    }
                }
                throw new IllegalArgumentException();
            }
            case "count": {
                switch ((String)dist.asScalar()) {
                    case "poisson": {
                        regressionModel = RegressionModelUtil.createRegression(features, featureCoefficients, (Number)intercept, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.EXP, (Schema)schema);
                        break block12;
                    }
                }
                throw new IllegalArgumentException();
            }
            default: {
                throw new IllegalArgumentException(name);
            }
        }
        return regressionModel;
    }
}

