package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/stanford/nlp/optimization/SGDToQNMinimizer.class */
/**
 * Two-phase minimizer: a run of stochastic gradient descent (SGD) over an
 * {@link AbstractStochasticCachingDiffFunction}, followed by a hand-off to a
 * {@link QNMinimizer}.
 *
 * <p>During the last {@code hessSampleSize * QNMem} SGD iterations the function is switched to
 * {@link StochasticCalculateMethods#IncorporatedFiniteDifference} and, for each iteration, the
 * current point and the corresponding {@code HdotVAt} result are recorded. Every
 * {@code hessSampleSize} such samples are averaged into one entry of the {@code sList} /
 * {@code yList} history that is passed to {@link QNMinimizer#setHistory} before the
 * quasi-Newton phase starts.
 *
 * <p>Progress is reported on {@code System.err}. Several configuration errors terminate the JVM
 * through {@code System.exit(1)}; see {@link #minimize(DiffFunction, double, double[], int)}.
 *
 * <p>Maintenance note: this class is {@link Serializable} and declares no
 * {@code serialVersionUID}, so the names and types of its instance fields are intentionally
 * left exactly as they were.
 */
public class SGDToQNMinimizer implements Minimizer<DiffFunction>, Serializable {
    /** Index of the SGD iteration currently being executed. */
    private int k;
    /** Batch size handed to the stochastic derivative calls. */
    private int bSize = 15;
    /** When true, {@link #say} and {@link #sayln} print nothing. */
    private boolean quiet = false;
    /** When true, per-iteration data and run settings are written to files in the working directory. */
    public boolean outputIterationsToFile = false;
    /** An iteration record is written every {@code outputFrequency} SGD iterations. */
    public int outputFrequency = 10;
    /** Initial gain; scaled by {@link #gainSchedule} at every iteration. */
    public double gain = 0.1d;
    /** Sliding window of the most recent batch gradients (at most {@link #memory} entries). */
    private List<double[]> gradList = null;
    /** Averaged {@code HdotVAt} samples handed to the QN minimizer. */
    private List<double[]> yList = null;
    /** Averaged points matching {@link #yList}, handed to the QN minimizer. */
    private List<double[]> sList = null;
    /** {@code HdotVAt} samples collected since the last averaging step. */
    private List<double[]> tmpYList = null;
    /** Points collected since the last averaging step. */
    private List<double[]> tmpSList = null;
    /** Number of gradients averaged to form the SGD search direction. */
    private int memory = 5;
    /** Number of passes through the data in the SGD phase; must be set to a non-negative value. */
    public int SGDPasses = -1;
    /** Iteration limit passed to the QN phase; must be set to a non-negative value. */
    public int QNPasses = -1;
    /** Number of raw samples averaged into one history entry. */
    private int hessSampleSize = 50;
    /** Passed to the {@link QNMinimizer} constructor and used to size the sampling window. */
    private int QNMem = 10;
    /** When true, {@code minimize} runs a batch-sum self test and then exits the JVM. */
    private boolean toTest = false;
    /** Currently not referenced anywhere in this class. */
    private static NumberFormat nf = new DecimalFormat("0.000E0");

    /** Suppresses the verbose progress output; a single "." is still printed per SGD iteration. */
    public void shutUp() {
        this.quiet = true;
    }

    /**
     * Sets the batch size used for the stochastic derivative calls.
     *
     * @param i the new batch size
     */
    public void setBatchSize(int i) {
        this.bSize = i;
    }

    /**
     * Creates a minimizer configured from classifier flags.
     *
     * @param seqClassifierFlags source of {@code stochasticBatchSize}, {@code initialGain},
     *     {@code SGDPasses}, {@code QNPasses}, {@code QNsize}, {@code outputIterationsToFile},
     *     {@code testObjFunction} and {@code SGD2QNhessSamples}
     */
    public SGDToQNMinimizer(SeqClassifierFlags seqClassifierFlags) {
        this.bSize = seqClassifierFlags.stochasticBatchSize;
        this.gain = seqClassifierFlags.initialGain;
        this.SGDPasses = seqClassifierFlags.SGDPasses;
        this.QNPasses = seqClassifierFlags.QNPasses;
        this.QNMem = seqClassifierFlags.QNsize;
        this.outputIterationsToFile = seqClassifierFlags.outputIterationsToFile;
        this.toTest = seqClassifierFlags.testObjFunction;
        this.hessSampleSize = seqClassifierFlags.SGD2QNhessSamples;
    }

