package edu.cmu.ml.rtw.pra.features;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import edu.cmu.ml.rtw.pra.config.PraConfig;
import edu.cmu.ml.rtw.pra.experiments.Dataset;
import edu.cmu.ml.rtw.users.matt.util.CollectionsUtil;
import edu.cmu.ml.rtw.users.matt.util.MapUtil;
import edu.cmu.ml.rtw.users.matt.util.Pair;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/cmu/ml/rtw/pra/features/FeatureGenerator.class */
/**
 * Generates PRA path features from a {@link PraConfig}.
 *
 * <p>It finds path types that connect a dataset's sources and targets (via {@link PathFinder}),
 * selects a subset of them, and computes feature values for chosen path types (via the
 * configured {@link PathFollower}). Every setting used here is read from the {@code PraConfig}
 * supplied at construction: graph, shard count, walk/iteration counts, factories, selector and
 * outputter.
 */
public class FeatureGenerator {
    /**
     * Pause (ms) after a PathFinder run before reading its results.
     *
     * <p>The reason for the pause is not visible in this file. Presumably it lets
     * GraphChi-backed work settle before results are read — confirm.
     */
    private static final long PATH_FINDER_PAUSE_MILLIS = 500L;

    /** Pause (ms) after a GraphChi-backed PathFollower run before reading its results. */
    private static final long PATH_FOLLOWER_PAUSE_MILLIS = 1000L;

    private final PraConfig config;

    public FeatureGenerator(PraConfig praConfig) {
        this.config = praConfig;
    }

    /**
     * Finds path types connecting the dataset's sources and targets, then keeps a subset.
     *
     * <p>Path counts found by the {@link PathFinder} have inverse edges collapsed using
     * {@code config.relationInverses}. The collapsed counts are written to
     * {@code found_path_counts.tsv}. Up to {@code config.numPaths} path types are then chosen by
     * {@code config.pathTypeSelector} and written to {@code kept_paths.tsv}.
     *
     * @param dataset training data whose sources/targets seed the path search
     * @return the selected path types
     * @throws RuntimeException wrapping an {@link InterruptedException} if the thread is
     *     interrupted while waiting for results; the interrupt flag is restored
     */
    public List<PathType> selectPathFeatures(Dataset dataset) {
        PathFinder pathFinder = createPathFinder(dataset);
        pathFinder.execute(this.config.numIters);
        Map<PathType, Integer> pathCounts;
        try {
            pause(PATH_FINDER_PAUSE_MILLIS);
            pathCounts = pathFinder.getPathCounts();
        } finally {
            // Always release the finder, even if the wait is interrupted.
            pathFinder.shutDown();
        }
        Map<PathType, Integer> collapsedCounts =
                collapseInverses(pathCounts, this.config.relationInverses);
        this.config.outputter.outputPathCounts(
                this.config.outputBase, "found_path_counts.tsv", collapsedCounts);
        List<PathType> selectedPaths =
                this.config.pathTypeSelector.selectPathTypes(collapsedCounts, this.config.numPaths);
        this.config.outputter.outputPaths(this.config.outputBase, "kept_paths.tsv", selectedPaths);
        return selectedPaths;
    }

    /**
     * Finds, for each (source, target) pair, the path types connecting them and their counts.
     *
     * <p>Inverse edges are collapsed per pair using {@code config.relationInverses}. The result
     * is written to {@code path_count_map.tsv}.
     *
     * @param dataset data whose sources/targets seed the path search
     * @return map from (source, target) pair to collapsed path-type counts
     * @throws RuntimeException wrapping an {@link InterruptedException} if the thread is
     *     interrupted while waiting for results; the interrupt flag is restored
     */
    public Map<Pair<Integer, Integer>, Map<PathType, Integer>> findConnectingPaths(Dataset dataset) {
        PathFinder pathFinder = createPathFinder(dataset);
        pathFinder.execute(this.config.numIters);
        Map<Pair<Integer, Integer>, Map<PathType, Integer>> pathCountMap;
        try {
            pause(PATH_FINDER_PAUSE_MILLIS);
            pathCountMap = pathFinder.getPathCountMap();
        } finally {
            // Always release the finder, even if the wait is interrupted.
            pathFinder.shutDown();
        }
        Map<Pair<Integer, Integer>, Map<PathType, Integer>> collapsedCountMap =
                collapseInversesInCountMap(pathCountMap, this.config.relationInverses);
        this.config.outputter.outputPathCountMap(
                this.config.outputBase, "path_count_map.tsv", collapsedCountMap, dataset);
        return collapsedCountMap;
    }

