package edu.cmu.ml.rtw.pra.features;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import edu.cmu.graphchi.preprocessing.VertexIdTranslate;
import edu.cmu.graphchi.util.IdCount;
import edu.cmu.graphchi.walks.distributions.DiscreteDistribution;
import edu.cmu.graphchi.walks.distributions.TwoKeyCompanion;
import edu.cmu.ml.rtw.pra.experiments.Instance;
import edu.cmu.ml.rtw.pra.features.RandomWalkPathFinder;
import edu.cmu.ml.rtw.pra.graphs.GraphOnDisk;
import edu.cmu.ml.rtw.users.matt.util.Index;
import edu.cmu.ml.rtw.users.matt.util.MapUtil;
import edu.cmu.ml.rtw.users.matt.util.Pair;
import java.rmi.RemoteException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:edu/cmu/ml/rtw/pra/features/RandomWalkPathFinderCompanion.class */
/**
 * Aggregates the walks produced by {@link RandomWalkPathFinder} and turns them into path-type
 * counts and local subgraphs.
 *
 * <p>Keys recorded in the inherited two-key distributions (see the {@code get*Key} methods): the
 * first key is the translated vertex id passed in with each walk (presumably the vertex the walk
 * is currently at; confirm against {@code TwoKeyCompanion}). The second key is the translated
 * vertex the walk started from. The recorded value is the path-type id encoded in the walk. Below,
 * the first key is called the "intermediate" vertex. Path-type ids are resolved to
 * {@link PathType}s through {@code pathDict}.
 */
public class RandomWalkPathFinderCompanion extends TwoKeyCompanion {
    /** Number of most frequent path types taken from each distribution when counting paths. */
    private static final int TOP_PATH_TYPES_PER_DISTRIBUTION = 5;

    /** Number of most frequent path types taken from each distribution for local subgraphs. */
    private static final int TOP_PATH_TYPES_PER_SUBGRAPH = 10;

    /**
     * Cap on the cross product under {@link PathTypePolicy#EVERYTHING}. At one intermediate vertex,
     * at most this many reached sources are used, and for each of them at most this many reached
     * targets. The vertices used are whichever come first in {@code HashSet} iteration order.
     */
    private static final int MAX_CROSS_PRODUCT_VERTICES = 9;

    private final VertexIdTranslate translate;
    private int[] sourceVertexIds;
    private final Index<PathType> pathDict;
    private PathTypePolicy policy;
    private final PathTypeFactory pathTypeFactory;
    private final GraphOnDisk graph;

    /** Supplies the count map to which path types found for a (source, target) pair are added. */
    private interface PairCountSink {
        Map<PathType, Integer> countsFor(int source, int target);
    }

    /**
     * Creates a companion.
     *
     * @param graphOnDisk graph used only to build fallback {@link Instance}s in
     *     {@link #getPathCountMap}
     * @param i forwarded unchanged to the {@code TwoKeyCompanion} constructor
     * @param j forwarded unchanged to the {@code TwoKeyCompanion} constructor
     * @param vertexIdTranslate translates the vertex ids seen by the walks back to original ids
     * @param index dictionary from path-type id to {@link PathType}
     * @param pathTypeFactory builds empty and concatenated path types
     * @param pathTypePolicy initial policy for combining walks
     * @throws RemoteException propagated from the superclass constructor
     */
    public RandomWalkPathFinderCompanion(GraphOnDisk graphOnDisk, int i, long j, VertexIdTranslate vertexIdTranslate, Index<PathType> index, PathTypeFactory pathTypeFactory, PathTypePolicy pathTypePolicy) throws RemoteException {
        super(i, j);
        this.graph = graphOnDisk;
        this.translate = vertexIdTranslate;
        this.pathDict = index;
        this.pathTypeFactory = pathTypeFactory;
        this.policy = pathTypePolicy;
    }

    /**
     * Sets the (untranslated) walk start vertices. {@link #getSecondKey} indexes this array with
     * the source index encoded in each walk, so it must be set before walks are recorded.
     */
    public void setSources(int[] iArr) {
        this.sourceVertexIds = iArr;
    }

    /** Replaces the policy used to combine walks and to weight single-sided counts. */
    public void setPolicy(PathTypePolicy pathTypePolicy) {
        this.policy = pathTypePolicy;
    }

