package org.apache.mahout.clustering.dirichlet;

import java.util.ArrayList;
import java.util.List;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.Model;
import org.apache.mahout.clustering.ModelDistribution;
import org.apache.mahout.clustering.dirichlet.models.DistributionDescription;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/* loaded from: input_file:org/apache/mahout/clustering/dirichlet/DirichletState.class */
/**
 * Mutable holder for the state of a Dirichlet clustering run: the list of
 * {@link DirichletCluster}s, the mixture vector associated with them, and the
 * model factory the initial clusters were sampled from.
 *
 * <p>The class performs no synchronization and hands out its internal list and
 * vector by reference, so callers share (and can mutate) the same objects.
 */
public class DirichletState {
    /** Number of clusters iterated over by {@link #totalCounts()} and {@link #getModels()}. */
    private int numClusters;

    /** Factory that produced the initial models in the constructor. */
    private ModelDistribution<VectorWritable> modelFactory;

    /** Current clusters; indices line up with the entries of {@link #mixture}. */
    private List<DirichletCluster> clusters;

    /** Mixture vector; entry {@code i} is the weight applied to cluster {@code i}. */
    private Vector mixture;

    /**
     * Second argument handed to {@code UncommonDistributions.rDirichlet}
     * (presumably the Dirichlet concentration parameter - confirm against that helper).
     */
    private final double alpha0;

    /**
     * Creates a state with {@code i} clusters whose models are sampled from the
     * prior of the given distribution, then draws the initial mixture.
     *
     * @param modelDistribution factory used to sample the initial models; each
     *                          sampled model is cast to {@link Cluster}
     * @param i                 number of clusters to sample
     * @param d                 value stored as {@code alpha0} and passed to
     *                          {@code UncommonDistributions.rDirichlet}
     */
    public DirichletState(ModelDistribution<VectorWritable> modelDistribution, int i, double d) {
        this.numClusters = i;
        this.modelFactory = modelDistribution;
        this.alpha0 = d;
        // Explicit type argument instead of the raw ArrayList, so the list is type-checked.
        this.clusters = new ArrayList<DirichletCluster>();
        for (Model<VectorWritable> model : modelDistribution.sampleFromPrior(i)) {
            this.clusters.add(new DirichletCluster((Cluster) model));
        }
        // numClusters and clusters must both be populated before the counts are computed.
        this.mixture = UncommonDistributions.rDirichlet(computeTotalCounts(), d);
    }

    /**
     * Convenience constructor that first builds the model distribution from its
     * description and then delegates to the main constructor.
     *
     * @param distributionDescription description used to create the model distribution
     * @param i                       number of clusters to sample
     * @param d                       value stored as {@code alpha0}
     */
    public DirichletState(DistributionDescription distributionDescription, int i, double d) {
        this(distributionDescription.createModelDistribution(), i, d);
    }

    /** @return the number of clusters this state iterates over */
    public int getNumClusters() {
        return this.numClusters;
    }

    /**
     * Replaces the cluster count. The cluster list is not resized; later calls
     * index the list from {@code 0} to {@code i - 1}.
     *
     * @param i the new cluster count
     */
    public void setNumClusters(int i) {
        this.numClusters = i;
    }

    /** @return the model factory currently held by this state */
    public ModelDistribution<VectorWritable> getModelFactory() {
        return this.modelFactory;
    }

    /** @param modelDistribution the model factory to store */
    public void setModelFactory(ModelDistribution<VectorWritable> modelDistribution) {
        this.modelFactory = modelDistribution;
    }

    /** @return the internal cluster list (not a copy) */
    public List<DirichletCluster> getClusters() {
        return this.clusters;
    }

    /** @param list the cluster list to store (kept by reference) */
    public void setClusters(List<DirichletCluster> list) {
        this.clusters = list;
    }

    /** @return the internal mixture vector (not a copy) */
    public Vector getMixture() {
        return this.mixture;
    }

    /** @param vector the mixture vector to store (kept by reference) */
    public void setMixture(Vector vector) {
        this.mixture = vector;
    }

    /**
     * Collects the total count of each of the first {@code numClusters} clusters.
     *
     * @return a new vector of size {@code numClusters} whose entry {@code i} is
     *         {@code clusters.get(i).getTotalCount()}
     */
    public Vector totalCounts() {
        return computeTotalCounts();
    }

    /**
     * Private worker behind {@link #totalCounts()}; the constructor calls this
     * directly rather than the public, overridable method.
     */
    private Vector computeTotalCounts() {
        DenseVector denseVector = new DenseVector(this.numClusters);
        for (int i = 0; i < this.numClusters; i++) {
            denseVector.set(i, this.clusters.get(i).getTotalCount());
        }
        return denseVector;
    }

    /**
     * Installs new models into the existing clusters and redraws the mixture.
     *
     * <p>For each index, {@code computeParameters()} is invoked on the supplied
     * model before it is set on the cluster at the same index. Afterwards the
     * mixture is replaced by a fresh {@code rDirichlet} draw over the total counts.
     *
     * @param clusterArr new models, paired by index with the current clusters
     */
    public void update(Cluster[] clusterArr) {
        for (int i = 0; i < clusterArr.length; i++) {
            clusterArr[i].computeParameters();
            this.clusters.get(i).setModel(clusterArr[i]);
        }
        this.mixture = UncommonDistributions.rDirichlet(totalCounts(), this.alpha0);
    }

    /**
     * Computes the mixture-weighted pdf of a point under one cluster's model.
     *
     * @param vectorWritable the point to evaluate
     * @param i              index of the cluster / mixture entry
     * @return {@code mixture.get(i)} multiplied by the model's pdf at the point
     */
    public double adjustedProbability(VectorWritable vectorWritable, int i) {
        return this.mixture.get(i) * this.clusters.get(i).getModel().pdf(vectorWritable);
    }

    /**
     * Copies the models of the first {@code numClusters} clusters into a new array.
     *
     * @return a freshly allocated array of the current models, in cluster order
     */
    public Model<VectorWritable>[] getModels() {
        // Generic array creation is not allowed in Java, so a raw Model[] is created;
        // this is safe because only Model<VectorWritable> instances are stored below.
        @SuppressWarnings("unchecked")
        Model<VectorWritable>[] modelArr = new Model[this.numClusters];
        for (int i = 0; i < this.numClusters; i++) {
            modelArr[i] = this.clusters.get(i).getModel();
        }
        return modelArr;
    }
}