    /**
     * Computes the feature matrix for the given path types over the given dataset.
     *
     * @param list path types to use as features
     * @param dataset instances to compute feature values for
     * @param str output file for the feature matrix; if {@code null}, nothing is written
     * @return the computed feature matrix
     * @throws RuntimeException wrapping an {@link InterruptedException} if the thread is
     *     interrupted while waiting for GraphChi-backed results; the interrupt flag is restored
     */
    public FeatureMatrix computeFeatureValues(List<PathType> list, Dataset dataset, String str) {
        PathFollower follower = this.config.pathFollowerFactory.create(
                list, this.config, dataset, createEdgesToExclude(dataset));
        follower.execute();
        FeatureMatrix featureMatrix;
        try {
            // Only GraphChi-backed followers get the post-execution pause.
            if (follower.usesGraphChi()) {
                pause(PATH_FOLLOWER_PAUSE_MILLIS);
            }
            featureMatrix = follower.getFeatureMatrix();
        } finally {
            // Always release the follower, even if the wait is interrupted.
            follower.shutDown();
        }
        if (str != null) {
            this.config.outputter.outputFeatureMatrix(str, featureMatrix, list);
        }
        return featureMatrix;
    }

    /**
     * Collapses inverse edges in every path type and sums the counts of path types that
     * collapse to the same result.
     *
     * @param map path type to count
     * @param map2 relation-inverse mapping passed to
     *     {@code pathTypeFactory.collapseEdgeInverses}
     * @return a new map from collapsed path type to summed count
     */
    @VisibleForTesting
    protected Map<PathType, Integer> collapseInverses(Map<PathType, Integer> map, Map<Integer, Integer> map2) {
        return collapseCounts(map, map2);
    }

    /**
     * Applies the same inverse-collapsing as {@link #collapseInverses} independently to the
     * path counts of every (source, target) pair.
     *
     * @param map (source, target) pair to path type counts
     * @param map2 relation-inverse mapping passed to
     *     {@code pathTypeFactory.collapseEdgeInverses}
     * @return a new map with the same pair keys and collapsed per-pair counts
     */
    @VisibleForTesting
    protected Map<Pair<Integer, Integer>, Map<PathType, Integer>> collapseInversesInCountMap(Map<Pair<Integer, Integer>, Map<PathType, Integer>> map, Map<Integer, Integer> map2) {
        Map<Pair<Integer, Integer>, Map<PathType, Integer>> collapsedMap = Maps.newHashMap();
        for (Map.Entry<Pair<Integer, Integer>, Map<PathType, Integer>> entry : map.entrySet()) {
            collapsedMap.put(entry.getKey(), collapseCounts(entry.getValue(), map2));
        }
        return collapsedMap;
    }

    /**
     * Builds the (source, target)/edge-type triples to exclude during path search.
     *
     * <p>For each pair produced by {@code CollectionsUtil.zipLists(sources, targets)} and each
     * edge type in {@code config.unallowedEdges}, emits {@code ((source, target), edgeType)}.
     * NOTE(review): {@code zipLists} presumably pairs elements by index — confirm its behavior
     * on lists of unequal length.
     *
     * @param dataset the dataset; may be {@code null}
     * @return a new mutable list. It is empty when {@code dataset} is {@code null} or has no
     *     sources or no targets.
     */
    @VisibleForTesting
    protected List<Pair<Pair<Integer, Integer>, Integer>> createEdgesToExclude(Dataset dataset) {
        List<Pair<Pair<Integer, Integer>, Integer>> edgesToExclude =
                new ArrayList<Pair<Pair<Integer, Integer>, Integer>>();
        if (dataset == null) {
            return edgesToExclude;
        }
        List<Integer> allSources = dataset.getAllSources();
        List<Integer> allTargets = dataset.getAllTargets();
        if (allSources.isEmpty() || allTargets.isEmpty()) {
            return edgesToExclude;
        }
        for (Pair<Integer, Integer> sourceTarget : CollectionsUtil.zipLists(allSources, allTargets)) {
            for (Integer edgeType : this.config.unallowedEdges) {
                edgesToExclude.add(new Pair<Pair<Integer, Integer>, Integer>(sourceTarget, edgeType));
            }
        }
        return edgesToExclude;
    }

    /** Builds a PathFinder for the dataset's sources/targets using the configured settings. */
    private PathFinder createPathFinder(Dataset dataset) {
        return new PathFinder(
                this.config.graph,
                this.config.numShards,
                dataset.getAllSources(),
                dataset.getAllTargets(),
                this.config.edgeExcluderFactory.create(createEdgesToExclude(dataset)),
                this.config.walksPerSource,
                this.config.pathTypePolicy,
                this.config.pathTypeFactory);
    }

    /** Collapses inverse edges in one path-count map, summing counts of merged path types. */
    private Map<PathType, Integer> collapseCounts(Map<PathType, Integer> counts, Map<Integer, Integer> relationInverses) {
        Map<PathType, Integer> collapsed = Maps.newHashMap();
        for (Map.Entry<PathType, Integer> entry : counts.entrySet()) {
            PathType collapsedType =
                    this.config.pathTypeFactory.collapseEdgeInverses(entry.getKey(), relationInverses);
            MapUtil.incrementCount(collapsed, collapsedType, entry.getValue().intValue());
        }
        return collapsed;
    }

    /**
     * Sleeps for the given time. On interrupt, it restores the thread's interrupt flag and then
     * rethrows as an unchecked exception (the same exception type the callers always threw).
     */
    private static void pause(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }
}