    /** First key: the translated form of the vertex id passed with the walk. */
    protected int getFirstKey(long walk, int vertex) {
        return this.translate.backward(vertex);
    }

    /** Second key: the translated id of the vertex the walk started from. */
    protected int getSecondKey(long walk, int vertex) {
        return this.translate.backward(this.sourceVertexIds[RandomWalkPathFinder.staticSourceIdx(walk)]);
    }

    /** Value: the path-type id encoded in the walk. */
    protected int getValue(long walk, int vertex) {
        return RandomWalkPathFinder.Manager.pathType(walk);
    }

    @VisibleForTesting
    protected void setDistributions(ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, DiscreteDistribution>> concurrentHashMap) {
        this.distributions = concurrentHashMap;
    }

    /** Intentionally does nothing in this class. */
    public void outputDistributions(String str) throws RemoteException {
    }

    /**
     * Calls {@code waitForFinish()}, then {@code drainBuffer} for every (first key, second key)
     * buffer. Presumably this flushes pending values into {@code distributions}; confirm in
     * TwoKeyCompanion.
     */
    private void assureReady() {
        waitForFinish();
        for (Object firstKey : this.buffers.keySet()) {
            Map<?, ?> buffersForFirstKey = (Map<?, ?>) this.buffers.get(firstKey);
            for (Object secondKey : buffersForFirstKey.keySet()) {
                drainBuffer((Integer) firstKey, (Integer) secondKey);
            }
        }
    }

    /** Returns the first keys (intermediate vertices) currently in {@code distributions}. */
    @SuppressWarnings("unchecked")
    private Set<Integer> intermediateVertices() {
        return (Set<Integer>) this.distributions.keySet();
    }

    /** Returns the per-walk-source distributions recorded at {@code intermediate}. */
    @SuppressWarnings("unchecked")
    private Map<Integer, DiscreteDistribution> distributionsAt(Integer intermediate) {
        return (Map<Integer, DiscreteDistribution>) this.distributions.get(intermediate);
    }

    /**
     * Counts the path types connecting the given sources and targets, summed over all pairs.
     *
     * <p>{@code sources} and {@code targets} are parallel lists: {@code sources.get(i)} is paired
     * with {@code targets.get(i)}. Under {@link PathTypePolicy#PAIRED_ONLY}, only walks between
     * paired vertices are combined. Under {@link PathTypePolicy#EVERYTHING}, walks from any source
     * are combined with walks from any target, capped per intermediate vertex.
     *
     * @param sources source vertex ids
     * @param targets target vertex ids, parallel to {@code sources}
     * @return map from concatenated path type to its policy-weighted count
     * @throws RuntimeException if the policy is neither PAIRED_ONLY nor EVERYTHING
     */
    public Map<PathType, Integer> getPathCounts(List<Integer> sources, List<Integer> targets) {
        logger.info("Waiting for finish");
        assureReady();
        logger.info("Getting paths");
        final Map<PathType, Integer> counts = Maps.newHashMap();
        if (this.policy == PathTypePolicy.PAIRED_ONLY) {
            // Every pair contributes to the same aggregate map.
            addPairedCounts(sources, targets, new PairCountSink() {
                @Override
                public Map<PathType, Integer> countsFor(int source, int target) {
                    return counts;
                }
            });
        } else if (this.policy == PathTypePolicy.EVERYTHING) {
            addEverythingCounts(counts, sources, targets);
        } else {
            throw new RuntimeException("Unknown path type policy: " + this.policy);
        }
        return counts;
    }

