package org.apache.mahout.cf.taste.impl.recommender.svd;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX WARN: Classes with same name are omitted:
  input_file:BOOT-INF/classes/libarx-3.7.1.jar:org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.class
 */
/* loaded from: input_file:BOOT-INF/lib/libarx-3.7.1.jar:org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.class */
public class ALSWRFactorizer extends AbstractFactorizer {
    // Data model whose preferences are factorized; also handed to the superclass constructor.
    private final DataModel dataModel;
    // Width of every user and item feature row.
    private final int numFeatures;
    // Passed unchanged to both least-squares solvers (presumably the regularization weight; confirm against the solver docs).
    private final double lambda;
    // Number of alternating user/item passes performed by factorize().
    private final int numIterations;
    // Selects ImplicitFeedbackAlternatingLeastSquaresSolver instead of AlternatingLeastSquaresSolver.
    private final boolean usesImplicitFeedback;
    // Only passed to the implicit-feedback solver (presumably its confidence scaling; confirm against the solver docs).
    private final double alpha;
    // Size of the fixed thread pool returned by createQueue().
    private final int numTrainingThreads;
    // Alpha used by the constructor that does not take an explicit alpha.
    private static final double DEFAULT_ALPHA = 40.0d;
    private static final Logger log = LoggerFactory.getLogger((Class<?>) ALSWRFactorizer.class);

    /* JADX WARN: Classes with same name are omitted:
      input_file:BOOT-INF/classes/libarx-3.7.1.jar:org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer$Features.class
     */
    /* loaded from: input_file:BOOT-INF/lib/libarx-3.7.1.jar:org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer$Features.class */
    /**
     * Holds the two feature matrices being trained: {@code M} with one row per item and
     * {@code U} with one row per user, each row {@code numFeatures} wide.
     */
    static class Features {
        private final DataModel dataModel;
        private final int numFeatures;
        private final double[][] M;
        private final double[][] U;

        /**
         * Allocates both matrices and seeds the item matrix: feature 0 of each item is
         * the item's average rating, the remaining features are {@code nextDouble() * 0.1}
         * draws from {@link RandomUtils#getRandom()}. The user matrix is left at its
         * freshly allocated (all-zero) state.
         *
         * @param aLSWRFactorizer factorizer supplying the data model, the feature count
         *                        and the item-ID-to-row mapping
         * @throws TasteException if the data model cannot be read
         */
        Features(ALSWRFactorizer aLSWRFactorizer) throws TasteException {
            dataModel = aLSWRFactorizer.dataModel;
            numFeatures = aLSWRFactorizer.numFeatures;
            RandomWrapper random = RandomUtils.getRandom();
            M = new double[dataModel.getNumItems()][numFeatures];
            for (LongPrimitiveIterator ids = dataModel.getItemIDs(); ids.hasNext();) {
                long itemID = ids.nextLong();
                double[] itemRow = M[aLSWRFactorizer.itemIndex(itemID).intValue()];
                itemRow[0] = averateRating(itemID);
                for (int feature = 1; feature < numFeatures; feature++) {
                    itemRow[feature] = random.nextDouble() * 0.1d;
                }
            }
            U = new double[dataModel.getNumUsers()][numFeatures];
        }

        /** @return the item feature matrix itself (not a copy) */
        double[][] getM() {
            return M;
        }

        /** @return the user feature matrix itself (not a copy) */
        double[][] getU() {
            return U;
        }

        /** @return a {@link DenseVector} built from row {@code i} of the user matrix */
        Vector getUserFeatureColumn(int i) {
            double[] userRow = U[i];
            return new DenseVector(userRow);
        }

        /** @return a {@link DenseVector} built from row {@code i} of the item matrix */
        Vector getItemFeatureColumn(int i) {
            double[] itemRow = M[i];
            return new DenseVector(itemRow);
        }

        /** Overwrites row {@code i} of the user matrix with the first {@code numFeatures} entries of {@code vector}. */
        void setFeatureColumnInU(int i, Vector vector) {
            setFeatureColumn(U, i, vector);
        }

        /** Overwrites row {@code i} of the item matrix with the first {@code numFeatures} entries of {@code vector}. */
        void setFeatureColumnInM(int i, Vector vector) {
            setFeatureColumn(M, i, vector);
        }

