package org.apache.mahout.classifier.sgd;

import com.google.common.base.Preconditions;
import java.util.Iterator;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;

/* JADX WARN: Classes with same name are omitted:
  input_file:BOOT-INF/lib/mahout-core-0.9.jar:org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.class
 */
/* loaded from: input_file:BOOT-INF/lib/mahout-mr-0.12.2.jar:org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.class */
/**
 * Base class for online (one example at a time) logistic regression learners.
 *
 * <p>The model is the coefficient matrix {@link #beta}. Training adds a gradient step to the
 * coefficients of the features present in each example, and the prior is applied lazily: a
 * coefficient is only "aged" by the prior when its feature is next touched (or when the model is
 * closed), using {@link #updateSteps} to know how many steps were skipped.
 *
 * <p>Subclasses supply the learning-rate schedule via {@link #currentLearningRate()} and
 * {@link #perTermLearningRate(int)}, and are expected to initialise {@code beta},
 * {@code numCategories}, {@code updateSteps}, {@code updateCounts} and {@code prior}; none of them
 * is initialised here.
 */
public abstract class AbstractOnlineLogisticRegression extends AbstractVectorClassifier implements OnlineLearner {
    /**
     * Coefficients of the model. Rows {@code 0 .. numCategories - 2} are updated by training;
     * columns are features (see {@link #numFeatures()}).
     */
    protected Matrix beta;

    /** Number of target categories; training and regularization touch {@code numCategories - 1} rows. */
    protected int numCategories;

    /** Training step counter, advanced once per {@code train} call and once when the model is closed. */
    protected int step;

    /** Per feature: the step at which its coefficients were last updated or regularized. */
    protected Vector updateSteps;

    /** Per feature: how many training examples have contained the feature. */
    protected Vector updateCounts;

    /** Prior used to age coefficients during regularization. */
    protected PriorFunction prior;

    /** When true, {@link #regularize(Vector)} is a no-op; cleared by training, set by {@link #close()}. */
    private boolean sealed;

    /** Weight of the prior; multiplied into the rate passed to {@code prior.age}. */
    private double lambda = 1.0E-5d;

    /** Strategy that computes the per-category gradient for a training example. */
    private Gradient gradient = new DefaultGradient();

    /**
     * Chainable setter for the prior weight.
     *
     * @param d new value of lambda
     * @return this learner, so configuration calls can be chained
     */
    public AbstractOnlineLogisticRegression lambda(double d) {
        this.lambda = d;
        return this;
    }

    /**
     * Applies the multinomial inverse link (a softmax with an implicit extra category whose
     * score is 0) to a vector of linear scores.
     *
     * <p><b>The argument is overwritten</b>: its entries are replaced in place by exponentials.
     *
     * @param vector linear scores; modified by this call
     * @return the result of dividing the exponentiated entries by the normaliser
     */
    public static Vector link(Vector vector) {
        double maxValue = vector.maxValue();
        if (maxValue >= 40.0d) {
            // Large scores: shift by the maximum before exponentiating and drop the "1 +" term
            // from the normaliser (after the shift that term would be exp(-max), i.e. negligible).
            vector.assign(Functions.minus(maxValue)).assign(Functions.EXP);
            return vector.divide(vector.norm(1.0d));
        }
        vector.assign(Functions.EXP);
        return vector.divide(1.0d + vector.norm(1.0d));
    }

    /**
     * Binomial logistic function {@code 1 / (1 + e^-d)}.
     *
     * <p>The two branches are algebraically identical; each one only ever exponentiates a
     * non-positive number so {@code Math.exp} cannot overflow.
     *
     * @param d the linear score
     * @return the logistic function of {@code d}
     */
    public static double link(double d) {
        if (d >= 0.0d) {
            return 1.0d / (1.0d + Math.exp(-d));
        }
        double exp = Math.exp(d);
        return exp / (1.0d + exp);
    }

    /**
     * Computes the raw linear scores {@code beta * vector}, after applying any pending
     * regularization to the coefficients of the features present in {@code vector}.
     *
     * @param vector the feature vector to score
     * @return {@code beta.times(vector)}
     */
    @Override
    public Vector classifyNoLink(Vector vector) {
        regularize(vector);
        return this.beta.times(vector);
    }

