package edu.cmu.ml.rtw.pra.features;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import edu.cmu.ml.rtw.pra.features.VectorPathTypeFactory;
import edu.cmu.ml.rtw.users.matt.util.MapUtil;
import edu.cmu.ml.rtw.users.matt.util.Pair;
import edu.cmu.ml.rtw.users.matt.util.Vector;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/cmu/ml/rtw/pra/features/VectorClusteringPathTypeSelector.class */
/**
 * A {@link PathTypeSelector} that merges similar path types before picking which ones to keep.
 *
 * <p>Path types are first grouped by a signature (see {@link #getSignature}). Within a group, all
 * path types have the same length, the same reverse flags, and the same edge types at every
 * position whose edge has no embedding. Inside each group, the pair of path types whose
 * concatenated edge embeddings have the largest dot product is merged repeatedly, for as long as
 * that dot product is strictly greater than the similarity threshold.
 *
 * <p>Merging registers new edge names (joined with {@link #CLUSTER_SEPARATOR}) in the factory's
 * edge dictionary and stores their embeddings in the factory's embeddings map. The
 * {@link VectorPathTypeFactory} passed in is therefore mutated by {@link #selectPathTypes}.
 */
public class VectorClusteringPathTypeSelector implements PathTypeSelector {
    /** Placed between the two original edge names when naming a merged edge type. */
    public static final String CLUSTER_SEPARATOR = " C+C ";

    private final VectorPathTypeFactory factory;
    private final double similarityThreshold;

    /**
     * Creates a selector.
     *
     * @param vectorPathTypeFactory supplies edge embeddings, the edge dictionary and new path type
     *     instances; it is mutated when path types are merged
     * @param similarityThreshold two path types are merged only if the dot product of their vectors
     *     is strictly greater than this value
     */
    public VectorClusteringPathTypeSelector(
            VectorPathTypeFactory vectorPathTypeFactory, double similarityThreshold) {
        this.factory = vectorPathTypeFactory;
        this.similarityThreshold = similarityThreshold;
    }

    /**
     * Clusters the path types within each signature group, then selects from the clustered counts
     * using {@link MapUtil#getTopKeys} (presumably the {@code numPaths} highest-count keys; see that
     * helper).
     *
     * @param pathCounts count for each path type; every key is cast to
     *     {@code VectorPathTypeFactory.VectorPathType}
     * @param numPaths number of path types to request from {@code MapUtil.getTopKeys}
     * @return the selected (possibly merged) path types
     */
    @Override // edu.cmu.ml.rtw.pra.features.PathTypeSelector
    public List<PathType> selectPathTypes(Map<PathType, Integer> pathCounts, int numPaths) {
        Map<PathType, Integer> clusteredCounts = Maps.newHashMap();
        for (List<Pair<PathType, Integer>> group : groupBySignature(pathCounts)) {
            for (Pair<PathType, Integer> cluster : clusterPathTypes(group)) {
                clusteredCounts.put(cluster.getLeft(), cluster.getRight());
            }
        }
        return MapUtil.getTopKeys(clusteredCounts, numPaths);
    }

    /**
     * Groups (path type, count) pairs by {@link #getSignature}.
     *
     * @param pathCounts count for each path type; every key must be a
     *     {@code VectorPathTypeFactory.VectorPathType}
     * @return one list of (path type, count) pairs per distinct signature
     */
    @VisibleForTesting
    protected Collection<List<Pair<PathType, Integer>>> groupBySignature(
            Map<PathType, Integer> pathCounts) {
        Map<String, List<Pair<PathType, Integer>>> groups = Maps.newHashMap();
        for (Map.Entry<PathType, Integer> entry : pathCounts.entrySet()) {
            String signature = getSignature((VectorPathTypeFactory.VectorPathType) entry.getKey());
            MapUtil.addValueToKeyList(
                    groups, signature, new Pair<PathType, Integer>(entry.getKey(), entry.getValue()));
        }
        return groups.values();
    }

    /**
     * Builds a string that is identical for path types that may be merged with each other.
     *
     * <p>The signature starts with {@code "-"}. For each edge it then appends {@code "_"} if the
     * edge is reversed. This is followed by {@code "v-"} if the edge type has an embedding, or by
     * the numeric edge type and {@code "-"} otherwise.
     *
     * @param pathType the path type to summarize
     * @return the signature string
     */
    @VisibleForTesting
    protected String getSignature(VectorPathTypeFactory.VectorPathType pathType) {
        int[] edgeTypes = pathType.getEdgeTypes();
        boolean[] reverse = pathType.getReverse();
        StringBuilder signature = new StringBuilder("-");
        for (int i = 0; i < edgeTypes.length; i++) {
            if (reverse[i]) {
                signature.append('_');
            }
            if (factory.getEmbedding(edgeTypes[i]) != null) {
                // Embedded edges are interchangeable for grouping purposes.
                signature.append("v-");
            } else {
                signature.append(edgeTypes[i]).append('-');
            }
        }
        return signature.toString();
    }