        /**
         * Copies the first {@code numFeatures} entries of {@code vector} into row
         * {@code i} of {@code dArr}.
         */
        protected void setFeatureColumn(double[][] dArr, int i, Vector vector) {
            int feature = 0;
            while (feature < numFeatures) {
                dArr[i][feature] = vector.get(feature);
                feature++;
            }
        }

        /**
         * Averages the values of all preferences the data model holds for one item.
         *
         * @param j item ID
         * @return whatever {@link FullRunningAverage#getAverage()} reports after every
         *         preference value has been added
         * @throws TasteException if the preferences cannot be fetched
         */
        protected double averateRating(long j) throws TasteException {
            PreferenceArray itemPrefs = dataModel.getPreferencesForItem(j);
            FullRunningAverage average = new FullRunningAverage();
            for (Iterator<Preference> prefs = itemPrefs.iterator(); prefs.hasNext();) {
                Preference pref = prefs.next();
                average.addDatum(pref.getValue());
            }
            return average.getAverage();
        }
    }

    /**
     * Creates a factorizer with every setting given explicitly.
     *
     * @param dataModel data model to factorize; also passed to the superclass
     * @param i         number of features per user/item row
     * @param d         lambda value handed to the least-squares solvers
     * @param i2        number of training iterations
     * @param z         {@code true} to train with the implicit-feedback solver
     * @param d2        alpha value handed to the implicit-feedback solver
     * @param i3        number of threads in each training pool
     * @throws TasteException declared by this constructor; nothing in its own body throws it
     */
    public ALSWRFactorizer(DataModel dataModel, int i, double d, int i2, boolean z, double d2, int i3) throws TasteException {
        super(dataModel);
        this.dataModel = dataModel;

        // Solver settings.
        numFeatures = i;
        lambda = d;
        usesImplicitFeedback = z;
        alpha = d2;

        // Training schedule.
        numIterations = i2;
        numTrainingThreads = i3;
    }

    /**
     * Creates a factorizer whose training thread count is the value of
     * {@code Runtime.getRuntime().availableProcessors()} at construction time.
     *
     * @param dataModel data model to factorize
     * @param i         number of features per user/item row
     * @param d         lambda value handed to the least-squares solvers
     * @param i2        number of training iterations
     * @param z         {@code true} to train with the implicit-feedback solver
     * @param d2        alpha value handed to the implicit-feedback solver
     * @throws TasteException declared by the constructor this one delegates to
     */
    public ALSWRFactorizer(DataModel dataModel, int i, double d, int i2, boolean z, double d2) throws TasteException {
        this(dataModel, i, d, i2, z, d2, Runtime.getRuntime().availableProcessors());
    }

    /**
     * Creates a factorizer with implicit feedback switched off and alpha set to
     * {@code DEFAULT_ALPHA}.
     *
     * @param dataModel data model to factorize
     * @param i         number of features per user/item row
     * @param d         lambda value handed to the least-squares solvers
     * @param i2        number of training iterations
     * @throws TasteException declared by the constructor this one delegates to
     */
    public ALSWRFactorizer(DataModel dataModel, int i, double d, int i2) throws TasteException {
        this(dataModel, i, d, i2, false, DEFAULT_ALPHA);
    }

