package edu.cmu.ml.rtw.pra.experiments;

import com.google.common.annotations.VisibleForTesting;
import edu.cmu.graphchi.ChiLogger;
import edu.cmu.ml.rtw.pra.config.PraConfig;
import edu.cmu.ml.rtw.users.matt.util.Dictionary;
import edu.cmu.ml.rtw.users.matt.util.FileUtil;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.cli.PosixParser;

/* loaded from: input_file:edu/cmu/ml/rtw/pra/experiments/KbPraDriver.class */
/**
 * Command-line driver that runs PRA over a knowledge base, one relation at a time.
 *
 * <p>Inputs are a KB directory, a graph directory, a split directory and a parameter file; a
 * {@code settings.txt} summary and per-relation output directories are created under an output
 * directory. File operations routed through {@link FileUtil} can be substituted in tests via the
 * protected constructor.
 */
public class KbPraDriver {
    private final FileUtil fileUtil;
    private static final Logger logger = ChiLogger.getLogger("kb-pra-driver");

    /** Instance used by the static {@link #runPra(CommandLine)} entry point; replaceable in tests. */
    @VisibleForTesting
    protected static KbPraDriver driver = new KbPraDriver();

    /** Creates a driver backed by a default {@link FileUtil}. */
    public KbPraDriver() {
        this(new FileUtil());
    }

    /** Creates a driver that performs its {@link FileUtil}-based operations through {@code fileUtil}. */
    @VisibleForTesting
    protected KbPraDriver(FileUtil fileUtil) {
        this.fileUtil = fileUtil;
    }

    /**
     * Runs PRA for every relation listed in {@code relations_to_run.tsv} in the split directory.
     *
     * <p>Before any relation is processed, {@code settings.txt} in the output directory is written
     * with the inputs used and a copy of the parameter file. Each relation uses
     * {@code outputDir/<relation>/} as its output base. A relation is cross-validated when its
     * final config has {@code allData} set (which {@link #initializeSplit} does when the split
     * directory has no entry for it), and trained/tested otherwise. Blank lines in the relation
     * list are skipped. After all relations finish, the elapsed time is appended to
     * {@code settings.txt} and printed to stdout.
     *
     * @param kbDirectory directory holding KB files (inverses, embeddings, ranges, relations, ...)
     * @param graphDirectory directory holding the graph, shard count and node/edge dictionaries
     * @param splitDirectory directory holding {@code relations_to_run.tsv} and per-relation splits
     * @param paramFile path to the PRA parameter file
     * @param outputDir base directory for all output (passed to {@code FileUtil.mkdirOrDie})
     * @throws IOException if an input cannot be read or an output cannot be written
     * @throws InterruptedException propagated from PRA training/testing
     */
    public void runPra(String kbDirectory, String graphDirectory, String splitDirectory,
            String paramFile, String outputDir) throws IOException, InterruptedException {
        String outputBase = fileUtil.addDirectorySeparatorIfNecessary(outputDir);
        String kbDir = fileUtil.addDirectorySeparatorIfNecessary(kbDirectory);
        String graphDir = fileUtil.addDirectorySeparatorIfNecessary(graphDirectory);
        String splitDir = fileUtil.addDirectorySeparatorIfNecessary(splitDirectory);
        fileUtil.mkdirOrDie(outputBase);
        long startTime = System.currentTimeMillis();

        PraConfig.Builder baseBuilder = new PraConfig.Builder();
        parseGraphFiles(graphDir, baseBuilder);
        try (BufferedReader paramReader = fileUtil.getBufferedReader(paramFile)) {
            baseBuilder.setFromParamFile(paramReader);
        }
        Map<String, String> nodeNames = null;
        String nodeNamesFile = kbDir + "node_names.tsv";
        if (fileUtil.fileExists(nodeNamesFile)) {
            nodeNames = fileUtil.readMapFromTsvFile(nodeNamesFile, true);
        }
        baseBuilder.setOutputter(
                new Outputter(baseBuilder.nodeDict, baseBuilder.edgeDict, nodeNames));
        writeSettings(outputBase, kbDir, graphDir, splitDir, paramFile);
        PraConfig baseConfig = baseBuilder.build();

        try (BufferedReader relationReader =
                fileUtil.getBufferedReader(splitDir + "relations_to_run.tsv")) {
            String relation;
            while ((relation = relationReader.readLine()) != null) {
                // A blank line is not a relation; running it would make outputBase itself the
                // relation's output directory.
                if (relation.trim().isEmpty()) {
                    continue;
                }
                runRelation(relation, baseConfig, kbDir, splitDir, outputBase);
            }
        }
        writeCompletion(outputBase, startTime);
    }

