package edu.cmu.ml.rtw.pra.features;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import edu.cmu.graphchi.preprocessing.VertexIdTranslate;
import edu.cmu.graphchi.util.IdCount;
import edu.cmu.graphchi.walks.distributions.DiscreteDistribution;
import edu.cmu.graphchi.walks.distributions.TwoKeyCompanion;
import edu.cmu.ml.rtw.pra.experiments.Instance;
import edu.cmu.ml.rtw.pra.features.RandomWalkPathFollower;
import edu.cmu.ml.rtw.pra.graphs.GraphOnDisk;
import edu.cmu.ml.rtw.users.matt.util.MapUtil;
import edu.cmu.ml.rtw.users.matt.util.Pair;
import java.rmi.RemoteException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:edu/cmu/ml/rtw/pra/features/RandomWalkPathFollowerCompanion.class */
/**
 * Walk companion that accumulates, per (source vertex, path type) pair, a distribution over the
 * vertices that walks were reported at, and converts those distributions into a
 * {@link FeatureMatrix}.
 *
 * <p>The distributions themselves are maintained by the {@link TwoKeyCompanion} base class. This
 * class only defines the keys and values:
 * <ul>
 *   <li>first key: the translated id of the walk's source vertex ({@link #getFirstKey});
 *   <li>second key: the path type decoded from the walk ({@link #getSecondKey});
 *   <li>value: the translated id of the vertex the walk is at ({@link #getValue}).
 * </ul>
 */
public class RandomWalkPathFollowerCompanion extends TwoKeyCompanion {
    // Decides which (source, target) rows are kept in the feature matrix; see acceptableRow.
    private MatrixRowPolicy acceptPolicy;
    // Optional explicit target whitelist for the ALL_TARGETS policy; may be null.
    private Set<Integer> allowedTargets;
    private final VertexIdTranslate translate;
    // GraphChi-internal ids of the walk sources, indexed by the source index encoded in a walk.
    private int[] sourceVertexIds;
    // Stored from the constructor; not read anywhere in this class.
    private final PathType[] pathTypes;
    private final boolean normalizeWalkProbabilities;
    // Used only to construct the Instance objects attached to output rows.
    private final GraphOnDisk graph;

    /**
     * @param graphOnDisk graph handed to every {@link Instance} created for the output matrix
     * @param numThreads first argument forwarded unchanged to the {@link TwoKeyCompanion}
     *     constructor (NOTE(review): presumably the worker thread count — confirm)
     * @param maxMemoryBytes second argument forwarded unchanged to the {@link TwoKeyCompanion}
     *     constructor (NOTE(review): presumably a memory budget — confirm)
     * @param vertexIdTranslate translator used to map internal vertex ids back via
     *     {@code backward}
     * @param pathTypeArr path types (stored but unused by this class)
     * @param matrixRowPolicy row acceptance policy used by {@link #acceptableRow}
     * @param targets optional target whitelist for {@code ALL_TARGETS}; may be null
     * @param normalizeWalks if true, feature values are counts divided by the distribution total
     * @throws RemoteException if the base companion constructor throws it
     */
    public RandomWalkPathFollowerCompanion(GraphOnDisk graphOnDisk, int numThreads,
            long maxMemoryBytes, VertexIdTranslate vertexIdTranslate, PathType[] pathTypeArr,
            MatrixRowPolicy matrixRowPolicy, Set<Integer> targets, boolean normalizeWalks)
            throws RemoteException {
        super(numThreads, maxMemoryBytes);
        this.translate = vertexIdTranslate;
        this.pathTypes = pathTypeArr;
        this.acceptPolicy = matrixRowPolicy;
        this.allowedTargets = targets;
        this.normalizeWalkProbabilities = normalizeWalks;
        this.graph = graphOnDisk;
    }

    /** Sets the GraphChi-internal ids of the walk sources; walks refer to them by index. */
    public void setSources(int[] sources) {
        this.sourceVertexIds = sources;
    }

    /** Returns the translated id of the source vertex this walk started from. */
    protected int getFirstKey(long walk, int atVertex) {
        int sourceIndex = RandomWalkPathFollower.staticSourceIdx(walk);
        return this.translate.backward(this.sourceVertexIds[sourceIndex]);
    }