    /* JADX WARN: Type inference failed for: r0v64, types: [org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator] */
    @Override // org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer
    public Factorization factorize() throws TasteException {
        log.info("starting to compute the factorization...");
        final Features features = new Features(this);
        OpenIntObjectHashMap<Vector> openIntObjectHashMap = null;
        OpenIntObjectHashMap<Vector> openIntObjectHashMap2 = null;
        if (this.usesImplicitFeedback) {
            openIntObjectHashMap = userFeaturesMapping(this.dataModel.getUserIDs(), this.dataModel.getNumUsers(), features.getU());
            openIntObjectHashMap2 = itemFeaturesMapping(this.dataModel.getItemIDs(), this.dataModel.getNumItems(), features.getM());
        }
        for (int i = 0; i < this.numIterations; i++) {
            log.info("iteration {}", Integer.valueOf(i));
            ExecutorService createQueue = createQueue();
            LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
            try {
                final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackAlternatingLeastSquaresSolver = this.usesImplicitFeedback ? new ImplicitFeedbackAlternatingLeastSquaresSolver(this.numFeatures, this.lambda, this.alpha, openIntObjectHashMap2) : null;
                while (userIDs.hasNext()) {
                    final long nextLong = userIDs.nextLong();
                    final ?? iterator2 = this.dataModel.getItemIDsFromUser(nextLong).iterator2();
                    final PreferenceArray preferencesFromUser = this.dataModel.getPreferencesFromUser(nextLong);
                    createQueue.execute(new Runnable() { // from class: org.apache.mahout.cf.taste.impl.recommender.svd.ALSWRFactorizer.1
                        @Override // java.lang.Runnable
                        public void run() {
                            ArrayList newArrayList = Lists.newArrayList();
                            while (iterator2.hasNext()) {
                                newArrayList.add(features.getItemFeatureColumn(ALSWRFactorizer.this.itemIndex(iterator2.nextLong()).intValue()));
                            }
                            features.setFeatureColumnInU(ALSWRFactorizer.this.userIndex(nextLong).intValue(), ALSWRFactorizer.this.usesImplicitFeedback ? implicitFeedbackAlternatingLeastSquaresSolver.solve(ALSWRFactorizer.this.sparseUserRatingVector(preferencesFromUser)) : AlternatingLeastSquaresSolver.solve(newArrayList, ALSWRFactorizer.ratingVector(preferencesFromUser), ALSWRFactorizer.this.lambda, ALSWRFactorizer.this.numFeatures));
                        }
                    });
                }
                ExecutorService createQueue2 = createQueue();
                LongPrimitiveIterator itemIDs = this.dataModel.getItemIDs();
                try {
                    final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackAlternatingLeastSquaresSolver2 = this.usesImplicitFeedback ? new ImplicitFeedbackAlternatingLeastSquaresSolver(this.numFeatures, this.lambda, this.alpha, openIntObjectHashMap) : null;
                    while (itemIDs.hasNext()) {
                        final long nextLong2 = itemIDs.nextLong();
                        final PreferenceArray preferencesForItem = this.dataModel.getPreferencesForItem(nextLong2);
                        createQueue2.execute(new Runnable() { // from class: org.apache.mahout.cf.taste.impl.recommender.svd.ALSWRFactorizer.2
                            @Override // java.lang.Runnable
                            public void run() {
                                ArrayList newArrayList = Lists.newArrayList();
                                Iterator<Preference> it = preferencesForItem.iterator();
                                while (it.hasNext()) {
                                    newArrayList.add(features.getUserFeatureColumn(ALSWRFactorizer.this.userIndex(it.next().getUserID()).intValue()));
                                }
                                features.setFeatureColumnInM(ALSWRFactorizer.this.itemIndex(nextLong2).intValue(), ALSWRFactorizer.this.usesImplicitFeedback ? implicitFeedbackAlternatingLeastSquaresSolver2.solve(ALSWRFactorizer.this.sparseItemRatingVector(preferencesForItem)) : AlternatingLeastSquaresSolver.solve(newArrayList, ALSWRFactorizer.ratingVector(preferencesForItem), ALSWRFactorizer.this.lambda, ALSWRFactorizer.this.numFeatures));
                            }
                        });
                    }
                    createQueue2.shutdown();
                    try {
                        createQueue2.awaitTermination(this.dataModel.getNumItems(), TimeUnit.SECONDS);
                    } catch (InterruptedException e) {
                        log.warn("Error when computing item features", (Throwable) e);
                    }
                } catch (Throwable th) {
                    createQueue2.shutdown();
                    try {
                        createQueue2.awaitTermination(this.dataModel.getNumItems(), TimeUnit.SECONDS);
                    } catch (InterruptedException e2) {
                        log.warn("Error when computing item features", (Throwable) e2);
                    }
                    throw th;
                }
            } finally {
                createQueue.shutdown();
                try {
                    createQueue.awaitTermination(this.dataModel.getNumUsers(), TimeUnit.SECONDS);
                } catch (InterruptedException e3) {
                    log.warn("Error when computing user features", (Throwable) e3);
                }
            }
        }
        log.info("finished computation of the factorization...");
        return createFactorization(features.getU(), features.getM());
    }

    /**
     * Creates the executor that runs the tasks of one training phase.
     *
     * @return a new fixed-size thread pool with {@code numTrainingThreads} threads
     */
    protected ExecutorService createQueue() {
        final int poolSize = numTrainingThreads;
        return Executors.newFixedThreadPool(poolSize);
    }

