package edu.cmu.ml.rtw.pra.models;

import cc.mallet.pipe.Noop;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import edu.cmu.ml.rtw.pra.config.PraConfig;
import edu.cmu.ml.rtw.pra.experiments.Dataset;
import edu.cmu.ml.rtw.pra.features.FeatureMatrix;
import edu.cmu.ml.rtw.pra.features.MatrixRow;
import edu.cmu.ml.rtw.pra.features.PathType;
import edu.cmu.ml.rtw.users.matt.util.Pair;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

/* loaded from: input_file:edu/cmu/ml/rtw/pra/models/PraModel.class */
/**
 * Learns per-path-type feature weights with logistic regression and uses those weights to score
 * (source, target) rows of a {@link FeatureMatrix}.
 */
public class PraModel {
    private static final Logger logger = Logger.getLogger("pra-model");

    private final PraConfig config;

    /**
     * @param praConfig configuration supplying output locations, the outputter, regularization
     *     weights, the feature-binarization flag and the path type factory
     */
    public PraModel(PraConfig praConfig) {
        this.config = praConfig;
    }

    /**
     * Trains a logistic regression classifier and returns one learned weight per path type.
     *
     * <p>Rows of {@code featureMatrix} are partitioned by their {@code "source target"} key: keys in
     * the dataset's positive instances become positive examples, keys in its negative instances
     * become negative examples, and every other row is treated as "unseen". When
     * {@code config.outputBase} is set, the three partitions and the learned weights are written
     * out through {@code config.outputter}.
     *
     * @param featureMatrix feature rows for candidate (source, target) pairs
     * @param dataset training data supplying the positive and negative instance keys
     * @param pathTypes the feature columns; weight {@code k} of the result belongs to
     *     {@code pathTypes.get(k)}
     * @return exactly {@code pathTypes.size()} weights; path types for which the trained model has
     *     no sparse parameter get {@code 0.0}
     */
    public List<Double> learnFeatureWeights(FeatureMatrix featureMatrix, Dataset dataset, List<PathType> pathTypes) {
        logger.info("Learning feature weights");
        logger.info("Prepping training data");
        Set<String> positiveKeys = dataset.getPositiveInstancesAsStrings();
        Set<String> negativeKeys = dataset.getNegativeInstancesAsStrings();
        List<MatrixRow> positiveRows = new ArrayList<MatrixRow>();
        List<MatrixRow> negativeRows = new ArrayList<MatrixRow>();
        List<MatrixRow> unseenRows = new ArrayList<MatrixRow>();
        for (MatrixRow row : featureMatrix.getRows()) {
            String key = row.sourceNode + " " + row.targetNode;
            if (positiveKeys.contains(key)) {
                positiveRows.add(row);
            } else if (negativeKeys.contains(key)) {
                negativeRows.add(row);
            } else {
                unseenRows.add(row);
            }
        }
        FeatureMatrix positiveMatrix = new FeatureMatrix(positiveRows);
        FeatureMatrix negativeMatrix = new FeatureMatrix(negativeRows);
        FeatureMatrix unseenMatrix = new FeatureMatrix(unseenRows);
        if (config.outputBase != null) {
            String base = config.outputBase;
            config.outputter.outputFeatureMatrix(base + "positive_matrix.tsv", positiveMatrix, pathTypes);
            config.outputter.outputFeatureMatrix(base + "negative_matrix.tsv", negativeMatrix, pathTypes);
            config.outputter.outputFeatureMatrix(base + "unseen_matrix.tsv", unseenMatrix, pathTypes);
        }

        InstanceList instances = new InstanceList(new Noop());
        // The alphabet is built from pathTypes, so feature index k corresponds to pathTypes.get(k).
        Alphabet alphabet = new Alphabet(pathTypes.toArray());
        int positiveFeatureCount = 0;
        for (MatrixRow row : positiveMatrix.getRows()) {
            positiveFeatureCount += row.columns;
            instances.addThruPipe(matrixRowToInstance(row, alphabet, true));
        }
        addNegativeEvidence(positiveMatrix.size(), positiveFeatureCount, negativeMatrix, unseenMatrix,
                instances, alphabet, config);

        MalletLogisticRegression classifier = new MalletLogisticRegression(alphabet);
        if (config.l2Weight != 0.0d) {
            logger.info("Setting L2 weight to " + config.l2Weight);
            classifier.setL2wt(config.l2Weight);
        }
        if (config.l1Weight != 0.0d) {
            logger.info("Setting L1 weight to " + config.l1Weight);
            classifier.setL1wt(config.l1Weight);
        }
        logger.info("Training the classifier");
        classifier.train(instances);

        List<Double> weights = sparseToDense(classifier.getSparseFeatures(), classifier.getSparseParams(),
                pathTypes.size());
        logger.info("Outputting feature weights");
        if (config.outputBase != null) {
            config.outputter.outputWeights(config.outputBase + "weights.tsv", weights, pathTypes);
        }
        return weights;
    }

