package edu.cmu.ml.rtw.pra.config;

import com.google.common.collect.Maps;
import edu.cmu.ml.rtw.pra.experiments.Dataset;
import edu.cmu.ml.rtw.pra.experiments.Outputter;
import edu.cmu.ml.rtw.pra.features.BasicPathTypeFactory;
import edu.cmu.ml.rtw.pra.features.EdgeExcluderFactory;
import edu.cmu.ml.rtw.pra.features.MatrixPathFollowerFactory;
import edu.cmu.ml.rtw.pra.features.MatrixRowPolicy;
import edu.cmu.ml.rtw.pra.features.MostFrequentPathTypeSelector;
import edu.cmu.ml.rtw.pra.features.PathFollowerFactory;
import edu.cmu.ml.rtw.pra.features.PathTypeFactory;
import edu.cmu.ml.rtw.pra.features.PathTypePolicy;
import edu.cmu.ml.rtw.pra.features.PathTypeSelector;
import edu.cmu.ml.rtw.pra.features.RandomWalkPathFollowerFactory;
import edu.cmu.ml.rtw.pra.features.SingleEdgeExcluderFactory;
import edu.cmu.ml.rtw.pra.features.VectorClusteringPathTypeSelector;
import edu.cmu.ml.rtw.pra.features.VectorPathTypeFactory;
import edu.cmu.ml.rtw.users.matt.util.Dictionary;
import edu.cmu.ml.rtw.users.matt.util.Vector;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/cmu/ml/rtw/pra/config/PraConfig.class */
/**
 * Configuration for a PRA run.
 *
 * <p>All fields are {@code final}, but the objects they reference (datasets, maps, sets,
 * dictionaries, factories) may themselves be mutable. Instances can only be created through
 * {@link Builder}.
 */
public class PraConfig {
    public final String graph;
    public final int numShards;
    public final String outputBase;
    public final Dataset allData;
    public final double percentTraining;
    public final Dataset trainingData;
    public final Dataset testingData;
    public final Map<Integer, Integer> relationInverses;
    public final int numIters;
    public final int walksPerSource;
    public final List<Integer> unallowedEdges;
    public final EdgeExcluderFactory edgeExcluderFactory;
    public final int numPaths;
    public final PathTypePolicy pathTypePolicy;
    public final PathTypeFactory pathTypeFactory;
    public final PathTypeSelector pathTypeSelector;
    public final PathFollowerFactory pathFollowerFactory;
    public final int maxMatrixFeatureFanOut;
    public final int walksPerPath;
    public final MatrixRowPolicy acceptPolicy;
    public final Set<Integer> allowedTargets;
    public final boolean normalizeWalkProbabilities;
    public final double l2Weight;
    public final double l1Weight;
    public final boolean binarizeFeatures;
    public final Dictionary nodeDict;
    public final Dictionary edgeDict;
    public final Outputter outputter;

    /**
     * Mutable builder for {@link PraConfig}.
     *
     * <p>Settings can be supplied through the fluent setters, copied from an existing config
     * with {@link #Builder(PraConfig)}, or read from a tab-separated parameter file with
     * {@link #setFromParamFile(BufferedReader)}.
     */
    public static class Builder {
        private String graph;
        private int numShards;
        private int numIters;
        private int walksPerSource;
        private int numPaths;
        private int walksPerPath;
        private double l2Weight;
        private double l1Weight;
        private Dataset allData;
        private double percentTraining;
        private Dataset trainingData;
        private Dataset testingData;
        private Set<Integer> allowedTargets;
        private List<Integer> unallowedEdges;
        private Map<Integer, Integer> relationInverses;
        private String outputBase;
        private PathTypePolicy pathTypePolicy = PathTypePolicy.PAIRED_ONLY;
        // NOTE(review): pathTypeFactory, nodeDict, edgeDict and outputter are public, unlike the
        // other builder fields. They are kept public because code outside this file may access
        // them directly.
        public PathTypeFactory pathTypeFactory = new BasicPathTypeFactory();
        private PathTypeSelector pathTypeSelector = new MostFrequentPathTypeSelector();
        private PathFollowerFactory pathFollowerFactory = new RandomWalkPathFollowerFactory();
        private int maxMatrixFeatureFanOut = 100;
        private MatrixRowPolicy acceptPolicy = MatrixRowPolicy.ALL_TARGETS;
        private boolean binarizeFeatures = false;
        private boolean normalizeWalkProbabilities = true;
        private EdgeExcluderFactory edgeExcluderFactory = new SingleEdgeExcluderFactory();
        public Dictionary nodeDict = new Dictionary();
        public Dictionary edgeDict = new Dictionary();
        public Outputter outputter = null;