    /**
     * Counts path types per instance. The instance's (source, target) pair is treated as paired,
     * regardless of the policy; the policy still affects count weighting.
     *
     * @param list instances whose source/target vertices to connect
     * @return map from instance to its path-type counts; instances with no recorded paths are
     *     absent
     */
    public Map<Instance, Map<PathType, Integer>> getPathCountMap(List<Instance> list) {
        logger.info("Waiting for finish");
        assureReady();
        logger.info("Getting paths");
        List<Integer> sources = Lists.newArrayList();
        List<Integer> targets = Lists.newArrayList();
        Map<Pair<Integer, Integer>, Instance> instancesByPair = Maps.newHashMap();
        for (Instance instance : list) {
            sources.add(instance.source());
            targets.add(instance.target());
            instancesByPair.put(Pair.makePair(instance.source(), instance.target()), instance);
        }
        final Map<Pair<Integer, Integer>, Map<PathType, Integer>> countsByPair = Maps.newHashMap();
        addPairedCounts(sources, targets, new PairCountSink() {
            @Override
            public Map<PathType, Integer> countsFor(int source, int target) {
                Pair<Integer, Integer> key = Pair.makePair(source, target);
                Map<PathType, Integer> pairCounts = countsByPair.get(key);
                if (pairCounts == null) {
                    pairCounts = Maps.newHashMap();
                    countsByPair.put(key, pairCounts);
                }
                return pairCounts;
            }
        });
        Map<Instance, Map<PathType, Integer>> result = Maps.newHashMap();
        for (Map.Entry<Pair<Integer, Integer>, Map<PathType, Integer>> entry : countsByPair.entrySet()) {
            Pair<Integer, Integer> endpoints = entry.getKey();
            Instance instance = instancesByPair.get(endpoints);
            if (instance == null) {
                // Every key is built from some instance's own (source, target), so this fallback
                // should not trigger; kept as a safety net.
                instance = new Instance(endpoints.getLeft(), endpoints.getRight(), false, this.graph);
            }
            result.put(instance, entry.getValue());
        }
        return result;
    }

    /**
     * Builds, for each instance, a map from path type to the (walk source, intermediate vertex)
     * pairs where that path type was among the top path types of a walk starting at the
     * instance's source or target. It also adds the instance's own (source, target) pair under
     * every path type found by {@link #getPathCountMap}.
     */
    public Map<Instance, Map<PathType, Set<Pair<Integer, Integer>>>> getLocalSubgraphs(List<Instance> list) {
        logger.info("Waiting for finish");
        assureReady();
        logger.info("Getting paths");
        Map<Instance, Map<PathType, Integer>> pathCountMap = getPathCountMap(list);

        // walk source -> path type -> (walk source, intermediate) pairs.
        Map<Integer, Map<PathType, Set<Pair<Integer, Integer>>>> edgesByWalkSource = Maps.newHashMap();
        for (Integer intermediate : intermediateVertices()) {
            for (Map.Entry<Integer, DiscreteDistribution> walks : distributionsAt(intermediate).entrySet()) {
                Integer walkSource = walks.getKey();
                for (IdCount idCount : walks.getValue().getTop(TOP_PATH_TYPES_PER_SUBGRAPH)) {
                    Map<PathType, Set<Pair<Integer, Integer>>> edges = edgesByWalkSource.get(walkSource);
                    if (edges == null) {
                        edges = Maps.newHashMap();
                        edgesByWalkSource.put(walkSource, edges);
                    }
                    addToKeySet(edges, pathTypeForId(idCount.id), Pair.makePair(walkSource, intermediate));
                }
            }
        }

        Map<Instance, Map<PathType, Set<Pair<Integer, Integer>>>> subgraphs = Maps.newHashMap();
        for (Instance instance : list) {
            int source = instance.source();
            int target = instance.target();
            Map<PathType, Set<Pair<Integer, Integer>>> subgraph = Maps.newHashMap();
            mergeEdges(subgraph, edgesByWalkSource.get(source));
            mergeEdges(subgraph, edgesByWalkSource.get(target));
            Map<PathType, Integer> pathCounts = pathCountMap.get(instance);
            if (pathCounts != null) {
                Pair<Integer, Integer> endpoints = Pair.makePair(source, target);
                for (PathType pathType : pathCounts.keySet()) {
                    addToKeySet(subgraph, pathType, endpoints);
                }
            }
            subgraphs.put(instance, subgraph);
        }
        return subgraphs;
    }

