package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Random;

/* loaded from: input_file:edu/stanford/nlp/optimization/SMDMinimizer.class */
/**
 * Stochastic minimizer that keeps a separate, adaptively updated gain (step size) for every
 * parameter, in the style of stochastic meta-descent (SMD).
 *
 * <p>Every iteration evaluates the gradient and a Hessian-vector product on a mini-batch of
 * {@code bSize} data points, using the {@link StochasticCalculateMethods} chosen at construction.
 * For each coordinate {@code i} the update performed by {@link #minimize} is:
 * <pre>
 *   gain[i] *= max(0.5, 1 - mu * grad[i] * v[i]);   gain[i] is then capped at 5
 *   v[i]     = lam * (1 + cPosDef * gain[i]) * v[i] - gain[i] * (grad[i] + lam * Hv[i])
 *   x[i]     = x[i] - gain[i] * grad[i]
 * </pre>
 *
 * <p>Only {@link AbstractStochasticCachingDiffFunction} objectives are supported. Several error
 * paths terminate the JVM with {@link System#exit(int)} rather than throwing.
 */
public class SMDMinimizer implements Minimizer {
    /** Iteration counter; after {@link #minimize} returns it holds the number of iterations run. */
    private int k;
    /** Mini-batch size used for every stochastic evaluation. */
    private int bSize = 15;
    /** When true, progress messages are suppressed (a '.' is still printed per iteration). */
    private boolean quiet = false;
    /** How derivatives / Hv products are computed; {@code NoneSpecified} makes minimize() exit. */
    private StochasticCalculateMethods methodSMD = StochasticCalculateMethods.NoneSpecified;
    /** If true, per-iteration statistics and run info are written to {@code SMD_*} files. */
    public boolean outputIterationsToFile = false;
    /** When writing to file, a line is written every {@code outputFrequency} iterations. */
    public int outputFrequency = 10;
    /** Starting gain for every parameter. */
    public double initialGain = 0.1d;
    // NOTE(review): not read by this class.
    public boolean restrictSteps = false;
    // NOTE(review): not read by this class.
    public boolean useAlgorithmicDifferentiation = true;
    /** Per-iteration output file (opened only when {@link #outputIterationsToFile} is set). */
    PrintWriter file = null;
    /** Run-summary output file (opened only when {@link #outputIterationsToFile} is set). */
    PrintWriter infoFile = null;
    /** Meta step size of the multiplicative gain update. */
    public double mu = 0.01d;
    /** Factor applied to v and to the Hv term in the v update. */
    public double lam = 1.0d;
    /** Coefficient of the gain-dependent term {@code (1 + cPosDef * gain)} in the v update. */
    public double cPosDef = 0.01d;
    /** If true, minimize() runs {@link #testObjectiveFunction} and then exits the JVM. */
    public boolean testObjFunc = false;
    /** If true, the min/max of v and of the gains are printed every iteration. */
    public boolean printMinMax = false;
    /** Full-data objective value at the returned point; only set when writing output files. */
    public double finalVal = 0.0d;
    /** If true, Hv is replaced by {@code newGrad * (newGrad . v)} while the batch value is below 5. */
    public boolean useGaussNewton = false;

    /** Format for the Gauss-Newton diagnostics. DecimalFormat is not thread-safe. */
    private static final NumberFormat NUMBER_FORMAT = new DecimalFormat("0.000E0");

    /** Suppresses progress output. */
    public void shutUp() {
        this.quiet = true;
    }

    /** Sets the mini-batch size used by subsequent calls to {@link #minimize}. */
    public void setBatchSize(int i) {
        this.bSize = i;
    }

    /**
     * Creates a minimizer with default settings.
     *
     * <p>This class offers no setter for the calculation method. A minimizer built this way
     * therefore exits the JVM when {@link #minimize} is called.
     */
    public SMDMinimizer() {
    }

    /**
     * Creates a minimizer with the given initial gain, batch size and calculation method.
     *
     * @param initialGain starting gain for every parameter
     * @param batchSize mini-batch size
     * @param method how derivatives and Hessian-vector products are computed
     */
    public SMDMinimizer(double initialGain, int batchSize, StochasticCalculateMethods method) {
        this.bSize = batchSize;
        this.initialGain = initialGain;
        this.methodSMD = method;
    }

    /** Equivalent to {@code minimize(function, functionTolerance, initial, -1)}, i.e. 100 passes. */
    @Override // edu.stanford.nlp.optimization.Minimizer
    public double[] minimize(Function function, double functionTolerance, double[] initial) {
        return minimize(function, functionTolerance, initial, -1);
    }

