package org.apache.mahout.clustering.lda.cvb;

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hbase.util.Strings;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX WARN: Classes with same name are omitted:
  input_file:BOOT-INF/lib/mahout-core-0.9.jar:org/apache/mahout/clustering/lda/cvb/CVB0Driver.class
 */
/* loaded from: input_file:BOOT-INF/lib/mahout-mr-0.12.2.jar:org/apache/mahout/clustering/lda/cvb/CVB0Driver.class */
public class CVB0Driver extends AbstractJob {
    private static final Logger log = LoggerFactory.getLogger(CVB0Driver.class);

    // Command-line option names. Several of them are also written into the job Configuration by
    // run(Configuration, ...) so that the map tasks can read them (presumably via the same keys — confirm in mappers).
    public static final String NUM_TOPICS = "num_topics";
    public static final String NUM_TERMS = "num_terms";
    public static final String DOC_TOPIC_SMOOTHING = "doc_topic_smoothing";
    public static final String TERM_TOPIC_SMOOTHING = "term_topic_smoothing";
    public static final String DICTIONARY = "dictionary";
    public static final String DOC_TOPIC_OUTPUT = "doc_topic_output";
    public static final String MODEL_TEMP_DIR = "topic_model_temp_dir";
    public static final String ITERATION_BLOCK_SIZE = "iteration_block_size";
    public static final String RANDOM_SEED = "random_seed";
    public static final String TEST_SET_FRACTION = "test_set_fraction";
    public static final String NUM_TRAIN_THREADS = "num_train_threads";
    public static final String NUM_UPDATE_THREADS = "num_update_threads";
    public static final String MAX_ITERATIONS_PER_DOC = "max_doc_topic_iters";
    public static final String MODEL_WEIGHT = "prev_iter_mult";
    public static final String NUM_REDUCE_TASKS = "num_reduce_tasks";
    public static final String BACKFILL_PERPLEXITY = "backfill_perplexity";

    /** Configuration key holding the comma-separated list of model part files (see setModelPaths/getModelPaths). */
    private static final String MODEL_PATHS = "mahout.lda.cvb.modelPath";

    private static final double DEFAULT_CONVERGENCE_DELTA = 0.0d;
    private static final double DEFAULT_DOC_TOPIC_SMOOTHING = 1.0E-4d;
    private static final double DEFAULT_TERM_TOPIC_SMOOTHING = 1.0E-4d;
    private static final int DEFAULT_ITERATION_BLOCK_SIZE = 10;
    private static final double DEFAULT_TEST_SET_FRACTION = 0.0d;
    private static final int DEFAULT_NUM_TRAIN_THREADS = 4;
    private static final int DEFAULT_NUM_UPDATE_THREADS = 1;
    private static final int DEFAULT_MAX_ITERATIONS_PER_DOC = 10;
    private static final int DEFAULT_NUM_REDUCE_TASKS = 10;

    /* JADX WARN: Classes with same name are omitted:
      input_file:BOOT-INF/lib/mahout-core-0.9.jar:org/apache/mahout/clustering/lda/cvb/CVB0Driver$DualDoubleSumReducer.class
     */
    /* loaded from: input_file:BOOT-INF/lib/mahout-mr-0.12.2.jar:org/apache/mahout/clustering/lda/cvb/CVB0Driver$DualDoubleSumReducer.class */
    /**
     * Reducer (also registered as the combiner of the perplexity job) that collapses its entire input into a
     * single record: the output key is the sum of every input key and the output value is the sum of every
     * input value, across all key groups.
     *
     * <p>{@link #run} is overridden directly, so {@code setup}/{@code cleanup} are never invoked. Exactly one
     * record is written per task, even when the task receives no input (both sums are then 0).
     */
    public static class DualDoubleSumReducer
        extends Reducer<DoubleWritable, DoubleWritable, DoubleWritable, DoubleWritable> {
        private final DoubleWritable outKey = new DoubleWritable();
        private final DoubleWritable outValue = new DoubleWritable();

        @Override
        public void run(Context context) throws IOException, InterruptedException {
            double keySum = 0.0d;
            double valueSum = 0.0d;
            while (context.nextKey()) {
                keySum += context.getCurrentKey().get();
                for (DoubleWritable value : context.getValues()) {
                    valueSum += value.get();
                }
            }
            outKey.set(keySum);
            outValue.set(valueSum);
            context.write(outKey, outValue);
        }
    }