    /**
     * Adds the counts from walks between paired vertices at every intermediate vertex. There are
     * three cases:
     * <ul>
     *   <li>a source's walk that reached its paired target;</li>
     *   <li>a target's walk that reached its paired source;</li>
     *   <li>a paired source's walk and target's walk that met at the same intermediate vertex.</li>
     * </ul>
     * Every instance position is considered, so duplicated sources or targets are all paired.
     */
    private void addPairedCounts(List<Integer> sources, List<Integer> targets, PairCountSink sink) {
        Map<Integer, List<Integer>> sourcesByTarget = partnersByKey(targets, sources);
        Map<Integer, List<Integer>> targetsBySource = partnersByKey(sources, targets);
        for (Integer intermediate : intermediateVertices()) {
            Map<Integer, DiscreteDistribution> walksBySource = distributionsAt(intermediate);
            List<Integer> pairedSources = sourcesByTarget.get(intermediate);
            if (pairedSources != null) {
                for (Integer source : pairedSources) {
                    DiscreteDistribution fromSource = walksBySource.get(source);
                    if (fromSource != null) {
                        incrementCounts(sink.countsFor(source, intermediate), fromSource,
                                this.pathTypeFactory.emptyPathType());
                    }
                }
            }
            List<Integer> pairedTargets = targetsBySource.get(intermediate);
            if (pairedTargets != null) {
                for (Integer target : pairedTargets) {
                    DiscreteDistribution fromTarget = walksBySource.get(target);
                    if (fromTarget != null) {
                        incrementCounts(sink.countsFor(intermediate, target),
                                this.pathTypeFactory.emptyPathType(), fromTarget);
                    }
                }
            }
            for (int i = 0; i < sources.size(); i++) {
                int source = sources.get(i);
                int target = targets.get(i);
                DiscreteDistribution fromSource = walksBySource.get(source);
                DiscreteDistribution fromTarget = walksBySource.get(target);
                if (fromSource != null && fromTarget != null) {
                    incrementCounts(sink.countsFor(source, target), fromSource, fromTarget);
                }
            }
        }
    }

    /**
     * EVERYTHING policy: at each intermediate vertex, adds single-sided counts for every reached
     * source (or target) when the intermediate vertex is itself a target (or source). It then
     * adds the capped source x target cross product once per intermediate vertex. Earlier code
     * repeated the cross product once per instance.
     */
    private void addEverythingCounts(Map<PathType, Integer> counts, List<Integer> sources, List<Integer> targets) {
        Set<Integer> sourceSet = Sets.newHashSet(sources);
        Set<Integer> targetSet = Sets.newHashSet(targets);
        for (Integer intermediate : intermediateVertices()) {
            Map<Integer, DiscreteDistribution> walksBySource = distributionsAt(intermediate);
            Set<Integer> reachedSources = Sets.newHashSet(walksBySource.keySet());
            reachedSources.retainAll(sourceSet);
            Set<Integer> reachedTargets = Sets.newHashSet(walksBySource.keySet());
            reachedTargets.retainAll(targetSet);
            if (targetSet.contains(intermediate)) {
                for (Integer source : reachedSources) {
                    incrementCounts(counts, walksBySource.get(source), this.pathTypeFactory.emptyPathType());
                }
            }
            if (sourceSet.contains(intermediate)) {
                for (Integer target : reachedTargets) {
                    incrementCounts(counts, this.pathTypeFactory.emptyPathType(), walksBySource.get(target));
                }
            }
            int sourcesUsed = 0;
            for (Integer source : reachedSources) {
                if (++sourcesUsed > MAX_CROSS_PRODUCT_VERTICES) {
                    break;
                }
                int targetsUsed = 0;
                for (Integer target : reachedTargets) {
                    if (++targetsUsed > MAX_CROSS_PRODUCT_VERTICES) {
                        break;
                    }
                    incrementCounts(counts, walksBySource.get(source), walksBySource.get(target));
                }
            }
        }
    }

    /** Maps each value in {@code keys} to the values of {@code partners} at the same positions. */
    private static Map<Integer, List<Integer>> partnersByKey(List<Integer> keys, List<Integer> partners) {
        Map<Integer, List<Integer>> result = Maps.newHashMap();
        for (int i = 0; i < keys.size(); i++) {
            Integer key = keys.get(i);
            List<Integer> bucket = result.get(key);
            if (bucket == null) {
                bucket = Lists.newArrayList();
                result.put(key, bucket);
            }
            bucket.add(partners.get(i));
        }
        return result;
    }

    /** Copies every (path type, edge) entry of {@code source}, if non-null, into fresh sets in {@code dest}. */
    private static void mergeEdges(Map<PathType, Set<Pair<Integer, Integer>>> dest,
                                   Map<PathType, Set<Pair<Integer, Integer>>> source) {
        if (source == null) {
            return;
        }
        for (Map.Entry<PathType, Set<Pair<Integer, Integer>>> entry : source.entrySet()) {
            for (Pair<Integer, Integer> edge : entry.getValue()) {
                addToKeySet(dest, entry.getKey(), edge);
            }
        }
    }

