/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.lda.cvb;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.lda.cvb.TopicModel;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Drives training of a {@link TopicModel} over a corpus of document vectors.
 *
 * <p>The trainer holds a "read" model (queried while inferring per-document topic
 * distributions) and a "write" model (which receives the resulting updates). When both
 * references point at the same instance the trainer runs in "read/write" mode and trains
 * documents in small batches, applying the updates only after the whole batch has been
 * inferred. Otherwise each document is submitted to the worker pool individually.
 *
 * <p>{@link #start()} must be called before {@link #train(Vector, Vector, boolean, int)} or
 * {@link #batchTrain(Map, boolean, int)}, and {@link #stop()} afterwards; the corpus-level
 * {@link #train(VectorIterable, VectorIterable, int)} does both itself.
 */
public class ModelTrainer {
    private static final Logger log = LoggerFactory.getLogger(ModelTrainer.class);

    /** Number of doc/topic inference passes run per document when computing corpus perplexity. */
    private static final int PERPLEXITY_DOC_TOPIC_ITERS = 10;

    /** Seconds {@link #stop()} waits for queued training jobs to drain. */
    private static final long SHUTDOWN_TIMEOUT_SECONDS = 60L;

    private static final double NANOS_PER_MILLI = 1000000.0;

    private final int numTopics;
    private final int numTerms;
    private TopicModel readModel;
    private TopicModel writeModel;
    private ThreadPoolExecutor threadPool;
    private BlockingQueue<Runnable> workQueue;
    private final int numTrainThreads;
    private final boolean isReadWrite;

    /**
     * @param initialReadModel  model queried during inference
     * @param initialWriteModel model receiving updates; may be the same instance as
     *                          {@code initialReadModel}, which enables batch ("read/write") mode
     * @param numTrainThreads   size of the worker pool (and of a training batch)
     * @param numTopics         row count of the per-document topic/term scratch matrix
     * @param numTerms          column count of the per-document topic/term scratch matrix
     */
    public ModelTrainer(TopicModel initialReadModel, TopicModel initialWriteModel, int numTrainThreads, int numTopics, int numTerms) {
        this.readModel = initialReadModel;
        this.writeModel = initialWriteModel;
        this.numTrainThreads = numTrainThreads;
        this.numTopics = numTopics;
        this.numTerms = numTerms;
        this.isReadWrite = initialReadModel == initialWriteModel;
    }

    /** Creates a trainer that reads from and writes to the same {@code model}. */
    public ModelTrainer(TopicModel model, int numTrainThreads, int numTopics, int numTerms) {
        this(model, model, numTrainThreads, numTopics, numTerms);
    }

    /** @return the model currently used for reading (swapped with the write model by {@link #stop()}) */
    public TopicModel getReadModel() {
        return this.readModel;
    }

    /**
     * Creates the bounded work queue and the fixed-size worker pool, pre-starts the workers
     * and resets the write model.
     */
    public void start() {
        log.info("Starting training threadpool with {} threads", this.numTrainThreads);
        this.workQueue = new ArrayBlockingQueue<Runnable>(this.numTrainThreads * 10);
        this.threadPool = new ThreadPoolExecutor(this.numTrainThreads, this.numTrainThreads, 0L, TimeUnit.SECONDS, this.workQueue);
        this.threadPool.allowCoreThreadTimeOut(false);
        // Workers must already be running because train(Vector, ...) enqueues directly on the
        // work queue instead of going through execute().
        this.threadPool.prestartAllCoreThreads();
        this.writeModel.reset();
    }

    /** Trains over the whole corpus with a single doc/topic inference pass per document. */
    public void train(VectorIterable matrix, VectorIterable docTopicCounts) {
        this.train(matrix, docTopicCounts, 1);
    }

    /** Computes perplexity over every document (no test-set sampling). */
    public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts) {
        return this.calculatePerplexity(matrix, docTopicCounts, 0.0);
    }

    /**
     * Computes the perplexity of the read model over (a sample of) the corpus, normalised by
     * the summed L1 norm of the scored documents.
     *
     * @param matrix         document vectors
     * @param docTopicCounts per-document topic vectors, iterated in lock-step with {@code matrix}
     * @param testFraction   {@code 0.0} to score every document; otherwise only documents whose
     *                       index satisfies {@code index % (1 / testFraction) == 0} are scored
     * @return accumulated perplexity divided by the accumulated L1 norm; this is {@code NaN}
     *         when no document was scored (0.0 / 0.0)
     */
    public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts, double testFraction) {
        Iterator<MatrixSlice> docIterator = matrix.iterator();
        Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
        double perplexity = 0.0;
        double matrixNorm = 0.0;
        while (docIterator.hasNext() && docTopicIterator.hasNext()) {
            MatrixSlice docSlice = docIterator.next();
            MatrixSlice topicSlice = docTopicIterator.next();
            int docId = docSlice.index();
            Vector document = docSlice.vector();
            Vector topicDist = topicSlice.vector();
            // NOTE(review): the sampling test is a floating-point modulus, so it is only exact
            // when 1 / testFraction is an integer - kept as-is to preserve which docs are sampled.
            if (testFraction != 0.0 && (double) docId % (1.0 / testFraction) != 0.0) {
                continue;
            }
            // Refine the document's topic distribution without touching the write model.
            this.trainSync(document, topicDist, false, PERPLEXITY_DOC_TOPIC_ITERS);
            perplexity += this.readModel.perplexity(document, topicDist);
            matrixNorm += document.norm(1.0);
        }
        return perplexity / matrixNorm;
    }

    /**
     * Trains over the whole corpus: starts the worker pool, feeds every (document, topic
     * distribution) pair to it, then stops the pool (which also swaps read and write models).
     *
     * <p>In read/write mode documents are collected into batches of {@code numTrainThreads}
     * and each batch is trained via {@link #batchTrain(Map, boolean, int)}; every document is
     * placed in a batch, the batch is emptied after training, and a trailing partial batch
     * is trained before stopping.
     *
     * @param matrix           document vectors
     * @param docTopicCounts   per-document topic vectors, iterated in lock-step with {@code matrix}
     * @param numDocTopicIters inference passes per document
     */
    public void train(VectorIterable matrix, VectorIterable docTopicCounts, int numDocTopicIters) {
        this.start();
        Iterator<MatrixSlice> docIterator = matrix.iterator();
        Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
        long startTime = System.nanoTime();
        int i = 0;
        double[] times = new double[100];
        // NOTE(review): the batch is keyed by the document Vector itself; if Vector equality is
        // content-based, identical documents in one batch collapse into one entry - confirm.
        Map<Vector, Vector> batch = Maps.newHashMap();
        int numTokensInBatch = 0;
        long batchStart = System.nanoTime();
        while (docIterator.hasNext() && docTopicIterator.hasNext()) {
            ++i;
            Vector document = docIterator.next().vector();
            Vector topicDist = docTopicIterator.next().vector();
            if (this.isReadWrite) {
                // Always enqueue the current document first so that none is skipped.
                batch.put(document, topicDist);
                if (log.isDebugEnabled()) {
                    numTokensInBatch += document.getNumNondefaultElements();
                }
                if (batch.size() >= this.numTrainThreads) {
                    batchStart = this.trainAndClearBatch(batch, numDocTopicIters, numTokensInBatch, batchStart);
                    numTokensInBatch = 0;
                }
                continue;
            }
            long start = System.nanoTime();
            this.train(document, topicDist, true, numDocTopicIters);
            if (!log.isDebugEnabled()) {
                continue;
            }
            times[i % times.length] = (double) (System.nanoTime() - start) / (NANOS_PER_MILLI * (double) document.getNumNondefaultElements());
            if (i % 100 != 0) {
                continue;
            }
            long elapsed = System.nanoTime() - startTime;
            log.debug("trained {} documents in {}ms", i, (double) elapsed / NANOS_PER_MILLI);
            if (i % 500 != 0) {
                continue;
            }
            Arrays.sort(times);
            log.debug("training took median {}ms per token-instance", times[times.length / 2]);
        }
        // Train whatever is left over when the corpus size is not a multiple of the batch size.
        if (!batch.isEmpty()) {
            this.trainAndClearBatch(batch, numDocTopicIters, numTokensInBatch, batchStart);
        }
        this.stop();
    }

    /**
     * Trains the given batch with updates enabled, empties it, and logs timing at debug level.
     *
     * @return the timestamp (from {@link System#nanoTime()}) taken right after training, to be
     *         used as the start time of the next batch
     */
    private long trainAndClearBatch(Map<Vector, Vector> batch, int numDocTopicIters, int numTokensInBatch, long batchStart) {
        int numDocs = batch.size();
        this.batchTrain(batch, true, numDocTopicIters);
        batch.clear();
        long time = System.nanoTime();
        log.debug("trained {} docs with {} tokens, start time {}, end time {}", numDocs, numTokensInBatch, batchStart, time);
        return time;
    }

    /**
     * Runs inference for every (document, topic distribution) entry of {@code batch} on the
     * worker pool, waits for all of them, and then, if {@code update} is set, applies each
     * document's result to the write model on the calling thread.
     *
     * <p>If the calling thread is interrupted while waiting, the whole batch is retried; the
     * thread's interrupt status is restored before this method returns.
     *
     * @param batch             documents mapped to their topic distributions
     * @param update            whether to apply the results to the write model
     * @param numDocTopicsIters inference passes per document
     */
    public void batchTrain(Map<Vector, Vector> batch, boolean update, int numDocTopicsIters) {
        boolean interrupted = false;
        try {
            while (true) {
                try {
                    List<TrainerRunnable> runnables = Lists.newArrayList();
                    for (Map.Entry<Vector, Vector> entry : batch.entrySet()) {
                        // Write model is null here: updates are applied below, after the batch.
                        runnables.add(new TrainerRunnable(this.readModel, null, entry.getKey(), entry.getValue(), new SparseRowMatrix(this.numTopics, this.numTerms, true), numDocTopicsIters));
                    }
                    this.threadPool.invokeAll(runnables);
                    if (update) {
                        for (TrainerRunnable runnable : runnables) {
                            this.writeModel.update(runnable.docTopicModel);
                        }
                    }
                    return;
                }
                catch (InterruptedException e) {
                    // Do not re-interrupt yet, or the retry would fail immediately.
                    interrupted = true;
                    log.warn("Interrupted during batch training, retrying!", e);
                }
            }
        }
        finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Queues a single document for asynchronous training, blocking while the work queue is
     * full. Requires {@link #start()} to have been called.
     *
     * <p>If interrupted while waiting for queue space the submission is retried; the thread's
     * interrupt status is restored before this method returns.
     *
     * @param document         document vector
     * @param docTopicCounts   topic vector for the document
     * @param update           whether the worker should apply the result to the write model
     * @param numDocTopicIters inference passes to run
     */
    public void train(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) {
        TrainerRunnable task = this.newRunnable(document, docTopicCounts, update, numDocTopicIters);
        boolean interrupted = false;
        try {
            while (true) {
                try {
                    this.workQueue.put(task);
                    return;
                }
                catch (InterruptedException e) {
                    // Do not re-interrupt yet, or the retried put() would fail immediately.
                    interrupted = true;
                    log.warn("Interrupted waiting to submit document to work queue: {}", document, e);
                }
            }
        }
        finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Trains a single document on the calling thread, optionally updating the write model. */
    public void trainSync(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) {
        this.newRunnable(document, docTopicCounts, update, numDocTopicIters).run();
    }

    /**
     * Runs inference for one document on the calling thread (without updating the write
     * model) and returns the read model's perplexity for it.
     */
    public double calculatePerplexity(Vector document, Vector docTopicCounts, int numDocTopicIters) {
        return this.newRunnable(document, docTopicCounts, false, numDocTopicIters).call();
    }

    /** Builds a training task with a fresh sparse numTopics x numTerms scratch matrix. */
    private TrainerRunnable newRunnable(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) {
        TopicModel target = update ? this.writeModel : null;
        return new TrainerRunnable(this.readModel, target, document, docTopicCounts, new SparseRowMatrix(this.numTopics, this.numTerms, true), numDocTopicIters);
    }

    /**
     * Shuts the worker pool down, waits up to {@value #SHUTDOWN_TIMEOUT_SECONDS} seconds for
     * queued jobs, stops the read model and then the write model, and finally swaps the two
     * so the freshly written model becomes the read model.
     *
     * <p>If the calling thread is interrupted while waiting, the remaining steps are skipped,
     * the error is logged and the thread's interrupt status is restored.
     */
    public void stop() {
        long startTime = System.nanoTime();
        log.info("Initiating stopping of training threadpool");
        try {
            this.threadPool.shutdown();
            if (!this.threadPool.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                log.warn("Threadpool timed out on await termination - jobs still running!");
            }
            long newTime = System.nanoTime();
            log.info("threadpool took: {}ms", (double) (newTime - startTime) / NANOS_PER_MILLI);
            startTime = newTime;
            this.readModel.stop();
            newTime = System.nanoTime();
            log.info("readModel.stop() took {}ms", (double) (newTime - startTime) / NANOS_PER_MILLI);
            startTime = newTime;
            this.writeModel.stop();
            newTime = System.nanoTime();
            log.info("writeModel.stop() took {}ms", (double) (newTime - startTime) / NANOS_PER_MILLI);
            TopicModel tmpModel = this.writeModel;
            this.writeModel = this.readModel;
            this.readModel = tmpModel;
        }
        catch (InterruptedException e) {
            log.error("Interrupted shutting down!", e);
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
        }
    }

    /** Persists the current read model to {@code outputPath}. */
    public void persist(Path outputPath) throws IOException {
        this.readModel.persist(outputPath, true);
    }

    /**
     * One unit of training work: runs the configured number of inference passes for a single
     * document against the read model and, when a write model is supplied, applies the
     * resulting doc/topic matrix to it. As a {@link Callable} it additionally returns the
     * read model's perplexity for the document.
     */
    private static final class TrainerRunnable
    implements Runnable,
    Callable<Double> {
        private final TopicModel readModel;
        /** Target of the update; {@code null} means "do not update". */
        private final TopicModel writeModel;
        private final Vector document;
        private final Vector docTopics;
        /** Scratch matrix handed to the read model and later used for the update. */
        private final Matrix docTopicModel;
        private final int numDocTopicIters;

        private TrainerRunnable(TopicModel readModel, TopicModel writeModel, Vector document, Vector docTopics, Matrix docTopicModel, int numDocTopicIters) {
            this.readModel = readModel;
            this.writeModel = writeModel;
            this.document = document;
            this.docTopics = docTopics;
            this.docTopicModel = docTopicModel;
            this.numDocTopicIters = numDocTopicIters;
        }

        @Override
        public void run() {
            for (int i = 0; i < this.numDocTopicIters; ++i) {
                this.readModel.trainDocTopicModel(this.document, this.docTopics, this.docTopicModel);
            }
            if (this.writeModel != null) {
                this.writeModel.update(this.docTopicModel);
            }
        }

        @Override
        public Double call() {
            this.run();
            return this.readModel.perplexity(this.document, this.docTopics);
        }
    }
}