    /** Returns the path type id encoded in the walk. */
    protected int getSecondKey(long walk, int atVertex) {
        return RandomWalkPathFollower.Manager.pathType(walk);
    }

    /** Returns the translated id of the vertex the walk is currently at. */
    protected int getValue(long walk, int atVertex) {
        return this.translate.backward(atVertex);
    }

    /** No walk is ever ignored by this companion. */
    protected boolean ignoreWalk(long walk) {
        return false;
    }

    @VisibleForTesting
    protected void setAcceptPolicy(MatrixRowPolicy matrixRowPolicy) {
        this.acceptPolicy = matrixRowPolicy;
    }

    @VisibleForTesting
    protected MatrixRowPolicy getAcceptPolicy() {
        return this.acceptPolicy;
    }

    @VisibleForTesting
    protected boolean getNormalizeWalks() {
        return this.normalizeWalkProbabilities;
    }

    @VisibleForTesting
    protected void setAllowedTargets(Set<Integer> targets) {
        this.allowedTargets = targets;
    }

    @VisibleForTesting
    protected void setDistributions(
            ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, DiscreteDistribution>> dists) {
        this.distributions = dists;
    }

    /** Intentionally a no-op: distributions are consumed via {@link #getFeatureMatrix}. */
    public void outputDistributions(String outputFile) throws RemoteException {
    }

    /**
     * Decides whether the (source, target) row should appear in the feature matrix.
     *
     * <ul>
     *   <li>{@code EVERYTHING}: always accepted.
     *   <li>{@code ALL_TARGETS}: accepted if {@code target} is in the explicit
     *       {@code allowedTargets} whitelist when one is set, otherwise if it is in
     *       {@code allTargets}.
     *   <li>{@code PAIRED_TARGETS_ONLY}: accepted if {@code target} is in
     *       {@code pairedTargets}.
     * </ul>
     *
     * @param source the row's source vertex (not consulted by any current policy)
     * @param target the row's target vertex
     * @param pairedTargets targets that appear with {@code source} in the input instances
     * @param allTargets union of all targets in the input instances
     * @throws RuntimeException if the policy is null or unrecognized
     */
    @VisibleForTesting
    protected boolean acceptableRow(int source, int target, Set<Integer> pairedTargets,
            Set<Integer> allTargets) {
        Integer boxedTarget = Integer.valueOf(target);
        if (this.acceptPolicy == MatrixRowPolicy.EVERYTHING) {
            return true;
        }
        if (this.acceptPolicy == MatrixRowPolicy.ALL_TARGETS) {
            Set<Integer> candidates = this.allowedTargets != null ? this.allowedTargets : allTargets;
            return candidates.contains(boxedTarget);
        }
        if (this.acceptPolicy == MatrixRowPolicy.PAIRED_TARGETS_ONLY) {
            return pairedTargets.contains(boxedTarget);
        }
        throw new RuntimeException("Accept policy not set to something recognizable: " + this.acceptPolicy);
    }

    /**
     * Waits for walk execution to finish, drains every (source, path type) buffer into the
     * distributions, and builds one {@link MatrixRow} per (source, target) pair that has at least
     * one accepted, finite feature value.
     *
     * <p>Feature ids are path type ids. Feature values are raw walk counts, or counts divided by
     * the distribution's total count when walk normalization is enabled. A row is labelled
     * positive iff its (source, target) pair matches a positive instance in {@code instances}.
     *
     * @param instances instances whose pairs define paired targets and positive labels
     * @return the feature matrix; row order follows HashMap iteration and is unspecified
     */
    public FeatureMatrix getFeatureMatrix(List<Instance> instances) {
        Set<Pair<Integer, Integer>> positivePairs = collectPositivePairs(instances);
        Map<Integer, Set<Integer>> targetsBySource = collectTargetsBySource(instances);

        logger.info("Waiting for execution to finish");
        waitForFinish();
        drainAllBuffers();

        Set<Integer> allTargets = Sets.newHashSet();
        for (Set<Integer> targets : targetsBySource.values()) {
            allTargets.addAll(targets);
        }

        Map<Pair<Integer, Integer>, List<Pair<Integer, Double>>> featuresByPair =
                collectFeatureValues(targetsBySource, allTargets);
        return buildFeatureMatrix(featuresByPair, positivePairs);
    }