    /**
     * Greedily merges the most similar pair of path types until no pair's similarity exceeds the
     * threshold.
     *
     * <p>Similarity is the dot product of the vectors from {@link #getVectorFromPathType}. Path
     * types with no vector (no embedded edges) are never merged. Each merge removes the two inputs
     * and appends the result of {@link #combinePathTypes} to the end of the list.
     *
     * @param pathTypes (path type, count) pairs, expected to share one signature; not modified
     * @return a new list with the remaining (possibly merged) (path type, count) pairs
     */
    @VisibleForTesting
    protected List<Pair<PathType, Integer>> clusterPathTypes(List<Pair<PathType, Integer>> pathTypes) {
        List<Pair<PathType, Integer>> clusters = Lists.newArrayList(pathTypes);
        while (true) {
            // Build each cluster's vector once per pass instead of once per compared pair. Vectors
            // are not reused across passes because combinePathTypes mutates the factory.
            List<Vector> vectors = new ArrayList<>(clusters.size());
            for (Pair<PathType, Integer> cluster : clusters) {
                vectors.add(getVectorFromPathType(cluster.getLeft()));
            }

            double bestSimilarity = similarityThreshold;
            int bestFirst = -1;
            int bestSecond = -1;
            for (int i = 0; i < clusters.size(); i++) {
                Vector first = vectors.get(i);
                if (first == null) {
                    continue;
                }
                for (int j = i + 1; j < clusters.size(); j++) {
                    Vector second = vectors.get(j);
                    if (second == null) {
                        continue;
                    }
                    double similarity = first.dotProduct(second);
                    if (similarity > bestSimilarity) {
                        bestSimilarity = similarity;
                        bestFirst = i;
                        bestSecond = j;
                    }
                }
            }

            if (bestFirst < 0) {
                return clusters;
            }

            Pair<PathType, Integer> firstCluster = clusters.get(bestFirst);
            Pair<PathType, Integer> secondCluster = clusters.get(bestSecond);
            // bestSecond > bestFirst: remove the higher index first so the lower one stays valid.
            clusters.remove(bestSecond);
            clusters.remove(bestFirst);
            clusters.add(combinePathTypes(firstCluster, secondCluster));
        }
    }

    /**
     * Concatenates, in path order, the embeddings of every edge in the path type that has one.
     *
     * @param pathType a {@code VectorPathTypeFactory.VectorPathType}
     * @return the concatenated vector, or {@code null} if no edge type in the path has an
     *     embedding. With exactly one embedded edge, the factory's own embedding object is returned.
     */
    @VisibleForTesting
    protected Vector getVectorFromPathType(PathType pathType) {
        VectorPathTypeFactory.VectorPathType vectorPathType =
                (VectorPathTypeFactory.VectorPathType) pathType;
        Vector result = null;
        for (int edgeType : vectorPathType.getEdgeTypes()) {
            Vector embedding = factory.getEmbedding(edgeType);
            if (embedding == null) {
                continue;
            }
            result = (result == null) ? embedding : result.concatenate(embedding);
        }
        return result;
    }

    /**
     * Merges two path types into a new one whose count is the sum of both counts.
     *
     * <p>Non-embedded edges are copied from the first path type. Each embedded position gets a new
     * edge type, named {@code name1 + CLUSTER_SEPARATOR + name2} and looked up in the factory's edge
     * dictionary. Its embedding is the count-weighted sum of the two embeddings, passed through
     * {@code Vector.normalize()} and stored in the factory's embeddings map. Reverse flags are
     * taken from the first path type. Afterwards {@code factory.initializeEmbeddings()} is called.
     *
     * @param first the first (path type, count) pair
     * @param second the second (path type, count) pair; expected to share the first's signature
     * @return the merged (path type, count) pair
     * @throws IllegalArgumentException if the second path type has no embedding at a position where
     *     the first one does
     */
    @VisibleForTesting
    protected Pair<PathType, Integer> combinePathTypes(
            Pair<PathType, Integer> first, Pair<PathType, Integer> second) {
        VectorPathTypeFactory.VectorPathType firstPath =
                (VectorPathTypeFactory.VectorPathType) first.getLeft();
        VectorPathTypeFactory.VectorPathType secondPath =
                (VectorPathTypeFactory.VectorPathType) second.getLeft();
        int[] firstEdges = firstPath.getEdgeTypes();
        int[] secondEdges = secondPath.getEdgeTypes();
        boolean[] firstReverse = firstPath.getReverse();
        int firstCount = first.getRight();
        int secondCount = second.getRight();

        int length = firstEdges.length;
        int[] mergedEdges = new int[length];
        boolean[] mergedReverse = new boolean[length];
        for (int i = 0; i < length; i++) {
            mergedReverse[i] = firstReverse[i];
            Vector firstEmbedding = factory.getEmbedding(firstEdges[i]);
            if (firstEmbedding == null) {
                // Non-embedded edges are part of the signature, so both paths share this edge.
                mergedEdges[i] = firstEdges[i];
                continue;
            }
            Vector secondEmbedding = factory.getEmbedding(secondEdges[i]);
            if (secondEmbedding == null) {
                throw new IllegalArgumentException(
                        "Cannot combine path types with different signatures: edge "
                                + secondEdges[i] + " at position " + i + " has no embedding");
            }
            String mergedName = factory.getEdgeDict().getString(firstEdges[i])
                    + CLUSTER_SEPARATOR
                    + factory.getEdgeDict().getString(secondEdges[i]);
            int mergedEdge = factory.getEdgeDict().getIndex(mergedName);
            mergedEdges[i] = mergedEdge;
            Vector mergedEmbedding = firstEmbedding.multiply(firstCount)
                    .add(secondEmbedding.multiply(secondCount));
            mergedEmbedding.normalize();
            factory.getEmbeddingsMap().put(mergedEdge, mergedEmbedding);
        }
        factory.initializeEmbeddings();
        return new Pair<PathType, Integer>(
                factory.newInstance(mergedEdges, mergedReverse), firstCount + secondCount);
    }
}
