/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.BitSet;
import java.util.Random;
import java.util.logging.Logger;

/**
 * Trains a {@link CRF} against an objective that is the sum of several
 * {@link Optimizable.ByGradientValue} components.
 *
 * <p>The components are wrapped in an {@link OptimizableCRF}, which adds up their values
 * and gradients and caches the results until the CRF's weights change. Unless an optimizer
 * has already been installed for the current {@link OptimizableCRF}, a
 * {@link LimitedMemoryBFGS} instance is created to drive it.
 */
public class CRFTrainerByValueGradients
extends TransducerTrainer
implements TransducerTrainer.ByOptimization {
    private static final Logger logger = MalletLogger.getLogger(CRFTrainerByValueGradients.class.getName());
    /** The model whose parameters are being trained. */
    CRF crf;
    /** Objective components; their values and gradients are summed by {@link OptimizableCRF}. */
    Optimizable.ByGradientValue[] optimizableByValueGradientObjects;
    /** Lazily created wrapper around {@link #crf} for the most recently used training set. */
    OptimizableCRF ocrf;
    /** Lazily created optimizer operating on {@link #ocrf}. */
    Optimizer opt;
    /** Number of optimizer steps that completed without throwing. */
    int iterationCount = 0;
    /** Result of the most recent {@link #train(InstanceList, int)} call that ran at least one step. */
    boolean converged;
    /** Value of {@code crf.weightsValueChangeStamp} for which the cached value is valid; -1 means none. */
    private int cachedValueWeightsStamp = -1;
    /** Value of {@code crf.weightsValueChangeStamp} for which the cached gradient is valid; -1 means none. */
    private int cachedGradientWeightsStamp = -1;
    public static final int DEFAULT_MAX_RESETS = 3;
    /** How many times one {@code train} call may rebuild the optimizer after an {@link OptimizationException}. */
    int maxResets = DEFAULT_MAX_RESETS;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 1;
    static final int NULL_INTEGER = -1;

    /**
     * Creates a trainer for the given model and objective components.
     *
     * @param crf the model to train
     * @param optimizableByValueGradientObjects the components whose values and gradients are summed
     */
    public CRFTrainerByValueGradients(CRF crf, Optimizable.ByGradientValue[] optimizableByValueGradientObjects) {
        this.crf = crf;
        this.optimizableByValueGradientObjects = optimizableByValueGradientObjects;
    }

    /** @return the CRF being trained, typed as a {@link Transducer} */
    @Override
    public Transducer getTransducer() {
        return this.crf;
    }

    /** @return the CRF being trained */
    public CRF getCRF() {
        return this.crf;
    }

    /** @return the current optimizer, or {@code null} if none has been created yet */
    @Override
    public Optimizer getOptimizer() {
        return this.opt;
    }

    /** @return the outcome of the most recent training run (see {@link #train(InstanceList, int)}) */
    public boolean isConverged() {
        return this.converged;
    }

    /** @return the same flag as {@link #isConverged()} */
    @Override
    public boolean isFinishedTraining() {
        return this.converged;
    }

    /** @return the number of optimizer steps completed so far across all training calls */
    @Override
    public int getIteration() {
        return this.iterationCount;
    }

    /** @return the array of objective components passed to the constructor (not copied) */
    public Optimizable.ByGradientValue[] getOptimizableByGradientValueObjects() {
        return this.optimizableByValueGradientObjects;
    }

    /**
     * Returns the {@link OptimizableCRF} for the given training set, creating a new one when
     * none exists yet or when the existing one was built for a different {@code InstanceList}
     * object (compared by reference). Creating a new one also discards the current optimizer.
     *
     * @param trainingSet the training data to associate with the optimizable
     * @return the optimizable wrapper for {@code trainingSet}
     */
    public OptimizableCRF getOptimizableCRF(InstanceList trainingSet) {
        if (this.ocrf == null || this.ocrf.trainingSet != trainingSet) {
            this.ocrf = new OptimizableCRF(this.crf, trainingSet);
            this.opt = null;
        }
        return this.ocrf;
    }

    /**
     * Returns an optimizer bound to the {@link OptimizableCRF} for the given training set,
     * creating a {@link LimitedMemoryBFGS} when there is no optimizer or the existing one
     * operates on a different optimizable.
     *
     * @param trainingSet the training data
     * @return the optimizer to use for {@code trainingSet}
     */
    public Optimizer getOptimizer(InstanceList trainingSet) {
        this.getOptimizableCRF(trainingSet);
        if (this.opt == null || this.ocrf != this.opt.getOptimizable()) {
            this.opt = new LimitedMemoryBFGS(this.ocrf);
        }
        return this.opt;
    }

    /**
     * Trains with an iteration limit of {@link Integer#MAX_VALUE}.
     *
     * @param training the training data
     * @return the result of {@link #train(InstanceList, int)}
     */
    public boolean trainIncremental(InstanceList training) {
        return this.train(training, Integer.MAX_VALUE);
    }

    /**
     * Runs up to {@code numIterations} single-step optimizer calls on the given training set.
     *
     * <p>If a step throws {@link OptimizationException}, the optimizer is discarded and rebuilt
     * and training continues with the next iteration, at most {@link #maxResets} times per
     * call. Once that budget is used up, a further exception ends training and the run is
     * reported as converged.
     *
     * @param trainingSet the training data; asserted to be non-empty
     * @param numIterations maximum number of optimizer steps; values {@code <= 0} return
     *        {@code false} immediately without touching any state
     * @return {@code true} if the optimizer reported convergence or the reset budget was
     *         exhausted, {@code false} if the iteration limit was reached first
     */
    @Override
    public boolean train(InstanceList trainingSet, int numIterations) {
        if (numIterations <= 0) {
            return false;
        }
        assert (trainingSet.size() > 0);
        this.getOptimizableCRF(trainingSet);
        this.getOptimizer(trainingSet);
        int numResets = 0;
        boolean converged = false;
        logger.info("CRF about to train with " + numIterations + " iterations");
        for (int i = 0; i < numIterations; ++i) {
            try {
                long startTime = System.currentTimeMillis();
                converged = this.opt.optimize(1);
                logger.info("CRF finished one iteration of maximizer, i=" + i + ", " + (System.currentTimeMillis() - startTime) / 1000L + " secs.");
                ++this.iterationCount;
                this.runEvaluators();
            }
            catch (OptimizationException e) {
                e.printStackTrace();
                logger.info("Catching exception.");
                if (numResets < this.maxResets) {
                    // Throw away the failed optimizer and build a fresh one; keep iterating.
                    logger.info("Resetting optimizer.");
                    ++numResets;
                    this.opt = null;
                    this.getOptimizer(trainingSet);
                } else {
                    // Reset budget exhausted: give up and report convergence so the loop ends.
                    logger.info("Saying converged.");
                    converged = true;
                }
            }
            if (converged) {
                logger.info("CRF training has converged, i=" + i);
                break;
            }
        }
        // Publish the outcome so isConverged()/isFinishedTraining() reflect this run.
        this.converged = converged;
        return converged;
    }

    /**
     * Trains in rounds, one per entry of {@code trainingProportions}. A proportion of exactly
     * {@code 1.0} trains on the full list; any other value trains on the first part of a
     * two-way split of {@code training} made with a {@code Random} seeded with 1.
     *
     * @param training the full training data
     * @param numIterationsPerProportion iteration limit passed to each round
     * @param trainingProportions fractions of the data to use per round; asserted non-empty
     *        and each asserted to be {@code <= 1.0}
     * @return the convergence result of the last round
     */
    public boolean train(InstanceList training, int numIterationsPerProportion, double[] trainingProportions) {
        assert (trainingProportions.length > 0);
        boolean converged = false;
        for (int i = 0; i < trainingProportions.length; ++i) {
            double proportion = trainingProportions[i];
            assert (proportion <= 1.0);
            logger.info("Training on " + proportion + "% of the data this round.");
            if (proportion == 1.0) {
                converged = this.train(training, numIterationsPerProportion);
            } else {
                InstanceList subset = training.split(new Random(1L), new double[]{proportion, 1.0 - proportion})[0];
                converged = this.train(subset, numIterationsPerProportion);
            }
        }
        return converged;
    }

    /**
     * Sets how many optimizer resets a single {@link #train(InstanceList, int)} call may
     * perform after optimization exceptions.
     *
     * @param maxResets the new reset budget
     */
    public void setMaxResets(int maxResets) {
        this.maxResets = maxResets;
    }

    /**
     * Serialization hook. Writes the version number and both cache stamps, then always
     * throws because the implementation is unfinished.
     *
     * @throws IllegalStateException always
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeInt(this.cachedGradientWeightsStamp);
        out.writeInt(this.cachedValueWeightsStamp);
        throw new IllegalStateException("Implementation not yet complete.");
    }

    /**
     * Deserialization hook. Reads the version number, then always throws because the
     * implementation is unfinished.
     *
     * @throws IllegalStateException always
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.readInt();
        throw new IllegalStateException("Implementation not yet complete.");
    }

    /**
     * Presents the CRF's parameters to an optimizer, with value and gradient defined as the
     * sums over the enclosing trainer's objective components. Both results are cached and
     * recomputed only when {@code crf.weightsValueChangeStamp} differs from the stamp
     * recorded at the last computation.
     */
    public class OptimizableCRF
    implements Optimizable.ByGradientValue,
    Serializable {
        InstanceList trainingSet;
        /** Last computed objective value; the initial number is only a placeholder. */
        double cachedValue = -1.23456789E8;
        /** Last computed gradient, one entry per CRF parameter factor. */
        double[] cachedGradie;
        /** Not used by the visible code other than being serialized. */
        BitSet infiniteValues = null;
        CRF crf;
        /** The enclosing trainer's objective components. */
        Optimizable.ByGradientValue[] opts;
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 0;

        /**
         * Binds the CRF and training set, sizes the gradient cache to the CRF's number of
         * factors, and invalidates the enclosing trainer's value and gradient cache stamps.
         *
         * @param crf the model whose parameters are exposed
         * @param ilist the training set this optimizable is associated with
         */
        protected OptimizableCRF(CRF crf, InstanceList ilist) {
            this.crf = crf;
            this.trainingSet = ilist;
            this.opts = CRFTrainerByValueGradients.this.optimizableByValueGradientObjects;
            this.cachedGradie = new double[crf.parameters.getNumFactors()];
            CRFTrainerByValueGradients.this.cachedValueWeightsStamp = -1;
            CRFTrainerByValueGradients.this.cachedGradientWeightsStamp = -1;
        }

        /** @return the number of factors in the CRF's parameters */
        @Override
        public int getNumParameters() {
            return this.crf.parameters.getNumFactors();
        }

        /** Copies the CRF's parameters into {@code buffer} by delegating to {@code crf.parameters}. */
        @Override
        public void getParameters(double[] buffer) {
            this.crf.parameters.getParameters(buffer);
        }

        /** @return the CRF parameter at {@code index} */
        @Override
        public double getParameter(int index) {
            return this.crf.parameters.getParameter(index);
        }

        /** Replaces all CRF parameters and notifies the CRF that its weights changed. */
        @Override
        public void setParameters(double[] buff) {
            this.crf.parameters.setParameters(buff);
            this.crf.weightsValueChanged();
        }

        /** Sets one CRF parameter and notifies the CRF that its weights changed. */
        @Override
        public void setParameter(int index, double value) {
            this.crf.parameters.setParameter(index, value);
            this.crf.weightsValueChanged();
        }

        /**
         * Returns the sum of the component values, recomputing it only when the CRF's weight
         * stamp differs from the one recorded for the cached value.
         *
         * @return the (possibly cached) summed objective value
         */
        @Override
        public double getValue() {
            if (this.crf.weightsValueChangeStamp != CRFTrainerByValueGradients.this.cachedValueWeightsStamp) {
                long startingTime = System.currentTimeMillis();
                this.cachedValue = 0.0;
                for (Optimizable.ByGradientValue component : this.opts) {
                    this.cachedValue += component.getValue();
                }
                CRFTrainerByValueGradients.this.cachedValueWeightsStamp = this.crf.weightsValueChangeStamp;
                logger.info("getValue() (loglikelihood) = " + this.cachedValue);
                logger.fine("Inference milliseconds = " + (System.currentTimeMillis() - startingTime));
            }
            return this.cachedValue;
        }

        /**
         * Writes the sum of the component gradients into {@code buffer}, recomputing it only
         * when the CRF's weight stamp differs from the one recorded for the cached gradient.
         *
         * @param buffer destination; the first {@code cachedGradie.length} entries are overwritten
         */
        @Override
        public void getValueGradient(double[] buffer) {
            if (CRFTrainerByValueGradients.this.cachedGradientWeightsStamp != this.crf.weightsValueChangeStamp) {
                // NOTE(review): the value is refreshed first, presumably so the components'
                // internal state is up to date before their gradients are read - confirm.
                this.getValue();
                MatrixOps.setAll(this.cachedGradie, 0.0);
                // Scratch array reused for each component's gradient.
                double[] b2 = new double[buffer.length];
                for (Optimizable.ByGradientValue component : this.opts) {
                    MatrixOps.setAll(b2, 0.0);
                    component.getValueGradient(b2);
                    MatrixOps.plusEquals(this.cachedGradie, b2);
                }
                CRFTrainerByValueGradients.this.cachedGradientWeightsStamp = this.crf.weightsValueChangeStamp;
            }
            System.arraycopy(this.cachedGradie, 0, buffer, 0, this.cachedGradie.length);
        }

        /**
         * Serialization hook. Writes the version number, training set, cached value, cached
         * gradient, {@code infiniteValues} and the CRF, in that order. {@code opts} is not written.
         */
        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
            out.writeObject(this.trainingSet);
            out.writeDouble(this.cachedValue);
            out.writeObject(this.cachedGradie);
            out.writeObject(this.infiniteValues);
            out.writeObject(this.crf);
        }

        /**
         * Deserialization hook. Reads the fields in the order {@code writeObject} wrote them;
         * the version number is read and ignored. {@code opts} is not restored here.
         */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.readInt();
            this.trainingSet = (InstanceList)in.readObject();
            this.cachedValue = in.readDouble();
            this.cachedGradie = (double[])in.readObject();
            this.infiniteValues = (BitSet)in.readObject();
            this.crf = (CRF)in.readObject();
        }
    }
}