    /**
     * Creates a minimizer with every tunable given explicitly.
     *
     * @param d initial gain
     * @param i batch size
     * @param i2 number of SGD passes through the data
     * @param i3 iteration limit for the QN phase
     * @param i4 number of samples averaged into one history entry
     * @param i5 QN memory size
     */
    public SGDToQNMinimizer(double d, int i, int i2, int i3, int i4, int i5) {
        this.bSize = i;
        this.gain = d;
        this.SGDPasses = i2;
        this.QNPasses = i3;
        this.QNMem = i5;
        this.hessSampleSize = i4;
    }

    /**
     * Creates a minimizer with a history sample size of 50 and a QN memory of 10.
     *
     * @param d initial gain
     * @param i batch size
     * @param i2 number of SGD passes through the data
     * @param i3 iteration limit for the QN phase
     */
    public SGDToQNMinimizer(double d, int i, int i2, int i3) {
        this(d, i, i2, i3, 50, 10);
    }

    /**
     * Sets the QN memory size.
     *
     * @param i the new QN memory size
     */
    public void setQNMem(int i) {
        this.QNMem = i;
    }

    /**
     * Sets how many samples are averaged into one history entry.
     *
     * @param i the new sample count
     */
    public void setHessSampleSize(int i) {
        this.hessSampleSize = i;
    }

    /**
     * Decay factor {@code d / (d + i)}: 1 at iteration 0, shrinking as {@code i} grows.
     *
     * @param i iteration index
     * @param d decay constant
     * @return the factor the initial gain is multiplied by
     */
    private double gainSchedule(int i, double d) {
        return d / (d + i);
    }

    /**
     * Returns the element-wise mean of the given vectors as a new array.
     *
     * @param list non-empty list of vectors; the result has the length of the first one
     * @return a freshly allocated array holding the mean
     */
    private double[] smooth(List<double[]> list) {
        double[] mean = new double[list.get(0).length];
        for (double[] vector : list) {
            ArrayMath.pairwiseAddInPlace(mean, vector);
        }
        ArrayMath.multiplyInPlace(mean, 1.0d / list.size());
        return mean;
    }

    /**
     * Equivalent to {@link #minimize(DiffFunction, double, double[], int)} with an iteration
     * argument of -1.
     */
    @Override // edu.stanford.nlp.optimization.Minimizer
    public double[] minimize(DiffFunction diffFunction, double d, double[] dArr) {
        return minimize(diffFunction, d, dArr, -1);
    }

