package org.apache.spark.ml.optim.aggregator;

import java.util.Arrays;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.SparseMatrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import scala.Array$;
import scala.Option;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric$DoubleIsFractional$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: HingeAggregator.scala */
@ScalaSignature(bytes = "\u0006\u0001u3Qa\u0003\u0007\u0001!aA\u0001B\u000b\u0001\u0003\u0002\u0003\u0006I\u0001\f\u0005\t_\u0001\u0011\t\u0011)A\u0005a!)A\b\u0001C\u0001{!9\u0011\t\u0001b\u0001\n#\u0012\u0005B\u0002$\u0001A\u0003%1\tC\u0004H\u0001\t\u0007I\u0011\u0002\"\t\r!\u0003\u0001\u0015!\u0003D\u0011!I\u0005\u0001#b\u0001\n\u0013Q\u0005\u0002C+\u0001\u0011\u000b\u0007I\u0011\u0002,\t\u000ba\u0003A\u0011A-\u0003)\tcwnY6IS:<W-Q4he\u0016<\u0017\r^8s\u0015\tia\"\u0001\u0006bO\u001e\u0014XmZ1u_JT!a\u0004\t\u0002\u000b=\u0004H/[7\u000b\u0005E\u0011\u0012AA7m\u0015\t\u0019B#A\u0003ta\u0006\u00148N\u0003\u0002\u0016-\u00051\u0011\r]1dQ\u0016T\u0011aF\u0001\u0004_J<7c\u0001\u0001\u001a?A\u0011!$H\u0007\u00027)\tA$A\u0003tG\u0006d\u0017-\u0003\u0002\u001f7\t1\u0011I\\=SK\u001a\u0004B\u0001I\u0011$S5\tA\"\u0003\u0002#\u0019\taB)\u001b4gKJ,g\u000e^5bE2,Gj\\:t\u0003\u001e<'/Z4bi>\u0014\bC\u0001\u0013(\u001b\u0005)#B\u0001\u0014\u0011\u0003\u001d1W-\u0019;ve\u0016L!\u0001K\u0013\u0003\u001b%s7\u000f^1oG\u0016\u0014En\\2l!\t\u0001\u0003!\u0001\u0007gSRLe\u000e^3sG\u0016\u0004Ho\u0001\u0001\u0011\u0005ii\u0013B\u0001\u0018\u001c\u0005\u001d\u0011un\u001c7fC:\faBY2D_\u00164g-[2jK:$8\u000fE\u00022iYj\u0011A\r\u0006\u0003gI\t\u0011B\u0019:pC\u0012\u001c\u0017m\u001d;\n\u0005U\u0012$!\u0003\"s_\u0006$7-Y:u!\t9$(D\u00019\u0015\tI\u0004#\u0001\u0004mS:\fGnZ\u0005\u0003wa\u0012aAV3di>\u0014\u0018A\u0002\u001fj]&$h\b\u0006\u0002?\u0001R\u0011\u0011f\u0010\u0005\u0006_\r\u0001\r\u0001\r\u0005\u0006U\r\u0001\r\u0001L\u0001\u0004I&lW#A\"\u0011\u0005i!\u0015BA#\u001c\u0005\rIe\u000e^\u0001\u0005I&l\u0007%A\u0006ok64U-\u0019;ve\u0016\u001c\u0018\u0001\u00048v[\u001a+\u0017\r^;sKN\u0004\u0013!E2pK\u001a4\u0017nY5f]R\u001c\u0018I\u001d:bsV\t1\nE\u0002\u001b\u0019:K!!T\u000e\u0003\u000b\u0005\u0013(/Y=\u0011\u0005iy\u0015B\u0001)\u001c\u0005\u0019!u.\u001e2mK\"\u0012\u0001B\u0015\t\u00035MK!\u0001V\u000e\u0003\u0013Q\u0014\u0018M\\:jK:$\u0018A\u00027j]\u0016\f'/F\u00017Q\tI!+A\u0002bI\u0012
$\"AW.\u000e\u0003\u0001AQ\u0001\u0018\u0006A\u0002\r\nQA\u00197pG.\u0004")
/* loaded from: input_file:org/apache/spark/ml/optim/aggregator/BlockHingeAggregator.class */
public class BlockHingeAggregator implements DifferentiableLossAggregator<InstanceBlock, BlockHingeAggregator> {
    private transient double[] coefficientsArray;
    private transient Vector linear;
    private final boolean fitIntercept;
    private final Broadcast<Vector> bcCoefficients;
    private final int dim;
    private final int numFeatures;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    private volatile transient byte bitmap$trans$0;
    private volatile boolean bitmap$0;

