package edu.cmu.ml.rtw.pra.on_demand;

import com.google.common.collect.Lists;
import edu.cmu.ml.rtw.pra.config.PraConfig;
import edu.cmu.ml.rtw.pra.experiments.Dataset;
import edu.cmu.ml.rtw.pra.features.FeatureGenerator;
import edu.cmu.ml.rtw.pra.features.FeatureMatrix;
import edu.cmu.ml.rtw.pra.features.MatrixRow;
import edu.cmu.ml.rtw.pra.features.MatrixRowPolicy;
import edu.cmu.ml.rtw.pra.features.PathType;
import edu.cmu.ml.rtw.pra.models.PraModel;
import edu.cmu.ml.rtw.users.matt.util.Dictionary;
import edu.cmu.ml.rtw.users.matt.util.FileUtil;
import edu.cmu.ml.rtw.users.matt.util.Index;
import edu.cmu.ml.rtw.users.matt.util.Pair;
import edu.cmu.ml.rtw.users.matt.util.PairComparator;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Produces PRA predictions on demand from previously trained per-relation models.
 *
 * <p>At construction time the predictor loads, for every relation, the weighted path types of its
 * model and the set of node ids that are acceptable targets. At prediction time it computes a
 * single feature matrix over the union of the path types of all requested relations, then scores
 * each matrix row against each requested relation's weights.
 */
public class OnlinePraPredictor {
    /** Maximum number of path contributions listed in a provenance string. */
    private static final int MAX_PROVENANCE_PATHS = 10;

    /** Per relation: the (path type, weight) pairs exactly as read from the weights file. */
    private final List<List<Pair<PathType, Double>>> models;
    /** Per relation: its name, taken from the parent directory of the weights file. */
    private final List<String> relationNames;
    /** Per relation: the path types of {@link #models}, in the same order. */
    private final List<List<PathType>> pathTypes;
    /** Per relation: the weights of {@link #models}, in the same order. */
    private final List<List<Double>> weights;
    /** Per relation: node ids that may be returned as prediction targets. */
    private final List<Set<Integer>> allowedTargets;
    private final Dictionary nodeDict;
    private final Dictionary edgeDict;
    private final PraConfig config;

    /**
     * Convenience constructor for a predictor that serves a single relation.
     *
     * @param modelFile weights file of the relation; its parent directory names the relation
     * @param nodeDictFile file the node dictionary is loaded from
     * @param edgeDictFile file the edge dictionary is loaded from
     * @param graphFile graph location handed to {@code PraConfig.Builder.setGraph}
     * @param numShards shard count handed to {@code PraConfig.Builder.setNumShards}
     * @param walksPerPath value handed to {@code PraConfig.Builder.setWalksPerPath}
     * @param allowedTargetsFile file listing the nodes allowed as targets for the relation
     * @param pathTypeFactoryDescription argument for
     *     {@code PraConfig.Builder.initializeVectorPathTypeFactory}; skipped when {@code null}
     * @param outputBase output location handed to {@code PraConfig.Builder.setOutputBase}; when
     *     the resulting config value is non-null it is used as a plain string prefix for the
     *     feature matrix file name
     */
    public OnlinePraPredictor(String modelFile, String nodeDictFile, String edgeDictFile, String graphFile,
            int numShards, int walksPerPath, String allowedTargetsFile, String pathTypeFactoryDescription,
            String outputBase) {
        this(Arrays.asList(modelFile), nodeDictFile, edgeDictFile, graphFile, numShards, walksPerPath,
                Arrays.asList(allowedTargetsFile), pathTypeFactoryDescription, outputBase);
    }