    /** Writes settings.txt (opened without append mode) with the run's inputs and parameters. */
    private void writeSettings(String outputBase, String kbDir, String graphDir, String splitDir,
            String paramFile) throws IOException {
        try (FileWriter settings = fileUtil.getFileWriter(outputBase + "settings.txt")) {
            settings.write("KB used: " + kbDir + "\n");
            settings.write("Graph used: " + graphDir + "\n");
            settings.write("Splits used: " + splitDir + "\n");
            settings.write("Parameter file used: " + paramFile + "\n");
            settings.write("Parameters:\n");
            try (BufferedReader paramReader = fileUtil.getBufferedReader(paramFile)) {
                fileUtil.copyLines(paramReader, settings);
            }
            settings.write("End of parameters\n");
        }
    }

    /** Configures and runs PRA (cross-validation or train/test) for a single relation. */
    private void runRelation(String relation, PraConfig baseConfig, String kbDir, String splitDir,
            String outputBase) throws IOException, InterruptedException {
        PraConfig.Builder builder = new PraConfig.Builder(baseConfig);
        logger.info("\n\n\n\nRunning PRA for relation " + relation);
        parseKbFiles(kbDir, relation, builder, outputBase, fileUtil);
        String relationOutputDir = fileUtil.addDirectorySeparatorIfNecessary(outputBase + relation);
        fileUtil.mkdirs(relationOutputDir);
        builder.setOutputBase(relationOutputDir);
        initializeSplit(splitDir, kbDir, relation, builder, new DatasetFactory(), fileUtil);
        PraConfig config = builder.build();
        if (config.allData != null) {
            new PraTrainAndTester().crossValidate(config);
        } else {
            new PraTrainAndTester().trainAndTest(config);
        }
    }

    /** Appends the success message and elapsed wall-clock time to settings.txt and stdout. */
    private void writeCompletion(String outputBase, long startTime) throws IOException {
        int elapsedSeconds = (int) ((System.currentTimeMillis() - startTime) / 1000);
        int minutes = elapsedSeconds / 60;
        int seconds = elapsedSeconds % 60;
        try (FileWriter settings = fileUtil.getFileWriter(outputBase + "settings.txt", true)) {
            settings.write("PRA appears to have finished all relations successfully\n");
            settings.write("Finished in " + minutes + " minutes and " + seconds + " seconds\n");
            System.out.println("Took " + minutes + " minutes and " + seconds + " seconds");
        }
    }

    /**
     * Loads the data split for {@code relation} into {@code builder}.
     *
     * <p>The relation's file name is the relation with {@code /} replaced by {@code _}. If
     * {@code splitDir + name} does not exist, all instances are read from
     * {@code kbDir/relations/<name>} into {@code allData}. In that case the training fraction is
     * the first value in {@code splitDir/percent_training.tsv}. Otherwise, training and testing
     * instances are read from {@code splitDir/<name>/training.tsv} and {@code testing.tsv}.
     *
     * @param splitDir split directory, ending in a separator
     * @param kbDir KB directory, ending in a separator
     * @param relation relation name as listed in {@code relations_to_run.tsv}
     * @param builder builder to populate; its {@code nodeDict} is used to read instances
     * @param datasetFactory factory used to read instance files
     * @param fileUtil file helper used for existence checks and reading the training fraction
     * @return {@code true} if all data was loaded for percentage-based splitting, {@code false}
     *     if explicit training/testing files were used
     * @throws IOException if a data file cannot be read
     */
    public boolean initializeSplit(String splitDir, String kbDir, String relation,
            PraConfig.Builder builder, DatasetFactory datasetFactory, FileUtil fileUtil)
            throws IOException {
        String relationFileName = relation.replace("/", "_");
        if (!fileUtil.fileExists(splitDir + relationFileName)) {
            builder.setAllData(datasetFactory.fromFile(
                    kbDir + "relations" + File.separator + relationFileName, builder.nodeDict));
            builder.setPercentTraining(fileUtil.readDoubleListFromFile(
                    splitDir + "percent_training.tsv").get(0).doubleValue());
            return true;
        }
        String relationSplitDir = splitDir + relationFileName + File.separator;
        builder.setTrainingData(
                datasetFactory.fromFile(relationSplitDir + "training.tsv", builder.nodeDict));
        builder.setTestingData(
                datasetFactory.fromFile(relationSplitDir + "testing.tsv", builder.nodeDict));
        return false;
    }