    /**
     * Computes the raw linear score of row 0 of {@code beta} against {@code vector}.
     * Unlike {@link #classifyNoLink(Vector)} this does not trigger regularization.
     *
     * @param vector the feature vector to score
     * @return the dot product of the first coefficient row with {@code vector}
     */
    public double classifyScalarNoLink(Vector vector) {
        return this.beta.viewRow(0).dot(vector);
    }

    /**
     * Scores {@code vector} and applies the multinomial inverse link to the result.
     *
     * @param vector the feature vector to classify
     * @return {@code link(classifyNoLink(vector))}
     */
    @Override
    public Vector classify(Vector vector) {
        return link(classifyNoLink(vector));
    }

    /**
     * Returns a single logistic score for a two-category model.
     *
     * @param vector the feature vector to classify
     * @return the logistic function of the row-0 linear score
     * @throws IllegalArgumentException if the model does not have exactly two categories
     */
    @Override
    public double classifyScalar(Vector vector) {
        Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories");
        regularize(vector);
        return link(classifyScalarNoLink(vector));
    }

    /**
     * Performs one training step on a single example.
     *
     * @param j      tracking key; not used by this implementation
     * @param str    group key, passed through to the configured {@link Gradient}
     * @param i      passed to the configured {@link Gradient} together with the example
     *               (presumably the index of the actual target category - confirm against Gradient)
     * @param vector the feature vector of the example
     */
    @Override
    public void train(long j, String str, int i, Vector vector) {
        unseal();
        double learningRate = currentLearningRate();

        // Catch up on the prior for the features we are about to touch.
        regularize(vector);

        Vector gradientValues = this.gradient.apply(str, i, vector, this);
        for (int row = 0; row < this.numCategories - 1; row++) {
            double gradientBase = gradientValues.get(row);
            for (Vector.Element element : vector.nonZeroes()) {
                int index = element.index();
                double delta = gradientBase * learningRate * perTermLearningRate(index) * element.get();
                this.beta.setQuick(row, index, this.beta.getQuick(row, index) + delta);
            }
        }

        // Bookkeeping is done only after every row has been updated, so that the values read by
        // perTermLearningRate() above are the same for all rows of this step.
        for (Vector.Element element : vector.nonZeroes()) {
            int index = element.index();
            this.updateSteps.setQuick(index, getStep());
            this.updateCounts.incrementQuick(index, 1.0d);
        }
        nextStep();
    }

    /**
     * Trains on one example without a group key.
     *
     * @param j      tracking key, forwarded unchanged
     * @param i      forwarded unchanged to {@link #train(long, String, int, Vector)}
     * @param vector the feature vector of the example
     */
    @Override
    public void train(long j, int i, Vector vector) {
        train(j, null, i, vector);
    }

    /**
     * Trains on one example with tracking key 0 and no group key.
     *
     * @param i      forwarded unchanged to {@link #train(long, String, int, Vector)}
     * @param vector the feature vector of the example
     */
    @Override
    public void train(int i, Vector vector) {
        train(0L, null, i, vector);
    }

    /**
     * Lazily applies the prior to the coefficients of every feature that is non-zero in
     * {@code vector}, making up for the steps since each feature was last touched.
     *
     * <p>Does nothing if {@code updateSteps} has not been initialised or the model is sealed.
     *
     * <p>Every coefficient row of a feature is aged by the same number of missed steps, and only
     * then is the feature's entry in {@code updateSteps} brought up to date. (Recording the step
     * while still looping over rows would make the missed-step count zero for all rows after the
     * first, leaving them unregularized in models with more than two categories.)
     *
     * @param vector the features whose coefficients should be brought up to date
     */
    public void regularize(Vector vector) {
        if (this.updateSteps == null || isSealed()) {
            return;
        }
        double learningRate = currentLearningRate();
        int currentStep = getStep();
        for (Vector.Element element : vector.nonZeroes()) {
            int index = element.index();
            double missingUpdates = currentStep - this.updateSteps.get(index);
            if (missingUpdates > 0.0d) {
                double rate = getLambda() * learningRate * perTermLearningRate(index);
                for (int row = 0; row < this.numCategories - 1; row++) {
                    this.beta.set(row, index, this.prior.age(this.beta.get(row, index), missingUpdates, rate));
                }
                this.updateSteps.set(index, currentStep);
            }
        }
    }