    /**
     * Loads the dictionaries and, for every entry of {@code modelFiles}, the model weights and the
     * allowed target set found at the same position of {@code allowedTargetsFiles}.
     *
     * @param modelFiles weights files, one per relation; each path must contain a parent
     *     directory, whose name becomes the relation name
     * @param nodeDictFile file the node dictionary is loaded from
     * @param edgeDictFile file the edge dictionary is loaded from
     * @param graphFile graph location handed to {@code PraConfig.Builder.setGraph}
     * @param numShards shard count handed to {@code PraConfig.Builder.setNumShards}
     * @param walksPerPath value handed to {@code PraConfig.Builder.setWalksPerPath}
     * @param allowedTargetsFiles files listing allowed target nodes, parallel to {@code modelFiles}
     * @param pathTypeFactoryDescription argument for
     *     {@code PraConfig.Builder.initializeVectorPathTypeFactory}; skipped when {@code null}
     * @param outputBase output location handed to {@code PraConfig.Builder.setOutputBase}
     * @throws IllegalArgumentException if a model path has no parent directory component
     * @throws RuntimeException wrapping any {@link IOException} raised while reading input files
     */
    public OnlinePraPredictor(List<String> modelFiles, String nodeDictFile, String edgeDictFile, String graphFile,
            int numShards, int walksPerPath, List<String> allowedTargetsFiles, String pathTypeFactoryDescription,
            String outputBase) {
        PraConfig.Builder builder = new PraConfig.Builder()
                .setGraph(graphFile)
                .setNumShards(numShards)
                .setWalksPerPath(walksPerPath)
                .setAcceptPolicy(MatrixRowPolicy.EVERYTHING)
                .setOutputBase(outputBase);
        if (pathTypeFactoryDescription != null) {
            try {
                builder.initializeVectorPathTypeFactory(pathTypeFactoryDescription);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        this.config = builder.build();
        this.models = new ArrayList<>();
        this.relationNames = new ArrayList<>();
        this.pathTypes = new ArrayList<>();
        this.weights = new ArrayList<>();
        this.allowedTargets = new ArrayList<>();
        this.nodeDict = new Dictionary();
        this.nodeDict.setFromFile(new File(nodeDictFile));
        this.edgeDict = new Dictionary();
        this.edgeDict.setFromFile(new File(edgeDictFile));

        for (int relation = 0; relation < modelFiles.size(); relation++) {
            String modelFile = modelFiles.get(relation);
            try {
                List<Pair<PathType, Double>> model = new PraModel(this.config).readWeightsFromFile(modelFile);
                this.models.add(model);

                // The relation name is the directory that contains the weights file.
                String[] pathParts = modelFile.split("/");
                if (pathParts.length < 2) {
                    throw new IllegalArgumentException(
                            "Model file path must include the relation directory: " + modelFile);
                }
                this.relationNames.add(pathParts[pathParts.length - 2]);

                this.allowedTargets.add(
                        new FileUtil().readIntegerSetFromFile(allowedTargetsFiles.get(relation), this.nodeDict));

                // Split the model into parallel lists so weights can be handed to the classifier
                // and path types to the feature generator without re-unpacking the pairs.
                List<PathType> relationPathTypes = new ArrayList<>();
                List<Double> relationWeights = new ArrayList<>();
                for (Pair<PathType, Double> weightedPath : model) {
                    relationPathTypes.add(weightedPath.getLeft());
                    relationWeights.add(weightedPath.getRight());
                }
                this.pathTypes.add(relationPathTypes);
                this.weights.add(relationWeights);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Predicts targets for one source node over all loaded relations.
     *
     * @param source name of the source node
     * @return the predictions for {@code source}; empty (never {@code null}) if there are none
     */
    public List<PraPrediction> predictTargets(String source) {
        return predictTargets(source, (List<String>) null);
    }

    /**
     * Predicts targets for one source node over the given relations.
     *
     * @param source name of the source node
     * @param relations relation names to predict for, or {@code null} for all loaded relations
     * @return the predictions for {@code source}; empty (never {@code null}) if there are none
     */
    public List<PraPrediction> predictTargets(String source, List<String> relations) {
        List<PraPrediction> predictions = predictTargets(Arrays.asList(source), relations).get(source);
        if (predictions == null) {
            return new ArrayList<>();
        }
        return predictions;
    }

    /**
     * Predicts targets for several source nodes over all loaded relations.
     *
     * @param sources names of the source nodes
     * @return sorted predictions keyed by source node name
     */
    public Map<String, List<PraPrediction>> predictTargets(List<String> sources) {
        return predictTargets(sources, (List<String>) null);
    }

    /**
     * Predicts targets for several source nodes over the given relations.
     *
     * <p>Source names missing from the node dictionary are skipped. Relation names that are not
     * loaded are reported on standard output and skipped. A source may map to an empty list when
     * it reached allowed targets but none of the row's features belong to the relation's model.
     *
     * @param sources names of the source nodes
     * @param relations relation names to predict for, or {@code null} for all loaded relations
     * @return predictions keyed by source node name, each list sorted by the natural ordering of
     *     {@code PraPrediction}; empty if no source name is known
     */
    public Map<String, List<PraPrediction>> predictTargets(List<String> sources, List<String> relations) {
        Map<String, List<PraPrediction>> predictions = new HashMap<>();
        List<Integer> sourceIds = resolveKnownNodes(sources);
        if (sourceIds.isEmpty()) {
            return predictions;
        }
        Dataset dataset = new Dataset.Builder().setPositiveSources(sourceIds).build();
        String matrixFile = this.config.outputBase != null
                ? this.config.outputBase + "batch_prediction_matrix.tsv"
                : null;

        List<Integer> relationIndices = new ArrayList<>();
        if (relations == null) {
            for (int relation = 0; relation < this.models.size(); relation++) {
                relationIndices.add(Integer.valueOf(relation));
            }
        } else {
            for (String relationName : relations) {
                int relation = this.relationNames.indexOf(relationName);
                if (relation == -1) {
                    System.out.println("NOTE: You requested predictions for a relation that isn't loaded: " + relationName);
                } else {
                    relationIndices.add(Integer.valueOf(relation));
                }
            }
        }

        // Build one shared list of path types across the requested relations so that features
        // are computed once. Index keys start at 1, so feature position j holds index key j + 1.
        Index<PathType> pathTypeIndex = new Index<>(this.config.pathTypeFactory);
        Map<Integer, Map<Integer, Integer>> translationMaps =
                createPathTranslationMaps(pathTypeIndex, relationIndices);
        List<PathType> allPathTypes = new ArrayList<>();
        for (int key = 1; key < pathTypeIndex.getNextIndex(); key++) {
            allPathTypes.add(pathTypeIndex.getKey(key));
        }
        FeatureMatrix featureMatrix =
                new FeatureGenerator(this.config).computeFeatureValues(allPathTypes, dataset, matrixFile);

        for (Integer relationKey : relationIndices) {
            int relation = relationKey.intValue();
            Set<Integer> targets = this.allowedTargets.get(relation);
            for (MatrixRow row : featureMatrix.getRows()) {
                if (!targets.contains(Integer.valueOf(row.targetNode))) {
                    continue;
                }
                String sourceName = this.nodeDict.getString(row.sourceNode);
                List<PraPrediction> forSource = predictions.get(sourceName);
                if (forSource == null) {
                    forSource = new ArrayList<>();
                    predictions.put(sourceName, forSource);
                }
                MatrixRow translated = translateRow(row, translationMaps.get(relationKey));
                if (translated != null) {
                    forSource.add(getPredictionForRow(translated, relation));
                }
            }
        }
        for (List<PraPrediction> forSource : predictions.values()) {
            Collections.sort(forSource);
        }
        return predictions;
    }

    /**
     * Resolves the given source names against the node dictionary and discards the result.
     *
     * <p>NOTE(review): this method has no observable output and never reads {@code list2}; it
     * looks like a leftover test stub — confirm whether it can be removed.
     *
     * @param list names of source nodes
     * @param list2 unused
     */
    public void testPredictTargets(List<String> list, List<String> list2) {
        resolveKnownNodes(list);
    }

    /**
     * Scores one row against one relation's model.
     *
     * @param matrixRow row whose path type ids are positions in the relation's model (as produced
     *     by {@link #translateRow})
     * @param i index of the relation among the loaded models
     * @return the prediction, including its score and a provenance string
     */
    public PraPrediction getPredictionForRow(MatrixRow matrixRow, int i) {
        // The weights are passed explicitly, so the model object is built without a config.
        double score = new PraModel(null).classifyMatrixRow(matrixRow, this.weights.get(i));
        String provenance = getProvenance(matrixRow, this.models.get(i), this.edgeDict);
        String sourceName = this.nodeDict.getString(matrixRow.sourceNode);
        String targetName = this.nodeDict.getString(matrixRow.targetNode);
        return new PraPrediction(this.relationNames.get(i), sourceName, targetName, provenance, score);
    }

    /**
     * Builds, for each requested relation, a map from shared feature position to the position of
     * the same path type in that relation's model.
     *
     * <p>The shared position is {@code index.getIndex(pathType) - 1}, matching the way
     * {@link #predictTargets(List, List)} lists the index keys starting from 1.
     * NOTE(review): this relies on {@code Index.getIndex} assigning an index to a path type it
     * has not seen before — confirm against {@code Index}.
     *
     * @param index shared path type index, queried for every path type of every relation
     * @param list indices of the relations to build maps for
     * @return translation maps keyed by relation index
     */
    public Map<Integer, Map<Integer, Integer>> createPathTranslationMaps(Index<PathType> index, List<Integer> list) {
        Map<Integer, Map<Integer, Integer>> translationMaps = new HashMap<>();
        for (Integer relationKey : list) {
            List<PathType> relationPathTypes = this.pathTypes.get(relationKey.intValue());
            Map<Integer, Integer> sharedToModel = new HashMap<>();
            translationMaps.put(relationKey, sharedToModel);
            for (int position = 0; position < relationPathTypes.size(); position++) {
                int sharedPosition = index.getIndex(relationPathTypes.get(position)) - 1;
                sharedToModel.put(Integer.valueOf(sharedPosition), Integer.valueOf(position));
            }
        }
        return translationMaps;
    }

    /**
     * Re-labels the columns of a row using the given translation map, dropping every column whose
     * path type id has no entry in the map.
     *
     * @param matrixRow the row to translate
     * @param map translation from the row's path type ids to the new ids
     * @return a new row with the same source and target and the surviving columns in their
     *     original order, or {@code null} if no column survives
     */
    public static MatrixRow translateRow(MatrixRow matrixRow, Map<Integer, Integer> map) {
        int[] translatedTypes = new int[Math.max(matrixRow.columns, 0)];
        double[] translatedValues = new double[translatedTypes.length];
        int kept = 0;
        for (int column = 0; column < matrixRow.columns; column++) {
            Integer translated = map.get(Integer.valueOf(matrixRow.pathTypes[column]));
            if (translated != null) {
                translatedTypes[kept] = translated.intValue();
                translatedValues[kept] = matrixRow.values[column];
                kept++;
            }
        }
        if (kept == 0) {
            return null;
        }
        return new MatrixRow(matrixRow.sourceNode, matrixRow.targetNode,
                Arrays.copyOf(translatedTypes, kept), Arrays.copyOf(translatedValues, kept));
    }

    /**
     * Describes which paths contributed to a row's score.
     *
     * <p>Each column yields one entry of the form
     * {@code <path>\t<value * weight> (<value> * <weight>)}. Entries are ordered with
     * {@code PairComparator.negativeRight()} on the absolute contribution and at most
     * {@value #MAX_PROVENANCE_PATHS} of them are joined with tab characters.
     *
     * @param matrixRow row whose path type ids are positions in {@code list}
     * @param list the relation's model as (path type, weight) pairs
     * @param dictionary dictionary used to render path types in human readable form
     * @return the tab-joined entries; empty if the row has no columns
     */
    public static String getProvenance(MatrixRow matrixRow, List<Pair<PathType, Double>> list, Dictionary dictionary) {
        List<Pair<String, Double>> contributions = new ArrayList<>();
        for (int column = 0; column < matrixRow.columns; column++) {
            Pair<PathType, Double> weightedPath = list.get(matrixRow.pathTypes[column]);
            String description = weightedPath.getLeft().encodeAsHumanReadableString(dictionary);
            double weight = weightedPath.getRight().doubleValue();
            double value = matrixRow.values[column];
            double contribution = value * weight;
            String entry = String.format("%s\t%.3f (%.2f * %.2f)", description, Double.valueOf(contribution),
                    Double.valueOf(value), Double.valueOf(weight));
            contributions.add(new Pair<String, Double>(entry, Double.valueOf(Math.abs(contribution))));
        }
        sortContributions(contributions);

        int shown = Math.min(MAX_PROVENANCE_PATHS, contributions.size());
        StringBuilder provenance = new StringBuilder();
        for (int rank = 0; rank < shown; rank++) {
            if (rank > 0) {
                provenance.append("\t");
            }
            provenance.append(contributions.get(rank).getLeft());
        }
        return provenance.toString();
    }

    /** Returns the dictionary ids of those names that the node dictionary knows, in input order. */
    private List<Integer> resolveKnownNodes(List<String> names) {
        List<Integer> ids = new ArrayList<>();
        for (String name : names) {
            if (this.nodeDict.hasString(name)) {
                ids.add(Integer.valueOf(this.nodeDict.getIndex(name)));
            }
        }
        return ids;
    }

    /**
     * Sorts provenance entries in place with {@code PairComparator.negativeRight()}.
     *
     * <p>The list is passed as a raw type because the generic signature of
     * {@code negativeRight()} is not visible from this file; the call is therefore unchecked.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static void sortContributions(List<Pair<String, Double>> contributions) {
        Collections.sort((List) contributions, PairComparator.negativeRight());
    }

    /**
     * Manual smoke test against a fixed set of models under {@code /home/mg1/pra/kod_models/}:
     * predicts four relations for {@code concept:city:pittsburgh} and prints up to five
     * predictions per relation.
     *
     * @param strArr ignored
     * @throws RuntimeException if {@code num_shards.tsv} cannot be read
     */
    public static void main(String[] strArr) throws IOException {
        // (relation, category of its targets)
        List<Pair<String, String>> relationsAndTargetTypes = Lists.newArrayList();
        relationsAndTargetTypes.add(new Pair<String, String>("concept:citylocatedinstate", "concept:stateorprovince"));
        relationsAndTargetTypes.add(new Pair<String, String>("concept:cityparks", "concept:park"));
        relationsAndTargetTypes.add(new Pair<String, String>("concept:citytelevisionstation", "concept:televisionstation"));
        relationsAndTargetTypes.add(new Pair<String, String>("concept:cityliesonriver", "concept:river"));

        String resultsDir = "/home/mg1/pra/kod_models/results/";
        String categoryDir = "/home/mg1/pra/kod_models/category_instances/";
        List<String> modelFiles = new ArrayList<>();
        List<String> allowedTargetsFiles = new ArrayList<>();
        List<String> relations = new ArrayList<>();
        for (Pair<String, String> entry : relationsAndTargetTypes) {
            modelFiles.add(resultsDir + entry.getLeft() + "/weights.tsv");
            relations.add(entry.getLeft());
            allowedTargetsFiles.add(categoryDir + entry.getRight());
        }
        String nodeDictFile = "/home/mg1/pra/kod_models/node_dict.tsv";
        String edgeDictFile = "/home/mg1/pra/kod_models/edge_dict.tsv";
        String graphFile = "/home/mg1/pra/kod_models/graph_chi/edges.tsv";

        String factoryDescription = null;
        try (BufferedReader reader =
                new BufferedReader(new FileReader("/home/mg1/pra/kod_models/factory_description.txt"))) {
            factoryDescription = reader.readLine();
        } catch (IOException ignored) {
            // Best effort: without a readable description the predictor is built with a null
            // description, which skips the vector path type factory initialisation.
        }

        int numShards;
        try (BufferedReader reader =
                new BufferedReader(new FileReader("/home/mg1/pra/kod_models/num_shards.tsv"))) {
            numShards = Integer.parseInt(reader.readLine());
        } catch (IOException e) {
            throw new RuntimeException("Missing num_shards.tsv file!", e);
        }

        new File("test_pra_output/").mkdirs();
        OnlinePraPredictor predictor = new OnlinePraPredictor(modelFiles, nodeDictFile, edgeDictFile, graphFile,
                numShards, 50, allowedTargetsFiles, factoryDescription, "test_pra_output/");
        List<PraPrediction> predictions = predictor.predictTargets("concept:city:pittsburgh", relations);
        System.out.println(predictions.size() + " predictions found");
        for (String relation : relations) {
            System.out.println(relation + ":");
            int shown = 0;
            for (PraPrediction prediction : predictions) {
                if (!prediction.relation.equals(relation)) {
                    continue;
                }
                shown++;
                if (shown > 5) {
                    break;
                }
                System.out.print(prediction.targetNode + " ");
                System.out.print(prediction.score + " ");
                System.out.println(prediction.provenance);
            }
            System.out.println();
        }
        // Explicit exit; presumably needed because background threads keep the JVM alive.
        System.exit(0);
    }
}