        /** Creates a builder populated with the default settings declared on the fields above. */
        public Builder() {
        }

        /**
         * Creates a builder that starts from every setting of {@code praConfig}.
         *
         * <p>The outputter is copied by reference as well. A builder derived this way and given
         * new dictionaries will therefore keep using the original outputter unless
         * {@link #setOutputter(Outputter)} is called.
         *
         * @param praConfig the configuration to copy
         */
        public Builder(PraConfig praConfig) {
            setGraph(praConfig.graph);
            setNumShards(praConfig.numShards);
            setNumIters(praConfig.numIters);
            setWalksPerSource(praConfig.walksPerSource);
            setNumPaths(praConfig.numPaths);
            setPathTypePolicy(praConfig.pathTypePolicy);
            setPathTypeFactory(praConfig.pathTypeFactory);
            setPathTypeSelector(praConfig.pathTypeSelector);
            setPathFollowerFactory(praConfig.pathFollowerFactory);
            setMaxMatrixFeatureFanOut(praConfig.maxMatrixFeatureFanOut);
            setWalksPerPath(praConfig.walksPerPath);
            setAcceptPolicy(praConfig.acceptPolicy);
            setL2Weight(praConfig.l2Weight);
            setL1Weight(praConfig.l1Weight);
            setBinarizeFeatures(praConfig.binarizeFeatures);
            setAllData(praConfig.allData);
            setPercentTraining(praConfig.percentTraining);
            setTrainingData(praConfig.trainingData);
            setTestingData(praConfig.testingData);
            setAllowedTargets(praConfig.allowedTargets);
            setNormalizeWalkProbabilities(praConfig.normalizeWalkProbabilities);
            setUnallowedEdges(praConfig.unallowedEdges);
            setEdgeExcluderFactory(praConfig.edgeExcluderFactory);
            setRelationInverses(praConfig.relationInverses);
            setOutputBase(praConfig.outputBase);
            setNodeDictionary(praConfig.nodeDict);
            setEdgeDictionary(praConfig.edgeDict);
            setOutputter(praConfig.outputter);
        }

        public Builder setGraph(String str) {
            this.graph = str;
            return this;
        }

        public Builder setNumShards(int i) {
            this.numShards = i;
            return this;
        }

        public Builder setNumIters(int i) {
            this.numIters = i;
            return this;
        }

        public Builder setPathFollowerFactory(PathFollowerFactory pathFollowerFactory) {
            this.pathFollowerFactory = pathFollowerFactory;
            return this;
        }

        public Builder setMaxMatrixFeatureFanOut(int i) {
            this.maxMatrixFeatureFanOut = i;
            return this;
        }

        public Builder setWalksPerSource(int i) {
            this.walksPerSource = i;
            return this;
        }

        public Builder setNumPaths(int i) {
            this.numPaths = i;
            return this;
        }

        public Builder setPathTypePolicy(PathTypePolicy pathTypePolicy) {
            this.pathTypePolicy = pathTypePolicy;
            return this;
        }

        public Builder setPathTypeFactory(PathTypeFactory pathTypeFactory) {
            this.pathTypeFactory = pathTypeFactory;
            return this;
        }