    /**
     * Learning-rate multiplier specific to one feature.
     *
     * @param i index of the feature
     * @return the multiplier applied to updates of that feature's coefficients
     */
    public abstract double perTermLearningRate(int i);

    /**
     * Global learning rate for the current step.
     *
     * @return the learning rate applied to gradient and regularization updates
     */
    public abstract double currentLearningRate();

    /**
     * Replaces the prior used for regularization.
     *
     * @param priorFunction the new prior
     */
    public void setPrior(PriorFunction priorFunction) {
        this.prior = priorFunction;
    }

    /**
     * Replaces the gradient strategy used by training.
     *
     * @param gradient the new gradient strategy
     */
    public void setGradient(Gradient gradient) {
        this.gradient = gradient;
    }

    /**
     * @return the prior currently used for regularization
     */
    public PriorFunction getPrior() {
        return this.prior;
    }

    /**
     * Returns the coefficient matrix after bringing all coefficients up to date.
     *
     * <p>Note the side effect: this calls {@link #close()}, which advances the step counter,
     * regularizes every feature and seals the model (unless it is already sealed).
     *
     * @return the live coefficient matrix (not a copy)
     */
    public Matrix getBeta() {
        close();
        return this.beta;
    }

    /**
     * Overwrites a single coefficient.
     *
     * @param i  row of the coefficient
     * @param i2 column (feature index) of the coefficient
     * @param d  the new value
     */
    public void setBeta(int i, int i2, double d) {
        this.beta.set(i, i2, d);
    }

    /**
     * @return the number of target categories
     */
    @Override
    public int numCategories() {
        return this.numCategories;
    }

    /**
     * @return the number of columns of the coefficient matrix
     */
    public int numFeatures() {
        return this.beta.numCols();
    }

    /**
     * @return the weight of the prior
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * @return the current value of the step counter
     */
    public int getStep() {
        return this.step;
    }

    /** Advances the step counter by one. */
    protected void nextStep() {
        this.step++;
    }

    /**
     * @return true if the model is sealed, i.e. regularization is currently skipped
     */
    public boolean isSealed() {
        return this.sealed;
    }

    /** Marks the model as unsealed so that regularization is applied again. */
    protected void unseal() {
        this.sealed = false;
    }

    /** Regularizes every feature by passing an all-ones vector with one entry per column of beta. */
    private void regularizeAll() {
        DenseVector allFeatures = new DenseVector(this.beta.numCols());
        allFeatures.assign(1.0d);
        regularize(allFeatures);
    }

    /**
     * Brings all coefficients up to date and seals the model. Calling it on an already sealed
     * model does nothing.
     */
    @Override
    public void close() {
        if (this.sealed) {
            return;
        }
        // Advance the counter first so that features updated in the most recent training step
        // also have a positive missed-step count and get regularized.
        this.step++;
        regularizeAll();
        this.sealed = true;
    }

    /**
     * Copies the learned state (coefficients, step counter and per-feature bookkeeping) from
     * another learner into this one. Lambda, prior, gradient and the sealed flag are not copied.
     *
     * @param abstractOnlineLogisticRegression the learner to copy from
     * @throws IllegalArgumentException if the two learners have different numbers of categories
     */
    public void copyFrom(AbstractOnlineLogisticRegression abstractOnlineLogisticRegression) {
        Preconditions.checkArgument(this.numCategories == abstractOnlineLogisticRegression.numCategories, "Can't copy unless number of target categories is the same");
        this.beta.assign(abstractOnlineLogisticRegression.beta);
        this.step = abstractOnlineLogisticRegression.step;
        this.updateSteps.assign(abstractOnlineLogisticRegression.updateSteps);
        this.updateCounts.assign(abstractOnlineLogisticRegression.updateCounts);
    }

    /**
     * Checks that the model contains no NaN or infinite coefficient.
     *
     * @return true if every entry of beta is a finite number
     */
    public boolean validModel() {
        // Map each coefficient to 1 if it is NaN/infinite and 0 otherwise, then sum: the sum is
        // the number of invalid coefficients.
        double invalidCount = this.beta.aggregate(Functions.PLUS, new DoubleFunction() {
            @Override
            public double apply(double d) {
                return (Double.isNaN(d) || Double.isInfinite(d)) ? 1.0d : 0.0d;
            }
        });
        return invalidCount < 1.0d;
    }
}