    /**
     * Runs the SMD iteration on {@code function} starting from {@code initial}.
     *
     * @param function must be an {@link AbstractStochasticCachingDiffFunction}
     * @param functionTolerance tolerance passed to {@link #testObjectiveFunction} when
     *     {@link #testObjFunc} is set. It is not used as a convergence criterion here.
     * @param initial starting point. This array is modified in place by the initial step.
     * @param maxPasses number of passes through the data; {@code <= 0} means 100 passes. The
     *     iteration budget is {@code passes * (dataDimension() / batchSize)}.
     * @return the final iterate. Once at least one iteration has run, this is an internal array
     *     rather than {@code initial}.
     * @throws UnsupportedOperationException if {@code function} is not an
     *     {@link AbstractStochasticCachingDiffFunction}
     */
    @Override // edu.stanford.nlp.optimization.Minimizer
    public double[] minimize(Function function, double functionTolerance, double[] initial, int maxPasses) {
        sayln("SMDMinimizer called on function of " + function.domainDimension() + " variables;");
        // The old check tested DiffFunction but then cast to the stochastic type.
        if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
            throw new UnsupportedOperationException(
                    "SMDMinimizer requires an AbstractStochasticCachingDiffFunction, got "
                            + function.getClass().getName());
        }
        AbstractStochasticCachingDiffFunction func = (AbstractStochasticCachingDiffFunction) function;
        if (methodSMD.equals(StochasticCalculateMethods.NoneSpecified)) {
            System.err.println("No method has been set for the Stochastic Calculation");
            System.exit(1);
        }
        func.method = methodSMD;
        if (testObjFunc) {
            // Diagnostic mode: shift the start point, compare derivative methods, then exit.
            ArrayMath.addInPlace(initial, 0.1d);
            testObjectiveFunction(function, initial, functionTolerance);
            System.exit(1);
        }

        int n = initial.length;
        double[] x = initial;
        double[] gains = new double[n];
        double[] v = new double[n];
        double[] grad = new double[n];
        double[] newGrad = new double[n];
        double[] newX = new double[n];
        double[] hv = new double[n];

        // Initial plain gradient step with the initial gain; v starts as that step.
        System.arraycopy(func.derivativeAt(x, v, bSize), 0, grad, 0, grad.length);
        for (int i = 0; i < n; i++) {
            gains[i] = initialGain;
            x[i] -= initialGain * grad[i];
            v[i] = -initialGain * grad[i];
        }

        long lastIterTime = System.currentTimeMillis();
        int numBatches = func.dataDimension() / bSize;
        int maxIterations = (maxPasses <= 0 ? 100 : maxPasses) * numBatches;
        sayln("       Batchsize of: " + bSize);
        sayln("       Batches per pass through data:  " + numBatches);
        sayln("       Max iterations is = " + maxIterations);
        sayln("       All calculations will be made using " + func.method.toString());