    /**
     * Runs the SGD phase and then the quasi-Newton phase.
     *
     * <p>The JVM is terminated with status 1 when the self test is enabled (after running it,
     * whatever its outcome), when the output files cannot be opened, when {@code SGDPasses} or
     * {@code QNPasses} is negative, or when the SGD phase has fewer than
     * {@code hessSampleSize * QNMem} iterations.
     *
     * @param diffFunction function to minimize; must be an
     *     {@link AbstractStochasticCachingDiffFunction}
     * @param d tolerance forwarded unchanged to the QN minimizer
     * @param dArr starting point; this array is not written to by the SGD phase
     * @param i not used by this implementation; the iteration counts come from
     *     {@code SGDPasses} and {@code QNPasses}
     * @return the array returned by the QN minimizer
     * @throws UnsupportedOperationException if {@code diffFunction} is not an
     *     {@link AbstractStochasticCachingDiffFunction}
     * @throws IllegalArgumentException if the batch size is not positive
     */
    @Override // edu.stanford.nlp.optimization.Minimizer
    public double[] minimize(DiffFunction diffFunction, double d, double[] dArr, int i) {
        sayln("SGDToQNMinimizer called on function of " + diffFunction.domainDimension() + " variables;");
        if (!(diffFunction instanceof AbstractStochasticCachingDiffFunction)) {
            throw new UnsupportedOperationException(
                    "SGDToQNMinimizer requires an AbstractStochasticCachingDiffFunction");
        }
        AbstractStochasticCachingDiffFunction function = (AbstractStochasticCachingDiffFunction) diffFunction;
        if (this.toTest) {
            boolean passed = new StochasticDiffFunctionTester(function).testSumOfBatches(dArr, 1.0E-4d);
            System.err.println(passed ? "Testing complete... exiting" : "Testing failed....exiting");
            // NOTE(review): exit status 1 for both outcomes is kept from the original code;
            // confirm whether a passing test should exit with 0 instead.
            System.exit(1);
        }
        function.method = StochasticCalculateMethods.GradientOnly;

        PrintWriter iterationLog = null;
        PrintWriter infoLog = null;
        double[] result;
        try {
            if (this.outputIterationsToFile) {
                String prefix = "SGD2QN" + this.bSize + "_g" + (this.gain * 1000.0d);
                try {
                    iterationLog = new PrintWriter((OutputStream) new FileOutputStream(prefix + ".output"), true);
                    infoLog = new PrintWriter((OutputStream) new FileOutputStream(prefix + ".info"), true);
                } catch (IOException e) {
                    System.err.println("Caught IOException outputing SGD data to file: " + e.getMessage());
                    System.exit(1);
                }
            }

            this.gradList = new ArrayList<double[]>();
            this.yList = new ArrayList<double[]>();
            this.sList = new ArrayList<double[]>();
            this.tmpYList = new ArrayList<double[]>();
            this.tmpSList = new ArrayList<double[]>();

            if (this.SGDPasses < 0 || this.QNPasses < 0) {
                System.err.println("Error:  Number of iterations not set");
                System.exit(1);
            }
            if (this.bSize <= 0) {
                // Previously a zero batch size surfaced as a bare ArithmeticException below.
                throw new IllegalArgumentException("Batch size must be positive but was " + this.bSize);
            }

            int batchesPerPass = function.dataDimension() / this.bSize;
            // Iterations after this index also collect samples for the QN history.
            int lastPlainIteration = (batchesPerPass * this.SGDPasses) - (this.hessSampleSize * this.QNMem);
            sayln("       Batchsize of: " + this.bSize);
            sayln("       Batches per pass through data:  " + batchesPerPass);
            sayln("       Passes of SGD: " + this.SGDPasses);
            sayln("       Passes of QN:  " + this.QNPasses);
            sayln("       Hess sample size:   " + this.hessSampleSize);
            sayln("       QNSize:   " + this.QNMem);
            sayln("       ");
            if (lastPlainIteration < 0) {
                System.err.println("Not enough data for mem and smoothing settings");
                System.exit(1);
            }
            int totalIterations = batchesPerPass * this.SGDPasses;

            if (infoLog != null) {
                writeRunInfo(infoLog, function, batchesPerPass);
            }

            double[] x = runSgdPhase(function, dArr, batchesPerPass, lastPlainIteration, totalIterations,
                    iterationLog);

            System.err.println("Passing off to QN");
            System.err.println("");
            QNMinimizer qNMinimizer = new QNMinimizer(this.QNMem);
            qNMinimizer.setHistory(this.sList, this.yList);
            result = qNMinimizer.minimize((DiffFunction) function, d, x, this.QNPasses);
        } finally {
            // Close the files on every path, including when an exception escapes either phase.
            if (infoLog != null || iterationLog != null) {
                if (infoLog != null) {
                    infoLog.close();
                }
                if (iterationLog != null) {
                    iterationLog.close();
                }
                System.err.println("Output Files Closed");
            }
        }
        System.err.println("");
        System.err.println("QN Minimization complete.");
        System.err.println("");
        return result;
    }