    /**
     * Expands parallel sparse (index, value) arrays into a dense list of {@code numFeatures} values.
     *
     * <p>Every position starts at 0.0. Entries whose index falls outside {@code [0, numFeatures)}
     * are ignored. The previous merge-walk assumed strictly increasing indices; it silently produced
     * a shorter, misaligned list otherwise.
     *
     * @param indices feature indices; assumed parallel to {@code values}
     * @param values parameter value for each entry of {@code indices}
     * @param numFeatures length of the returned list
     */
    private static List<Double> sparseToDense(int[] indices, double[] values, int numFeatures) {
        double[] dense = new double[numFeatures];
        for (int k = 0; k < indices.length; k++) {
            int index = indices[k];
            if (index >= 0 && index < numFeatures) {
                dense[index] = values[k];
            }
        }
        List<Double> result = new ArrayList<Double>(numFeatures);
        for (double value : dense) {
            result.add(Double.valueOf(value));
        }
        return result;
    }

    /**
     * Adds negative training instances to {@code instances}.
     *
     * <p>It currently delegates to the weighting strategy: every negative row is added, and if the
     * negatives carry fewer features than the positives, every unseen row is added with a fractional
     * instance weight.
     *
     * @param positiveInstanceCount number of positive rows (not used by the current strategy)
     * @param positiveFeatureCount total number of non-zero columns across the positive rows
     * @param negativeMatrix rows labeled negative by the dataset
     * @param unseenMatrix rows that are neither positive nor negative
     * @param instances the training list to append to
     * @param alphabet feature alphabet shared with the positive instances
     * @param praConfig configuration (not used by the current strategy)
     */
    public void addNegativeEvidence(int positiveInstanceCount, int positiveFeatureCount, FeatureMatrix negativeMatrix,
            FeatureMatrix unseenMatrix, InstanceList instances, Alphabet alphabet, PraConfig praConfig) {
        weightUnseenExamples(positiveFeatureCount, negativeMatrix, unseenMatrix, instances, alphabet, praConfig);
    }

    /**
     * Alternative negative-evidence strategy, not called within this class.
     *
     * <p>It shuffles {@code unseenMatrix} in place and adds up to {@code numToSample} of its rows
     * as negative instances. The count is capped at the matrix size to avoid indexing past the end.
     */
    private void sampleUnseenExamples(int numToSample, FeatureMatrix negativeMatrix, FeatureMatrix unseenMatrix,
            InstanceList instances, Alphabet alphabet, PraConfig praConfig) {
        unseenMatrix.shuffle();
        int limit = Math.min(numToSample, unseenMatrix.size());
        for (int i = 0; i < limit; i++) {
            instances.addThruPipe(matrixRowToInstance(unseenMatrix.getRow(i), alphabet, false));
        }
    }

    /**
     * Adds all negative rows as negative instances, then balances feature mass using unseen rows.
     *
     * <p>If the negatives carry fewer total feature columns than the positives, every unseen row is
     * added as a negative instance. Each gets the weight {@code deficit / unseenFeatureCount}, so
     * together they contribute the missing feature mass. The division is done in floating point;
     * integer division truncated the weight, usually to 0. Nothing is added when there are no unseen
     * features, which previously caused an {@code ArithmeticException}.
     */
    private void weightUnseenExamples(int positiveFeatureCount, FeatureMatrix negativeMatrix, FeatureMatrix unseenMatrix,
            InstanceList instances, Alphabet alphabet, PraConfig praConfig) {
        int negativeFeatureCount = 0;
        for (MatrixRow row : negativeMatrix.getRows()) {
            negativeFeatureCount += row.columns;
            instances.addThruPipe(matrixRowToInstance(row, alphabet, false));
        }
        logger.info("Number of positive features: " + positiveFeatureCount);
        logger.info("Number of negative features: " + negativeFeatureCount);
        if (negativeFeatureCount >= positiveFeatureCount) {
            return;
        }
        logger.info("Using unseen examples to make up the difference");
        int deficit = positiveFeatureCount - negativeFeatureCount;
        int unseenFeatureCount = 0;
        for (MatrixRow row : unseenMatrix.getRows()) {
            unseenFeatureCount += row.columns;
        }
        if (unseenFeatureCount == 0) {
            logger.warning("No unseen features available; cannot make up the difference");
            return;
        }
        double unseenWeight = (double) deficit / unseenFeatureCount;
        logger.info("Unseen weight: " + unseenWeight);
        for (MatrixRow row : unseenMatrix.getRows()) {
            Instance instance = matrixRowToInstance(row, alphabet, false);
            instances.addThruPipe(instance);
            instances.setInstanceWeight(instance, unseenWeight);
        }
    }