    /**
     * Configures {@code builder} from a graph directory.
     *
     * <p>Sets the graph to {@code graph_chi/edges.tsv}, the shard count to the first line of
     * {@code num_shards.tsv}, and the node/edge dictionaries from {@code node_dict.tsv} and
     * {@code edge_dict.tsv}.
     *
     * @param directory graph directory (a trailing separator is added if necessary)
     * @param builder builder to populate
     * @throws IOException if a graph file cannot be read or {@code num_shards.tsv} is empty
     */
    public void parseGraphFiles(String directory, PraConfig.Builder builder) throws IOException {
        String graphDir = fileUtil.addDirectorySeparatorIfNecessary(directory);
        builder.setGraph(graphDir + "graph_chi" + File.separator + "edges.tsv");
        System.out.println("Loading node and edge dictionaries from graph directory: " + graphDir);
        String shardFile = graphDir + "num_shards.tsv";
        String shardLine;
        try (BufferedReader shardReader = new BufferedReader(new FileReader(shardFile))) {
            shardLine = shardReader.readLine();
        }
        if (shardLine == null) {
            throw new IOException("Shard count file is empty: " + shardFile);
        }
        builder.setNumShards(Integer.parseInt(shardLine.trim()));
        Dictionary nodeDict = new Dictionary();
        nodeDict.setFromFile(graphDir + "node_dict.tsv");
        builder.setNodeDictionary(nodeDict);
        Dictionary edgeDict = new Dictionary();
        edgeDict.setFromFile(graphDir + "edge_dict.tsv");
        builder.setEdgeDictionary(edgeDict);
    }

    /**
     * Configures {@code builder} with the KB-derived settings for one relation.
     *
     * <p>The following are set on the builder:
     * <ul>
     *   <li>relation inverses, from {@code kbDir/inverses.tsv} if present;</li>
     *   <li>unallowed edges, meaning the relation, its inverse, and any embeddings listed for
     *       either in {@code kbDir/embeddings.tsv} if present;</li>
     *   <li>allowed targets, from the category-instance file named for the relation in
     *       {@code kbDir/ranges.tsv}.</li>
     * </ul>
     * If there is no range file, or the relation has no entry in it, a warning is appended to
     * {@code outputBase/settings.txt} and printed, and no allowed targets are set.
     *
     * @param kbDir KB directory, ending in a separator
     * @param relation relation name
     * @param builder builder to populate; its {@code edgeDict}/{@code nodeDict} are used
     * @param outputBase output directory holding {@code settings.txt}, ending in a separator
     * @param fileUtil file helper used for all reads and the settings writer
     * @throws IOException if a KB file cannot be read or settings.txt cannot be written
     */
    public void parseKbFiles(String kbDir, String relation, PraConfig.Builder builder,
            String outputBase, FileUtil fileUtil) throws IOException {
        Map<Integer, Integer> inverses = createInverses(kbDir + "inverses.tsv", builder.edgeDict);
        builder.setRelationInverses(inverses);
        Map<String, List<String>> embeddings = null;
        String embeddingsFile = kbDir + "embeddings.tsv";
        if (fileUtil.fileExists(embeddingsFile)) {
            embeddings = fileUtil.readMapListFromTsvFile(embeddingsFile);
        }
        builder.setUnallowedEdges(
                createUnallowedEdges(relation, inverses, embeddings, builder.edgeDict));

        String rangesFile = kbDir + "ranges.tsv";
        if (!fileUtil.fileExists(rangesFile)) {
            appendSettingsWarning(fileUtil, outputBase,
                    "No range file found! I hope your accept policy is as you want it...\n",
                    "No range file found!");
            return;
        }
        String range = fileUtil.readMapFromTsvFile(rangesFile).get(relation);
        if (range == null) {
            // Previously this dereferenced null and crashed; treat it like a missing range file.
            appendSettingsWarning(fileUtil, outputBase,
                    "No range found for relation " + relation
                            + "! I hope your accept policy is as you want it...\n",
                    "No range found for relation " + relation + "!");
            return;
        }
        builder.setAllowedTargets(fileUtil.readIntegerSetFromFile(
                kbDir + "category_instances" + File.separator + range.replace("/", "_"),
                builder.nodeDict));
    }

    /** Appends {@code fileMessage} to settings.txt and prints {@code consoleMessage}. */
    private static void appendSettingsWarning(FileUtil fileUtil, String outputBase,
            String fileMessage, String consoleMessage) throws IOException {
        try (FileWriter settings = fileUtil.getFileWriter(outputBase + "settings.txt", true)) {
            settings.write(fileMessage);
            System.out.println(consoleMessage);
        }
    }