    /** Writes the settings of the current run, one per line, to the info file. */
    private void writeRunInfo(PrintWriter infoLog, AbstractStochasticCachingDiffFunction function,
            int batchesPerPass) {
        infoLog.println(function.domainDimension() + "; DomainDimension ");
        infoLog.println(this.bSize + "; batchSize ");
        infoLog.println(this.SGDPasses + "; SGDPasses");
        infoLog.println(this.QNPasses + "; QNPasses");
        infoLog.println(batchesPerPass + "; numBatches ");
        infoLog.println(this.hessSampleSize + ";hessSampleSize");
        infoLog.println(this.QNMem + ";QNSize");
        infoLog.println(this.outputFrequency + "; outputFrequency");
    }

    /**
     * Executes the SGD iterations, filling {@code sList} / {@code yList} along the way, and
     * returns the final point. {@code iterationLog} may be null when file output is disabled.
     */
    private double[] runSgdPhase(AbstractStochasticCachingDiffFunction function, double[] initial,
            int batchesPerPass, int lastPlainIteration, int totalIterations, PrintWriter iterationLog) {
        final int dim = initial.length;
        double[] x = initial;

        sayln("Iter: n ++TimeforLastIt++\n");
        long lastTick = System.currentTimeMillis();
        for (this.k = 0; this.k < totalIterations; this.k++) {
            double stepSize = this.gain * gainSchedule(this.k, 5 * batchesPerPass);
            say("Iter: " + this.k + "  ++ " + (System.currentTimeMillis() - lastTick) + " ms++  ");
            lastTick = System.currentTimeMillis();

            // Reuse the oldest gradient buffer once the averaging window is full.
            double[] newGrad = (this.k > 0 && this.gradList.size() >= this.memory)
                    ? this.gradList.remove(0)
                    : new double[dim];

            if (this.k > lastPlainIteration) {
                double[] hessianTimesX = new double[dim];
                function.method = StochasticCalculateMethods.IncorporatedFiniteDifference;
                function.hasNewVals = true;
                System.arraycopy(function.derivativeAt(x, x, this.bSize), 0, newGrad, 0, newGrad.length);
                function.returnPreviousValues = true;
                System.arraycopy(function.HdotVAt(x, x, this.bSize), 0, hessianTimesX, 0, dim);
                this.tmpYList.add(hessianTimesX);
                // Store a snapshot: the vector has to stay paired with the product computed from it.
                this.tmpSList.add(x.clone());
                if (this.tmpYList.size() == this.hessSampleSize) {
                    this.yList.add(smooth(this.tmpYList));
                    this.sList.add(smooth(this.tmpSList));
                    this.tmpYList.clear();
                    this.tmpSList.clear();
                }
            } else {
                function.hasNewVals = true;
                System.arraycopy(function.derivativeAt(x, this.bSize), 0, newGrad, 0, newGrad.length);
            }

            this.gradList.add(newGrad);
            double[] direction = smooth(this.gradList);

            // A separate array keeps the current point intact for the logging below.
            double[] nextX = new double[dim];
            for (int j = 0; j < dim; j++) {
                nextX[j] = x[j] - (stepSize * direction[j]);
            }

            if (this.outputIterationsToFile && this.k % this.outputFrequency == 0) {
                long elapsed = System.currentTimeMillis() - lastTick;
                double value = function.valueAt(x);
                say(" TrueValue{ " + value + " } ");
                iterationLog.println(this.k + " , " + value + " , " + stepSize + " , " + elapsed + " , "
                        + ArrayMath.norm(ArrayMath.pairwiseSubtract(x, nextX)));
            }

            x = nextX;
            if (this.quiet) {
                System.err.print(".");
            } else {
                sayln("");
            }
        }
        return x;
    }

    /** Prints {@code str} followed by a line break on {@code System.err} unless quiet. */
    private void sayln(String str) {
        if (!this.quiet) {
            System.err.println(str);
        }
    }

    /** Prints {@code str} on {@code System.err} unless quiet. */
    private void say(String str) {
        if (!this.quiet) {
            System.err.print(str);
        }
    }
}