    /** Adds {@code value} to the set stored under {@code key}, creating the set if needed. */
    private static <K, V> void addToKeySet(Map<K, Set<V>> map, K key, V value) {
        Set<V> values = map.get(key);
        if (values == null) {
            values = Sets.newHashSet();
            map.put(key, values);
        }
        values.add(value);
    }

    /** Resolves a path-type id through {@code pathDict}. */
    private PathType pathTypeForId(int id) {
        return (PathType) this.pathDict.getKey(id);
    }

    /**
     * Weights a count that comes from a single walk (one side of the path is empty).
     * PAIRED_ONLY squares it, EVERYTHING cubes it, and any other policy leaves it unchanged. The
     * result saturates at the int range instead of wrapping.
     * NOTE(review): the rationale for these exponents is not visible here; confirm.
     */
    private int weightSingleSidedCount(int count) {
        if (this.policy == PathTypePolicy.PAIRED_ONLY) {
            return saturatedMultiply(count, count);
        }
        if (this.policy == PathTypePolicy.EVERYTHING) {
            return saturatedMultiply(count, saturatedMultiply(count, count));
        }
        return count;
    }

    /** Adds (firstPart + top path types of {@code secondParts}) with single-sided weighting. */
    @VisibleForTesting
    protected void incrementCounts(Map<PathType, Integer> map, PathType firstPart, DiscreteDistribution secondParts) {
        for (IdCount idCount : secondParts.getTop(TOP_PATH_TYPES_PER_DISTRIBUTION)) {
            incrementCounts(map, firstPart, pathTypeForId(idCount.id), weightSingleSidedCount(idCount.count));
        }
    }

    /** Adds (top path types of {@code firstParts} + secondPart) with single-sided weighting. */
    @VisibleForTesting
    protected void incrementCounts(Map<PathType, Integer> map, DiscreteDistribution firstParts, PathType secondPart) {
        for (IdCount idCount : firstParts.getTop(TOP_PATH_TYPES_PER_DISTRIBUTION)) {
            incrementCounts(map, pathTypeForId(idCount.id), secondPart, weightSingleSidedCount(idCount.count));
        }
    }

    /** Adds every (top first path type + top second path type) with the product of their counts. */
    @VisibleForTesting
    protected void incrementCounts(Map<PathType, Integer> map, DiscreteDistribution firstParts, DiscreteDistribution secondParts) {
        for (IdCount idCount : firstParts.getTop(TOP_PATH_TYPES_PER_DISTRIBUTION)) {
            incrementCounts(map, pathTypeForId(idCount.id), secondParts, idCount.count);
        }
    }

    /** Adds (firstPart + each top path type of {@code secondParts}) with weight {@code weight * count}. */
    @VisibleForTesting
    protected void incrementCounts(Map<PathType, Integer> map, PathType firstPart, DiscreteDistribution secondParts, int weight) {
        for (IdCount idCount : secondParts.getTop(TOP_PATH_TYPES_PER_DISTRIBUTION)) {
            incrementCounts(map, firstPart, pathTypeForId(idCount.id), saturatedMultiply(weight, idCount.count));
        }
    }

    /**
     * Adds {@code amount} to the count of {@code concatenatePathTypes(firstPart, secondPart)}.
     * The sum saturates at the int range instead of wrapping.
     */
    @VisibleForTesting
    protected void incrementCounts(Map<PathType, Integer> map, PathType firstPart, PathType secondPart, int amount) {
        PathType combined = this.pathTypeFactory.concatenatePathTypes(firstPart, secondPart);
        Integer existing = map.get(combined);
        map.put(combined, existing == null ? amount : saturatedAdd(existing, amount));
    }

    /** Multiplies two ints, clamping to the int range instead of silently wrapping. */
    private static int saturatedMultiply(int a, int b) {
        return clampToInt((long) a * b);
    }

    /** Adds two ints, clamping to the int range instead of silently wrapping. */
    private static int saturatedAdd(int a, int b) {
        return clampToInt((long) a + b);
    }

    private static int clampToInt(long value) {
        if (value > Integer.MAX_VALUE) {
            return Integer.MAX_VALUE;
        }
        if (value < Integer.MIN_VALUE) {
            return Integer.MIN_VALUE;
        }
        return (int) value;
    }
}
