package org.apache.mahout.clustering.lda.cvb;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.util.ToolRunner;
import org.apache.lucene.analysis.wikipedia.WikipediaTokenizer;
import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.DistributedRowMatrixWriter;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX WARN: Classes with same name are omitted:
  input_file:BOOT-INF/lib/mahout-core-0.9.jar:org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.class
 */
/* loaded from: input_file:BOOT-INF/lib/mahout-mr-0.12.2.jar:org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.class */
/**
 * Trains an LDA topic model entirely in memory using collapsed variational Bayes (CVB0).
 *
 * <p>The corpus is a {@link Matrix} with one row per document and one column per term. Each
 * training pass is delegated to a {@link ModelTrainer}. This class:
 * <ul>
 *   <li>drives the passes and logs perplexity after each one;</li>
 *   <li>stops on convergence or after a maximum number of passes;</li>
 *   <li>persists the learned model.</li>
 * </ul>
 * It can also be launched as a Hadoop {@code Tool} through {@link #main(String[])}.
 */
public class InMemoryCollapsedVariationalBayes0 extends AbstractJob {

    private static final Logger log = LoggerFactory.getLogger(InMemoryCollapsedVariationalBayes0.class);

    private int numTopics;
    /** Dictionary size, or the corpus column count when no dictionary was supplied. */
    private int numTerms;
    /** Number of rows (documents) in {@link #corpusWeights}. */
    private int numDocuments;
    /** Smoothing parameter for p(topic | document). */
    private double alpha;
    /** Smoothing parameter for p(term | topic). */
    private double eta;
    /** When true, the read model is logged after every pass. */
    private boolean verbose = false;
    /** Term dictionary indexed by term id; may be {@code null}. */
    private String[] terms;
    /** Document/term weights, one row per document. */
    private Matrix corpusWeights;
    /** Sum of the L1 norms of all non-empty document rows. */
    private double totalCorpusWeight;
    /** Initial |model|/|corpus| ratio; 0 selects separate initial and updated models. */
    private double initialModelCorpusFraction;
    /** numDocuments x numTopics matrix, initialized uniformly to 1/numTopics. */
    private Matrix docTopicCounts;
    private int numTrainingThreads;
    private int numUpdatingThreads;
    private ModelTrainer modelTrainer;

    /**
     * Used only by {@link #main(String[])} so that {@code ToolRunner} can inject a
     * {@link Configuration}. The real trainer instance is built inside {@link #main2}.
     */
    private InMemoryCollapsedVariationalBayes0() {
    }

    /**
     * Enables or disables logging of the full read model after each training pass.
     *
     * @param verbose {@code true} to log the model after every pass
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * Creates a trainer over an in-memory corpus and initializes the topic model.
     *
     * @param corpus document/term weight matrix, one row per document
     * @param terms term dictionary indexed by term id, or {@code null}. When {@code null}, the
     *     number of terms is taken from {@code corpus.numCols()}.
     * @param numTopics number of topics to learn
     * @param alpha smoothing parameter for p(topic | document)
     * @param eta smoothing parameter for p(term | topic)
     * @param numTrainers number of training threads handed to the {@link ModelTrainer}
     * @param numUpdaters number of model-update threads handed to each {@link TopicModel}
     * @param modelCorpusFraction initial |model|/|corpus| ratio. {@code 0} creates separate
     *     initial and updated models. Any other value shares one model whose initial weight is
     *     {@code modelCorpusFraction * totalCorpusWeight}.
     */
    public InMemoryCollapsedVariationalBayes0(Matrix corpus, String[] terms, int numTopics, double alpha,
            double eta, int numTrainers, int numUpdaters, double modelCorpusFraction) {
        this.numTopics = numTopics;
        this.alpha = alpha;
        this.eta = eta;
        this.corpusWeights = corpus;
        this.numDocuments = corpus.numRows();
        this.terms = terms;
        this.initialModelCorpusFraction = modelCorpusFraction;
        this.numTerms = terms != null ? terms.length : corpus.numCols();
        this.numTrainingThreads = numTrainers;
        this.numUpdatingThreads = numUpdaters;
        postInitCorpus();
        initializeModel();
    }