    /**
     * Collects the preference values of {@code preferenceArray}, in array order, into
     * a dense vector.
     *
     * @param preferenceArray preferences whose values are read
     * @return a {@link DenseVector} constructed from the values array with the
     *         constructor's boolean flag set to {@code true}
     */
    protected static Vector ratingVector(PreferenceArray preferenceArray) {
        final int count = preferenceArray.length();
        double[] ratings = new double[count];
        int position = 0;
        while (position < count) {
            ratings[position] = preferenceArray.get(position).getValue();
            position++;
        }
        return new DenseVector(ratings, true);
    }

    /**
     * Builds a lookup from item ID to a vector over that item's feature row.
     *
     * <p>Keys are the item IDs narrowed to {@code int}, so IDs outside the
     * {@code int} range are truncated by the cast.
     *
     * @param longPrimitiveIterator iterator over the item IDs to map; consumed by this call
     * @param i                     initial capacity of the returned map
     * @param dArr                  item feature matrix, indexed by {@code itemIndex(id)}
     * @return map from truncated item ID to a {@link DenseVector} built from the
     *         item's row with the constructor's boolean flag set to {@code true}
     */
    protected OpenIntObjectHashMap<Vector> itemFeaturesMapping(LongPrimitiveIterator longPrimitiveIterator, int i, double[][] dArr) {
        OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<>(i);
        while (longPrimitiveIterator.hasNext()) {
            // nextLong() reads the primitive directly; next() would box each ID into a Long.
            long itemID = longPrimitiveIterator.nextLong();
            int itemRow = itemIndex(itemID);
            mapping.put((int) itemID, new DenseVector(dArr[itemRow], true));
        }
        return mapping;
    }

    /**
     * Builds a lookup from user ID to a vector over that user's feature row.
     *
     * <p>Keys are the user IDs narrowed to {@code int}, so IDs outside the
     * {@code int} range are truncated by the cast.
     *
     * @param longPrimitiveIterator iterator over the user IDs to map; consumed by this call
     * @param i                     initial capacity of the returned map
     * @param dArr                  user feature matrix, indexed by {@code userIndex(id)}
     * @return map from truncated user ID to a {@link DenseVector} built from the
     *         user's row with the constructor's boolean flag set to {@code true}
     */
    protected OpenIntObjectHashMap<Vector> userFeaturesMapping(LongPrimitiveIterator longPrimitiveIterator, int i, double[][] dArr) {
        OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<>(i);
        while (longPrimitiveIterator.hasNext()) {
            // nextLong() reads the primitive directly; next() would box each ID into a Long.
            long userID = longPrimitiveIterator.nextLong();
            int userRow = userIndex(userID);
            mapping.put((int) userID, new DenseVector(dArr[userRow], true));
        }
        return mapping;
    }

    /**
     * Converts the preferences recorded for one item into a sparse vector indexed by
     * user ID.
     *
     * <p>User IDs are narrowed to {@code int} to serve as vector indices, so IDs
     * outside the {@code int} range are truncated by the cast.
     *
     * @param preferenceArray preferences for a single item
     * @return sparse vector of cardinality {@code Integer.MAX_VALUE} holding each
     *         preference value at the index of its (truncated) user ID
     */
    protected Vector sparseItemRatingVector(PreferenceArray preferenceArray) {
        SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, preferenceArray.length());
        Iterator<Preference> prefs = preferenceArray.iterator();
        while (prefs.hasNext()) {
            // Bind the element once: both the index and the value come from the same preference.
            Preference preference = prefs.next();
            ratings.set((int) preference.getUserID(), preference.getValue());
        }
        return ratings;
    }

    /**
     * Converts the preferences recorded for one user into a sparse vector indexed by
     * item ID.
     *
     * <p>Item IDs are narrowed to {@code int} to serve as vector indices, so IDs
     * outside the {@code int} range are truncated by the cast.
     *
     * @param preferenceArray preferences of a single user
     * @return sparse vector of cardinality {@code Integer.MAX_VALUE} holding each
     *         preference value at the index of its (truncated) item ID
     */
    protected Vector sparseUserRatingVector(PreferenceArray preferenceArray) {
        SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, preferenceArray.length());
        Iterator<Preference> prefs = preferenceArray.iterator();
        while (prefs.hasNext()) {
            // Bind the element once: both the index and the value come from the same preference.
            Preference preference = prefs.next();
            ratings.set((int) preference.getItemID(), preference.getValue());
        }
        return ratings;
    }
}
