/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.hadoop.als;

import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import java.io.Closeable;
import java.io.IOException;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.hadoop.als.ALS;
import org.apache.mahout.cf.taste.hadoop.als.MultithreadedSharingMapper;
import org.apache.mahout.cf.taste.hadoop.als.SharingMapper;
import org.apache.mahout.cf.taste.hadoop.als.SolveExplicitFeedbackMapper;
import org.apache.mahout.cf.taste.hadoop.als.SolveImplicitFeedbackMapper;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.mapreduce.MergeVectorsCombiner;
import org.apache.mahout.common.mapreduce.MergeVectorsReducer;
import org.apache.mahout.common.mapreduce.TransposeMapper;
import org.apache.mahout.common.mapreduce.VectorSumCombiner;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.VarIntWritable;
import org.apache.mahout.math.VarLongWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.Vectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * MapReduce driver for the ALS factorization of a rating matrix into a user-feature matrix
 * {@code U} and an item-feature matrix {@code M}.
 *
 * <p>Pipeline, as implemented in {@link #run(String[])}:
 * <ol>
 *   <li>optionally (with {@code --usesLongIDs true}) write (int index, long ID) pairs for users and
 *       items to {@code userIDIndex} / {@code itemIDIndex} under the output path;</li>
 *   <li>turn each text input line into a single-entry rating vector keyed by item and sum them into
 *       one sparse rating vector per item (temp path {@code itemRatings});</li>
 *   <li>transpose those into one rating vector per user ({@code userRatings} under the output path),
 *       counting users with {@link Stats#NUM_USERS};</li>
 *   <li>compute each item's average rating and use it as the first feature of the initial {@code M};
 *       the remaining features are random;</li>
 *   <li>for every iteration, recompute {@code U} from the previous {@code M}, then {@code M} from the
 *       new {@code U}. The last iteration writes to {@code U} / {@code M} under the output path,
 *       earlier ones to temp paths.</li>
 * </ol>
 *
 * <p>Input lines are split with {@code TasteHadoopUtils.splitPrefTokens}; tokens 0, 1 and 2 are read
 * as user ID, item ID and rating respectively.
 */
public class ParallelALSFactorizationJob extends AbstractJob {
    private static final Logger log = LoggerFactory.getLogger(ParallelALSFactorizationJob.class);

    // Configuration keys handed to the solver / preprocessing tasks.
    static final String NUM_FEATURES = ParallelALSFactorizationJob.class.getName() + ".numFeatures";
    static final String LAMBDA = ParallelALSFactorizationJob.class.getName() + ".lambda";
    static final String ALPHA = ParallelALSFactorizationJob.class.getName() + ".alpha";
    static final String NUM_ENTITIES = ParallelALSFactorizationJob.class.getName() + ".numEntities";
    static final String USES_LONG_IDS = ParallelALSFactorizationJob.class.getName() + ".usesLongIDs";
    static final String TOKEN_POS = ParallelALSFactorizationJob.class.getName() + ".tokenPos";

    private boolean implicitFeedback;
    private int numIterations;
    private int numFeatures;
    private double lambda;
    private double alpha;
    private int numThreadsPerSolver;
    private boolean usesLongIDs;
    private int numItems;
    private int numUsers;

    public static void main(String[] args) throws Exception {
        ToolRunner.run(new ParallelALSFactorizationJob(), args);
    }

    /**
     * Parses the command line and runs the whole factorization pipeline.
     *
     * @param args command-line arguments
     * @return {@code 0} on success, {@code -1} if argument parsing or any preprocessing job fails
     * @throws IllegalArgumentException if {@code numFeatures} or {@code numThreadsPerSolver} is not positive
     * @throws IllegalStateException if a solver job fails
     */
    @Override
    public int run(String[] args) throws Exception {
        addInputOption();
        addOutputOption();
        addOption("lambda", null, "regularization parameter", true);
        addOption("implicitFeedback", null, "data consists of implicit feedback?", String.valueOf(false));
        addOption("alpha", null, "confidence parameter (only used on implicit feedback)", String.valueOf(40));
        addOption("numFeatures", null, "dimension of the feature space", true);
        addOption("numIterations", null, "number of iterations", true);
        addOption("numThreadsPerSolver", null, "threads per solver mapper", String.valueOf(1));
        addOption("usesLongIDs", null, "input contains long IDs that need to be translated");

        Map<String, List<String>> parsedArgs = parseArguments(args);
        if (parsedArgs == null) {
            return -1;
        }

        numFeatures = Integer.parseInt(getOption("numFeatures"));
        numIterations = Integer.parseInt(getOption("numIterations"));
        lambda = Double.parseDouble(getOption("lambda"));
        alpha = Double.parseDouble(getOption("alpha"));
        implicitFeedback = Boolean.parseBoolean(getOption("implicitFeedback"));
        numThreadsPerSolver = Integer.parseInt(getOption("numThreadsPerSolver"));
        usesLongIDs = Boolean.parseBoolean(getOption("usesLongIDs", String.valueOf(false)));

        // Fail fast instead of after several cluster jobs: initializeM writes feature 0 of every row,
        // and a solver mapper with no threads cannot make progress.
        Preconditions.checkArgument(numFeatures > 0, "numFeatures must be positive, was %s", numFeatures);
        Preconditions.checkArgument(numThreadsPerSolver > 0,
            "numThreadsPerSolver must be positive, was %s", numThreadsPerSolver);

        if (usesLongIDs) {
            // index -> long ID pairs for users (token 0) and items (token 1); abort like every other job.
            if (!runIDIndexJob("userIDIndex", 0) || !runIDIndexJob("itemIDIndex", 1)) {
                return -1;
            }
        }

        Job itemRatings = prepareJob(getInputPath(), pathToItemRatings(), TextInputFormat.class,
            ItemRatingVectorsMapper.class, IntWritable.class, VectorWritable.class, VectorSumReducer.class,
            IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
        itemRatings.setCombinerClass(VectorSumCombiner.class);
        itemRatings.getConfiguration().set(USES_LONG_IDS, String.valueOf(usesLongIDs));
        if (!itemRatings.waitForCompletion(true)) {
            return -1;
        }

        Job userRatings = prepareJob(pathToItemRatings(), pathToUserRatings(), TransposeMapper.class,
            IntWritable.class, VectorWritable.class, MergeUserVectorsReducer.class, IntWritable.class,
            VectorWritable.class);
        userRatings.setCombinerClass(MergeVectorsCombiner.class);
        if (!userRatings.waitForCompletion(true)) {
            return -1;
        }

        Job averageItemRatings = prepareJob(pathToItemRatings(), getTempPath("averageRatings"),
            AverageRatingMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class,
            IntWritable.class, VectorWritable.class);
        averageItemRatings.setCombinerClass(MergeVectorsCombiner.class);
        if (!averageItemRatings.waitForCompletion(true)) {
            return -1;
        }

        Vector averageRatings = ALS.readFirstRow(getTempPath("averageRatings"), getConf());
        numItems = averageRatings.getNumNondefaultElements();
        numUsers = (int) userRatings.getCounters().findCounter(Stats.NUM_USERS).getValue();
        log.info("Found {} users and {} items", numUsers, numItems);

        initializeM(averageRatings);

        for (int currentIteration = 0; currentIteration < numIterations; ++currentIteration) {
            log.info("Recomputing U (iteration {}/{})", currentIteration, numIterations);
            runSolver(pathToUserRatings(), pathToU(currentIteration), pathToM(currentIteration - 1),
                currentIteration, "U", numItems);
            log.info("Recomputing M (iteration {}/{})", currentIteration, numIterations);
            runSolver(pathToItemRatings(), pathToM(currentIteration), pathToU(currentIteration),
                currentIteration, "M", numUsers);
        }
        return 0;
    }

    /**
     * Runs one job that reads the long ID at input token {@code tokenPos} of every line and writes
     * (index, ID) pairs to {@code getOutputPath(indexName)}.
     *
     * @return whether the job completed successfully
     */
    private boolean runIDIndexJob(String indexName, int tokenPos)
            throws IOException, ClassNotFoundException, InterruptedException {
        Job mapIDs = prepareJob(getInputPath(), getOutputPath(indexName), TextInputFormat.class,
            MapLongIDsMapper.class, VarIntWritable.class, VarLongWritable.class, IDMapReducer.class,
            VarIntWritable.class, VarLongWritable.class, SequenceFileOutputFormat.class);
        mapIDs.getConfiguration().set(TOKEN_POS, String.valueOf(tokenPos));
        return mapIDs.waitForCompletion(true);
    }

    /**
     * Writes the initial item-feature matrix to {@code pathToM(-1)}: one row per non-zero entry of
     * {@code averageRatings}, whose feature 0 is that average and whose other features are random.
     */
    private void initializeM(Vector averageRatings) throws IOException {
        RandomWrapper random = RandomUtils.getRandom();
        Path initialM = pathToM(-1);
        FileSystem fs = FileSystem.get(initialM.toUri(), getConf());
        SequenceFile.Writer writer = new SequenceFile.Writer(fs, getConf(),
            new Path(initialM, "part-m-00000"), IntWritable.class, VectorWritable.class);
        boolean threw = true;
        try {
            IntWritable index = new IntWritable();
            VectorWritable featureVector = new VectorWritable();
            for (Vector.Element e : averageRatings.nonZeroes()) {
                Vector row = new DenseVector(numFeatures);
                row.setQuick(0, e.get());
                for (int m = 1; m < numFeatures; ++m) {
                    row.setQuick(m, random.nextDouble());
                }
                index.set(e.index());
                featureVector.set(row);
                writer.append(index, featureVector);
            }
            threw = false;
        } finally {
            // If the loop failed, swallow a close() failure so it cannot mask the original exception.
            Closeables.close(writer, threw);
        }
    }

    /**
     * Runs one multithreaded solver job that recomputes {@code matrixName} ("U" or "M") from
     * {@code ratings}, shipping every part file of {@code pathToUorM} via the distributed cache.
     *
     * @throws IllegalStateException if the job fails
     */
    private void runSolver(Path ratings, Path output, Path pathToUorM, int currentIteration,
            String matrixName, int numEntities) throws ClassNotFoundException, IOException, InterruptedException {
        SharingMapper.reset();

        Class<? extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>> solverMapperClass;
        String feedbackType;
        if (implicitFeedback) {
            solverMapperClass = SolveImplicitFeedbackMapper.class;
            feedbackType = "implicit";
        } else {
            solverMapperClass = SolveExplicitFeedbackMapper.class;
            feedbackType = "explicit";
        }
        String name = "Recompute " + matrixName + ", iteration (" + currentIteration + '/' + numIterations
            + "), (" + numThreadsPerSolver + " threads, " + numFeatures + " features, " + feedbackType
            + " feedback)";

        Job solverForUorI = prepareJob(ratings, output, SequenceFileInputFormat.class,
            MultithreadedSharingMapper.class, IntWritable.class, VectorWritable.class,
            SequenceFileOutputFormat.class, name);
        Configuration solverConf = solverForUorI.getConfiguration();
        solverConf.set(LAMBDA, String.valueOf(lambda));
        solverConf.set(ALPHA, String.valueOf(alpha));
        solverConf.setInt(NUM_FEATURES, numFeatures);
        solverConf.set(NUM_ENTITIES, String.valueOf(numEntities));

        FileSystem fs = FileSystem.get(pathToUorM.toUri(), solverConf);
        for (FileStatus part : fs.listStatus(pathToUorM, PathFilters.partFilter())) {
            log.debug("Adding {} to distributed cache", part.getPath());
            DistributedCache.addCacheFile(part.getPath().toUri(), solverConf);
        }

        MultithreadedMapper.setMapperClass(solverForUorI, solverMapperClass);
        MultithreadedMapper.setNumberOfThreads(solverForUorI, numThreadsPerSolver);

        if (!solverForUorI.waitForCompletion(true)) {
            throw new IllegalStateException("Job failed!");
        }
    }

    /**
     * Location of {@code M} after {@code iteration}: the output path for the last iteration, a temp path
     * otherwise. Iteration {@code -1} is the initial matrix written by {@link #initializeM(Vector)}.
     */
    private Path pathToM(int iteration) {
        return iteration == numIterations - 1 ? getOutputPath("M") : getTempPath("M-" + iteration);
    }

    /** Location of {@code U} after {@code iteration}: the output path for the last iteration, a temp path otherwise. */
    private Path pathToU(int iteration) {
        return iteration == numIterations - 1 ? getOutputPath("U") : getTempPath("U-" + iteration);
    }

    private Path pathToItemRatings() {
        return getTempPath("itemRatings");
    }

    private Path pathToUserRatings() {
        return getOutputPath("userRatings");
    }

    /** Emits the first long ID seen for each index. */
    static class IDMapReducer extends Reducer<VarIntWritable, VarLongWritable, VarIntWritable, VarLongWritable> {

        @Override
        protected void reduce(VarIntWritable index, Iterable<VarLongWritable> ids, Context ctx)
                throws IOException, InterruptedException {
            ctx.write(index, ids.iterator().next());
        }
    }

    /**
     * Emits (TasteHadoopUtils.idToIndex(id), id) for the long ID at token {@link #TOKEN_POS} of each
     * input line.
     */
    static class MapLongIDsMapper extends Mapper<LongWritable, Text, VarIntWritable, VarLongWritable> {
        private int tokenPos;
        private final VarIntWritable index = new VarIntWritable();
        private final VarLongWritable idWritable = new VarLongWritable();

        @Override
        protected void setup(Context ctx) throws IOException, InterruptedException {
            tokenPos = ctx.getConfiguration().getInt(TOKEN_POS, -1);
            Preconditions.checkState(tokenPos >= 0, "%s must be set to a non-negative token position", TOKEN_POS);
        }

        @Override
        protected void map(LongWritable key, Text line, Context ctx) throws IOException, InterruptedException {
            String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
            long id = Long.parseLong(tokens[tokenPos]);
            index.set(TasteHadoopUtils.idToIndex(id));
            idWritable.set(id);
            ctx.write(index, idWritable);
        }
    }

    /**
     * For each (row, vector) input emits, under key 0, a sparse vector whose only entry at {@code row}
     * is the average of the vector's non-zero values.
     */
    static class AverageRatingMapper extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
        private final IntWritable firstIndex = new IntWritable(0);
        private final Vector featureVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
        private final VectorWritable featureVectorWritable = new VectorWritable();

        @Override
        protected void map(IntWritable r, VectorWritable v, Context ctx) throws IOException, InterruptedException {
            FullRunningAverage avg = new FullRunningAverage();
            for (Vector.Element e : v.get().nonZeroes()) {
                avg.addDatum(e.get());
            }
            featureVector.setQuick(r.get(), avg.getAverage());
            featureVectorWritable.set(featureVector);
            ctx.write(firstIndex, featureVectorWritable);
            // The vector is reused across calls; clear the entry so it holds a single value next time.
            featureVector.setQuick(r.get(), 0.0);
        }
    }

    /**
     * Turns each text rating line into (itemIndex, vector with a single entry userIndex -> rating).
     * Tokens 0 and 1 are translated with {@code TasteHadoopUtils.readID}, honouring {@link #USES_LONG_IDS}.
     */
    static class ItemRatingVectorsMapper extends Mapper<LongWritable, Text, IntWritable, VectorWritable> {
        private final IntWritable itemIDWritable = new IntWritable();
        private final VectorWritable ratingsWritable = new VectorWritable(true);
        private final Vector ratings = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
        private boolean usesLongIDs;

        @Override
        protected void setup(Context ctx) throws IOException, InterruptedException {
            usesLongIDs = ctx.getConfiguration().getBoolean(USES_LONG_IDS, false);
        }

        @Override
        protected void map(LongWritable offset, Text line, Context ctx) throws IOException, InterruptedException {
            String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
            int userID = TasteHadoopUtils.readID(tokens[0], usesLongIDs);
            int itemID = TasteHadoopUtils.readID(tokens[1], usesLongIDs);
            float rating = Float.parseFloat(tokens[2]);
            ratings.setQuick(userID, rating);
            itemIDWritable.set(itemID);
            ratingsWritable.set(ratings);
            ctx.write(itemIDWritable, ratingsWritable);
            // The vector is reused across calls; clear the entry so it holds a single value next time.
            ratings.setQuick(userID, 0.0);
        }
    }

    /**
     * Combines all partial vectors of a key with {@code VectorWritable.merge}, emits the result as a
     * sequential-access sparse vector and counts one {@link Stats#NUM_USERS} per key.
     */
    static class MergeUserVectorsReducer
            extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
        private final VectorWritable result = new VectorWritable();

        @Override
        public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context ctx)
                throws IOException, InterruptedException {
            Vector merged = VectorWritable.merge(vectors.iterator()).get();
            result.set(new SequentialAccessSparseVector(merged));
            ctx.write(key, result);
            ctx.getCounter(Stats.NUM_USERS).increment(1L);
        }
    }

    /** Sums all vectors of a key with {@code Vectors.sum} and emits the result as a sequential-access sparse vector. */
    static class VectorSumReducer
            extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
        private final VectorWritable result = new VectorWritable();

        @Override
        protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context ctx)
                throws IOException, InterruptedException {
            Vector sum = Vectors.sum(values.iterator());
            result.set(new SequentialAccessSparseVector(sum));
            ctx.write(key, result);
        }
    }

    /** Job counters. */
    static enum Stats {
        /** Incremented once per key by {@link MergeUserVectorsReducer}. */
        NUM_USERS
    }
}

