package org.apache.spark.ml.regression;

import java.util.Arrays;

import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.mllib.optimization.SquaredL2Updater;
import org.apache.spark.mllib.optimization.Updater;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.RichInt$;

/* compiled from: FMRegressor.scala */
/* loaded from: input_file:org/apache/spark/ml/regression/FactorizationMachines$.class */
public final class FactorizationMachines$ implements Serializable {
    public static FactorizationMachines$ MODULE$;
    private final String GD;
    private final String AdamW;
    private final String[] supportedSolvers;
    private final String LogisticLoss;
    private final String SquaredError;
    private final String[] supportedRegressorLosses;
    private final String[] supportedClassifierLosses;
    private final String[] supportedLosses;

    static {
        new FactorizationMachines$();
    }

    public String GD() {
        return this.GD;
    }

    public String AdamW() {
        return this.AdamW;
    }

    public String[] supportedSolvers() {
        return this.supportedSolvers;
    }

    public String LogisticLoss() {
        return this.LogisticLoss;
    }

    public String SquaredError() {
        return this.SquaredError;
    }

    public String[] supportedRegressorLosses() {
        return this.supportedRegressorLosses;
    }

    public String[] supportedClassifierLosses() {
        return this.supportedClassifierLosses;
    }

    public String[] supportedLosses() {
        return this.supportedLosses;
    }

    public Updater parseSolver(String str, int i) {
        String GD = GD();
        if (GD != null ? GD.equals(str) : str == null) {
            return new SquaredL2Updater();
        }
        String AdamW = AdamW();
        if (AdamW != null ? !AdamW.equals(str) : str != null) {
            throw new MatchError(str);
        }
        return new AdamWUpdater(i);
    }

    public BaseFactorizationMachinesGradient parseLoss(String str, int i, boolean z, boolean z2, int i2) {
        String LogisticLoss = LogisticLoss();
        if (LogisticLoss != null ? LogisticLoss.equals(str) : str == null) {
            return new LogisticFactorizationMachinesGradient(i, z, z2, i2);
        }
        String SquaredError = SquaredError();
        if (SquaredError != null ? !SquaredError.equals(str) : str != null) {
            throw new IllegalArgumentException(new StringBuilder(35).append("loss function type ").append(str).append(" is invalidation").toString());
        }
        return new MSEFactorizationMachinesGradient(i, z, z2, i2);
    }

    public Tuple3<Object, Vector, Matrix> splitCoefficients(Vector vector, int i, int i2, boolean z, boolean z2) {
        int i3 = (i * i2) + (z2 ? i : 0) + (z ? 1 : 0);
        Predef$.MODULE$.require(i3 == vector.size(), () -> {
            return new StringBuilder(50).append("coefficients.size did not match the excepted size ").append(i3).toString();
        });
        return new Tuple3<>(BoxesRunTime.boxToDouble(z ? vector.apply(vector.size() - 1) : 0.0d), z2 ? new DenseVector((double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(vector.toArray())).slice(i * i2, (i * i2) + i)) : Vectors$.MODULE$.sparse(i, Nil$.MODULE$), new DenseMatrix(i, i2, (double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(vector.toArray())).slice(0, i * i2), true));
    }

    public Vector combineCoefficients(double d, Vector vector, Matrix matrix, boolean z, boolean z2) {
        return new DenseVector((double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps((double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(matrix.toDense().values())).$plus$plus(z2 ? new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(vector.toArray())) : new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(Array$.MODULE$.emptyDoubleArray())), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())))).$plus$plus(z ? new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(new double[]{d})) : new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(Array$.MODULE$.emptyDoubleArray())), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())));
    }

    public double getRawPrediction(Vector vector, double d, Vector vector2, Matrix matrix) {
        DoubleRef create = DoubleRef.create(d + vector.dot(vector2));
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), matrix.numCols()).foreach$mVc$sp(i -> {
            DoubleRef create2 = DoubleRef.create(0.0d);
            DoubleRef create3 = DoubleRef.create(0.0d);
            vector.foreachNonZero((i, d2) -> {
                Tuple2.mcID.sp spVar = new Tuple2.mcID.sp(i, d2);
                if (spVar == null) {
                    throw new MatchError(spVar);
                }
                double apply = matrix.apply(spVar._1$mcI$sp(), i) * spVar._2$mcD$sp();
                create2.elem += apply * apply;
                create3.elem += apply;
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            });
            create.elem += 0.5d * ((create3.elem * create3.elem) - create2.elem);
        });
        return create.elem;
    }

    private Object readResolve() {
        return MODULE$;
    }

    private FactorizationMachines$() {
        MODULE$ = this;
        this.GD = "gd";
        this.AdamW = "adamW";
        this.supportedSolvers = new String[]{GD(), AdamW()};
        this.LogisticLoss = "logisticLoss";
        this.SquaredError = "squaredError";
        this.supportedRegressorLosses = new String[]{SquaredError()};
        this.supportedClassifierLosses = new String[]{LogisticLoss()};
        this.supportedLosses = (String[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(supportedRegressorLosses())).$plus$plus(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(supportedClassifierLosses())), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)));
    }
}