    /**
     * Command-line entry point: declares and parses the options, derives the vocabulary size when it is not
     * given explicitly, and delegates to {@link #run(Configuration, Path, Path, int, int, double, double, int,
     * int, double, Path, Path, Path, long, float, int, int, int, int, boolean)}.
     *
     * @param args command-line arguments
     * @return 0 on success, -1 if argument parsing fails or a final output job fails
     * @throws IllegalArgumentException if neither {@code num_terms} nor {@code dictionary} is supplied
     */
    @Override
    public int run(String[] args) throws Exception {
        addInputOption();
        addOutputOption();
        addOption(DefaultOptionCreator.maxIterationsOption().create());
        addOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION, "cd", "The convergence delta value",
            String.valueOf(DEFAULT_CONVERGENCE_DELTA));
        addOption(DefaultOptionCreator.overwriteOption().create());
        addOption(NUM_TOPICS, "k", "Number of topics to learn", true);
        addOption(NUM_TERMS, "nt", "Vocabulary size", false);
        addOption(DOC_TOPIC_SMOOTHING, "a", "Smoothing for document/topic distribution",
            String.valueOf(DEFAULT_DOC_TOPIC_SMOOTHING));
        addOption(TERM_TOPIC_SMOOTHING, "e", "Smoothing for topic/term distribution",
            String.valueOf(DEFAULT_TERM_TOPIC_SMOOTHING));
        addOption(DICTIONARY, "dict", "Path to term-dictionary file(s) (glob expression supported)", false);
        addOption(DOC_TOPIC_OUTPUT, "dt", "Output path for the training doc/topic distribution", false);
        addOption(MODEL_TEMP_DIR, "mt", "Path to intermediate model path (useful for restarting)", false);
        addOption(ITERATION_BLOCK_SIZE, "block", "Number of iterations per perplexity check",
            String.valueOf(DEFAULT_ITERATION_BLOCK_SIZE));
        addOption(RANDOM_SEED, "seed", "Random seed", false);
        addOption(TEST_SET_FRACTION, "tf", "Fraction of data to hold out for testing",
            String.valueOf(DEFAULT_TEST_SET_FRACTION));
        addOption(NUM_TRAIN_THREADS, "ntt", "number of threads per mapper to train with",
            String.valueOf(DEFAULT_NUM_TRAIN_THREADS));
        addOption(NUM_UPDATE_THREADS, "nut", "number of threads per mapper to update the model with",
            String.valueOf(DEFAULT_NUM_UPDATE_THREADS));
        addOption(MAX_ITERATIONS_PER_DOC, "mipd", "max number of iterations per doc for p(topic|doc) learning",
            String.valueOf(DEFAULT_MAX_ITERATIONS_PER_DOC));
        addOption(NUM_REDUCE_TASKS, null, "number of reducers to use during model estimation",
            String.valueOf(DEFAULT_NUM_REDUCE_TASKS));
        addOption(buildOption(BACKFILL_PERPLEXITY, null, "enable backfilling of missing perplexity values",
            false, false, null));

        if (parseArguments(args) == null) {
            return -1;
        }

