package edu.cmu.ml.rtw.pra.experiments;

import com.google.common.annotations.VisibleForTesting;
import edu.cmu.graphchi.ChiLogger;
import edu.cmu.ml.rtw.pra.config.PraConfig;
import edu.cmu.ml.rtw.pra.features.FeatureGenerator;
import edu.cmu.ml.rtw.pra.features.PathType;
import edu.cmu.ml.rtw.pra.models.PraModel;
import edu.cmu.ml.rtw.users.matt.util.FileUtil;
import edu.cmu.ml.rtw.users.matt.util.Pair;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/* loaded from: input_file:edu/cmu/ml/rtw/pra/experiments/PraTrainAndTester.class */
/**
 * Drives training and testing of a PRA model from a {@link PraConfig}: selects path features,
 * learns their weights, scores the testing data with the learned model, and hands the results to
 * the config's outputter.
 */
public class PraTrainAndTester {
    private static final Logger logger = ChiLogger.getLogger("pra-driver");
    private final FileUtil fileUtil;

    public PraTrainAndTester() {
        this(new FileUtil());
    }

    @VisibleForTesting
    public PraTrainAndTester(FileUtil fileUtil) {
        this.fileUtil = fileUtil;
    }

    /**
     * Splits {@code praConfig.allData} into a training and a testing portion, then trains and
     * tests on that split.
     *
     * <p>The split uses {@code praConfig.percentTraining}. Both halves are passed to
     * {@code praConfig.outputter.outputSplitFiles} together with {@code praConfig.outputBase}.
     * Training and testing then run on a copy of the config. In that copy, {@code allData} is
     * cleared, {@code percentTraining} is 0, and the two halves are installed as the training and
     * testing data.
     *
     * @param praConfig configuration whose {@code allData} is to be split
     */
    public void crossValidate(PraConfig praConfig) {
        Pair<Dataset, Dataset> splitData = praConfig.allData.splitData(praConfig.percentTraining);
        Dataset trainingSplit = splitData.getLeft();
        Dataset testingSplit = splitData.getRight();
        praConfig.outputter.outputSplitFiles(praConfig.outputBase, trainingSplit, testingSplit);

        PraConfig.Builder builder = new PraConfig.Builder(praConfig);
        builder.setAllData(null);
        builder.setPercentTraining(0.0d);
        builder.setTrainingData(trainingSplit);
        builder.setTestingData(testingSplit);
        trainAndTest(builder.build());
    }

    /**
     * Trains a model on {@code praConfig.trainingData} and then tests it on
     * {@code praConfig.testingData}.
     *
     * @param praConfig configuration providing the training and testing data
     */
    public void trainAndTest(PraConfig praConfig) {
        testPraModel(praConfig, trainPraModel(praConfig));
    }

    /**
     * Selects path features from the training data and learns a weight for each of them.
     *
     * @param praConfig configuration providing {@code trainingData}
     * @return one (path type, learned weight) pair per selected path feature, in selection order
     */
    public List<Pair<PathType, Double>> trainPraModel(PraConfig praConfig) {
        FeatureGenerator featureGenerator = new FeatureGenerator(praConfig);
        List<PathType> pathTypes = featureGenerator.selectPathFeatures(praConfig.trainingData);
        // No output file is requested for the training feature matrix (null file name).
        List<Double> weights = new PraModel(praConfig).learnFeatureWeights(
                featureGenerator.computeFeatureValues(pathTypes, praConfig.trainingData, null),
                praConfig.trainingData,
                pathTypes);

        // Assumes learnFeatureWeights returns one weight per path type, in the same order.
        List<Pair<PathType, Double>> model = new ArrayList<>(pathTypes.size());
        for (int i = 0; i < pathTypes.size(); i++) {
            model.add(new Pair<PathType, Double>(pathTypes.get(i), weights.get(i)));
        }
        return model;
    }

    /**
     * Scores {@code praConfig.testingData} with the given model and passes the scores to the
     * config's outputter.
     *
     * <p>Path types whose weight is exactly 0.0 are dropped before feature computation.
     *
     * @param praConfig configuration providing {@code testingData}, {@code outputBase} and
     *     {@code outputter}
     * @param list the (path type, weight) model, e.g. as returned by {@link #trainPraModel}
     * @return the result of {@code PraModel.classifyInstances}. NOTE(review): the key/value
     *     semantics of this map are defined by PraModel; they are not visible here.
     */
    public Map<Integer, List<Pair<Integer, Double>>> testPraModel(
            PraConfig praConfig, List<Pair<PathType, Double>> list) {
        List<PathType> pathTypes = new ArrayList<>();
        List<Double> weights = new ArrayList<>();
        for (Pair<PathType, Double> pathAndWeight : list) {
            // Zero-weight features cannot change a score, so skip computing them.
            if (pathAndWeight.getRight().doubleValue() != 0.0d) {
                pathTypes.add(pathAndWeight.getLeft());
                weights.add(pathAndWeight.getRight());
            }
        }

        String testMatrixFile = outputPath(praConfig.outputBase, "test_matrix.tsv");
        Map<Integer, List<Pair<Integer, Double>>> scores = new PraModel(praConfig).classifyInstances(
                new FeatureGenerator(praConfig).computeFeatureValues(
                        pathTypes, praConfig.testingData, testMatrixFile),
                weights);

        // Use the same null handling as the test matrix. Plain concatenation would yield the
        // bogus path "nullscores.tsv" when no output base is configured.
        // NOTE(review): assumes outputScores tolerates a null file name, as
        // computeFeatureValues does; confirm.
        praConfig.outputter.outputScores(
                outputPath(praConfig.outputBase, "scores.tsv"), scores, praConfig);
        return scores;
    }

    /**
     * Returns {@code base + fileName}, or {@code null} when no output base is configured.
     */
    private static String outputPath(String base, String fileName) {
        return base == null ? null : base + fileName;
    }
}
