package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

/**
 * Limited-memory quasi-Newton minimizer (L-BFGS style).
 *
 * <p>Each iteration computes a search direction from the stored correction pairs
 * {@code s = x_new - x} and {@code y = grad_new - grad} using the two-loop recursion in
 * {@code computeDir}. It then runs a backtracking line search along that direction.
 * {@code double[]} and {@code float[]} objectives are both supported, and each keeps its
 * own history lists.
 *
 * <p>Stored pairs are capped at {@code M} when {@code M > 0}. They never exceed
 * {@value #MAX_MEMORY} in any case. When an {@link OutOfMemoryError} occurs during an
 * iteration, {@code M} is lowered to the current history size and the iteration is
 * retried.
 *
 * <p>Instances hold per-run state and are not safe for concurrent use.
 */
public class QNMinimizer implements Minimizer<DiffFunction> {

    /** Upper bound on stored correction pairs; applies even when {@code M} is dynamic. */
    private static final int MAX_MEMORY = 20;

    /** Number of recent objective values over which the average improvement is measured. */
    private static final int IMPROVEMENT_WINDOW = 10;

    /** The convergence test is only applied once more than this many values are recorded. */
    private static final int MIN_VALUES_FOR_CONVERGENCE = 5;

    /** Current iteration number; the line searches read it to choose their first step. */
    private int k;
    /** History size; {@code <= 0} means "grow until {@link #MAX_MEMORY}". */
    private int M = 0;
    /** Set by {@link #setHistory}; only used to choose the iteration-log file name. */
    private boolean histSet = false;
    private boolean quiet = false;
    /**
     * When true, the double minimizer writes "value,elapsedMillis" per iteration to a file.
     * Nothing in this class sets it to true.
     */
    private boolean outputIterationsToFile = false;
    PrintWriter file = null;
    private Function monitor;
    private FloatFunction floatMonitor;
    /** Per-instance because {@link DecimalFormat} is not thread-safe. */
    private final NumberFormat nf = new DecimalFormat("0.000E0");
    private List<double[]> sList = null;
    private List<double[]> yList = null;
    private List<Double> roList = null;
    private List<float[]> sList_float = new ArrayList<float[]>();
    private List<float[]> yList_float = new ArrayList<float[]>();
    private List<Float> roList_float = new ArrayList<Float>();

    /**
     * Thrown by {@code computeDir} when the newest gradient difference {@code y} has zero
     * norm. Both minimizers catch it and return the current point.
     */
    public static class SurpriseConvergence extends Throwable {
        public SurpriseConvergence(String str) {
            super(str);
        }
    }

    /** Suppresses progress output; a single "." per iteration is still printed. */
    public void shutUp() {
        this.quiet = true;
    }

    /** Sets the history size; values {@code <= 0} select the dynamic setting. */
    public void setM(int m) {
        this.M = m;
    }

    /** Creates a minimizer with dynamic history size and no monitor. */
    public QNMinimizer() {
    }

    /** Creates a minimizer that keeps at most {@code m} correction pairs. */
    public QNMinimizer(int m) {
        this.M = m;
    }

    /** Creates a minimizer that evaluates {@code monitor} at each accepted double point. */
    public QNMinimizer(Function monitor) {
        this.monitor = monitor;
    }

    /** Creates a minimizer with both a double monitor and a history size {@code m}. */
    public QNMinimizer(Function monitor, int m) {
        this.monitor = monitor;
        this.M = m;
    }

    /** Creates a minimizer that evaluates {@code floatMonitor} at each accepted float point. */
    public QNMinimizer(FloatFunction floatMonitor) {
        this.floatMonitor = floatMonitor;
    }

