/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.loglinear.learning;

import edu.stanford.nlp.loglinear.learning.AbstractDifferentiableFunction;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import edu.stanford.nlp.util.RuntimeInterruptedException;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public abstract class AbstractBatchOptimizer {
    /** Redwood logging channel for this class; it is never reassigned, hence {@code final}. */
    private static final Redwood.RedwoodChannels log = Redwood.channels(AbstractBatchOptimizer.class);

    /**
     * Constraints registered through {@code addSparseConstraint} and {@code addDenseConstraint}.
     * The training worker applies every entry to the derivative and to the weights on each iteration.
     */
    List<Constraint> constraints = new ArrayList<>();

    /**
     * Convenience overload that trains with default settings: a {@code new ConcatVector(0)} as the
     * initial weights, no L2 regularization, a convergence threshold of {@code 1.0E-5}, and
     * non-quiet (verbose, stdin-interruptible) mode.
     *
     * @param dataset the training examples
     * @param fn the differentiable function evaluated on each example
     * @param <T> the type of a training example
     * @return the weights produced by the full {@code optimize} overload
     */
    public <T> ConcatVector optimize(T[] dataset, AbstractDifferentiableFunction<T> fn) {
        ConcatVector emptyInitialWeights = new ConcatVector(0);
        double noL2Regularization = 0.0;
        double defaultConvergenceNorm = 1.0E-5;
        boolean quiet = false;
        return this.optimize(dataset, fn, emptyInitialWeights, noL2Regularization, defaultConvergenceNorm, quiet);
    }

    /**
     * Trains on a background thread and blocks the caller until training ends.
     *
     * <p>In quiet mode the caller simply waits for the training thread to terminate. In non-quiet
     * mode stdin is polled while the training thread is alive; as soon as any input is available
     * the worker's stop flag is set and the current weights are returned immediately. In that
     * early-exit case the worker thread is only signalled, not awaited, so it may still be
     * finishing its current step when this method returns.
     *
     * <p>Waiting is done on the thread itself ({@code join}) rather than on the worker's
     * {@code isFinished} flag, so the caller neither busy-spins nor hangs if the worker dies with
     * an exception.
     *
     * @param dataset the training examples
     * @param fn the differentiable function evaluated on each example
     * @param initialWeights starting point; the worker trains on a deep clone of it
     * @param l2regularization L2 coefficient; the worker subtracts
     *        {@code l2regularization * (w . w)} from the objective
     * @param convergenceDerivativeNorm training stops once the derivative's dot product with itself
     *        drops below this value
     * @param quiet when {@code true}, no progress is logged and stdin is not consulted
     * @param <T> the type of a training example
     * @return the worker's weight vector
     * @throws RuntimeInterruptedException if the calling thread is interrupted while waiting
     */
    public <T> ConcatVector optimize(T[] dataset, AbstractDifferentiableFunction<T> fn, ConcatVector initialWeights, double l2regularization, double convergenceDerivativeNorm, boolean quiet) {
        if (!quiet) {
            log.info("\n**************\nBeginning training\n");
        } else {
            log.info("[Beginning quiet training]");
        }
        TrainingWorker<T> mainWorker = new TrainingWorker<T>(dataset, fn, initialWeights, l2regularization, convergenceDerivativeNorm, quiet);
        Thread workerThread = new Thread(mainWorker);
        workerThread.start();

        if (quiet) {
            try {
                // Thread termination is the one signal that cannot be missed or reordered.
                workerThread.join();
            }
            catch (InterruptedException e) {
                throw new RuntimeInterruptedException(e);
            }
            log.info("[Quiet training complete]");
            return mainWorker.weights;
        }

        log.info("NOTE: you can press any key (and maybe ENTER afterwards to jog stdin) to terminate learning early.");
        log.info("The convergence criteria are quite aggressive if left uninterrupted, and will run for a while");
        log.info("if left to their own devices.\n");

        // Deliberately not closed: closing the reader would close System.in for the whole JVM.
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        final long pollMillis = 50L;
        boolean stdinUsable = true;
        while (workerThread.isAlive()) {
            if (stdinUsable) {
                try {
                    if (br.ready()) {
                        log.info("received quit command: quitting");
                        log.info("training completed by interruption");
                        mainWorker.isFinished = true;
                        return mainWorker.weights;
                    }
                }
                catch (IOException e) {
                    // A broken stdin will keep failing; report it once and just wait for training.
                    log.info("could not poll stdin for a quit command; early termination disabled: " + e);
                    stdinUsable = false;
                }
            }
            try {
                // Sleeps between polls but returns at once when the worker terminates.
                workerThread.join(pollMillis);
            }
            catch (InterruptedException e) {
                throw new RuntimeInterruptedException(e);
            }
        }
        log.info("training completed without interruption");
        return mainWorker.weights;
    }

    /**
     * Registers a constraint on a single entry of a sparse component. During training the
     * registered constraint sets that entry of the weights to {@code value} and zeroes the same
     * entry of the derivative.
     *
     * @param component the component the constraint applies to
     * @param index the index within the component
     * @param value the value the weight entry is set to
     */
    public void addSparseConstraint(int component, int index, double value) {
        Constraint sparseConstraint = new Constraint(component, index, value);
        this.constraints.add(sparseConstraint);
    }

    /**
     * Registers a constraint on a whole dense component. During training the registered constraint
     * sets that component of the weights to {@code arr} and replaces the same component of the
     * derivative with a single-element zero array.
     *
     * <p>The array is stored by reference, not copied.
     *
     * @param component the component the constraint applies to
     * @param arr the values the weight component is set to
     */
    public void addDenseConstraint(int component, double[] arr) {
        Constraint denseConstraint = new Constraint(component, arr);
        this.constraints.add(denseConstraint);
    }

    /**
     * Performs one optimizer step; implemented by concrete optimizers.
     *
     * <p>The training worker calls this once per iteration after computing the gradient, and ends
     * training when it returns {@code true}. Parameter names are decompiler placeholders; the
     * meanings below are taken from the worker's call site.
     *
     * @param var1 the current weights (presumably updated in place, since the worker keeps using
     *        the same object afterwards; confirm against implementations)
     * @param var2 the derivative: averaged over the dataset, with the L2 term added and the
     *        registered constraints applied
     * @param var3 the objective value: the dataset-averaged log-likelihood minus the L2 penalty
     * @param var5 the state object obtained from {@code getFreshOptimizationState}
     * @param var6 the {@code quiet} flag passed to {@code optimize}
     * @return {@code true} if the worker should stop training
     */
    public abstract boolean updateWeights(ConcatVector var1, ConcatVector var2, double var3, OptimizationState var5, boolean var6);

    /**
     * Creates the optimizer state for a new training run; implemented by concrete optimizers.
     *
     * <p>Called once when a training worker is constructed. The result is handed back to
     * {@code updateWeights} on every iteration.
     *
     * @param var1 the initial weights passed to {@code optimize} (decompiler placeholder name)
     * @return the state object for this run
     */
    protected abstract OptimizationState getFreshOptimizationState(ConcatVector var1);

    /**
     * Runnable that executes the batch training loop: compute the gradient over the whole dataset
     * (optionally split across threads), regularize, apply constraints, test for convergence and
     * ask the enclosing optimizer to update the weights.
     *
     * <p>Termination is always published: whatever way {@link #run()} exits (convergence, external
     * stop request, or an exception), {@code isFinished} is set and waiters on
     * {@code naturalTerminationBarrier} are notified, with the flag written before the
     * notification and under the same monitor.
     */
    private class TrainingWorker<T>
    implements Runnable {
        /** Current weights; a deep clone of the initial weights, handed to the optimizer each step. */
        ConcatVector weights;
        OptimizationState optimizationState;
        /**
         * Stop flag. Written by the thread calling {@code optimize} and by this worker, read by
         * gradient worker threads, hence {@code volatile}.
         */
        volatile boolean isFinished = false;
        /** Whether the gradient is computed on several threads (more than one processor available). */
        boolean useThreads = Runtime.getRuntime().availableProcessors() > 1;
        T[] dataset;
        AbstractDifferentiableFunction<T> fn;
        double l2regularization;
        double convergenceDerivativeNorm;
        boolean quiet;
        /** Monitor notified when this worker terminates. */
        final Object naturalTerminationBarrier = new Object();

        public TrainingWorker(T[] dataset, AbstractDifferentiableFunction<T> fn, ConcatVector initialWeights, double l2regularization, double convergenceDerivativeNorm, boolean quiet) {
            this.optimizationState = AbstractBatchOptimizer.this.getFreshOptimizationState(initialWeights);
            this.weights = initialWeights.deepClone();
            this.dataset = dataset;
            this.fn = fn;
            this.l2regularization = l2regularization;
            this.convergenceDerivativeNorm = convergenceDerivativeNorm;
            this.quiet = quiet;
        }

        /**
         * Estimates the relative cost of one datum: for a {@code GraphicalModel}, the sum of
         * {@code combinatorialNeighborStatesCount()} over its factors' feature tables; otherwise 1.
         */
        private int estimateRelativeRuntime(T datum) {
            if (datum instanceof GraphicalModel) {
                int cost = 0;
                GraphicalModel model = (GraphicalModel)datum;
                for (GraphicalModel.Factor f : model.factors) {
                    cost += f.featuresTable.combinatorialNeighborStatesCount();
                }
                return cost;
            }
            return 1;
        }

        /**
         * Runs the training loop and then publishes termination, even if training throws.
         */
        @Override
        public void run() {
            try {
                this.train();
            }
            finally {
                synchronized (this.naturalTerminationBarrier) {
                    // Flag first, then notify, so a woken waiter that re-checks the flag sees true.
                    this.isFinished = true;
                    this.naturalTerminationBarrier.notifyAll();
                }
            }
        }

        /** Iterates gradient computation and weight updates until convergence or a stop request. */
        private void train() {
            int numThreads = Math.max(1, Runtime.getRuntime().availableProcessors());
            List<List<T>> queues = this.useThreads ? this.partitionDataset(numThreads) : null;
            Random random = new Random();
            final int datasetSize = this.dataset.length;

            while (!this.isFinished) {
                long startTime = System.currentTimeMillis();
                long threadWaiting = 0L;
                ConcatVector derivative = this.weights.newEmptyClone();
                double logLikelihood = 0.0;

                if (this.useThreads) {
                    List<GradientWorker<T>> workers = new ArrayList<GradientWorker<T>>(numThreads);
                    Thread[] threads = new Thread[numThreads];
                    for (int i = 0; i < numThreads; ++i) {
                        GradientWorker<T> worker = new GradientWorker<T>(this, i, numThreads, queues.get(i), this.fn, this.weights);
                        threads[i] = new Thread(worker);
                        worker.jvmThreadId = threads[i].getId();
                        workers.add(worker);
                        threads[i].start();
                    }

                    long minFinishTime = Long.MAX_VALUE;
                    long maxFinishTime = Long.MIN_VALUE;
                    long minCPUTime = Long.MAX_VALUE;
                    long maxCPUTime = Long.MIN_VALUE;
                    int slowestWorker = 0;
                    int fastestWorker = 0;
                    for (int i = 0; i < numThreads; ++i) {
                        try {
                            threads[i].join();
                        }
                        catch (InterruptedException e) {
                            throw new RuntimeInterruptedException(e);
                        }
                        GradientWorker<T> worker = workers.get(i);
                        logLikelihood += worker.localLogLikelihood;
                        derivative.addVectorInPlace(worker.localDerivative, 1.0);
                        minFinishTime = Math.min(minFinishTime, worker.finishedAtTime);
                        maxFinishTime = Math.max(maxFinishTime, worker.finishedAtTime);
                        if (worker.cpuTimeRequired < minCPUTime) {
                            fastestWorker = i;
                            minCPUTime = worker.cpuTimeRequired;
                        }
                        if (worker.cpuTimeRequired > maxCPUTime) {
                            slowestWorker = i;
                            maxCPUTime = worker.cpuTimeRequired;
                        }
                    }
                    threadWaiting = maxFinishTime - minFinishTime;
                    this.rebalance(queues.get(slowestWorker), queues.get(fastestWorker), minCPUTime, maxCPUTime, random);
                    if (this.isFinished) {
                        // A stop was requested mid-pass; the partial gradient must not be applied.
                        return;
                    }
                } else {
                    for (T datum : this.dataset) {
                        assert (datum != null);
                        logLikelihood += this.fn.getSummaryForInstance(datum, this.weights, derivative);
                        if (this.isFinished) {
                            return;
                        }
                    }
                }

                // Average over the dataset, then add the L2 term to objective and gradient.
                logLikelihood /= (double)datasetSize;
                derivative.mapInPlace(d -> d / (double)datasetSize);
                long gradientComputationTime = System.currentTimeMillis() - startTime;
                logLikelihood -= this.l2regularization * this.weights.dotProduct(this.weights);
                derivative.addVectorInPlace(this.weights, -2.0 * this.l2regularization);
                for (Constraint constraint : AbstractBatchOptimizer.this.constraints) {
                    constraint.applyToDerivative(derivative);
                }

                // Note: this is the derivative's dot product with itself, compared as-is.
                double derivativeNorm = derivative.dotProduct(derivative);
                if (derivativeNorm < this.convergenceDerivativeNorm) {
                    if (!this.quiet) {
                        log.info("Derivative norm " + derivativeNorm + " < " + this.convergenceDerivativeNorm + ": quitting");
                    }
                    return;
                }
                if (!this.quiet) {
                    log.info("[" + gradientComputationTime + " ms, threads waiting " + threadWaiting + " ms]");
                }

                boolean converged = AbstractBatchOptimizer.this.updateWeights(this.weights, derivative, logLikelihood, this.optimizationState, this.quiet);
                for (Constraint constraint : AbstractBatchOptimizer.this.constraints) {
                    constraint.applyToWeights(this.weights);
                }
                if (converged) {
                    return;
                }
            }
        }

        /**
         * Splits the dataset into {@code numThreads} queues, assigning each datum to the queue with
         * the smallest estimated total cost so far (lowest index wins ties).
         */
        private List<List<T>> partitionDataset(int numThreads) {
            List<List<T>> queues = new ArrayList<List<T>>(numThreads);
            for (int i = 0; i < numThreads; ++i) {
                queues.add(new ArrayList<T>());
            }
            int[] queueEstimatedTotalCost = new int[numThreads];
            for (T datum : this.dataset) {
                int minCostQueue = 0;
                for (int i = 1; i < numThreads; ++i) {
                    if (queueEstimatedTotalCost[i] < queueEstimatedTotalCost[minCostQueue]) {
                        minCostQueue = i;
                    }
                }
                queueEstimatedTotalCost[minCostQueue] += this.estimateRelativeRuntime(datum);
                queues.get(minCostQueue).add(datum);
            }
            return queues;
        }

        /**
         * Moves randomly chosen items from the slowest worker's queue to the fastest worker's queue.
         * The number moved is half the slowest queue's size scaled by the relative CPU-time gap.
         */
        private void rebalance(List<T> slowestQueue, List<T> fastestQueue, long minCPUTime, long maxCPUTime, Random random) {
            if (maxCPUTime <= 0L) {
                // No measurable CPU time (e.g. timing unavailable): the ratio is undefined, move nothing.
                return;
            }
            double waitingPercentage = (double)(maxCPUTime - minCPUTime) / (double)maxCPUTime;
            int needTransferItems = (int)Math.floor((double)slowestQueue.size() * waitingPercentage * 0.5);
            for (int i = 0; i < needTransferItems; ++i) {
                int toTransfer = random.nextInt(slowestQueue.size());
                fastestQueue.add(slowestQueue.remove(toTransfer));
            }
        }
    }

    /**
     * Runnable that accumulates the log-likelihood and derivative over one queue of data, and
     * records how long that took so the training worker can rebalance the queues.
     */
    private static class GradientWorker<T>
    implements Runnable {
        /** Derivative accumulated over this worker's queue; starts as an empty clone of the weights. */
        ConcatVector localDerivative;
        /** Sum of {@code getSummaryForInstance} results over this worker's queue. */
        double localLogLikelihood = 0.0;
        /** Owning training worker; its {@code isFinished} flag aborts this worker early. */
        TrainingWorker<T> mainWorker;
        int threadIdx;
        int numThreads;
        List<T> queue;
        AbstractDifferentiableFunction<T> fn;
        ConcatVector weights;
        /** Id of the thread running this worker; assigned by the creator, not read by this class. */
        long jvmThreadId = 0L;
        /** Wall-clock time (ms since epoch) at which the queue was fully processed; 0 if aborted. */
        long finishedAtTime = 0L;
        /** CPU time in nanoseconds spent processing the queue; 0 if aborted or not measurable. */
        long cpuTimeRequired = 0L;

        public GradientWorker(TrainingWorker<T> mainWorker, int threadIdx, int numThreads, List<T> queue, AbstractDifferentiableFunction<T> fn, ConcatVector weights) {
            this.mainWorker = mainWorker;
            this.threadIdx = threadIdx;
            this.numThreads = numThreads;
            this.queue = queue;
            this.fn = fn;
            this.weights = weights;
            this.localDerivative = weights.newEmptyClone();
        }

        /**
         * Processes every datum in the queue, stopping without recording timings as soon as the
         * owning worker is flagged finished.
         */
        @Override
        public void run() {
            long startTime = GradientWorker.currentThreadCpuNanos();
            for (T datum : this.queue) {
                this.localLogLikelihood += this.fn.getSummaryForInstance(datum, this.weights, this.localDerivative);
                if (this.mainWorker.isFinished) {
                    return;
                }
            }
            this.finishedAtTime = System.currentTimeMillis();
            long endTime = GradientWorker.currentThreadCpuNanos();
            this.cpuTimeRequired = endTime - startTime;
        }

        /**
         * Returns the CPU time of the calling thread in nanoseconds, or -1 when the JVM cannot
         * measure it. Two -1 readings subtract to 0, i.e. "no timing information".
         *
         * <p>The current thread is measured rather than looking up {@code jvmThreadId}: that lookup
         * throws if the id was never assigned or if the JVM cannot time other threads.
         */
        private static long currentThreadCpuNanos() {
            java.lang.management.ThreadMXBean threadBean = ManagementFactory.getThreadMXBean();
            if (!threadBean.isCurrentThreadCpuTimeSupported()) {
                return -1L;
            }
            return threadBean.getCurrentThreadCpuTime();
        }
    }

    /**
     * Base type for the per-run state a concrete optimizer carries between calls to
     * {@code updateWeights}. It declares no members of its own; the implicit constructor has the
     * same {@code protected} access as the class.
     */
    protected abstract class OptimizationState {
    }

    /**
     * A fixed-value constraint on either one entry of a sparse component or a whole dense
     * component. Which of the two it is depends on the constructor used.
     */
    private static class Constraint {
        int component;
        /** {@code true} for a single-entry (sparse) constraint, {@code false} for a dense one. */
        boolean isSparse;
        /** Entry index; only meaningful for sparse constraints. */
        int index;
        /** Fixed value of the entry; only meaningful for sparse constraints. */
        double value;
        /** Fixed values of the component; only set for dense constraints. */
        double[] arr;

        /** Creates a sparse constraint fixing entry {@code index} of {@code component} to {@code value}. */
        public Constraint(int component, int index, double value) {
            this.component = component;
            this.index = index;
            this.value = value;
            this.isSparse = true;
        }

        /** Creates a dense constraint fixing {@code component} to {@code arr} (stored by reference). */
        public Constraint(int component, double[] arr) {
            this.component = component;
            this.arr = arr;
            this.isSparse = false;
        }

        /** Writes the constrained value(s) into {@code weights}. */
        public void applyToWeights(ConcatVector weights) {
            if (!this.isSparse) {
                weights.setDenseComponent(this.component, this.arr);
                return;
            }
            weights.setSparseComponent(this.component, this.index, this.value);
        }

        /**
         * Zeroes the constrained part of {@code derivative}: the single entry for a sparse
         * constraint, or the component replaced by a one-element zero array for a dense one.
         */
        public void applyToDerivative(ConcatVector derivative) {
            if (!this.isSparse) {
                double[] singleZero = new double[1];
                derivative.setDenseComponent(this.component, singleZero);
                return;
            }
            derivative.setSparseComponent(this.component, this.index, 0.0);
        }
    }
}