    /** Computes {@link #totalCorpusWeight} and logs basic corpus statistics. */
    private void postInitCorpus() {
        totalCorpusWeight = 0.0;
        int numNonZero = 0;
        for (int docId = 0; docId < numDocuments; docId++) {
            Vector document = corpusWeights.viewRow(docId);
            if (document == null) {
                continue;
            }
            double docWeight = document.norm(1.0);
            // Empty documents contribute neither weight nor non-zero entries.
            if (docWeight != 0.0) {
                numNonZero += document.getNumNondefaultElements();
                totalCorpusWeight += docWeight;
            }
        }
        log.info(String.format("Initializing corpus with %d docs, %d terms, %d nonzero entries, total termWeight %f",
                numDocuments, numTerms, numNonZero, totalCorpusWeight));
    }

    /**
     * Builds the initial and updated {@link TopicModel}s, the uniform document/topic matrix, and
     * the {@link ModelTrainer}.
     */
    private void initializeModel() {
        boolean separateModels = initialModelCorpusFraction == 0.0;
        double modelWeight = separateModels ? 1.0 : initialModelCorpusFraction * totalCorpusWeight;
        // NOTE(review): getConf() is called from the constructor. For instances built with the
        // public constructor it is presumably still null here; confirm TopicModel tolerates that.
        TopicModel initialModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(),
                terms, numUpdatingThreads, modelWeight);
        initialModel.setConf(getConf());
        // With a zero fraction, a second model without a random source is created; otherwise the
        // same model instance is passed twice to the trainer.
        TopicModel updatedModel = separateModels
                ? new TopicModel(numTopics, numTerms, eta, alpha, null, terms, numUpdatingThreads, 1.0)
                : initialModel;
        updatedModel.setConf(getConf());
        docTopicCounts = new DenseMatrix(numDocuments, numTopics);
        docTopicCounts.assign(1.0 / numTopics);
        modelTrainer = new ModelTrainer(initialModel, updatedModel, numTrainingThreads, numTopics, numTerms);
    }

    /** Runs one training pass over every document (no hold-out). */
    public void trainDocuments() {
        trainDocuments(0.0);
    }

    /**
     * Runs one training pass over the corpus and logs how long it took.
     *
     * <p>Each document is trained from a fresh uniform topic distribution. When
     * {@code testFraction} is non-zero, documents whose row index is an exact multiple of
     * {@code 1 / testFraction} are skipped (held out).
     *
     * <p>NOTE(review): for fractions whose reciprocal is not an integer, this holds out far fewer
     * documents than the fraction suggests. Any change here must stay in sync with how
     * {@code ModelTrainer.calculatePerplexity} selects held-out documents.
     *
     * @param testFraction fraction of documents to hold out, or {@code 0} to train on all
     */
    public void trainDocuments(double testFraction) {
        long start = System.nanoTime();
        modelTrainer.start();
        for (int docId = 0; docId < corpusWeights.numRows(); docId++) {
            boolean heldOut = testFraction != 0.0 && docId % (1.0 / testFraction) == 0.0;
            if (heldOut) {
                continue;
            }
            Vector uniformTopics = new DenseVector(numTopics).assign(1.0 / numTopics);
            // NOTE(review): `true` and 10 are passed straight to ModelTrainer.trainSync; presumably
            // "update the model" and the per-document inference iteration count.
            modelTrainer.trainSync(corpusWeights.viewRow(docId), uniformTopics, true, 10);
        }
        modelTrainer.stop();
        logTime("train documents", System.nanoTime() - start);
    }

    /**
     * Trains until convergence without holding out any documents.
     *
     * @see #iterateUntilConvergence(double, int, int, double)
     */
    public double iterateUntilConvergence(double minFractionalErrorChange, int maxIterations, int minIter) {
        return iterateUntilConvergence(minFractionalErrorChange, maxIterations, minIter, 0.0);
    }

    /**
     * Trains for at least {@code minIter} burn-in passes, then keeps training until the relative
     * change in perplexity drops to {@code minFractionalErrorChange} or {@code maxIterations}
     * passes have run in total.
     *
     * @param minFractionalErrorChange stop once |new - old| / old is at or below this value
     * @param maxIterations upper bound on the total number of passes (burn-in included in the count)
     * @param minIter number of burn-in passes run unconditionally
     * @param testFraction fraction of documents held out of every pass (see
     *     {@link #trainDocuments(double)})
     * @return the last computed perplexity, or 0 if no pass ran at all
     */
    public double iterateUntilConvergence(double minFractionalErrorChange, int maxIterations, int minIter,
            double testFraction) {
        int iter = 0;
        double oldPerplexity = 0.0;
        while (iter < minIter) {
            trainDocuments(testFraction);
            if (verbose) {
                log.info("model after: {}: {}", iter, modelTrainer.getReadModel());
            }
            log.info("iteration {} complete", iter);
            oldPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, testFraction);
            log.info("{} = perplexity", oldPerplexity);
            iter++;
        }
        // Seed with the last burn-in perplexity so a run with no post-burn-in passes returns a
        // real value instead of 0.
        double newPerplexity = oldPerplexity;
        double fractionalChange = Double.MAX_VALUE;
        while (iter < maxIterations && fractionalChange > minFractionalErrorChange) {
            // Keep honouring the hold-out after burn-in; otherwise held-out docs leak into training.
            trainDocuments(testFraction);
            if (verbose) {
                log.info("model after: {}: {}", iter, modelTrainer.getReadModel());
            }
            newPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, testFraction);
            log.info("{} = perplexity", newPerplexity);
            iter++;
            fractionalChange = Math.abs(newPerplexity - oldPerplexity) / oldPerplexity;
            log.info("{} = fractionalChange", fractionalChange);
            oldPerplexity = newPerplexity;
        }
        // Also count convergence reached exactly on the final permitted pass.
        boolean converged = iter < maxIterations || fractionalChange <= minFractionalErrorChange;
        if (converged) {
            log.info(String.format("Converged! fractional error change: %f, error %f",
                    fractionalChange, newPerplexity));
        } else {
            log.info(String.format("Reached max iteration count (%d), fractional error change: %f, error: %f",
                    maxIterations, fractionalChange, newPerplexity));
        }
        return newPerplexity;
    }

    /**
     * Persists the trained model via the {@link ModelTrainer}.
     *
     * @param outputPath destination path for the model
     * @throws IOException if the model cannot be written
     */
    public void writeModel(Path outputPath) throws IOException {
        modelTrainer.persist(outputPath);
    }

    /** Logs {@code nanos} converted to milliseconds, labelled with {@code label}. */
    private static void logTime(String label, long nanos) {
        log.info("{} time: {}ms", label, nanos / 1000000.0);
    }

    /**
     * Command-line entry point. It parses the options, loads the dictionary and corpus, trains
     * the model, and writes both the topic model and the document/topic matrix.
     *
     * @param args command-line arguments
     * @param conf Hadoop configuration used for all file-system access
     * @return 0 on success, -1 when help was requested or the options could not be parsed
     * @throws Exception on I/O failures or malformed numeric option values
     */
    public static int main2(String[] args, Configuration conf) throws Exception {
        DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
        ArgumentBuilder abuilder = new ArgumentBuilder();
        GroupBuilder gbuilder = new GroupBuilder();

        Option helpOpt = DefaultOptionCreator.helpOption();

        Option inputDirOpt = obuilder.withLongName("input").withRequired(true)
                .withArgument(abuilder.withName("input").withMinimum(1).withMaximum(1).create())
                .withDescription("The Directory on HDFS containing the collapsed, properly formatted files having one doc per line")
                .withShortName("i").create();
        Option dictOpt = obuilder.withLongName("dictionary").withRequired(false)
                .withArgument(abuilder.withName("dictionary").withMinimum(1).withMaximum(1).create())
                .withDescription("The path to the term-dictionary format is ... ")
                .withShortName("d").create();
        Option dfsOpt = obuilder.withLongName("dfs").withRequired(false)
                .withArgument(abuilder.withName("dfs").withMinimum(1).withMaximum(1).create())
                .withDescription("HDFS namenode URI")
                .withShortName("dfs").create();
        Option numTopicsOpt = obuilder.withLongName("numTopics").withRequired(true)
                .withArgument(abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create())
                .withDescription("Number of topics to learn")
                .withShortName("top").create();
        Option outputTopicFileOpt = obuilder.withLongName("topicOutputFile").withRequired(true)
                .withArgument(abuilder.withName("topicOutputFile").withMinimum(1).withMaximum(1).create())
                .withDescription("File to write out p(term | topic)")
                .withShortName("to").create();
        Option outputDocFileOpt = obuilder.withLongName("docOutputFile").withRequired(true)
                .withArgument(abuilder.withName("docOutputFile").withMinimum(1).withMaximum(1).create())
                .withDescription("File to write out p(topic | docid)")
                .withShortName("do").create();
        Option alphaOpt = obuilder.withLongName("alpha").withRequired(false)
                .withArgument(abuilder.withName("alpha").withMinimum(1).withMaximum(1).withDefault("0.1").create())
                .withDescription("Smoothing parameter for p(topic | document) prior")
                .withShortName("a").create();
        Option etaOpt = obuilder.withLongName("eta").withRequired(false)
                .withArgument(abuilder.withName("eta").withMinimum(1).withMaximum(1).withDefault("0.1").create())
                .withDescription("Smoothing parameter for p(term | topic)")
                .withShortName("e").create();
        // The next three short names reuse unrelated public constants purely for their string
        // values. NOTE(review): presumably "m", "b" and "c"; confirm before replacing with literals.
        Option maxIterOpt = obuilder.withLongName("maxIterations").withRequired(false)
                .withArgument(abuilder.withName("maxIterations").withMinimum(1).withMaximum(1).withDefault("10").create())
                .withDescription("Maximum number of training passes")
                .withShortName(FuzzyKMeansDriver.M_OPTION).create();
        Option modelCorpusFractionOpt = obuilder.withLongName("modelCorpusFraction").withRequired(false)
                .withArgument(abuilder.withName("modelCorpusFraction").withMinimum(1).withMaximum(1).withDefault("0.0").create())
                .withShortName("mcf")
                .withDescription("For online updates, initial value of |model|/|corpus|").create();
        Option burnInOpt = obuilder.withLongName("burnInIterations").withRequired(false)
                .withArgument(abuilder.withName("burnInIterations").withMinimum(1).withMaximum(1).withDefault("5").create())
                .withDescription("Minimum number of iterations")
                .withShortName(WikipediaTokenizer.BOLD).create();
        Option convergenceOpt = obuilder.withLongName("convergence").withRequired(false)
                .withArgument(abuilder.withName("convergence").withMinimum(1).withMaximum(1).withDefault("0.0").create())
                .withDescription("Fractional rate of perplexity to consider convergence")
                .withShortName(WikipediaTokenizer.CATEGORY).create();
        // Accepted for compatibility; its value is not read below.
        Option reInferDocTopicsOpt = obuilder.withLongName("reInferDocTopics").withRequired(false)
                .withArgument(abuilder.withName("reInferDocTopics").withMinimum(1).withMaximum(1).withDefault("no").create())
                .withDescription("re-infer p(topic | doc) : [no | randstart | continue]")
                .withShortName("rdt").create();
        Option numTrainThreadsOpt = obuilder.withLongName("numTrainThreads").withRequired(false)
                .withArgument(abuilder.withName("numTrainThreads").withMinimum(1).withMaximum(1).withDefault("1").create())
                .withDescription("number of threads to train with")
                .withShortName("ntt").create();
        Option numUpdateThreadsOpt = obuilder.withLongName("numUpdateThreads").withRequired(false)
                .withArgument(abuilder.withName("numUpdateThreads").withMinimum(1).withMaximum(1).withDefault("1").create())
                .withDescription("number of threads to update the model with")
                .withShortName("nut").create();
        Option verboseOpt = obuilder.withLongName("verbose").withRequired(false)
                .withArgument(abuilder.withName("verbose").withMinimum(1).withMaximum(1).withDefault("false").create())
                .withDescription("print verbose information, like top-terms in each topic, during iteration")
                .withShortName("v").create();

        // Option order here determines the order shown in the help text.
        Group group = gbuilder.withName("Options")
                .withOption(inputDirOpt)
                .withOption(numTopicsOpt)
                .withOption(alphaOpt)
                .withOption(etaOpt)
                .withOption(maxIterOpt)
                .withOption(burnInOpt)
                .withOption(convergenceOpt)
                .withOption(dictOpt)
                .withOption(reInferDocTopicsOpt)
                .withOption(outputDocFileOpt)
                .withOption(outputTopicFileOpt)
                .withOption(dfsOpt)
                .withOption(numTrainThreadsOpt)
                .withOption(numUpdateThreadsOpt)
                .withOption(modelCorpusFractionOpt)
                .withOption(verboseOpt)
                .create();

        try {
            Parser parser = new Parser();
            parser.setGroup(group);
            parser.setHelpOption(helpOpt);
            CommandLine cmdLine = parser.parse(args);
            if (cmdLine.hasOption(helpOpt)) {
                CommandLineUtil.printHelp(group);
                return -1;
            }

            String inputDirString = (String) cmdLine.getValue(inputDirOpt);
            String dictDirString = cmdLine.hasOption(dictOpt) ? (String) cmdLine.getValue(dictOpt) : null;
            int numTopics = Integer.parseInt((String) cmdLine.getValue(numTopicsOpt));
            double alpha = Double.parseDouble((String) cmdLine.getValue(alphaOpt));
            double eta = Double.parseDouble((String) cmdLine.getValue(etaOpt));
            int maxIterations = Integer.parseInt((String) cmdLine.getValue(maxIterOpt));
            int burnInIterations = Integer.parseInt((String) cmdLine.getValue(burnInOpt));
            double minFractionalErrorChange = Double.parseDouble((String) cmdLine.getValue(convergenceOpt));
            int numTrainThreads = Integer.parseInt((String) cmdLine.getValue(numTrainThreadsOpt));
            int numUpdateThreads = Integer.parseInt((String) cmdLine.getValue(numUpdateThreadsOpt));
            String topicOutFile = (String) cmdLine.getValue(outputTopicFileOpt);
            String docOutFile = (String) cmdLine.getValue(outputDocFileOpt);
            boolean verbose = Boolean.parseBoolean((String) cmdLine.getValue(verboseOpt));
            double modelCorpusFraction = Double.parseDouble((String) cmdLine.getValue(modelCorpusFractionOpt));

            long start = System.nanoTime();
            // Only fill in the namenode when it is missing AND -dfs was given; Hadoop's
            // Configuration.set rejects null values.
            if (conf.get("fs.default.name") == null && cmdLine.hasOption(dfsOpt)) {
                conf.set("fs.default.name", (String) cmdLine.getValue(dfsOpt));
            }
            String[] terms = loadDictionary(dictDirString, conf);
            logTime("dictionary loading", System.nanoTime() - start);

            start = System.nanoTime();
            Matrix corpus = loadVectors(inputDirString, conf);
            logTime("vector seqfile corpus loading", System.nanoTime() - start);

            start = System.nanoTime();
            InMemoryCollapsedVariationalBayes0 cvb0 = new InMemoryCollapsedVariationalBayes0(corpus, terms,
                    numTopics, alpha, eta, numTrainThreads, numUpdateThreads, modelCorpusFraction);
            logTime("cvb0 init", System.nanoTime() - start);

            start = System.nanoTime();
            cvb0.setVerbose(verbose);
            cvb0.iterateUntilConvergence(minFractionalErrorChange, maxIterations, burnInIterations);
            logTime("total training time", System.nanoTime() - start);

            start = System.nanoTime();
            cvb0.writeModel(new Path(topicOutFile));
            DistributedRowMatrixWriter.write(new Path(docOutFile), conf, cvb0.docTopicCounts);
            logTime("printTopics", System.nanoTime() - start);
            return 0;
        } catch (OptionException e) {
            log.error("Error while parsing options", e);
            CommandLineUtil.printHelp(group);
            // A parse failure is not a successful run; report it the same way as the help path.
            return -1;
        }
    }

    /**
     * Loads a (term, termId) sequence-file dictionary into an array indexed by term id.
     *
     * @param dictionaryPath path of the dictionary sequence file, or {@code null}
     * @param conf Hadoop configuration
     * @return array of terms indexed by id (ids absent from the file map to {@code null}), or
     *     {@code null} when {@code dictionaryPath} is {@code null}
     */
    private static String[] loadDictionary(String dictionaryPath, Configuration conf) {
        if (dictionaryPath == null) {
            return null;
        }
        Path path = new Path(dictionaryPath);
        List<Pair<Integer, String>> entries = new ArrayList<Pair<Integer, String>>();
        int maxTermId = 0;
        for (Pair<Writable, IntWritable> record : new SequenceFileIterable<Writable, IntWritable>(path, true, conf)) {
            // Key/value instances are reused by the iterable, so extract the values immediately.
            int termId = record.getSecond().get();
            entries.add(Pair.of(termId, record.getFirst().toString()));
            maxTermId = Math.max(maxTermId, termId);
        }
        String[] terms = new String[maxTermId + 1];
        for (Pair<Integer, String> entry : entries) {
            terms[entry.getFirst()] = entry.getSecond();
        }
        return terms;
    }

    @Override
    public Configuration getConf() {
        return super.getConf();
    }

    /**
     * Loads (docId, vector) sequence files into a {@link SparseRowMatrix} with one row per doc id.
     *
     * <p>{@code vectorPathString} may name a single file or a directory. In the directory case,
     * the entries accepted by {@code PathFilters.logsCRCFilter()} are read. {@link NamedVector}
     * wrappers are unwrapped. The column count and access type are taken from the first vector
     * read. Row ids that never appear become rows of the matrix without an input vector.
     *
     * @param vectorPathString input file or directory
     * @param conf Hadoop configuration
     * @return the corpus matrix, sized {@code (maxDocId + 1) x numCols}
     * @throws IOException if reading fails or no vectors were found
     */
    private static Matrix loadVectors(String vectorPathString, Configuration conf) throws IOException {
        Path vectorPath = new Path(vectorPathString);
        FileSystem fs = vectorPath.getFileSystem(conf);
        List<Path> subPaths = new ArrayList<Path>();
        if (fs.isFile(vectorPath)) {
            subPaths.add(vectorPath);
        } else {
            for (FileStatus fileStatus : fs.listStatus(vectorPath, PathFilters.logsCRCFilter())) {
                subPaths.add(fileStatus.getPath());
            }
        }

        List<Pair<Integer, Vector>> rowList = new ArrayList<Pair<Integer, Vector>>();
        int maxDocId = Integer.MIN_VALUE;
        int numCols = -1;
        boolean sequentialAccess = false;
        for (Path subPath : subPaths) {
            for (Pair<IntWritable, VectorWritable> record
                    : new SequenceFileIterable<IntWritable, VectorWritable>(subPath, true, conf)) {
                int docId = record.getFirst().get();
                Vector vector = record.getSecond().get();
                if (vector instanceof NamedVector) {
                    vector = ((NamedVector) vector).getDelegate();
                }
                if (numCols < 0) {
                    numCols = vector.size();
                    sequentialAccess = vector.isSequentialAccess();
                }
                rowList.add(Pair.of(docId, vector));
                maxDocId = Math.max(maxDocId, docId);
            }
        }
        // Without this guard an empty input would surface as a NegativeArraySizeException below.
        if (rowList.isEmpty()) {
            throw new IOException("No document vectors found under " + vectorPath);
        }

        int numRows = maxDocId + 1;
        Vector[] rowVectors = new Vector[numRows];
        for (Pair<Integer, Vector> row : rowList) {
            rowVectors[row.getFirst()] = row.getSecond();
        }
        return new SparseRowMatrix(numRows, numCols, rowVectors, true, !sequentialAccess);
    }

    /** {@code Tool} entry point: delegates to {@link #main2} with the injected configuration. */
    @Override
    public int run(String[] args) throws Exception {
        return main2(args, getConf());
    }

    /** Launches the job through {@code ToolRunner}. */
    public static void main(String[] args) throws Exception {
        ToolRunner.run(new InMemoryCollapsedVariationalBayes0(), args);
    }
}