        int numTopics = Integer.parseInt(getOption(NUM_TOPICS));
        Path inputPath = getInputPath();
        Path topicModelOutputPath = getOutputPath();
        int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
        int iterationBlockSize = Integer.parseInt(getOption(ITERATION_BLOCK_SIZE));
        double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION));
        double alpha = Double.parseDouble(getOption(DOC_TOPIC_SMOOTHING));
        double eta = Double.parseDouble(getOption(TERM_TOPIC_SMOOTHING));
        int numTrainThreads = Integer.parseInt(getOption(NUM_TRAIN_THREADS));
        int numUpdateThreads = Integer.parseInt(getOption(NUM_UPDATE_THREADS));
        int maxItersPerDoc = Integer.parseInt(getOption(MAX_ITERATIONS_PER_DOC));
        Path dictionaryPath = hasOption(DICTIONARY) ? new Path(getOption(DICTIONARY)) : null;

        int numTerms;
        if (hasOption(NUM_TERMS)) {
            numTerms = Integer.parseInt(getOption(NUM_TERMS));
        } else {
            // Without an explicit vocabulary size we must scan the dictionary; fail clearly instead of an NPE.
            Preconditions.checkArgument(dictionaryPath != null,
                "Either '%s' or '%s' must be specified", NUM_TERMS, DICTIONARY);
            numTerms = getNumTerms(getConf(), dictionaryPath);
        }

        Path docTopicOutputPath = hasOption(DOC_TOPIC_OUTPUT) ? new Path(getOption(DOC_TOPIC_OUTPUT)) : null;
        Path modelTempPath = hasOption(MODEL_TEMP_DIR)
            ? new Path(getOption(MODEL_TEMP_DIR))
            : getTempPath("topicModelState");
        long seed = hasOption(RANDOM_SEED)
            ? Long.parseLong(getOption(RANDOM_SEED))
            : System.nanoTime() % 10000;
        float testFraction = hasOption(TEST_SET_FRACTION) ? Float.parseFloat(getOption(TEST_SET_FRACTION)) : 0.0f;
        int numReduceTasks = Integer.parseInt(getOption(NUM_REDUCE_TASKS));
        boolean backfillPerplexity = hasOption(BACKFILL_PERPLEXITY);

        return run(getConf(), inputPath, topicModelOutputPath, numTopics, numTerms, alpha, eta, maxIterations,
            iterationBlockSize, convergenceDelta, dictionaryPath, docTopicOutputPath, modelTempPath, seed,
            testFraction, numTrainThreads, numUpdateThreads, maxItersPerDoc, numReduceTasks, backfillPerplexity);
    }

    /**
     * Derives the vocabulary size from a term dictionary: the largest {@code IntWritable} value found in the
     * (Text, IntWritable) sequence files matching {@code dictionaryPath}, plus one.
     *
     * @param conf           Hadoop configuration
     * @param dictionaryPath path or glob of the dictionary sequence file(s)
     * @return max term id + 1 (0 if the matched files contain no entries)
     * @throws IOException if no file matches {@code dictionaryPath} or reading fails
     */
    private static int getNumTerms(Configuration conf, Path dictionaryPath) throws IOException {
        FileSystem fs = dictionaryPath.getFileSystem(conf);
        FileStatus[] dictionaryFiles = fs.globStatus(dictionaryPath);
        // globStatus returns null for a missing non-glob path; an empty result means nothing matched.
        if (dictionaryFiles == null || dictionaryFiles.length == 0) {
            throw new IOException("No dictionary files found matching " + dictionaryPath);
        }
        Text term = new Text();
        IntWritable termId = new IntWritable();
        int maxTermId = -1;
        for (FileStatus status : dictionaryFiles) {
            SequenceFile.Reader reader = new SequenceFile.Reader(fs, status.getPath(), conf);
            try {
                while (reader.next(term, termId)) {
                    maxTermId = Math.max(maxTermId, termId.get());
                }
            } finally {
                reader.close();
            }
        }
        return maxTermId + 1;
    }

    /**
     * Runs CVB0 LDA training, resuming from any model iterations already present under
     * {@code topicModelStateTempPath}, then optionally writes the final topic/term model and the per-document
     * topic inference.
     *
     * <p>Perplexity is computed every {@code iterationBlockSize} iterations when {@code testFraction > 0}.
     * Training stops at {@code maxIterations} or, when {@code convergenceDelta > 0}, once the relative change
     * between the last two recorded perplexities drops below {@code convergenceDelta}.
     *
     * @return 0 on success, -1 if either final output job fails
     * @throws IllegalArgumentException if {@code testFraction} is outside [0, 1], or is 0 while
     *                                  {@code backfillPerplexity} is set
     * @throws InterruptedException     if an iteration or perplexity job fails
     */
    public int run(Configuration conf, Path inputPath, Path topicModelOutputPath, int numTopics, int numTerms,
                   double alpha, double eta, int maxIterations, int iterationBlockSize, double convergenceDelta,
                   Path dictionaryPath, Path docTopicOutputPath, Path topicModelStateTempPath, long randomSeed,
                   float testFraction, int numTrainThreads, int numUpdateThreads, int maxItersPerDoc,
                   int numReduceTasks, boolean backfillPerplexity)
        throws ClassNotFoundException, IOException, InterruptedException {
        setConf(conf);
        Preconditions.checkArgument(testFraction >= 0.0f && testFraction <= 1.0f,
            "Expected 'testFraction' value in range [0, 1] but found value '%s'", testFraction);
        Preconditions.checkArgument(!backfillPerplexity || testFraction > 0.0f,
            "Expected 'testFraction' value in range (0, 1] but found value '%s'", testFraction);

        log.info("Will run Collapsed Variational Bayes (0th-derivative approximation) learning for LDA on {} (numTerms: {}), finding {}-topics, with document/topic prior {}, topic/term prior {}.  Maximum iterations to run will be {}, unless the change in perplexity is less than {}.  Topic model output (p(term|topic) for each topic) will be stored {}.  Random initialization seed is {}, holding out {} of the data for perplexity check\n",
            inputPath, numTerms, numTopics, alpha, eta, maxIterations, convergenceDelta, topicModelOutputPath,
            randomSeed, testFraction);
        String dictionaryNote = dictionaryPath == null
            ? "" : "Dictionary to be used located " + dictionaryPath.toString() + '\n';
        String docTopicNote = docTopicOutputPath == null
            ? "" : "p(topic|docId) will be stored " + docTopicOutputPath.toString() + '\n';
        log.info(dictionaryNote + docTopicNote);

        FileSystem fs = FileSystem.get(topicModelStateTempPath.toUri(), conf);
        int iterationNumber = getCurrentIterationNumber(conf, topicModelStateTempPath, maxIterations);
        log.info("Current iteration number: {}", iterationNumber);

        // Parameters handed to the map tasks through the job configuration.
        conf.set(NUM_TOPICS, String.valueOf(numTopics));
        conf.set(NUM_TERMS, String.valueOf(numTerms));
        conf.set(DOC_TOPIC_SMOOTHING, String.valueOf(alpha));
        conf.set(TERM_TOPIC_SMOOTHING, String.valueOf(eta));
        conf.set(RANDOM_SEED, String.valueOf(randomSeed));
        conf.set(NUM_TRAIN_THREADS, String.valueOf(numTrainThreads));
        conf.set(NUM_UPDATE_THREADS, String.valueOf(numUpdateThreads));
        conf.set(MAX_ITERATIONS_PER_DOC, String.valueOf(maxItersPerDoc));
        conf.set(MODEL_WEIGHT, "1");
        conf.set(TEST_SET_FRACTION, String.valueOf(testFraction));

        // Rebuild the perplexity history of previously completed iterations (when resuming).
        List<Double> perplexities = new ArrayList<>();
        for (int i = 1; i <= iterationNumber; i++) {
            Path modelPath = modelPath(topicModelStateTempPath, i);
            double perplexity = readPerplexity(conf, topicModelStateTempPath, i);
            if (Double.isNaN(perplexity)) {
                if (!(backfillPerplexity && i % iterationBlockSize == 0)) {
                    // No stored value and nothing to backfill: skip it. Recording NaN here would make
                    // rateOfChange() NaN forever and disable convergence detection.
                    continue;
                }
                log.info("Backfilling perplexity at iteration {}", i);
                if (!fs.exists(modelPath)) {
                    log.error("Model path '{}' does not exist; Skipping iteration {} perplexity calculation",
                        modelPath.toString(), i);
                    continue;
                }
                perplexity = calculatePerplexity(conf, inputPath, modelPath, i);
            }
            perplexities.add(perplexity);
            log.info("Perplexity at iteration {} = {}", i, perplexity);
        }

        long startTime = System.currentTimeMillis();
        while (iterationNumber < maxIterations) {
            if (convergenceDelta > 0.0d) {
                double delta = rateOfChange(perplexities);
                if (delta < convergenceDelta) {
                    log.info("Convergence achieved at iteration {} with perplexity {} and delta {}",
                        iterationNumber, perplexities.get(perplexities.size() - 1), delta);
                    break;
                }
            }
            iterationNumber++;
            log.info("About to run iteration {} of {}", iterationNumber, maxIterations);
            Path modelInputPath = modelPath(topicModelStateTempPath, iterationNumber - 1);
            Path modelOutputPath = modelPath(topicModelStateTempPath, iterationNumber);
            runIteration(conf, inputPath, modelInputPath, modelOutputPath, iterationNumber, maxIterations,
                numReduceTasks);
            if (testFraction > 0.0f && iterationNumber % iterationBlockSize == 0) {
                perplexities.add(calculatePerplexity(conf, inputPath, modelOutputPath, iterationNumber));
                log.info("Current perplexity = {}", perplexities.get(perplexities.size() - 1));
                log.info("(p_{} - p_{}) / p_0 = {}; target = {}", iterationNumber,
                    iterationNumber - iterationBlockSize, rateOfChange(perplexities), convergenceDelta);
            }
        }
        log.info("Completed {} iterations in {} seconds", iterationNumber,
            (System.currentTimeMillis() - startTime) / 1000);
        log.info("Perplexities: ({})", Joiner.on(", ").join(perplexities));

        // Both output jobs are submitted before either is awaited so they can run concurrently.
        Path finalModelPath = modelPath(topicModelStateTempPath, iterationNumber);
        Job topicModelJob = topicModelOutputPath != null
            ? writeTopicModel(conf, finalModelPath, topicModelOutputPath) : null;
        Job docInferenceJob = docTopicOutputPath != null
            ? writeDocTopicInference(conf, inputPath, finalModelPath, docTopicOutputPath) : null;
        if (topicModelJob != null && !topicModelJob.waitForCompletion(true)) {
            return -1;
        }
        if (docInferenceJob != null && !docInferenceJob.waitForCompletion(true)) {
            return -1;
        }
        return 0;
    }

    /**
     * Relative change between the last two perplexities, normalised by the first one.
     *
     * @return {@code |p[n-1] - p[n-2]| / p[0]}, or {@link Double#MAX_VALUE} when fewer than two values exist
     */
    private static double rateOfChange(List<Double> perplexities) {
        int size = perplexities.size();
        if (size < 2) {
            return Double.MAX_VALUE;
        }
        return Math.abs(perplexities.get(size - 1) - perplexities.get(size - 2)) / perplexities.get(0);
    }

    /**
     * Runs a single-reducer job that computes perplexity of {@code modelPath} over {@code corpus} and writes
     * it to {@code perplexity-<iteration>} next to the model, then reads it back.
     *
     * @throws InterruptedException if the perplexity job does not complete successfully
     */
    private double calculatePerplexity(Configuration conf, Path corpus, Path modelPath, int iteration)
        throws IOException, ClassNotFoundException, InterruptedException {
        String jobName = "Calculating perplexity for " + modelPath;
        log.info("About to run: {}", jobName);
        Path outputPath = perplexityPath(modelPath.getParent(), iteration);
        Job job = prepareJob(corpus, outputPath, CachingCVB0PerplexityMapper.class, DoubleWritable.class,
            DoubleWritable.class, DualDoubleSumReducer.class, DoubleWritable.class, DoubleWritable.class);
        job.setJobName(jobName);
        job.setCombinerClass(DualDoubleSumReducer.class);
        job.setNumReduceTasks(1);
        setModelPaths(job, modelPath);
        HadoopUtil.delete(conf, outputPath);
        if (!job.waitForCompletion(true)) {
            throw new InterruptedException("Failed to calculate perplexity for: " + modelPath);
        }
        return readPerplexity(conf, modelPath.getParent(), iteration);
    }

    /**
     * Reads the perplexity stored for {@code iteration} under {@code topicModelStateTemp}: the sum of the
     * values of all part files divided by the sum of their keys (the model weight).
     *
     * @return the stored perplexity, or {@link Double#NaN} if no perplexity output exists for that iteration
     */
    public static double readPerplexity(Configuration conf, Path topicModelStateTemp, int iteration)
        throws IOException {
        Path perplexityPath = perplexityPath(topicModelStateTemp, iteration);
        FileSystem fs = FileSystem.get(perplexityPath.toUri(), conf);
        if (!fs.exists(perplexityPath)) {
            log.warn("Perplexity path {} does not exist, returning NaN", perplexityPath);
            return Double.NaN;
        }
        double perplexity = 0.0d;
        double modelWeight = 0.0d;
        long entries = 0;
        for (Pair<DoubleWritable, DoubleWritable> pair : new SequenceFileDirIterable<DoubleWritable, DoubleWritable>(
            perplexityPath, PathType.LIST, PathFilters.partFilter(), null, true, conf)) {
            modelWeight += pair.getFirst().get();
            perplexity += pair.getSecond().get();
            entries++;
        }
        log.info("Read {} entries with total perplexity {} and model weight {}", entries, perplexity, modelWeight);
        return perplexity / modelWeight;
    }

    /**
     * Submits (without waiting) a map-only job that writes the topic/term distributions found in
     * {@code modelInput} to {@code output} through {@code CVB0TopicTermVectorNormalizerMapper}.
     */
    private Job writeTopicModel(Configuration conf, Path modelInput, Path output)
        throws IOException, InterruptedException, ClassNotFoundException {
        String jobName = String.format("Writing final topic/term distributions from %s to %s", modelInput, output);
        log.info("About to run: {}", jobName);
        Job job = prepareJob(modelInput, output, SequenceFileInputFormat.class,
            CVB0TopicTermVectorNormalizerMapper.class, IntWritable.class, VectorWritable.class,
            SequenceFileOutputFormat.class, jobName);
        job.submit();
        return job;
    }

    /**
     * Submits (without waiting) a map-only job that infers the per-document topic distribution of
     * {@code corpus} using the model at {@code modelInput}, writing results to {@code output}.
     */
    private Job writeDocTopicInference(Configuration conf, Path corpus, Path modelInput, Path output)
        throws IOException, ClassNotFoundException, InterruptedException {
        String jobName = String.format("Writing final document/topic inference from %s to %s", corpus, output);
        log.info("About to run: {}", jobName);
        Job job = prepareJob(corpus, output, SequenceFileInputFormat.class, CVB0DocInferenceMapper.class,
            IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, jobName);
        if (modelInput != null) {
            // Query the model's own FileSystem: the corpus may live on a different one.
            FileSystem fs = modelInput.getFileSystem(conf);
            if (fs.exists(modelInput)) {
                FileStatus[] statuses = fs.listStatus(modelInput, PathFilters.partFilter());
                URI[] modelUris = new URI[statuses.length];
                for (int i = 0; i < statuses.length; i++) {
                    modelUris[i] = statuses[i].getPath().toUri();
                }
                // The Job holds its own copy of the configuration, so the cache files must be set on it.
                DistributedCache.setCacheFiles(modelUris, job.getConfiguration());
                setModelPaths(job, modelInput);
            }
        }
        job.submit();
        return job;
    }

    /** Returns {@code <topicModelStateTempPath>/model-<iterationNumber>}. */
    public static Path modelPath(Path topicModelStateTempPath, int iterationNumber) {
        return new Path(topicModelStateTempPath, "model-" + iterationNumber);
    }

    /** Returns {@code <topicModelStateTempPath>/perplexity-<iterationNumber>}. */
    public static Path perplexityPath(Path topicModelStateTempPath, int iterationNumber) {
        return new Path(topicModelStateTempPath, "perplexity-" + iterationNumber);
    }

    /**
     * Counts consecutive existing {@code model-1}, {@code model-2}, ... directories, capped at
     * {@code maxIterations}, i.e. the number of iterations that can be resumed from.
     */
    private static int getCurrentIterationNumber(Configuration conf, Path modelTempDir, int maxIterations)
        throws IOException {
        FileSystem fs = FileSystem.get(modelTempDir.toUri(), conf);
        int iterationNumber = 1;
        Path iterationPath = modelPath(modelTempDir, iterationNumber);
        while (fs.exists(iterationPath) && iterationNumber <= maxIterations) {
            log.info("Found previous state: {}", iterationPath);
            iterationNumber++;
            iterationPath = modelPath(modelTempDir, iterationNumber);
        }
        return iterationNumber - 1;
    }

    /**
     * Runs one training iteration: reads the model from {@code modelInput} (if it exists), trains over
     * {@code corpusInput}, and writes the summed model to {@code modelOutput} (deleted first).
     *
     * @throws InterruptedException if the iteration job does not complete successfully
     */
    public void runIteration(Configuration conf, Path corpusInput, Path modelInput, Path modelOutput,
                             int iterationNumber, int maxIterations, int numReduceTasks)
        throws IOException, ClassNotFoundException, InterruptedException {
        String jobName = String.format("Iteration %d of %d, input path: %s", iterationNumber, maxIterations,
            modelInput);
        log.info("About to run: {}", jobName);
        Job job = prepareJob(corpusInput, modelOutput, CachingCVB0Mapper.class, IntWritable.class,
            VectorWritable.class, VectorSumReducer.class, IntWritable.class, VectorWritable.class);
        job.setCombinerClass(VectorSumReducer.class);
        job.setNumReduceTasks(numReduceTasks);
        job.setJobName(jobName);
        setModelPaths(job, modelInput);
        HadoopUtil.delete(conf, modelOutput);
        if (!job.waitForCompletion(true)) {
            throw new InterruptedException(String.format("Failed to complete iteration %d stage 1", iterationNumber));
        }
    }

    /**
     * Records the URIs of the part files under {@code modelPath} in the job configuration (key
     * {@code mahout.lda.cvb.modelPath}). Does nothing when {@code modelPath} is null or does not exist.
     *
     * @throws IllegalStateException if {@code modelPath} exists but contains no part files
     */
    private static void setModelPaths(Job job, Path modelPath) throws IOException {
        if (modelPath == null) {
            return;
        }
        Configuration conf = job.getConfiguration();
        FileSystem fs = FileSystem.get(modelPath.toUri(), conf);
        if (!fs.exists(modelPath)) {
            return;
        }
        FileStatus[] statuses = fs.listStatus(modelPath, PathFilters.partFilter());
        Preconditions.checkState(statuses.length > 0, "No part files found in model path '%s'", modelPath.toString());
        String[] modelPaths = new String[statuses.length];
        for (int i = 0; i < statuses.length; i++) {
            modelPaths[i] = statuses[i].getPath().toUri().toString();
        }
        conf.setStrings(MODEL_PATHS, modelPaths);
    }

    /**
     * Returns the model part-file paths previously stored by {@code setModelPaths}, or {@code null} if none
     * were set.
     */
    public static Path[] getModelPaths(Configuration conf) {
        String[] modelPathNames = conf.getStrings(MODEL_PATHS);
        if (modelPathNames == null || modelPathNames.length == 0) {
            return null;
        }
        Path[] modelPaths = new Path[modelPathNames.length];
        for (int i = 0; i < modelPathNames.length; i++) {
            modelPaths[i] = new Path(modelPathNames[i]);
        }
        return modelPaths;
    }

    public static void main(String[] args) throws Exception {
        ToolRunner.run(new Configuration(), new CVB0Driver(), args);
    }
}