    /** Returns the (source, target) pairs of all positive instances. */
    private Set<Pair<Integer, Integer>> collectPositivePairs(List<Instance> instances) {
        Set<Pair<Integer, Integer>> positivePairs = Sets.newHashSet();
        for (Instance instance : instances) {
            if (instance.isPositive()) {
                positivePairs.add(Pair.makePair(Integer.valueOf(instance.source()),
                        Integer.valueOf(instance.target())));
            }
        }
        return positivePairs;
    }

    /** Groups every instance's target under its source, regardless of label. */
    private Map<Integer, Set<Integer>> collectTargetsBySource(List<Instance> instances) {
        Map<Integer, Set<Integer>> targetsBySource = Maps.newHashMap();
        for (Instance instance : instances) {
            MapUtil.addValueToKeySet(targetsBySource, Integer.valueOf(instance.source()),
                    Integer.valueOf(instance.target()));
        }
        return targetsBySource;
    }

    /** Calls drainBuffer for every (first key, second key) buffer currently present. */
    private void drainAllBuffers() {
        for (Integer firstKey : this.buffers.keySet()) {
            for (Integer secondKey : this.buffers.get(firstKey).keySet()) {
                drainBuffer(firstKey.intValue(), secondKey.intValue());
            }
        }
    }

    /** Maps each accepted (source, target) pair to its list of (path type, value) features. */
    private Map<Pair<Integer, Integer>, List<Pair<Integer, Double>>> collectFeatureValues(
            Map<Integer, Set<Integer>> targetsBySource, Set<Integer> allTargets) {
        Set<Integer> noTargets = Sets.newHashSet();
        Map<Pair<Integer, Integer>, List<Pair<Integer, Double>>> featuresByPair = Maps.newHashMap();
        for (Integer source : this.distributions.keySet()) {
            ConcurrentHashMap<Integer, DiscreteDistribution> byPathType = this.distributions.get(source);
            Set<Integer> pairedTargets = MapUtil.getWithDefault(targetsBySource, source, noTargets);
            for (Integer pathType : byPathType.keySet()) {
                DiscreteDistribution distribution = byPathType.get(pathType);
                double totalCount = distribution.totalCount();
                for (IdCount idCount : distribution.getTop(distribution.size())) {
                    int target = idCount.id;
                    if (!acceptableRow(source.intValue(), target, pairedTargets, allTargets)) {
                        continue;
                    }
                    double value = idCount.count;
                    if (this.normalizeWalkProbabilities) {
                        value /= totalCount;
                    }
                    // Division by a zero total yields Infinity (count > 0) or NaN (count == 0);
                    // neither is a usable feature value, so drop both.
                    if (Double.isInfinite(value) || Double.isNaN(value)) {
                        continue;
                    }
                    MapUtil.addValueToKeyList(featuresByPair,
                            Pair.makePair(source, Integer.valueOf(target)),
                            Pair.makePair(pathType, Double.valueOf(value)));
                }
            }
        }
        return featuresByPair;
    }

    /** Converts the per-pair feature lists into matrix rows labelled from {@code positivePairs}. */
    private FeatureMatrix buildFeatureMatrix(
            Map<Pair<Integer, Integer>, List<Pair<Integer, Double>>> featuresByPair,
            Set<Pair<Integer, Integer>> positivePairs) {
        List<MatrixRow> rows = new ArrayList<MatrixRow>();
        for (Map.Entry<Pair<Integer, Integer>, List<Pair<Integer, Double>>> entry
                : featuresByPair.entrySet()) {
            Pair<Integer, Integer> pair = entry.getKey();
            List<Pair<Integer, Double>> features = entry.getValue();
            int[] featureIds = new int[features.size()];
            double[] featureValues = new double[features.size()];
            for (int k = 0; k < features.size(); k++) {
                featureIds[k] = features.get(k).getLeft().intValue();
                featureValues[k] = features.get(k).getRight().doubleValue();
            }
            boolean isPositive = positivePairs.contains(pair);
            Instance instance = new Instance(pair.getLeft().intValue(), pair.getRight().intValue(),
                    isPositive, this.graph);
            rows.add(new MatrixRow(instance, featureIds, featureValues));
        }
        return new FeatureMatrix(rows);
    }
}
