/*
 * Decompiled with CFR 0.152.
 */
package weka.attributeSelection;

import java.util.BitSet;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.SubsetEvaluator;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.AbstractEvaluationMetric;
import weka.classifiers.evaluation.InformationRetrievalEvaluationMetric;
import weka.classifiers.rules.ZeroR;
import weka.core.Capabilities;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Attribute subset evaluator that scores a subset by cross-validating a base
 * learner on the training data restricted to that subset (plus the class).
 *
 * <p>Cross-validation is repeated (up to {@link #MAX_XVAL_REPEATS} times) while the
 * standard deviation of the repeated estimates, relative to their mean, exceeds the
 * configured threshold. Larger returned values always mean a better subset: error-style
 * measures are negated (or turned into accuracy) before being returned.
 */
public class WrapperSubsetEval
extends ASEvaluation
implements SubsetEvaluator,
OptionHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -4573057658746728675L;
    /** Maximum number of cross-validation repetitions performed per subset. */
    private static final int MAX_XVAL_REPEATS = 5;
    /** Training data supplied to {@link #buildEvaluator(Instances)}. */
    private Instances m_trainInstances;
    /** Class attribute index of the training data. */
    private int m_classIndex;
    /** Number of attributes (including the class) in the training data. */
    private int m_numAttribs;
    /** Learner whose cross-validated performance scores a subset. */
    private Classifier m_BaseClassifier;
    /** Number of cross-validation folds. */
    private int m_folds;
    /** Seed for the cross-validation splits. */
    private int m_seed;
    /** Relative standard deviation above which cross-validation is repeated; negative disables repetition. */
    private double m_threshold;
    public static final int EVAL_DEFAULT = 1;
    public static final int EVAL_ACCURACY = 2;
    public static final int EVAL_RMSE = 3;
    public static final int EVAL_MAE = 4;
    public static final int EVAL_FMEASURE = 5;
    public static final int EVAL_AUC = 6;
    public static final int EVAL_AUPRC = 7;
    public static final int EVAL_CORRELATION = 8;
    public static final int EVAL_PLUGIN = 9;
    public static final Tag[] TAGS_EVALUATION;
    /** Resolved 0-based IR class value index, or -1 to use the class-weighted average. */
    protected int m_IRClassVal = -1;
    /** IR class value as configured by the user (label or 1-based index); empty means unset. */
    protected String m_IRClassValS = "";
    protected static List<AbstractEvaluationMetric> PLUGIN_METRICS;
    protected Tag m_evaluationMeasure = TAGS_EVALUATION[0];

    public String globalInfo() {
        return "WrapperSubsetEval:\n\nEvaluates attribute sets by using a learning scheme. Cross validation is used to estimate the accuracy of the learning scheme for a set of attributes.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Ron Kohavi and George H. John");
        result.setValue(TechnicalInformation.Field.YEAR, "1997");
        result.setValue(TechnicalInformation.Field.TITLE, "Wrappers for feature subset selection");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Artificial Intelligence");
        result.setValue(TechnicalInformation.Field.VOLUME, "97");
        result.setValue(TechnicalInformation.Field.NUMBER, "1-2");
        result.setValue(TechnicalInformation.Field.PAGES, "273-324");
        result.setValue(TechnicalInformation.Field.NOTE, "Special issue on relevance");
        result.setValue(TechnicalInformation.Field.ISSN, "0004-3702");
        return result;
    }

    /** Creates an evaluator with the default options (ZeroR, 5 folds, seed 1, threshold 0.01). */
    public WrapperSubsetEval() {
        this.resetOptions();
    }

    /**
     * Lists the options understood by {@link #setOptions(String[])}, followed by the
     * base learner's own options when it is an {@link OptionHandler}.
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>(6);
        newVector.addElement(new Option("\tclass name of base learner to use for\n\taccuracy estimation.\n\tPlace any classifier options LAST on the command line\n\tfollowing a \"--\". eg.:\n\t\t-B weka.classifiers.bayes.NaiveBayes ... -- -K\n\t(default: weka.classifiers.rules.ZeroR)", "B", 1, "-B <base learner>"));
        newVector.addElement(new Option("\tnumber of cross validation folds to use for estimating accuracy.\n\t(default=5)", "F", 1, "-F <num>"));
        newVector.addElement(new Option("\tSeed for cross validation accuracy estimation.\n\t(default = 1)", "R", 1, "-R <seed>"));
        newVector.addElement(new Option("\tthreshold by which to execute another cross validation\n\t(standard deviation---expressed as a percentage of the mean).\n\t(default: 0.01 (1%))", "T", 1, "-T <num>"));
        newVector.addElement(new Option("\tPerformance evaluation measure to use for selecting attributes.\n\t(Default = default: accuracy for discrete class and rmse for numeric class)", "E", 1, "-E " + Tag.toOptionList(TAGS_EVALUATION)));
        // The flag name matches what setOptions/getOptions actually parse and emit ("IRClass").
        newVector.addElement(new Option("\tOptional class value (label or 1-based index) to use in conjunction with\n\tIR statistics (f-meas, auc or auprc). Omitting this option will use\n\tthe class-weighted average.", "IRClass", 1, "-IRClass <label | index>"));
        if (this.m_BaseClassifier instanceof OptionHandler) {
            newVector.addElement(new Option("", "", 0, "\nOptions specific to scheme " + this.m_BaseClassifier.getClass().getName() + ":"));
            newVector.addAll(Collections.list(((OptionHandler) this.m_BaseClassifier).listOptions()));
        }
        return newVector.elements();
    }

    /**
     * Parses the options listed by {@link #listOptions()}. Options that are not given
     * revert to their defaults.
     *
     * @param options the option array; consumed entries are blanked by {@link Utils}
     * @throws Exception if the classifier cannot be created or an option value is invalid
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.resetOptions();
        String optionString = Utils.getOption('B', options);
        if (optionString.length() == 0) {
            optionString = ZeroR.class.getName();
        }
        this.setClassifier(AbstractClassifier.forName(optionString, Utils.partitionOptions(options)));
        optionString = Utils.getOption('F', options);
        if (optionString.length() != 0) {
            this.setFolds(Integer.parseInt(optionString));
        }
        optionString = Utils.getOption('R', options);
        if (optionString.length() != 0) {
            this.setSeed(Integer.parseInt(optionString));
        }
        optionString = Utils.getOption('T', options);
        if (optionString.length() != 0) {
            this.setThreshold(Double.parseDouble(optionString));
        }
        optionString = Utils.getOption('E', options);
        if (optionString.length() != 0) {
            Tag match = null;
            for (Tag t : TAGS_EVALUATION) {
                if (t.getIDStr().equalsIgnoreCase(optionString)) {
                    match = t;
                    break;
                }
            }
            // An unrecognised measure used to be ignored silently, leaving the default in place.
            if (match == null) {
                throw new IllegalArgumentException("Unknown evaluation measure: " + optionString);
            }
            this.setEvaluationMeasure(new SelectedTag(match.getIDStr(), TAGS_EVALUATION));
        }
        optionString = Utils.getOption("IRClass", options);
        if (optionString.length() > 0) {
            this.setIRClassValue(optionString);
        }
    }

    public void setIRClassValue(String val) {
        this.m_IRClassValS = val;
    }

    public String getIRClassValue() {
        return this.m_IRClassValS;
    }

    public String IRClassValueTipText() {
        return "The class label, or 1-based index of the class label, to use when evaluating subsets with an IR metric (such as f-measure or AUC. Leaving this unset will result in the class frequency weighted average of the metric being used.";
    }

    public String evaluationMeasureTipText() {
        return "The measure used to evaluate the performance of attribute combinations.";
    }

    public SelectedTag getEvaluationMeasure() {
        return new SelectedTag(this.m_evaluationMeasure.getIDStr(), TAGS_EVALUATION);
    }

    /**
     * Sets the evaluation measure. Selections built on a tag array other than
     * {@link #TAGS_EVALUATION} (compared by reference) are ignored.
     */
    public void setEvaluationMeasure(SelectedTag newMethod) {
        if (newMethod.getTags() == TAGS_EVALUATION) {
            this.m_evaluationMeasure = newMethod.getSelectedTag();
        }
    }

    public String thresholdTipText() {
        return "Repeat xval if stdev of mean exceeds this value.";
    }

    public void setThreshold(double t) {
        this.m_threshold = t;
    }

    public double getThreshold() {
        return this.m_threshold;
    }

    public String foldsTipText() {
        return "Number of xval folds to use when estimating subset accuracy.";
    }

    public void setFolds(int f) {
        this.m_folds = f;
    }

    public int getFolds() {
        return this.m_folds;
    }

    public String seedTipText() {
        return "Seed to use for randomly generating xval splits.";
    }

    public void setSeed(int s) {
        this.m_seed = s;
    }

    public int getSeed() {
        return this.m_seed;
    }

    public String classifierTipText() {
        return "Classifier to use for estimating the accuracy of subsets";
    }

    public void setClassifier(Classifier newClassifier) {
        this.m_BaseClassifier = newClassifier;
    }

    public Classifier getClassifier() {
        return this.m_BaseClassifier;
    }

    /**
     * Returns the current settings in a form accepted by {@link #setOptions(String[])}.
     * The base learner's options follow a trailing {@code "--"}.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (this.getClassifier() != null) {
            options.add("-B");
            options.add(this.getClassifier().getClass().getName());
        }
        options.add("-F");
        options.add("" + this.getFolds());
        options.add("-T");
        options.add("" + this.getThreshold());
        options.add("-R");
        options.add("" + this.getSeed());
        options.add("-E");
        options.add(this.m_evaluationMeasure.getIDStr());
        if (this.m_IRClassValS != null && this.m_IRClassValS.length() > 0) {
            options.add("-IRClass");
            options.add(this.m_IRClassValS);
        }
        options.add("--");
        if (this.m_BaseClassifier instanceof OptionHandler) {
            Collections.addAll(options, ((OptionHandler) this.m_BaseClassifier).getOptions());
        }
        return options.toArray(new String[0]);
    }

    /** Restores every option (including measure and IR class value) to its default. */
    protected void resetOptions() {
        this.m_trainInstances = null;
        this.m_BaseClassifier = new ZeroR();
        this.m_folds = 5;
        this.m_seed = 1;
        this.m_threshold = 0.01;
        // Without these, setOptions() without -E / -IRClass kept whatever was set before.
        this.m_evaluationMeasure = TAGS_EVALUATION[0];
        this.m_IRClassValS = "";
        this.m_IRClassVal = -1;
    }

    /**
     * Returns the base learner's capabilities, narrowed by the selected measure.
     * Numeric/date classes are allowed only if the learner supports them AND the
     * measure is not restricted to nominal classes.
     *
     * @return the capabilities of this evaluator
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result;
        if (this.getClassifier() == null) {
            result = super.getCapabilities();
            result.disableAll();
        } else {
            result = this.getClassifier().getCapabilities();
        }
        for (Capabilities.Capability cap : Capabilities.Capability.values()) {
            result.enableDependency(cap);
        }
        // Remember what the learner itself supports before narrowing, so a learner
        // without numeric-class support is never advertised as having it.
        boolean learnerNumericClass = result.handles(Capabilities.Capability.NUMERIC_CLASS);
        boolean learnerDateClass = result.handles(Capabilities.Capability.DATE_CLASS);
        result.disable(Capabilities.Capability.NUMERIC_CLASS);
        result.disable(Capabilities.Capability.DATE_CLASS);
        int measureId = this.m_evaluationMeasure.getID();
        boolean nominalClassOnly = measureId == EVAL_ACCURACY || measureId == EVAL_FMEASURE || measureId == EVAL_AUC || measureId == EVAL_AUPRC;
        if (measureId >= EVAL_PLUGIN && this.m_evaluationMeasure instanceof PluginTag) {
            nominalClassOnly = ((PluginTag) this.m_evaluationMeasure).getMetric().appliesToNominalClass();
        }
        if (!nominalClassOnly) {
            if (learnerNumericClass) {
                result.enable(Capabilities.Capability.NUMERIC_CLASS);
            }
            if (learnerDateClass) {
                result.enable(Capabilities.Capability.DATE_CLASS);
            }
        }
        result.setMinimumNumberInstances(this.getFolds());
        return result;
    }

    /**
     * Stores the training data and resolves the configured IR class value.
     *
     * @param data the training data (must have a class attribute)
     * @throws Exception if the data fails the capability test or the IR class value does
     *     not match any value of a nominal class
     */
    @Override
    public void buildEvaluator(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        this.m_trainInstances = data;
        this.m_classIndex = data.classIndex();
        this.m_numAttribs = data.numAttributes();
        // Reset first so a value resolved by an earlier build cannot leak into this one.
        this.m_IRClassVal = -1;
        if (this.m_IRClassValS != null && this.m_IRClassValS.length() > 0) {
            this.m_IRClassVal = this.resolveIRClassValue(data);
        }
    }

    /** Resolves {@link #m_IRClassValS} (1-based index or label) to a 0-based class value index. */
    private int resolveIRClassValue(Instances data) {
        int resolved = -1;
        try {
            resolved = Integer.parseInt(this.m_IRClassValS) - 1;
        }
        catch (NumberFormatException e) {
            // Not a number: resolved stays -1 and the text is treated as a label below.
        }
        if (!data.classAttribute().isNominal()) {
            // IR statistics are unused for non-nominal classes; keep the numeric result (or -1).
            return resolved;
        }
        if (resolved >= 0 && resolved < data.classAttribute().numValues()) {
            return resolved;
        }
        // Fall back to a label lookup, so labels that look like numbers (e.g. "0") still work.
        int byLabel = data.classAttribute().indexOfValue(this.m_IRClassValS);
        if (byLabel < 0) {
            throw new IllegalArgumentException("IR class value '" + this.m_IRClassValS + "' is neither a class label nor a valid 1-based class index");
        }
        return byLabel;
    }

    /** Returns the attribute indices to keep for {@code subset}: its set bits (class excluded), then the class index. */
    private int[] selectedAttributeIndices(BitSet subset) {
        int count = 0;
        for (int a = subset.nextSetBit(0); a >= 0 && a < this.m_numAttribs; a = subset.nextSetBit(a + 1)) {
            if (a != this.m_classIndex) {
                ++count;
            }
        }
        int[] featArray = new int[count + 1];
        int j = 0;
        for (int a = subset.nextSetBit(0); a >= 0 && a < this.m_numAttribs; a = subset.nextSetBit(a + 1)) {
            if (a != this.m_classIndex) {
                featArray[j++] = a;
            }
        }
        featArray[j] = this.m_classIndex;
        return featArray;
    }

    /**
     * Scores an attribute subset by cross-validating the base learner on it.
     *
     * @param subset bits set for the attributes to keep (bits at or beyond the number
     *     of attributes are ignored; the class is always kept)
     * @return the averaged score, oriented so that larger is better
     * @throws Exception if filtering or cross-validation fails, or a plugin metric is missing
     */
    @Override
    public double evaluateSubset(BitSet subset) throws Exception {
        Remove delTransform = new Remove();
        delTransform.setInvertSelection(true);
        delTransform.setAttributeIndicesArray(this.selectedAttributeIndices(subset));
        delTransform.setInputFormat(this.m_trainInstances);
        // Filter.useFilter does not modify its input, so no defensive copy of the data is needed.
        Instances reduced = Filter.useFilter(this.m_trainInstances, delTransform);
        int measureId = this.m_evaluationMeasure.getID();
        double[] repError = new double[MAX_XVAL_REPEATS];
        Random random = new Random(this.m_seed);
        AbstractEvaluationMetric pluginMetric = null;
        String statName = null;
        int runs = 0;
        do {
            Evaluation evaluation = new Evaluation(reduced);
            evaluation.crossValidateModel(this.m_BaseClassifier, reduced, this.m_folds, random);
            switch (measureId) {
                case EVAL_DEFAULT:
                case EVAL_ACCURACY: {
                    repError[runs] = evaluation.errorRate();
                    break;
                }
                case EVAL_RMSE: {
                    repError[runs] = evaluation.rootMeanSquaredError();
                    break;
                }
                case EVAL_MAE: {
                    repError[runs] = evaluation.meanAbsoluteError();
                    break;
                }
                case EVAL_FMEASURE: {
                    repError[runs] = this.m_IRClassVal < 0 ? evaluation.weightedFMeasure() : evaluation.fMeasure(this.m_IRClassVal);
                    break;
                }
                case EVAL_AUC: {
                    repError[runs] = this.m_IRClassVal < 0 ? evaluation.weightedAreaUnderROC() : evaluation.areaUnderROC(this.m_IRClassVal);
                    break;
                }
                case EVAL_AUPRC: {
                    repError[runs] = this.m_IRClassVal < 0 ? evaluation.weightedAreaUnderPRC() : evaluation.areaUnderPRC(this.m_IRClassVal);
                    break;
                }
                case EVAL_CORRELATION: {
                    repError[runs] = evaluation.correlationCoefficient();
                    break;
                }
                default: {
                    PluginTag tag = (PluginTag) this.m_evaluationMeasure;
                    String metricName = tag.getMetricName();
                    statName = tag.getStatisticName();
                    pluginMetric = evaluation.getPluginMetric(metricName);
                    if (pluginMetric == null) {
                        throw new Exception("Metric " + metricName + " does not seem to be available");
                    }
                    if (pluginMetric instanceof InformationRetrievalEvaluationMetric) {
                        InformationRetrievalEvaluationMetric irMetric = (InformationRetrievalEvaluationMetric) pluginMetric;
                        repError[runs] = this.m_IRClassVal < 0 ? irMetric.getClassWeightedAverageStatistic(statName) : irMetric.getStatistic(statName, this.m_IRClassVal);
                    } else {
                        repError[runs] = pluginMetric.getStatistic(statName);
                    }
                }
            }
            ++runs;
        } while (runs < MAX_XVAL_REPEATS && this.repeat(repError, runs));
        double evalMetric = 0.0;
        for (int k = 0; k < runs; ++k) {
            evalMetric += repError[k];
        }
        evalMetric /= (double) runs;
        // Orient the result so that larger is always better.
        switch (measureId) {
            case EVAL_DEFAULT:
            case EVAL_ACCURACY:
            case EVAL_RMSE:
            case EVAL_MAE: {
                boolean accuracyStyle = measureId == EVAL_DEFAULT || measureId == EVAL_ACCURACY;
                if (accuracyStyle && this.m_trainInstances.classAttribute().isNominal()) {
                    evalMetric = 1.0 - evalMetric;
                } else {
                    evalMetric = -evalMetric;
                }
                break;
            }
            default: {
                if (pluginMetric != null && !pluginMetric.statisticIsMaximisable(statName)) {
                    evalMetric = -evalMetric;
                }
            }
        }
        return evalMetric;
    }

    /** Describes the configured evaluator, or notes that it has not been built yet. */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        if (this.m_trainInstances == null) {
            text.append("\tWrapper subset evaluator has not been built yet\n");
            return text.toString();
        }
        text.append("\tWrapper Subset Evaluator\n");
        text.append("\tLearning scheme: " + this.getClassifier().getClass().getName() + "\n");
        text.append("\tScheme options: ");
        if (this.m_BaseClassifier instanceof OptionHandler) {
            for (String classifierOption : ((OptionHandler) this.m_BaseClassifier).getOptions()) {
                text.append(classifierOption + " ");
            }
        }
        text.append("\n");
        // Non-empty exactly when a specific class value (index 0 included) was selected.
        String irClassLabel = this.m_IRClassVal >= 0 ? "(class value: " + this.m_trainInstances.classAttribute().value(this.m_IRClassVal) + ")" : "";
        boolean numericClass = this.m_trainInstances.attribute(this.m_classIndex).isNumeric();
        switch (this.m_evaluationMeasure.getID()) {
            case EVAL_DEFAULT:
            case EVAL_ACCURACY: {
                text.append(numericClass ? "\tSubset evaluation: RMSE\n" : "\tSubset evaluation: classification accuracy\n");
                break;
            }
            case EVAL_RMSE: {
                text.append(numericClass ? "\tSubset evaluation: RMSE\n" : "\tSubset evaluation: RMSE (probability estimates)\n");
                break;
            }
            case EVAL_MAE: {
                text.append(numericClass ? "\tSubset evaluation: MAE\n" : "\tSubset evaluation: MAE (probability estimates)\n");
                break;
            }
            case EVAL_FMEASURE: {
                text.append("\tSubset evaluation: F-measure " + irClassLabel + "\n");
                break;
            }
            case EVAL_AUC: {
                text.append("\tSubset evaluation: area under the ROC curve " + irClassLabel + "\n");
                break;
            }
            case EVAL_AUPRC: {
                text.append("\tSubset evaluation: area under the precision-recall curve " + irClassLabel + "\n");
                break;
            }
            case EVAL_CORRELATION: {
                text.append("\tSubset evaluation: correlation coefficient\n");
                break;
            }
            default: {
                text.append("\tSubset evaluation: " + this.m_evaluationMeasure.getReadable());
                if (this.m_evaluationMeasure instanceof PluginTag && ((PluginTag) this.m_evaluationMeasure).getMetric() instanceof InformationRetrievalEvaluationMetric) {
                    // Same ">= 0" rule as the built-in IR measures: class value 0 is valid.
                    text.append(" " + irClassLabel);
                }
                text.append("\n");
            }
        }
        text.append("\tNumber of folds for accuracy estimation: " + this.m_folds + "\n");
        return text.toString();
    }

    /**
     * Decides whether another cross-validation run is needed.
     *
     * @param repError the estimates gathered so far
     * @param entries how many leading entries of {@code repError} are valid
     * @return true if the population standard deviation divided by |mean| exceeds the
     *     threshold (always true after one run, always false for a negative threshold)
     */
    private boolean repeat(double[] repError, int entries) {
        if (this.m_threshold < 0.0) {
            return false;
        }
        if (entries == 1) {
            return true;
        }
        double mean = 0.0;
        for (int k = 0; k < entries; ++k) {
            mean += repError[k];
        }
        mean /= (double) entries;
        double variance = 0.0;
        for (int k = 0; k < entries; ++k) {
            double diff = repError[k] - mean;
            variance += diff * diff;
        }
        double stdDev = Math.sqrt(variance / (double) entries);
        // |mean|: measures such as correlation can be negative, which previously made the
        // ratio negative and suppressed repetition entirely.
        return stdDev / Math.abs(mean) > this.m_threshold;
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 12170 $");
    }

    /** Drops the stored training instances but keeps the header; a no-op before building. */
    @Override
    public void clean() {
        if (this.m_trainInstances != null) {
            this.m_trainInstances = new Instances(this.m_trainInstances, 0);
        }
    }

    public static void main(String[] args) {
        WrapperSubsetEval.runEvaluator(new WrapperSubsetEval(), args);
    }

    static {
        PLUGIN_METRICS = AbstractEvaluationMetric.getPluginMetrics();
        int totalPluginCount = 0;
        if (PLUGIN_METRICS != null) {
            for (AbstractEvaluationMetric m : PLUGIN_METRICS) {
                totalPluginCount += m.getStatisticNames().size();
            }
        }
        TAGS_EVALUATION = new Tag[8 + totalPluginCount];
        WrapperSubsetEval.TAGS_EVALUATION[0] = new Tag(EVAL_DEFAULT, "default", "Default: accuracy (discrete class); RMSE (numeric class)");
        WrapperSubsetEval.TAGS_EVALUATION[1] = new Tag(EVAL_ACCURACY, "acc", "Accuracy (discrete class only)");
        WrapperSubsetEval.TAGS_EVALUATION[2] = new Tag(EVAL_RMSE, "rmse", "RMSE (of the class probabilities for discrete class)");
        WrapperSubsetEval.TAGS_EVALUATION[3] = new Tag(EVAL_MAE, "mae", "MAE (of the class probabilities for discrete class)");
        WrapperSubsetEval.TAGS_EVALUATION[4] = new Tag(EVAL_FMEASURE, "f-meas", "F-measure (discrete class only)");
        WrapperSubsetEval.TAGS_EVALUATION[5] = new Tag(EVAL_AUC, "auc", "AUC (area under the ROC curve - discrete class only)");
        WrapperSubsetEval.TAGS_EVALUATION[6] = new Tag(EVAL_AUPRC, "auprc", "AUPRC (area under the precision-recall curve - discrete class only)");
        WrapperSubsetEval.TAGS_EVALUATION[7] = new Tag(EVAL_CORRELATION, "corr-coeff", "Correlation coefficient - numeric class only");
        if (PLUGIN_METRICS != null) {
            int index = 8;
            for (AbstractEvaluationMetric m : PLUGIN_METRICS) {
                for (String stat : m.getStatisticNames()) {
                    // Plugin tag IDs are slot + 2 (the first plugin gets EVAL_PLUGIN + 1).
                    // Only "ID >= EVAL_PLUGIN" is relied upon; the numbering is kept unchanged
                    // so existing tag IDs do not shift.
                    WrapperSubsetEval.TAGS_EVALUATION[index] = new PluginTag(index + 2, m, stat);
                    ++index;
                }
            }
        }
    }

    /** Evaluation-measure tag that wraps one statistic of a plugin evaluation metric. */
    protected static class PluginTag
    extends Tag {
        private static final long serialVersionUID = -6978438858413428382L;
        protected AbstractEvaluationMetric m_metric;
        protected String m_statisticName;

        public PluginTag(int ID, AbstractEvaluationMetric metric, String statisticName) {
            super(ID, statisticName, statisticName);
            this.m_metric = metric;
            this.m_statisticName = statisticName;
        }

        public String getMetricName() {
            return this.m_metric.getMetricName();
        }

        public String getStatisticName() {
            return this.m_statisticName;
        }

        public AbstractEvaluationMetric getMetric() {
            return this.m_metric;
        }
    }
}