        public Builder setPathTypeSelector(PathTypeSelector pathTypeSelector) {
            this.pathTypeSelector = pathTypeSelector;
            return this;
        }

        public Builder setWalksPerPath(int i) {
            this.walksPerPath = i;
            return this;
        }

        public Builder setAcceptPolicy(MatrixRowPolicy matrixRowPolicy) {
            this.acceptPolicy = matrixRowPolicy;
            return this;
        }

        public Builder setL2Weight(double d) {
            this.l2Weight = d;
            return this;
        }

        public Builder setL1Weight(double d) {
            this.l1Weight = d;
            return this;
        }

        public Builder setBinarizeFeatures(boolean z) {
            this.binarizeFeatures = z;
            return this;
        }

        public Builder setAllData(Dataset dataset) {
            this.allData = dataset;
            return this;
        }

        public Builder setPercentTraining(double d) {
            this.percentTraining = d;
            return this;
        }

        public Builder setTrainingData(Dataset dataset) {
            this.trainingData = dataset;
            return this;
        }

        public Builder setTestingData(Dataset dataset) {
            this.testingData = dataset;
            return this;
        }

        public Builder setAllowedTargets(Set<Integer> set) {
            this.allowedTargets = set;
            return this;
        }

        public Builder setNormalizeWalkProbabilities(boolean z) {
            this.normalizeWalkProbabilities = z;
            return this;
        }

        public Builder setUnallowedEdges(List<Integer> list) {
            this.unallowedEdges = list;
            return this;
        }

        public Builder setEdgeExcluderFactory(EdgeExcluderFactory edgeExcluderFactory) {
            this.edgeExcluderFactory = edgeExcluderFactory;
            return this;
        }

        public Builder setRelationInverses(Map<Integer, Integer> map) {
            this.relationInverses = map;
            return this;
        }

        public Builder setOutputBase(String str) {
            this.outputBase = str;
            return this;
        }

        public Builder setNodeDictionary(Dictionary dictionary) {
            this.nodeDict = dictionary;
            return this;
        }

        public Builder setEdgeDictionary(Dictionary dictionary) {
            this.edgeDict = dictionary;
            return this;
        }

        public Builder setOutputter(Outputter outputter) {
            this.outputter = outputter;
            return this;
        }

        /**
         * Builds a {@link PraConfig} from the current builder state.
         *
         * <p>If no outputter was set, one is created from the current node and edge dictionaries.
         * NOTE(review): that default outputter is stored back on this builder. Later calls to
         * {@code build()} reuse it even if the dictionaries have been replaced in the meantime.
         *
         * @return a new configuration snapshot of this builder's settings
         */
        public PraConfig build() {
            if (this.outputter == null) {
                this.outputter = new Outputter(this.nodeDict, this.edgeDict);
            }
            return new PraConfig(this);
        }