    /** Writes {@code a + c * b} element-wise into {@code out} (which may alias a or b) and returns it. */
    private static double[] plusAndConstMult(double[] a, double[] b, double c, double[] out) {
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] + (c * b[i]);
        }
        return out;
    }

    /**
     * Seeds the double minimizer with an existing history of correction pairs.
     *
     * <p>The lists are stored by reference, not copied. The next {@code minimize} call
     * appends to them and removes from them. Arrays it removes are reused as scratch space.
     * The history is dropped when that call finishes.
     *
     * @param list the {@code s} vectors (point differences)
     * @param list2 the {@code y} vectors (gradient differences), aligned with {@code list}
     * @throws IllegalArgumentException if the two lists differ in size
     */
    public void setHistory(List<double[]> list, List<double[]> list2) {
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException(
                    "s and y histories differ in size: " + list.size() + " vs " + list2.size());
        }
        this.sList = list;
        this.yList = list2;
        this.histSet = true;
        this.roList = new ArrayList<Double>();
        for (int i = 0; i < list.size(); i++) {
            this.roList.add(1.0d / ArrayMath.innerProduct(list.get(i), list2.get(i)));
        }
    }

    /**
     * Computes the L-BFGS descent direction {@code -H * grad} into {@code dir} using the
     * two-loop recursion over the stored pairs. The initial inverse Hessian is scaled by
     * {@code (s.y)/(y.y)} of the newest pair.
     *
     * @throws SurpriseConvergence if the newest {@code y} is the zero vector
     */
    private void computeDir(double[] dir, double[] grad) throws SurpriseConvergence {
        System.arraycopy(grad, 0, dir, 0, grad.length);
        int mem = sList.size();
        double[] alpha = new double[mem];
        // First loop: newest pair to oldest.
        for (int i = mem - 1; i >= 0; i--) {
            alpha[i] = roList.get(i) * ArrayMath.innerProduct(sList.get(i), dir);
            plusAndConstMult(dir, yList.get(i), -alpha[i], dir);
        }
        if (mem != 0) {
            double[] yNewest = yList.get(mem - 1);
            double yDotY = ArrayMath.innerProduct(yNewest, yNewest);
            if (yDotY == 0.0d) {
                throw new SurpriseConvergence("Y is 0!!");
            }
            ArrayMath.multiplyInPlace(dir, ArrayMath.innerProduct(sList.get(mem - 1), yNewest) / yDotY);
        }
        // Second loop: oldest pair to newest.
        for (int i = 0; i < mem; i++) {
            double beta = roList.get(i) * ArrayMath.innerProduct(yList.get(i), dir);
            plusAndConstMult(dir, sList.get(i), alpha[i] - beta, dir);
        }
        // The recursion yields H * grad; negate it to get a descent direction.
        ArrayMath.multiplyInPlace(dir, -1.0d);
    }

    @Override // edu.stanford.nlp.optimization.Minimizer
    public double[] minimize(DiffFunction diffFunction, double d, double[] dArr) {
        return minimize(diffFunction, d, dArr, -1);
    }

    /**
     * Minimizes {@code function} starting from {@code initial}.
     *
     * <p>The run stops when either of these holds:
     * <ul>
     *   <li>More than {@value #MIN_VALUES_FOR_CONVERGENCE} values are recorded and the
     *       average improvement over the last few iterations, divided by {@code |f(x_new)|},
     *       drops below {@code functionTolerance}. The window is up to
     *       {@value #IMPROVEMENT_WINDOW} iterations.</li>
     *   <li>{@code maxIterations > 0} and the iteration counter has reached it.</li>
     * </ul>
     *
     * @param function objective with gradient
     * @param functionTolerance relative-improvement threshold
     * @param initial starting point; not modified
     * @param maxIterations iteration limit, or {@code <= 0} for none
     * @return the final point; this is the current point if {@code SurpriseConvergence} occurs
     */
    @Override // edu.stanford.nlp.optimization.Minimizer
    public double[] minimize(DiffFunction function, double functionTolerance, double[] initial, int maxIterations) {
        say("QNMinimizer called on double function of " + function.domainDimension() + " variables;");
        if (this.M > 0) {
            sayln(" using M = " + this.M + ".");
        } else {
            sayln(" using dynamic setting of M.");
        }
        sayln("Iter: n <chooseDir> [(derivInDir) chooseNewPoint] newValue (relAvgImprovement)\n");
        LinkedList<Double> recentValues = new LinkedList<Double>();
        // Work on a copy: the point arrays are swapped and overwritten as scratch space.
        double[] x = initial.clone();
        double value = function.valueAt(x);
        if (this.monitor != null) {
            this.monitor.valueAt(x);
        }
        try {
            if (this.outputIterationsToFile) {
                String fileName = "QN_m_" + this.M + (this.histSet ? "_SET.output" : ".output");
                try {
                    this.file = new PrintWriter((OutputStream) new FileOutputStream(fileName), true);
                } catch (IOException e) {
                    System.err.println("Caught IOException outputing QN data to file: " + e.getMessage());
                    System.exit(1);
                }
            }
            int dim = x.length;
            double[] grad = new double[dim];
            System.arraycopy(function.derivativeAt(x), 0, grad, 0, dim);
            double[] newGrad = new double[dim];
            double[] newX = new double[dim];
            double[] dir = new double[dim];
            if (this.sList == null) {
                this.sList = new ArrayList<double[]>();
                this.yList = new ArrayList<double[]>();
                this.roList = new ArrayList<Double>();
            }
            boolean limitIterations = maxIterations > 0;
            double averageImprovement;
            this.k = 0;
            while (true) {
                long start = System.currentTimeMillis();
                double newValue;
                double relImprovement;
                int numValues;
                try {
                    say("Iter: " + this.k + " ");
                    say("<");
                    computeDir(dir, grad);
                    say("> ");
                    double[] s;
                    double[] y;
                    if ((this.M > 0 && this.sList.size() >= this.M) || this.sList.size() >= MAX_MEMORY) {
                        // History is full: drop the oldest pair and reuse its arrays.
                        s = this.sList.remove(0);
                        y = this.yList.remove(0);
                        this.roList.remove(0);
                    } else {
                        s = new double[dim];
                        y = new double[dim];
                    }
                    say("[");
                    newValue = lineSearch(function, dir, x, newX, grad, value);
                    say("] ");
                    System.arraycopy(function.derivativeAt(newX), 0, newGrad, 0, dim);
                    say(nf.format(newValue));
                    long elapsed = System.currentTimeMillis() - start;
                    if (this.outputIterationsToFile) {
                        this.file.println(newValue + "," + elapsed);
                    }
                    plusAndConstMult(newX, x, -1.0d, s);
                    plusAndConstMult(newGrad, grad, -1.0d, y);
                    this.sList.add(s);
                    this.yList.add(y);
                    this.roList.add(1.0d / ArrayMath.innerProduct(s, y));
                    recentValues.add(value);
                    numValues = recentValues.size();
                    double oldest = numValues == IMPROVEMENT_WINDOW ? recentValues.remove() : recentValues.peek();
                    averageImprovement = (oldest - newValue) / numValues;
                    // |newValue| keeps the ratio meaningful for negative objective values.
                    relImprovement = averageImprovement / Math.abs(newValue);
                    sayln(" (" + nf.format(relImprovement) + ")");
                } catch (SurpriseConvergence e) {
                    System.err.println("surprise!");
                    clearStuff();
                    return x;
                } catch (OutOfMemoryError e) {
                    if (this.sList.isEmpty()) {
                        // Nothing left to release; retrying would loop forever.
                        throw e;
                    }
                    sayln(" --- Reached memory limit.  Setting m and redoing iteration...");
                    this.M = this.sList.size();
                    continue; // redo the same iteration (k unchanged)
                }
                boolean keepGoing = (numValues <= MIN_VALUES_FOR_CONVERGENCE || relImprovement >= functionTolerance)
                        && (!limitIterations || this.k < maxIterations);
                if (!keepGoing) {
                    break;
                }
                if (this.monitor != null) {
                    this.monitor.valueAt(newX);
                }
                value = newValue;
                double[] swap = x;
                x = newX;
                newX = swap;
                System.arraycopy(newGrad, 0, grad, 0, dim);
                if (this.quiet) {
                    System.err.print(".");
                }
                this.k++;
            }
            System.err.println("no improvement: " + averageImprovement);
            clearStuff();
            return newX;
        } finally {
            if (this.outputIterationsToFile && this.file != null) {
                this.file.close();
            }
        }
    }

    /**
     * Experimental interpolating line search. It is not called anywhere in this class, and
     * it always returns a fixed step (0.1 or 1.0E-4) rather than the interpolated one.
     */
    private double lineSearchInterpolated(DiffFunction diffFunction, double[] dArr, double[] dArr2, double[] dArr3, double[] dArr4, double d) {
        double innerProduct;
        double innerProduct2 = ArrayMath.innerProduct(dArr, dArr4);
        double d2 = d;
        if (innerProduct2 > 0.0d) {
            System.err.print(" !! Direction of ascent !!");
        }
        double d3 = this.k <= 2 ? 0.1d : 1.0d;
        do {
            double d4 = d2;
            d2 = diffFunction.valueAt(plusAndConstMult(dArr2, dArr, d3, dArr3));
            innerProduct = ArrayMath.innerProduct(diffFunction.derivativeAt(dArr3), dArr);
            if (innerProduct < innerProduct2) {
                System.err.println("Warning: Function is showing non convex behavior.");
                System.err.println("         Taking a small step and hoping things change");
                return 1.0E-4d;
            }
            double d5 = (-innerProduct2) / (2.0d * ((d2 - (innerProduct2 * d3)) / (d3 * d3)));
            if (Math.abs(d5 - d3) < 1.0E-4d) {
                System.err.println("!");
            }
            d3 = d5;
            if (d2 > d4 + (0.01d * d3 * innerProduct2)) {
                return 0.1d;
            }
        } while (Math.abs(innerProduct) <= 0.99d * Math.abs(innerProduct2));
        return 0.1d;
    }

    /**
     * Backtracking (Armijo) line search along {@code dir} from {@code x}.
     *
     * <p>The first trial step is 0.1 while {@code k <= 2} and 1.0 afterwards. Each rejected
     * step is multiplied by 0.1. A step {@code t} is accepted once
     * {@code f(x + t*dir) <= value + 0.01 * t * (dir . grad)}.
     *
     * @param dir search direction
     * @param x current point; not modified
     * @param newX output; receives the accepted point
     * @param grad gradient at {@code x}
     * @param value objective at {@code x}
     * @return objective at the accepted point
     */
    private double lineSearch(Function function, double[] dir, double[] x, double[] newX, double[] grad, double value) {
        double dirDeriv = ArrayMath.innerProduct(dir, grad);
        say("(" + nf.format(dirDeriv) + ")");
        if (dirDeriv > 0.0d) {
            say("{WARNING--- direction of positive gradient chosen!}");
        }
        double step = this.k <= 2 ? 0.1d : 1.0d;
        double backoff = 0.1d;
        double sufficientDecrease = 0.01d * dirDeriv;
        while (true) {
            double newValue = function.valueAt(plusAndConstMult(x, dir, step, newX));
            if (newValue <= value + (sufficientDecrease * step)) {
                return newValue;
            }
            say(newValue < value ? "!" : ".");
            step = backoff * step;
        }
    }

    /** Writes {@code a + c * b} element-wise into {@code out} (which may alias a or b) and returns it. */
    private static float[] plusAndConstMult(float[] a, float[] b, float c, float[] out) {
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] + (c * b[i]);
        }
        return out;
    }

    /**
     * Float counterpart of the double {@code computeDir}: writes {@code -H * grad} into
     * {@code dir} using the float history lists.
     *
     * @throws SurpriseConvergence if the newest {@code y} is the zero vector
     */
    private void computeDir(float[] dir, float[] grad) throws SurpriseConvergence {
        System.arraycopy(grad, 0, dir, 0, grad.length);
        int mem = this.sList_float.size();
        float[] alpha = new float[mem];
        for (int i = mem - 1; i >= 0; i--) {
            alpha[i] = this.roList_float.get(i) * ((float) ArrayMath.innerProduct(this.sList_float.get(i), dir));
            plusAndConstMult(dir, this.yList_float.get(i), -alpha[i], dir);
        }
        if (mem != 0) {
            float[] yNewest = this.yList_float.get(mem - 1);
            float yDotY = (float) ArrayMath.innerProduct(yNewest, yNewest);
            if (yDotY == 0.0f) {
                throw new SurpriseConvergence("Y is 0!!");
            }
            float gamma = ((float) ArrayMath.innerProduct(this.sList_float.get(mem - 1), yNewest)) / yDotY;
            ArrayMath.multiplyInPlace(dir, gamma);
        }
        for (int i = 0; i < mem; i++) {
            float beta = this.roList_float.get(i) * ((float) ArrayMath.innerProduct(this.yList_float.get(i), dir));
            plusAndConstMult(dir, this.sList_float.get(i), alpha[i] - beta, dir);
        }
        ArrayMath.multiplyInPlace(dir, -1.0d);
    }

    /**
     * Minimizes a float objective starting from {@code fArr}.
     *
     * <p>The run stops once more than {@value #MIN_VALUES_FOR_CONVERGENCE} values are
     * recorded and the average improvement divided by {@code |f(x_new)|} drops below
     * {@code f}. There is no iteration limit.
     *
     * @param floatFunction objective; must be a {@code DiffFloatFunction}
     * @param f relative-improvement threshold
     * @param fArr starting point; not modified
     * @return the final point
     * @throws UnsupportedOperationException if {@code floatFunction} has no derivative
     */
    public float[] minimize(FloatFunction floatFunction, float f, float[] fArr) {
        say("QNMinimizer called on float function of " + floatFunction.domainDimension() + " variables;");
        if (this.M > 0) {
            sayln(" Using m = " + this.M);
        } else {
            sayln(" Using dynamic setting of M.");
        }
        if (!(floatFunction instanceof DiffFloatFunction)) {
            throw new UnsupportedOperationException("QNMinimizer requires a DiffFloatFunction");
        }
        DiffFloatFunction function = (DiffFloatFunction) floatFunction;
        LinkedList<Float> recentValues = new LinkedList<Float>();
        // Work on a copy: the point arrays are swapped and overwritten as scratch space.
        float[] x = fArr.clone();
        float value = function.valueAt(x);
        if (this.floatMonitor != null) {
            this.floatMonitor.valueAt(x);
        }
        int dim = x.length;
        float[] grad = new float[dim];
        System.arraycopy(function.derivativeAt(x), 0, grad, 0, dim);
        float[] newGrad = new float[dim];
        float[] newX = new float[dim];
        float[] dir = new float[dim];
        this.sList_float = new ArrayList<float[]>();
        this.yList_float = new ArrayList<float[]>();
        this.roList_float = new ArrayList<Float>();
        sayln("Iter: n <chooseDir> [(derivInDir) chooseNewPoint] newValue (relAvgImprovement)\n");
        this.k = 0;
        while (true) {
            float newValue;
            float relImprovement;
            int numValues;
            try {
                say("Iter: " + this.k + " ");
                say("<");
                computeDir(dir, grad);
                say("> ");
                float[] s;
                float[] y;
                if ((this.M > 0 && this.sList_float.size() >= this.M) || this.sList_float.size() >= MAX_MEMORY) {
                    // History is full: drop the oldest pair and reuse its arrays.
                    s = this.sList_float.remove(0);
                    y = this.yList_float.remove(0);
                    this.roList_float.remove(0);
                } else {
                    s = new float[dim];
                    y = new float[dim];
                }
                say("[");
                newValue = lineSearch(function, dir, x, newX, grad, value);
                say("] ");
                System.arraycopy(function.derivativeAt(newX), 0, newGrad, 0, dim);
                say(nf.format(newValue));
                plusAndConstMult(newX, x, -1.0f, s);
                plusAndConstMult(newGrad, grad, -1.0f, y);
                this.sList_float.add(s);
                this.yList_float.add(y);
                this.roList_float.add((float) (1.0d / ArrayMath.innerProduct(s, y)));
                recentValues.add(value);
                numValues = recentValues.size();
                float oldest = numValues == IMPROVEMENT_WINDOW ? recentValues.remove() : recentValues.peek();
                float averageImprovement = (oldest - newValue) / numValues;
                // |newValue| keeps the ratio meaningful for negative objective values.
                relImprovement = averageImprovement / Math.abs(newValue);
                sayln(" (" + nf.format(relImprovement) + ")");
            } catch (SurpriseConvergence e) {
                clearStuff();
                return x;
            } catch (OutOfMemoryError e) {
                if (this.sList_float.isEmpty()) {
                    // Nothing left to release; retrying would loop forever.
                    throw e;
                }
                sayln(" --- Reached memory limit.  Setting m and redoing iteration...");
                this.M = this.sList_float.size();
                continue; // redo the same iteration (k unchanged)
            }
            boolean keepGoing = numValues <= MIN_VALUES_FOR_CONVERGENCE || relImprovement >= f;
            if (!keepGoing) {
                clearStuff();
                return newX;
            }
            if (this.floatMonitor != null) {
                this.floatMonitor.valueAt(newX);
            }
            value = newValue;
            float[] swap = x;
            x = newX;
            newX = swap;
            System.arraycopy(newGrad, 0, grad, 0, dim);
            if (this.quiet) {
                System.err.print(".");
            }
            this.k++;
        }
    }

    /**
     * Float backtracking (Armijo) line search.
     *
     * <p>While {@code k <= 2} the first step is 0.1 with back-off 0.1. Afterwards the first
     * step is 1.0 with back-off 0.5. The acceptance test is the same as in the double
     * version.
     *
     * @return objective at the accepted point, which is written into {@code newX}
     */
    private float lineSearch(FloatFunction floatFunction, float[] dir, float[] x, float[] newX, float[] grad, float value) {
        float dirDeriv = (float) ArrayMath.innerProduct(dir, grad);
        say("(" + nf.format(dirDeriv) + ")");
        if (dirDeriv > 0.0f) {
            say("{WARNING--- direction of positive gradient chosen!}");
        }
        boolean early = this.k <= 2;
        float step = early ? 0.1f : 1.0f;
        float backoff = early ? 0.1f : 0.5f;
        float sufficientDecrease = 0.01f * dirDeriv;
        while (true) {
            float newValue = floatFunction.valueAt(plusAndConstMult(x, dir, step, newX));
            if (newValue <= value + (sufficientDecrease * step)) {
                return newValue;
            }
            say(newValue < value ? "!" : ".");
            step = backoff * step;
        }
    }

    /** Drops all stored history (both double and float) so it can be garbage-collected. */
    private void clearStuff() {
        this.sList = null;
        this.yList = null;
        this.roList = null;
        this.sList_float = null;
        this.yList_float = null;
        this.roList_float = null;
    }

    /** Prints {@code str} and a newline to stderr unless quiet. */
    private void sayln(String str) {
        if (this.quiet) {
            return;
        }
        System.err.println(str);
    }

    /** Prints {@code str} to stderr unless quiet. */
    private void say(String str) {
        if (this.quiet) {
            return;
        }
        System.err.print(str);
    }

    /**
     * Demo: minimizes {@code 1 + (0.5 * sum(w_i * x_i^2))^pi} over 500000 variables.
     * It runs once with the double minimizer and once with the float minimizer.
     */
    public static void main(String[] strArr) {
        final double[] dArr = new double[500000];
        final float[] fArr = new float[500000];
        double[] dArr2 = new double[500000];
        for (int i = 0; i < 500000; i++) {
            dArr2[i] = ((i + 1) / 500000.0d) - 0.5d;
            dArr[i] = (5.0d * (i + 1)) / 500000.0d;
            fArr[i] = (float) dArr[i];
        }
        float[] doubleArrayToFloatArray = ArrayMath.doubleArrayToFloatArray(dArr2);
        final double[] dArr3 = new double[500000];
        final float[] fArr2 = new float[500000];
        DiffFunction diffFunction = new DiffFunction() {
            @Override // edu.stanford.nlp.optimization.DiffFunction
            public double[] derivativeAt(double[] dArr4) {
                double valuePow = 3.141592653589793d * valuePow(dArr4, 2.141592653589793d);
                for (int i2 = 0; i2 < 500000; i2++) {
                    dArr3[i2] = dArr4[i2] * dArr[i2] * valuePow;
                }
                return dArr3;
            }

            @Override // edu.stanford.nlp.optimization.Function
            public double valueAt(double[] dArr4) {
                return 1.0d + valuePow(dArr4, 3.141592653589793d);
            }

            private double valuePow(double[] dArr4, double d) {
                double d2 = 0.0d;
                for (int i2 = 0; i2 < 500000; i2++) {
                    d2 += dArr4[i2] * dArr4[i2] * dArr[i2];
                }
                return Math.pow(d2 * 0.5d, d);
            }

            @Override // edu.stanford.nlp.optimization.Function
            public int domainDimension() {
                return 500000;
            }
        };
        DiffFloatFunction diffFloatFunction = new DiffFloatFunction() {
            @Override // edu.stanford.nlp.optimization.DiffFloatFunction
            public float[] derivativeAt(float[] fArr3) {
                float valuePow = 3.1415927f * valuePow(fArr3, 2.141592653589793d);
                for (int i2 = 0; i2 < 500000; i2++) {
                    fArr2[i2] = fArr3[i2] * fArr[i2] * valuePow;
                }
                return fArr2;
            }

            @Override // edu.stanford.nlp.optimization.FloatFunction
            public float valueAt(float[] fArr3) {
                return 1.0f + valuePow(fArr3, 3.141592653589793d);
            }

            private float valuePow(float[] fArr3, double d) {
                float f = 0.0f;
                for (int i2 = 0; i2 < 500000; i2++) {
                    f += fArr3[i2] * fArr3[i2] * fArr[i2];
                }
                return (float) Math.pow(f * 0.5d, d);
            }

            @Override // edu.stanford.nlp.optimization.FloatFunction
            public int domainDimension() {
                return 500000;
            }
        };
        QNMinimizer qNMinimizer = new QNMinimizer();
        System.out.println("-------------------------");
        System.out.println("-----               -----");
        System.out.println("-----    DOUBLE     -----");
        System.out.println("-----               -----");
        System.out.println("-------------------------");
        System.out.println();
        qNMinimizer.minimize(diffFunction, 1.0E-4d, dArr2);
        System.out.println("-------------------------");
        System.out.println("-----               -----");
        System.out.println("-----     FLOAT     -----");
        System.out.println("-----               -----");
        System.out.println("-------------------------");
        System.out.println();
        qNMinimizer.setM(0);
        qNMinimizer.minimize(diffFloatFunction, 1.0E-4f, doubleArrayToFloatArray);
    }
}