    /**
     * Builds the list of edge indices that must not be used for {@code relation}.
     *
     * <p>The list contains the relation's own index, then its inverse (if {@code inverses} has
     * one), then the indices of the edges {@code embeddings} lists for the relation and then for
     * its inverse. Duplicates are not removed.
     *
     * @param relation relation name
     * @param inverses map from edge index to inverse edge index
     * @param embeddings map from relation name to embedded edge names; may be {@code null}
     * @param edgeDict edge dictionary used to translate between names and indices
     * @return a new mutable list of edge indices
     */
    public List<Integer> createUnallowedEdges(String relation, Map<Integer, Integer> inverses,
            Map<String, List<String>> embeddings, Dictionary edgeDict) {
        List<Integer> unallowedEdges = new ArrayList<>();
        int relationIndex = edgeDict.getIndex(relation);
        unallowedEdges.add(relationIndex);
        Integer inverseIndex = inverses.get(relationIndex);
        String inverse = null;
        if (inverseIndex != null) {
            unallowedEdges.add(inverseIndex);
            inverse = edgeDict.getString(inverseIndex.intValue());
        }
        if (embeddings != null) {
            addEdgeIndices(embeddings.get(relation), edgeDict, unallowedEdges);
            if (inverse != null) {
                addEdgeIndices(embeddings.get(inverse), edgeDict, unallowedEdges);
            }
        }
        return unallowedEdges;
    }

    /** Appends the edge index of every name in {@code edgeNames} (if non-null) to {@code out}. */
    private static void addEdgeIndices(List<String> edgeNames, Dictionary edgeDict,
            List<Integer> out) {
        if (edgeNames == null) {
            return;
        }
        for (String edgeName : edgeNames) {
            out.add(edgeDict.getIndex(edgeName));
        }
    }

    /**
     * Reads a file of tab-separated relation pairs and maps each relation's edge index to the
     * other's, in both directions.
     *
     * <p>Blank lines are skipped.
     *
     * @param path path to the inverses file
     * @param edgeDict edge dictionary used to look up indices
     * @return the inverse map; empty if {@code path} does not exist
     * @throws IOException if the file cannot be read or a line has fewer than two columns
     */
    public Map<Integer, Integer> createInverses(String path, Dictionary edgeDict)
            throws IOException {
        Map<Integer, Integer> inverses = new HashMap<>();
        if (!fileUtil.fileExists(path)) {
            return inverses;
        }
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue;
                }
                String[] fields = line.split("\t");
                if (fields.length < 2) {
                    throw new IOException("Malformed line in " + path
                            + " (expected two tab-separated relations): " + line);
                }
                int first = edgeDict.getIndex(fields[0]);
                int second = edgeDict.getIndex(fields[1]);
                inverses.put(first, second);
                inverses.put(second, first);
            }
        }
        return inverses;
    }

    /**
     * Command-line entry point.
     *
     * <p>Parses the options, runs PRA, then calls {@code System.exit(0)}. The explicit exit
     * presumably stops lingering non-daemon threads (TODO confirm). On a parse error, usage help
     * is printed instead.
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        try {
            runPra(new PosixParser().parse(createOptionParser(), args));
            System.exit(0);
        } catch (ParseException e) {
            printHelp("ParseException while processing arguments: " + e.getMessage());
        }
    }

    /** Returns the command-line options understood by this driver. */
    public static Options createOptionParser() {
        Options options = new Options();
        options.addOption("k", "kb-files", true, "KB files directory");
        options.addOption("g", "graph-files", true, "Graph files directory");
        options.addOption("s", "split", true, "Split specification directory");
        options.addOption("p", "param-file", true, "parameter file");
        options.addOption("o", "outdir", true, "base directory for output");
        return options;
    }

    /** Prints {@code message} (if non-null) followed by the option usage summary. */
    private static void printHelp(String message) {
        if (message != null) {
            System.out.println(message);
        }
        new HelpFormatter().printHelp("KbPraDriver", createOptionParser());
    }

    /**
     * Runs PRA using the directories given on the command line.
     *
     * <p>Prints usage help instead when any of the five options is missing.
     */
    public static void runPra(CommandLine commandLine) throws IOException, InterruptedException {
        String outputDir = commandLine.getOptionValue("outdir");
        String kbDirectory = commandLine.getOptionValue("kb-files");
        String graphDirectory = commandLine.getOptionValue("graph-files");
        String splitDirectory = commandLine.getOptionValue("split");
        String paramFile = commandLine.getOptionValue("param-file");
        if (outputDir == null || kbDirectory == null || graphDirectory == null
                || splitDirectory == null || paramFile == null) {
            printHelp("One or more of the parameters was missing");
        } else {
            driver.runPra(kbDirectory, graphDirectory, splitDirectory, paramFile, outputDir);
        }
    }
}