    /**
     * Scores every row of {@code featureMatrix} and groups the scores by source node.
     *
     * @param featureMatrix rows to score
     * @param weights one weight per path type, indexed by the values in {@code row.pathTypes}
     * @return map from source node to a list of (target node, score) pairs, in row order
     */
    public Map<Integer, List<Pair<Integer, Double>>> classifyInstances(FeatureMatrix featureMatrix, List<Double> weights) {
        Map<Integer, List<Pair<Integer, Double>>> scoresBySource = new HashMap<Integer, List<Pair<Integer, Double>>>();
        for (MatrixRow row : featureMatrix.getRows()) {
            double score = classifyMatrixRow(row, weights);
            Integer source = Integer.valueOf(row.sourceNode);
            List<Pair<Integer, Double>> targets = scoresBySource.get(source);
            if (targets == null) {
                targets = new ArrayList<Pair<Integer, Double>>();
                scoresBySource.put(source, targets);
            }
            targets.add(new Pair<Integer, Double>(Integer.valueOf(row.targetNode), Double.valueOf(score)));
        }
        return scoresBySource;
    }

    /**
     * Reads path-type weights from a file with one {@code pathType<TAB>weight} entry per line.
     *
     * <p>Blank lines are skipped. The reader is closed even if parsing fails.
     *
     * @param filename path of the weights file
     * @return (path type, weight) pairs in file order
     * @throws IOException if the file cannot be read, or a non-blank line has fewer than two
     *     tab-separated fields
     */
    public List<Pair<PathType, Double>> readWeightsFromFile(String filename) throws IOException {
        List<Pair<PathType, Double>> weights = new ArrayList<Pair<PathType, Double>>();
        BufferedReader reader = new BufferedReader(new FileReader(filename));
        try {
            String line;
            int lineNumber = 0;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().length() == 0) {
                    continue;
                }
                String[] fields = line.split("\t");
                if (fields.length < 2) {
                    throw new IOException("Malformed weight line " + lineNumber + " in " + filename + ": " + line);
                }
                weights.add(new Pair<PathType, Double>(config.pathTypeFactory.fromString(fields[0]),
                        Double.valueOf(Double.parseDouble(fields[1]))));
            }
        } finally {
            reader.close();
        }
        return weights;
    }

    /**
     * Computes the linear score of a row: the sum over its columns of feature value times the
     * weight of that column's path type.
     *
     * <p>Feature values go through the same {@code config.binarizeFeatures} transform used when
     * building training instances, so scoring matches training. Applying it to values that are
     * already binarized leaves them unchanged.
     */
    public double classifyMatrixRow(MatrixRow row, List<Double> weights) {
        double score = 0.0d;
        for (int i = 0; i < row.columns; i++) {
            score += featureValue(row.values[i]) * weights.get(row.pathTypes[i]).doubleValue();
        }
        return score;
    }

    /**
     * Converts a matrix row into a Mallet instance.
     *
     * <p>The target is {@code 1.0} for a positive row and {@code 0.0} otherwise, and the name is
     * {@code "source target"}. The row's value array is copied, never modified.
     *
     * @param row the row to convert
     * @param alphabet feature alphabet whose indices match {@code row.pathTypes}
     * @param positive whether the row is a positive example
     */
    public Instance matrixRowToInstance(MatrixRow row, Alphabet alphabet, boolean positive) {
        double label = positive ? 1.0d : 0.0d;
        double[] values = new double[row.values.length];
        for (int i = 0; i < values.length; i++) {
            values[i] = featureValue(row.values[i]);
        }
        return new Instance(new FeatureVector(alphabet, row.pathTypes, values), Double.valueOf(label),
                row.sourceNode + " " + row.targetNode, null);
    }

    /** Maps positive values to 1.0 when {@code config.binarizeFeatures} is set; otherwise identity. */
    private double featureValue(double rawValue) {
        if (config.binarizeFeatures && rawValue > 0.0d) {
            return 1.0d;
        }
        return rawValue;
    }
}
