/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.lda.cvb;

import java.io.IOException;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.lda.cvb.CVB0Driver;
import org.apache.mahout.clustering.lda.cvb.ModelTrainer;
import org.apache.mahout.clustering.lda.cvb.TopicModel;
import org.apache.mahout.common.MemoryUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Mapper that scores a random sample of the incoming documents against a
 * {@link TopicModel} and emits one record per sampled document.
 *
 * <p>For every sampled document the output key is {@code norm(1.0)} of the
 * document vector and the output value is whatever
 * {@link ModelTrainer#calculatePerplexity} returns for that document.
 * Documents that are not sampled produce no output.
 *
 * <p>All tuning parameters are read from the job {@link Configuration} in
 * {@link #setup}; the model itself is either loaded from the paths reported by
 * {@link CVB0Driver#getModelPaths} or, when there are none, freshly constructed.
 */
public class CachingCVB0PerplexityMapper
extends Mapper<IntWritable, VectorWritable, DoubleWritable, DoubleWritable> {
    private static final Logger log = LoggerFactory.getLogger(CachingCVB0PerplexityMapper.class);

    /** Evaluates documents against {@link #readModel}. */
    private ModelTrainer modelTrainer;
    /** Model the perplexity is measured against; stopped in {@link #cleanup}. */
    private TopicModel readModel;
    /** Iteration limit handed to the trainer for each document. */
    private int maxIters;
    /** Number of topics; also the length of {@link #topicVector}. */
    private int numTopics;
    /** Sampling threshold compared against {@code random.nextFloat()} in {@link #map}. */
    private float testFraction;
    /** Seeded generator driving the per-document sampling decision. */
    private Random random;
    /** Scratch vector re-filled before every perplexity calculation. */
    private Vector topicVector;
    /** Reused output key instance. */
    private final DoubleWritable outKey = new DoubleWritable();
    /** Reused output value instance. */
    private final DoubleWritable outValue = new DoubleWritable();

    /**
     * Starts the memory logger, reads the job configuration and builds the
     * read model, the model trainer and the reusable topic vector.
     *
     * @param context mapper context supplying the job configuration
     * @throws IOException declared by the overridden method
     * @throws InterruptedException declared by the overridden method
     */
    protected void setup(Mapper.Context context) throws IOException, InterruptedException {
        MemoryUtil.startMemoryLogger(5000L);

        log.info("Retrieving configuration");
        Configuration conf = context.getConfiguration();
        float termTopicSmoothing = conf.getFloat("term_topic_smoothing", Float.NaN);
        float docTopicSmoothing = conf.getFloat("doc_topic_smoothing", Float.NaN);
        long randomSeed = conf.getLong("random_seed", 1234L);
        int termCount = conf.getInt("num_terms", -1);
        int updateThreads = conf.getInt("num_update_threads", 1);
        int trainThreads = conf.getInt("num_train_threads", 4);
        float previousModelWeight = conf.getFloat("prev_iter_mult", 1.0f);
        this.numTopics = conf.getInt("num_topics", -1);
        this.maxIters = conf.getInt("max_doc_topic_iters", 10);
        this.testFraction = conf.getFloat("test_set_fraction", 0.1f);
        this.random = RandomUtils.getRandom(randomSeed);

        log.info("Initializing read model");
        Path[] modelPaths = CVB0Driver.getModelPaths(conf);
        if (modelPaths == null || modelPaths.length == 0) {
            log.info("No model files found");
            // No stored model: build one from the configured dimensions, with its own
            // generator created from the same seed.
            // NOTE(review): this constructor is given the train-thread count, whereas the
            // file-backed one below is given the update-thread count; kept as found.
            this.readModel = new TopicModel(this.numTopics, termCount, (double)termTopicSmoothing, docTopicSmoothing, RandomUtils.getRandom(randomSeed), null, trainThreads, previousModelWeight);
        } else {
            this.readModel = new TopicModel(conf, termTopicSmoothing, (double)docTopicSmoothing, null, updateThreads, previousModelWeight, modelPaths);
        }

        log.info("Initializing model trainer");
        this.modelTrainer = new ModelTrainer(this.readModel, null, trainThreads, this.numTopics, termCount);

        log.info("Initializing topic vector");
        this.topicVector = new DenseVector(new double[this.numTopics]);
    }

    /**
     * Stops the read model and then the memory logger started in {@link #setup}.
     *
     * @param context mapper context (unused)
     * @throws IOException declared by the overridden method
     * @throws InterruptedException declared by the overridden method
     */
    protected void cleanup(Mapper.Context context) throws IOException, InterruptedException {
        this.readModel.stop();
        MemoryUtil.stopMemoryLogger();
    }

    /**
     * Samples the document and, if selected, writes its {@code norm(1.0)} as key
     * and its calculated perplexity as value.
     *
     * <p>When the configured test fraction is below {@code 1.0f}, a document is
     * skipped whenever the next random float is greater than or equal to that
     * fraction; otherwise every document is processed and no random number is drawn.
     *
     * @param docId document identifier (not used by this mapper)
     * @param document wrapper around the document vector
     * @param context mapper context receiving the counter update and the output record
     * @throws IOException if writing the output record fails
     * @throws InterruptedException if writing the output record is interrupted
     */
    public void map(IntWritable docId, VectorWritable document, Mapper.Context context) throws IOException, InterruptedException {
        // Kept in this exact form: the short-circuit decides whether a random number is
        // consumed, and float comparisons must not be rewritten via De Morgan (NaN).
        boolean skipped = this.testFraction < 1.0f && this.random.nextFloat() >= this.testFraction;
        if (!skipped) {
            context.getCounter(Counters.SAMPLED_DOCUMENTS).increment(1L);

            Vector docVector = document.get();
            this.outKey.set(docVector.norm(1.0));

            // Every topic entry is reset to the same value, 1 / numTopics, before scoring.
            Vector initialTopics = this.topicVector.assign(1.0 / (double)this.numTopics);
            double perplexity = this.modelTrainer.calculatePerplexity(docVector, initialTopics, this.maxIters);
            this.outValue.set(perplexity);

            context.write(this.outKey, this.outValue);
        }
    }

    /** Job counters maintained by this mapper. */
    public enum Counters {
        /** Incremented once for every document that passes the sampling check. */
        SAMPLED_DOCUMENTS
    }
}