    /* JADX WARN: Type inference failed for: r0v1, types: [org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator, org.apache.spark.ml.optim.aggregator.BlockHingeAggregator] */
    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public BlockHingeAggregator merge(BlockHingeAggregator blockHingeAggregator) {
        ?? merge;
        merge = merge(blockHingeAggregator);
        return merge;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public Vector gradient() {
        Vector gradient;
        gradient = gradient();
        return gradient;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double weight() {
        double weight;
        weight = weight();
        return weight;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double loss() {
        double loss;
        loss = loss();
        return loss;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double weightSum() {
        return this.weightSum;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public void weightSum_$eq(double d) {
        this.weightSum = d;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double lossSum() {
        return this.lossSum;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public void lossSum_$eq(double d) {
        this.lossSum = d;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v8, types: [org.apache.spark.ml.optim.aggregator.BlockHingeAggregator] */
    private double[] gradientSumArray$lzycompute() {
        double[] gradientSumArray;
        ?? r0 = this;
        synchronized (r0) {
            if (!this.bitmap$0) {
                gradientSumArray = gradientSumArray();
                this.gradientSumArray = gradientSumArray;
                r0 = this;
                r0.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public int dim() {
        return this.dim;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    /* JADX WARN: Multi-variable type inference failed */
    private double[] coefficientsArray$lzycompute() {
        synchronized (this) {
            if (((byte) (this.bitmap$trans$0 & 1)) == 0) {
                DenseVector denseVector = (Vector) this.bcCoefficients.value();
                if (denseVector instanceof DenseVector) {
                    Option unapply = DenseVector$.MODULE$.unapply(denseVector);
                    if (!unapply.isEmpty()) {
                        this.coefficientsArray = (double[]) unapply.get();
                        this.bitmap$trans$0 = (byte) (this.bitmap$trans$0 | 1);
                    }
                }
                throw new IllegalArgumentException(new StringBuilder(54).append("coefficients only supports dense vector").append(" but got type ").append(this.bcCoefficients.value().getClass()).append(".").toString());
            }
        }
        return this.coefficientsArray;
    }

    private double[] coefficientsArray() {
        return ((byte) (this.bitmap$trans$0 & 1)) == 0 ? coefficientsArray$lzycompute() : this.coefficientsArray;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v10, types: [org.apache.spark.ml.optim.aggregator.BlockHingeAggregator] */
    private Vector linear$lzycompute() {
        ?? r0 = this;
        synchronized (r0) {
            if (((byte) (this.bitmap$trans$0 & 2)) == 0) {
                this.linear = Vectors$.MODULE$.dense(this.fitIntercept ? (double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(coefficientsArray())).take(numFeatures()) : coefficientsArray());
                r0 = this;
                r0.bitmap$trans$0 = (byte) (this.bitmap$trans$0 | 2);
            }
        }
        return this.linear;
    }

    private Vector linear() {
        return ((byte) (this.bitmap$trans$0 & 2)) == 0 ? linear$lzycompute() : this.linear;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public BlockHingeAggregator add(InstanceBlock instanceBlock) {
        Predef$.MODULE$.require(instanceBlock.matrix().isTransposed());
        Predef$.MODULE$.require(numFeatures() == instanceBlock.numFeatures(), () -> {
            return new StringBuilder(66).append("Dimensions mismatch when adding new ").append("instance. Expecting ").append(this.numFeatures()).append(" but got ").append(instanceBlock.numFeatures()).append(".").toString();
        });
        Predef$.MODULE$.require(instanceBlock.weightIter().forall(d -> {
            return d >= ((double) 0);
        }), () -> {
            return new StringBuilder(34).append("instance weights ").append(instanceBlock.weightIter().mkString("[", ",", "]")).append(" has to be >= 0.0").toString();
        });
        if (instanceBlock.weightIter().forall(d2 -> {
            return d2 == ((double) 0);
        })) {
            return this;
        }
        int size = instanceBlock.size();
        DenseVector dense = this.fitIntercept ? Vectors$.MODULE$.dense((double[]) Array$.MODULE$.fill(size, () -> {
            return BoxesRunTime.unboxToDouble(new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())).last());
        }, ClassTag$.MODULE$.Double())).toDense() : Vectors$.MODULE$.zeros(size).toDense();
        BLAS$.MODULE$.gemv(1.0d, instanceBlock.matrix(), linear(), 1.0d, dense);
        double d3 = 0.0d;
        int i = 0;
        while (true) {
            int i2 = i;
            if (i2 >= size) {
                break;
            }
            double apply$mcDI$sp = instanceBlock.getWeight().apply$mcDI$sp(i2);
            if (apply$mcDI$sp > 0) {
                double label = instanceBlock.getLabel(i2);
                double d4 = (label + label) - 1.0d;
                double apply = (1.0d - (d4 * dense.apply(i2))) * apply$mcDI$sp;
                if (apply > 0) {
                    d3 += apply;
                    dense.values()[i2] = (-d4) * apply$mcDI$sp;
                } else {
                    dense.values()[i2] = 0.0d;
                }
            } else {
                dense.values()[i2] = 0.0d;
            }
            i = i2 + 1;
        }
        lossSum_$eq(lossSum() + d3);
        weightSum_$eq(weightSum() + BoxesRunTime.unboxToDouble(instanceBlock.weightIter().sum(Numeric$DoubleIsFractional$.MODULE$)));
        if (new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(dense.values())).forall(d5 -> {
            return d5 == ((double) 0);
        })) {
            return this;
        }
        boolean z = false;
        SparseMatrix sparseMatrix = null;
        DenseMatrix matrix = instanceBlock.matrix();
        if (matrix instanceof DenseMatrix) {
            DenseMatrix denseMatrix = matrix;
            BLAS$.MODULE$.nativeBLAS().dgemv("N", denseMatrix.numCols(), denseMatrix.numRows(), 1.0d, denseMatrix.values(), denseMatrix.numCols(), dense.values(), 1, 1.0d, gradientSumArray(), 1);
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else {
            if (matrix instanceof SparseMatrix) {
                z = true;
                sparseMatrix = (SparseMatrix) matrix;
                if (this.fitIntercept) {
                    DenseVector dense2 = Vectors$.MODULE$.zeros(numFeatures()).toDense();
                    BLAS$.MODULE$.gemv(1.0d, sparseMatrix.transpose(), dense, 0.0d, dense2);
                    BLAS$.MODULE$.getBLAS(numFeatures()).daxpy(numFeatures(), 1.0d, dense2.values(), 1, gradientSumArray(), 1);
                    BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
                }
            }
            if (!z || this.fitIntercept) {
                throw new IllegalArgumentException(new StringBuilder(21).append("Unknown matrix type ").append(matrix.getClass()).append(".").toString());
            }
            BLAS$.MODULE$.gemv(1.0d, sparseMatrix.transpose(), dense, 1.0d, new DenseVector(gradientSumArray()));
            BoxedUnit boxedUnit3 = BoxedUnit.UNIT;
        }
        if (this.fitIntercept) {
            gradientSumArray()[numFeatures()] = gradientSumArray()[numFeatures()] + BoxesRunTime.unboxToDouble(new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(dense.values())).sum(Numeric$DoubleIsFractional$.MODULE$));
        }
        return this;
    }

    public BlockHingeAggregator(boolean z, Broadcast<Vector> broadcast) {
        this.fitIntercept = z;
        this.bcCoefficients = broadcast;
        DifferentiableLossAggregator.$init$(this);
        this.dim = ((Vector) broadcast.value()).size();
        this.numFeatures = z ? dim() - 1 : dim();
    }
}