        try {
            if (outputIterationsToFile) {
                outputInitializeFiles();
                infoFile.println(function.domainDimension() + "; DomainDimension ");
                infoFile.println(bSize + "; batchSize ");
                infoFile.println(cPosDef + "; c");
                infoFile.println(lam + "; lambda ");
                infoFile.println(initialGain + "; initGain");
                infoFile.println(maxIterations + "; maxIterations");
                infoFile.println(numBatches + "; numBatches ");
                infoFile.println(func.method.toString() + "; calculationMethod");
                infoFile.println(outputFrequency + "; outputFrequency");
            }
            sayln("Iter: n [time ms] Value \n");
            System.arraycopy(func.derivativeAt(x, v, bSize), 0, grad, 0, grad.length);
            func.returnPreviousValues = true;
            System.arraycopy(func.HdotVAt(x, v, grad, bSize), 0, hv, 0, hv.length);
            long lastOutputTime = System.currentTimeMillis();

            for (k = 0; k < maxIterations; k++) {
                say("Iter: " + k + " [ " + (System.currentTimeMillis() - lastIterTime) + " ms]  ");
                lastIterTime = System.currentTimeMillis();
                for (int i = 0; i < x.length; i++) {
                    double gainFactor = 1.0d - mu * grad[i] * v[i];
                    gains[i] *= (gainFactor < 0.5d ? 0.5d : gainFactor);
                    if (gains[i] > 5.0d) {
                        gains[i] = 5.0d;
                    }
                    v[i] = lam * (1.0d + cPosDef * gains[i]) * v[i] - gains[i] * (grad[i] + lam * hv[i]);
                    newX[i] = x[i] - gains[i] * grad[i];
                }
                if (printMinMax) {
                    say("vMin = " + ArrayMath.min(v) + "  ");
                    say("vMax = " + ArrayMath.max(v) + "  ");
                    say("gainMin = " + ArrayMath.min(gains) + "  ");
                    say("gainMax = " + ArrayMath.max(gains) + "  ");
                }
                // Force re-evaluation at newX, then reuse those values for the Hv product.
                func.hasNewVals = true;
                System.arraycopy(func.derivativeAt(newX, v, bSize), 0, newGrad, 0, newGrad.length);
                double value = func.value;
                func.returnPreviousValues = true;
                System.arraycopy(func.HdotVAt(newX, v, newGrad, bSize), 0, hv, 0, hv.length);
                if (useGaussNewton) {
                    double[] gaussNewtonHv = ArrayMath.multiply(newGrad, ArrayMath.innerProduct(newGrad, v));
                    say(NUMBER_FORMAT.format(ArrayMath.norm(ArrayMath.pairwiseSubtract(hv, gaussNewtonHv))) + "  G-H ");
                    say(NUMBER_FORMAT.format(value));
                    if (value < 5.0d) {
                        hv = gaussNewtonHv;
                    } else {
                        System.err.print(" switched ");
                    }
                }
                if (outputIterationsToFile && k % outputFrequency == 0) {
                    long elapsed = System.currentTimeMillis() - lastOutputTime;
                    // NOTE(review): from the second iteration on, x and newX are the same array
                    // (because of "x = newX" below). The logged step norm is then always 0, and
                    // valueAt(x) is evaluated at the new point.
                    double trueValue = func.valueAt(x);
                    say(" TrueValue{ " + trueValue + " } ");
                    file.println(k + " , " + trueValue + " , " + value + " , " + elapsed + " , "
                            + ArrayMath.norm(ArrayMath.pairwiseSubtract(x, newX)));
                    lastOutputTime = System.currentTimeMillis();
                }
                sayln("");
                System.arraycopy(newGrad, 0, grad, 0, newGrad.length);
                x = newX;
                if (quiet) {
                    System.err.print(".");
                }
            }

            if (outputIterationsToFile) {
                finalVal = func.valueAt(x);
                infoFile.println(" iterationCount ; " + k);
                infoFile.println(" finalValue ; " + finalVal);
                infoFile.println(finalVal);
            }
        } finally {
            // Close the files even if an evaluation throws mid-run.
            if (outputIterationsToFile) {
                closeOutputFiles();
            }
        }
        if (outputIterationsToFile) {
            System.err.println("Output Files Closed");
        }
        return x;
    }

    /**
     * Opens {@link #file} and {@link #infoFile}. The names encode the method, mu, lam, batch size
     * and gain. Exits the JVM if a file cannot be created.
     */
    private void outputInitializeFiles() {
        if (!outputIterationsToFile) {
            return;
        }
        String prefix = "SMD_" + methodSMD.toString() + "_mu" + (mu * 1000.0d) + "_lam" + (lam * 1000.0d)
                + "_b" + bSize + "_g" + (initialGain * 1000.0d);
        try {
            file = new PrintWriter(new FileOutputStream(prefix + ".output"), true);
            infoFile = new PrintWriter(new FileOutputStream(prefix + ".info"), true);
        } catch (IOException e) {
            System.err.println("Caught IOException outputing SMD data to file: " + e.getMessage());
            System.exit(1);
        }
    }

    /** Closes whichever output writers are non-null: the info file first, then the iteration file. */
    private void closeOutputFiles() {
        if (infoFile != null) {
            infoFile.close();
        }
        if (file != null) {
            file.close();
        }
    }

    /** Prints {@code str} plus a newline to stderr unless quiet. */
    private void sayln(String str) {
        if (!quiet) {
            System.err.println(str);
        }
    }

    /** Prints {@code str} to stderr unless quiet. */
    private void say(String str) {
        if (!quiet) {
            System.err.print(str);
        }
    }

    /**
     * Factors {@code |j|} into primes by trial division.
     *
     * @return a length-64 array. Element 0 holds the number of factors found; elements
     *     {@code 1..count} hold the prime factors with multiplicity, in non-decreasing order.
     *     The count is 0 for {@code j} of 0 or +/-1, and for {@code Long.MIN_VALUE}, whose
     *     absolute value overflows.
     */
    public static long[] primeFactors(long j) {
        long[] factors = new long[64];
        long remaining = Math.abs(j);
        int count = 0;
        if (remaining > 0) {
            while (remaining % 2 == 0) {
                factors[++count] = 2;
                remaining /= 2;
            }
            while (remaining % 3 == 0) {
                factors[++count] = 3;
                remaining /= 3;
            }
            // Candidates 6m-1 and 6m+1. Long arithmetic and a division-based bound avoid the
            // int overflow of the former "i * i <= abs" test.
            for (long base = 5; base <= remaining / base; base += 6) {
                for (long p = base; p <= base + 2; p += 2) {
                    while (remaining % p == 0) {
                        factors[++count] = p;
                        remaining /= p;
                    }
                }
            }
            if (remaining > 1) {
                factors[++count] = remaining;
            }
        }
        factors[0] = count;
        return factors;
    }

    /**
     * Chooses a batch size that evenly divides {@code size}. With exactly two prime factors it is
     * the smaller one; otherwise it is the product of all factors except the largest.
     *
     * <p>Exits the JVM when {@code size} has fewer than two prime factors. That covers prime
     * sizes (batch size 1) as well as 0 and 1.
     */
    private static long getTestBatchSize(long size) {
        long[] factors = primeFactors(size);
        long factorCount = factors[0];
        // A prime size has one factor; the old "== 0" test missed exactly the case it reports.
        if (factorCount <= 1) {
            System.err.println("Attempt to test function on data of prime dimension.  This would involve a batchSize of 1 and may take a very long time.");
            System.exit(1);
        }
        if (factorCount == 2) {
            return factors[1];
        }
        long batchSize = 1;
        for (int f = 1; f < factorCount; f++) {
            batchSize *= factors[f];
        }
        return batchSize;
    }

    /**
     * Cross-checks a stochastic objective at {@code x}, batch by batch, along a random direction
     * {@code v}. It checks that:
     * <ul>
     *   <li>summed batch gradients match the full gradient;</li>
     *   <li>IncorporatedFiniteDifference and ExternalFiniteDifference gradients agree;</li>
     *   <li>summed batch values match the full value;</li>
     *   <li>the Hv products of the two methods agree.</li>
     * </ul>
     * Results are reported via {@code sayln}.
     *
     * @param function must be an {@link AbstractStochasticCachingDiffFunction}. Its
     *     {@code method} field is left set to ExternalFiniteDifference.
     * @param x point at which to test
     * @param functionTolerance tolerance for every comparison
     * @return true if all four checks passed (previously this always returned true)
     * @throws UnsupportedOperationException if {@code function} is not an
     *     {@link AbstractStochasticCachingDiffFunction}
     */
    public boolean testObjectiveFunction(Function function, double[] x, double functionTolerance) {
        if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
            throw new UnsupportedOperationException(
                    "testObjectiveFunction requires an AbstractStochasticCachingDiffFunction, got "
                            + function.getClass().getName());
        }
        AbstractStochasticCachingDiffFunction func = (AbstractStochasticCachingDiffFunction) function;
        int n = x.length;
        Random random = new Random(System.currentTimeMillis());
        double[] v = new double[n];
        double[] hv = new double[n];
        double[] hvExternal = new double[n];
        double[] gradExternal = new double[n];
        double[] fullGrad = new double[n];
        double[] summedGrad = new double[n];
        double[] batchGrad = new double[n];
        double maxGradDiff = 0.0d;
        double maxHvDiff = 0.0d;
        double summedValue = 0.0d;

        int dataSize = func.dataDimension();
        int testBatchSize = (int) getTestBatchSize(dataSize);
        // "<= 0" (not "< 0"): a zero batch size would otherwise throw at the modulo below.
        if (testBatchSize <= 0 || testBatchSize > dataSize || dataSize % testBatchSize != 0) {
            System.err.println("Invalid testBatchSize found, testing aborted.  Data size: " + dataSize + " batchSize: " + testBatchSize);
            System.exit(1);
        }
        int numBatches = dataSize / testBatchSize;
        sayln("Testing Gradients");
        sayln("data dimension  = " + dataSize);
        sayln("batch size = " + testBatchSize);
        sayln("number of batches = " + numBatches);

        for (int i = 0; i < n; i++) {
            v[i] = random.nextDouble();
        }

        // NOTE(review): assumes successive stochastic calls walk through consecutive batches,
        // and that recalculatePrevBatch re-evaluates the batch just used — confirm against
        // AbstractStochasticCachingDiffFunction.
        for (int batch = 0; batch < numBatches; batch++) {
            func.method = StochasticCalculateMethods.IncorporatedFiniteDifference;
            summedValue += func.valueAt(x, v, testBatchSize);
            func.returnPreviousValues = true;
            System.arraycopy(func.derivativeAt(x, v, testBatchSize), 0, batchGrad, 0, batchGrad.length);
            func.returnPreviousValues = true;
            System.arraycopy(func.HdotVAt(x, v, testBatchSize), 0, hv, 0, hv.length);
            summedGrad = ArrayMath.pairwiseAdd(summedGrad, batchGrad);

            func.method = StochasticCalculateMethods.ExternalFiniteDifference;
            func.recalculatePrevBatch = true;
            System.arraycopy(func.derivativeAt(x, v, testBatchSize), 0, gradExternal, 0, gradExternal.length);
            func.recalculatePrevBatch = true;
            System.arraycopy(func.HdotVAt(x, v, gradExternal, testBatchSize), 0, hvExternal, 0, hvExternal.length);

            double gradDiff = ArrayMath.norm_inf(ArrayMath.pairwiseSubtract(gradExternal, batchGrad));
            if (gradDiff > maxGradDiff) {
                maxGradDiff = gradDiff;
            }
            double hvDiff = ArrayMath.norm_inf(ArrayMath.pairwiseSubtract(hv, hvExternal));
            if (hvDiff > maxHvDiff) {
                maxHvDiff = hvDiff;
            }
        }

        System.arraycopy(func.derivativeAt(x), 0, fullGrad, 0, fullGrad.length);
        double fullValue = func.valueAt(x);

        boolean gradientOk = ArrayMath.norm_inf(ArrayMath.pairwiseSubtract(fullGrad, summedGrad)) < functionTolerance;
        if (gradientOk) {
            sayln("");
            sayln("  Gradient is looking good");
        } else {
            double norm = ArrayMath.norm(ArrayMath.pairwiseSubtract(summedGrad, fullGrad));
            sayln("");
            sayln("  Seems there is a problem.  Gradient is off by norm of " + norm);
        }

        boolean methodsAgree = maxGradDiff < functionTolerance;
        if (methodsAgree) {
            sayln("");
            sayln("  Both gradients are the same");
        } else {
            sayln("");
            sayln("  Seems there is a problem.  The two methods of calculating the gradient are different  max |AD-FD|_inf Error of " + maxGradDiff);
        }

        boolean valueOk = Math.abs(fullValue - summedValue) < functionTolerance;
        if (valueOk) {
            sayln("");
            sayln("  Value is looking good");
        } else {
            sayln("");
            sayln("  Seems there is a problem.  Value is off by " + (summedValue - fullValue));
        }

        boolean hvOk = maxHvDiff < functionTolerance;
        if (hvOk) {
            sayln("");
            sayln("  Hv Approimations line up well");
        } else {
            sayln("");
            sayln("    Seems there is a problem.  Hv approximations aren't quite close enough -- max |AD-FD|_inf Error of " + maxHvDiff);
        }
        return gradientOk && methodsAgree && valueOk && hvOk;
    }

    /**
     * Demo objective: {@code f(x) = 1 + (0.5 * sum_i a_i x_i^2)^pi} with {@code a_i = 5 (i+1) / N}.
     *
     * <p>NOTE(review): minimize() requires an AbstractStochasticCachingDiffFunction. This plain
     * DiffFunction is therefore rejected (UnsupportedOperationException) before any iteration
     * runs, and the no-arg constructor sets no calculation method either.
     */
    public static void main(String[] strArr) {
        final int dim = 500000;
        final double[] weights = new double[dim];
        double[] start = new double[dim];
        for (int i = 0; i < dim; i++) {
            start[i] = ((i + 1) / 500000.0d) - 0.5d;
            weights[i] = (5.0d * (i + 1)) / 500000.0d;
        }
        final double[] gradient = new double[dim];
        new SMDMinimizer().minimize(new DiffFunction() {
            @Override // edu.stanford.nlp.optimization.DiffFunction
            public double[] derivativeAt(double[] point) {
                double scale = 3.141592653589793d * valuePow(point, 2.141592653589793d);
                for (int i = 0; i < dim; i++) {
                    gradient[i] = point[i] * weights[i] * scale;
                }
                return gradient;
            }

            @Override // edu.stanford.nlp.optimization.Function
            public double valueAt(double[] point) {
                return 1.0d + valuePow(point, 3.141592653589793d);
            }

            /** Returns (0.5 * sum_i weights[i] * point[i]^2) ^ exponent. */
            private double valuePow(double[] point, double exponent) {
                double sum = 0.0d;
                for (int i = 0; i < dim; i++) {
                    sum += point[i] * point[i] * weights[i];
                }
                return Math.pow(sum * 0.5d, exponent);
            }

            @Override // edu.stanford.nlp.optimization.Function
            public int domainDimension() {
                return dim;
            }
        }, 1.0E-4d, start);
    }
}
