/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffFunction;
import edu.stanford.nlp.optimization.Function;
import edu.stanford.nlp.optimization.StochasticMinimizer;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.List;

/**
 * Stochastic minimizer that keeps a bounded history of displacement pairs
 * ({@code s} = change in the point, {@code y} = regularized change in the
 * gradient) and uses them to rescale the stochastic gradient into a step
 * direction. The direction computation in {@link #computeDir} has the shape of
 * the two-loop recursion used by limited-memory quasi-Newton methods.
 *
 * <p>The three history lists ({@code sList}, {@code yList}, {@code roList}) are
 * parallel: index {@code i} in each refers to the same stored pair. Every
 * mutation in this class keeps them the same length.
 *
 * <p>State such as {@code gain}, {@code bSize}, {@code x}, {@code newX},
 * {@code grad}, {@code newGrad}, {@code k} and {@code numBatches} is not
 * declared here; it is presumably inherited from {@code StochasticMinimizer}.
 *
 * @param <T> the type of function being minimized
 */
public class SQNMinimizer<T extends Function>
extends StochasticMinimizer<T> {
    private static final Redwood.RedwoodChannels log = Redwood.channels(SQNMinimizer.class);
    /** Maximum number of (s, y) pairs to retain; values &lt;= 0 effectively keep a single pair. */
    private int M = 0;
    /** Weight of the {@code s} term added to {@code y} when a new pair is built. */
    private double lambda = 1.0;
    /** Multiplier applied to the first-loop coefficients in the second loop of {@link #computeDir}. */
    private double cPosDef = 1.0;
    /** Scale applied to the raw gradient when no history is available. */
    private double epsilon = 1.0E-10;
    private List<double[]> sList = new ArrayList<double[]>();
    private List<double[]> yList = new ArrayList<double[]>();
    private List<Double> roList = new ArrayList<Double>();
    double[] dir;
    double[] s;
    double[] y;
    double ro;

    /**
     * Sets the maximum number of history pairs to retain.
     *
     * @param m the new history size
     */
    public void setM(int m) {
        this.M = m;
    }

    /**
     * Creates a minimizer with the given history size and otherwise default settings.
     *
     * @param m the maximum number of history pairs to retain
     */
    public SQNMinimizer(int m) {
        this.M = m;
    }

    /** Creates a minimizer with default settings (history size 0). */
    public SQNMinimizer() {
    }

    /**
     * Creates a fully configured minimizer.
     *
     * @param mem the maximum number of history pairs to retain
     * @param initialGain value stored in {@code gain}
     * @param batchSize value stored in {@code bSize}
     * @param output value stored in {@code outputIterationsToFile}
     */
    public SQNMinimizer(int mem, double initialGain, int batchSize, boolean output) {
        this.gain = initialGain;
        this.bSize = batchSize;
        this.M = mem;
        this.outputIterationsToFile = output;
    }

    /**
     * Builds a short identifier from the batch size and the gain.
     *
     * @return {@code "SQN" + bSize + "_g" + (int) (gain * 1000)}
     */
    @Override
    public String getName() {
        int g = (int)(this.gain * 1000.0);
        return "SQN" + this.bSize + "_g" + g;
    }

    /**
     * Computes {@code d[i] = a[i] + c * b[i]} for every index of {@code a}.
     * {@code d} may be the same array as {@code a}.
     *
     * @return {@code d}
     */
    private static double[] plusAndConstMult(double[] a, double[] b, double c, double[] d) {
        for (int i = 0; i < a.length; ++i) {
            d[i] = a[i] + c * b[i];
        }
        return d;
    }

    /**
     * Tuning is not implemented: logs a notice and returns the current settings.
     *
     * @return the current batch size and gain, unchanged
     */
    @Override
    public Pair<Integer, Double> tune(Function function, double[] initial, long msPerTest) {
        log.info("No tuning set yet");
        return new Pair<Integer, Double>(this.bSize, this.gain);
    }

    /**
     * Writes the step direction for gradient {@code fg} into {@code dir}.
     *
     * <p>With an empty history the result is {@code -epsilon * fg}. Otherwise
     * the gradient is passed backwards through the stored pairs, scaled by
     * {@code (s.y)/(y.y)} of the newest pair, passed forwards through the pairs
     * again, and finally negated.
     *
     * @param dir output array, overwritten in place
     * @param fg the gradient to transform
     * @throws SurpriseConvergence if the newest stored {@code y} is the zero
     *         vector; {@code dir} is left partially computed in that case
     */
    private void computeDir(double[] dir, double[] fg) throws SurpriseConvergence {
        System.arraycopy(fg, 0, dir, 0, fg.length);
        int mmm = this.sList.size();
        double[] as = new double[mmm];
        // First loop: newest pair to oldest.
        for (int i = mmm - 1; i >= 0; --i) {
            as[i] = this.roList.get(i) * ArrayMath.innerProduct(this.sList.get(i), dir);
            SQNMinimizer.plusAndConstMult(dir, this.yList.get(i), -as[i], dir);
        }
        if (mmm != 0) {
            double[] lastY = this.yList.get(mmm - 1);
            double yDotY = ArrayMath.innerProduct(lastY, lastY);
            if (yDotY == 0.0) {
                throw new SurpriseConvergence("Y is 0!!");
            }
            double gamma = ArrayMath.innerProduct(this.sList.get(mmm - 1), lastY) / yDotY;
            ArrayMath.multiplyInPlace(dir, gamma);
        } else {
            ArrayMath.multiplyInPlace(dir, this.epsilon);
        }
        // Second loop: oldest pair to newest.
        for (int i = 0; i < mmm; ++i) {
            double b = this.roList.get(i) * ArrayMath.innerProduct(this.yList.get(i), dir);
            SQNMinimizer.plusAndConstMult(dir, this.sList.get(i), this.cPosDef * as[i] - b, dir);
        }
        ArrayMath.multiplyInPlace(dir, -1.0);
    }

    /**
     * Writes the no-history direction {@code -epsilon * fg} into {@code dir}.
     * This is what {@link #computeDir} produces when the history is empty.
     */
    private void computeFallbackDir(double[] dir, double[] fg) {
        for (int i = 0; i < fg.length; ++i) {
            dir[i] = -(this.epsilon * fg[i]);
        }
    }

    /**
     * Resets all per-run state: the three history lists and the direction buffer.
     *
     * @param func the function about to be minimized; only its dimension is read
     */
    @Override
    protected void init(AbstractStochasticCachingDiffFunction func) {
        this.sList = new ArrayList<double[]>();
        this.yList = new ArrayList<double[]>();
        // roList must be reset with the other two or the indices drift apart.
        this.roList = new ArrayList<Double>();
        this.dir = new double[func.domainDimension()];
    }

    /**
     * Performs one step: computes a direction from {@code newGrad}, writes the
     * new point into {@code newX}, then re-evaluates the gradient at
     * {@code newX} to build a fresh (s, y) pair and append it to the history.
     *
     * @param dfunction the function being minimized
     */
    @Override
    protected void takeStep(AbstractStochasticCachingDiffFunction dfunction) {
        int i;
        try {
            this.computeDir(this.dir, this.newGrad);
        }
        catch (SurpriseConvergence e) {
            // The history is unusable and dir is only partially computed:
            // drop the history and fall back to the no-history direction.
            log.info("SQN history reset: " + e.getMessage());
            this.clearStuff();
            this.computeFallbackDir(this.dir, this.newGrad);
        }
        double thisGain = this.gain * SQNMinimizer.gainSchedule(this.k, 5 * this.numBatches);
        for (i = 0; i < this.x.length; ++i) {
            this.newX[i] = this.x[i] + thisGain * this.dir[i];
        }
        this.say(" A ");
        // Evict the oldest pair(s) once the history is full, recycling their
        // arrays. The emptiness guard matters for M <= 0, where there is
        // nothing to evict on the first step.
        double[] recycledS = null;
        double[] recycledY = null;
        while (!this.sList.isEmpty() && this.sList.size() >= this.M) {
            recycledS = this.sList.remove(0);
            recycledY = this.yList.remove(0);
            this.roList.remove(0);
        }
        if (recycledS != null) {
            this.s = recycledS;
            this.y = recycledY;
        } else {
            this.s = new double[this.x.length];
            this.y = new double[this.x.length];
        }
        dfunction.recalculatePrevBatch = true;
        System.arraycopy(dfunction.derivativeAt(this.newX, this.bSize), 0, this.y, 0, this.grad.length);
        double sDotY = 0.0;
        for (i = 0; i < this.x.length; ++i) {
            this.s[i] = this.newX[i] - this.x[i];
            this.y[i] = this.y[i] - this.newGrad[i] + this.lambda * this.s[i];
            sDotY += this.s[i] * this.y[i];
        }
        if (sDotY == 0.0) {
            // 1 / (s.y) would be infinite and would turn later directions into
            // NaN, so this pair is discarded and ro keeps its previous value.
            return;
        }
        this.ro = 1.0 / sDotY;
        this.sList.add(this.s);
        this.yList.add(this.y);
        this.roList.add(this.ro);
    }

    /** Empties the history while keeping the lists usable for subsequent steps. */
    private void clearStuff() {
        this.sList.clear();
        this.yList.clear();
        this.roList.clear();
    }

    /** Signals that the newest stored gradient difference is exactly zero. */
    private static class SurpriseConvergence
    extends Exception {
        private static final long serialVersionUID = -4377976289620760327L;

        public SurpriseConvergence(String s) {
            super(s);
        }
    }
}