        /**
         * Reads settings from a parameter file, one {@code name<TAB>value} pair per line.
         *
         * <p>Parameter names are matched case-insensitively. Blank lines are skipped. Any text
         * after a second tab on a line is ignored. Note that the {@code path type selector}
         * entry depends on {@code path type embeddings} having appeared earlier in the file.
         *
         * @param bufferedReader source of parameter lines; it is read to the end but not closed
         * @throws IOException if reading from {@code bufferedReader}, or from an embeddings file
         *     named in it, fails
         * @throws RuntimeException if a line is malformed or names an unrecognized parameter
         */
        public void setFromParamFile(BufferedReader bufferedReader) throws IOException {
            String line;
            while ((line = bufferedReader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    // Skip blank lines, e.g. a trailing newline at the end of the file. Previously
                    // these caused an ArrayIndexOutOfBoundsException.
                    continue;
                }
                String[] fields = line.split("\t");
                if (fields.length < 2) {
                    throw new RuntimeException(
                            "Malformed parameter specification (expected name<TAB>value): " + line);
                }
                String name = fields[0];
                String value = fields[1];
                if (name.equalsIgnoreCase("L1 weight")) {
                    setL1Weight(Double.parseDouble(value));
                } else if (name.equalsIgnoreCase("L2 weight")) {
                    setL2Weight(Double.parseDouble(value));
                } else if (name.equalsIgnoreCase("walks per source")) {
                    setWalksPerSource(Integer.parseInt(value));
                } else if (name.equalsIgnoreCase("walks per path")) {
                    setWalksPerPath(Integer.parseInt(value));
                } else if (name.equalsIgnoreCase("path finding iterations")) {
                    setNumIters(Integer.parseInt(value));
                } else if (name.equalsIgnoreCase("number of paths to keep")) {
                    setNumPaths(Integer.parseInt(value));
                } else if (name.equalsIgnoreCase("binarize features")) {
                    setBinarizeFeatures(Boolean.parseBoolean(value));
                } else if (name.equalsIgnoreCase("normalize walk probabilities")) {
                    setNormalizeWalkProbabilities(Boolean.parseBoolean(value));
                } else if (name.equalsIgnoreCase("matrix accept policy")) {
                    setAcceptPolicy(MatrixRowPolicy.parseFromString(value));
                } else if (name.equalsIgnoreCase("path accept policy")) {
                    setPathTypePolicy(PathTypePolicy.parseFromString(value));
                } else if (name.equalsIgnoreCase("path type embeddings")) {
                    initializeVectorPathTypeFactory(value);
                } else if (name.equalsIgnoreCase("path type selector")) {
                    initializePathTypeSelector(value);
                } else if (name.equalsIgnoreCase("path follower")) {
                    initializePathFollowerFactory(value);
                } else if (name.equalsIgnoreCase("max matrix feature fan out")) {
                    setMaxMatrixFeatureFanOut(Integer.parseInt(value));
                } else {
                    throw new RuntimeException("Unrecognized parameter specification: " + line);
                }
            }
        }

        /**
         * Installs a {@link VectorPathTypeFactory} built from edge embedding files.
         *
         * <p>{@code str} is a comma-separated list. The first two entries are parsed as doubles
         * and passed through as the factory's third and fourth constructor arguments; their
         * meaning is defined by {@code VectorPathTypeFactory}. Every remaining entry is a path to
         * an embeddings file. Each line of such a file is {@code edgeName<TAB>v1<TAB>v2...}. The
         * edge name is resolved through the edge dictionary, and the values become that edge's
         * {@code Vector}.
         *
         * @param str the comma-separated specification described above
         * @throws IOException if an embeddings file cannot be opened or read
         * @throws IllegalArgumentException if fewer than two comma-separated entries are given
         */
        public void initializeVectorPathTypeFactory(String str) throws IOException {
            System.out.println("Initializing vector path type factory");
            String[] parts = str.split(",");
            if (parts.length < 2) {
                throw new IllegalArgumentException(
                        "Path type embeddings specification needs at least two numeric values: " + str);
            }
            double firstParam = Double.parseDouble(parts[0]);
            double secondParam = Double.parseDouble(parts[1]);
            Map<Integer, Vector> embeddings = new HashMap<Integer, Vector>();
            for (int i = 2; i < parts.length; i++) {
                String embeddingsFile = parts[i];
                System.out.println("Embeddings file: " + embeddingsFile);
                // try-with-resources closes each file. The previous loop also never terminated,
                // because it did not stop when readLine() returned null.
                try (BufferedReader reader = new BufferedReader(new FileReader(embeddingsFile))) {
                    String line;
                    while ((line = reader.readLine()) != null) {
                        String[] fields = line.split("\t");
                        int edgeIndex = this.edgeDict.getIndex(fields[0]);
                        double[] values = new double[fields.length - 1];
                        for (int j = 0; j < values.length; j++) {
                            values[j] = Double.parseDouble(fields[j + 1]);
                        }
                        embeddings.put(edgeIndex, new Vector(values));
                    }
                }
            }
            setPathTypeFactory(
                    new VectorPathTypeFactory(this.edgeDict, embeddings, firstParam, secondParam));
        }

        /**
         * Installs a path type selector described by {@code str}.
         *
         * <p>Only {@code VectorClusteringPathTypeSelector,<double>} is supported. It requires the
         * current path type factory to be a {@link VectorPathTypeFactory}, for example one set up
         * by {@link #initializeVectorPathTypeFactory(String)}.
         *
         * @param str the selector specification
         * @throws RuntimeException if the selector name is not recognized
         * @throws IllegalStateException if the current path type factory is not a
         *     {@code VectorPathTypeFactory}
         * @throws IllegalArgumentException if the numeric argument is missing
         */
        public void initializePathTypeSelector(String str) {
            if (!str.startsWith("VectorClusteringPathTypeSelector")) {
                throw new RuntimeException("Unrecognized path type selector parameter!");
            }
            System.out.println("Using VectorClusteringPathTypeSelector");
            if (!(this.pathTypeFactory instanceof VectorPathTypeFactory)) {
                throw new IllegalStateException(
                        "VectorClusteringPathTypeSelector requires a VectorPathTypeFactory; "
                                + "specify \"path type embeddings\" before \"path type selector\"");
            }
            String[] parts = str.split(",");
            if (parts.length < 2) {
                throw new IllegalArgumentException(
                        "VectorClusteringPathTypeSelector needs a numeric argument: " + str);
            }
            setPathTypeSelector(new VectorClusteringPathTypeSelector(
                    (VectorPathTypeFactory) this.pathTypeFactory, Double.parseDouble(parts[1])));
        }

        /**
         * Installs the path follower factory named by {@code str}.
         *
         * @param str {@code "random walks"} or {@code "matrix multiplication"}, case-insensitive
         * @throws RuntimeException if {@code str} names neither option. Previously such values
         *     were silently ignored, unlike the other parameter parsers.
         */
        public void initializePathFollowerFactory(String str) {
            if (str.equalsIgnoreCase("random walks")) {
                setPathFollowerFactory(new RandomWalkPathFollowerFactory());
            } else if (str.equalsIgnoreCase("matrix multiplication")) {
                setPathFollowerFactory(new MatrixPathFollowerFactory());
            } else {
                throw new RuntimeException("Unrecognized path follower: " + str);
            }
        }
    }

    /** Copies every setting from {@code builder}; only reachable through {@link Builder#build()}. */
    private PraConfig(Builder builder) {
        this.graph = builder.graph;
        this.numShards = builder.numShards;
        this.numIters = builder.numIters;
        this.walksPerSource = builder.walksPerSource;
        this.numPaths = builder.numPaths;
        this.pathTypePolicy = builder.pathTypePolicy;
        this.pathTypeFactory = builder.pathTypeFactory;
        this.pathTypeSelector = builder.pathTypeSelector;
        this.pathFollowerFactory = builder.pathFollowerFactory;
        this.maxMatrixFeatureFanOut = builder.maxMatrixFeatureFanOut;
        this.walksPerPath = builder.walksPerPath;
        this.acceptPolicy = builder.acceptPolicy;
        this.l2Weight = builder.l2Weight;
        this.l1Weight = builder.l1Weight;
        this.binarizeFeatures = builder.binarizeFeatures;
        this.allData = builder.allData;
        this.percentTraining = builder.percentTraining;
        this.trainingData = builder.trainingData;
        this.testingData = builder.testingData;
        this.allowedTargets = builder.allowedTargets;
        this.normalizeWalkProbabilities = builder.normalizeWalkProbabilities;
        this.unallowedEdges = builder.unallowedEdges;
        this.edgeExcluderFactory = builder.edgeExcluderFactory;
        this.relationInverses = builder.relationInverses;
        this.outputBase = builder.outputBase;
        this.nodeDict = builder.nodeDict;
        this.edgeDict = builder.edgeDict;
        this.outputter = builder.outputter;
    }
}
