package edu.cmu.ml.rtw.pra.models;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.optimize.OrthantWiseLimitedMemoryBFGS;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/ml/rtw/pra/models/MalletLogisticRegression.class */
/**
 * Binary logistic regression whose weights are fitted with Mallet's L-BFGS optimizers.
 *
 * <p>The weights live in exactly one of two representations at a time:
 * <ul>
 *   <li><b>dense</b> ({@code dParams}): one slot per alphabet entry, followed by a final slot
 *       holding the bias;</li>
 *   <li><b>sparse</b> ({@code sFeatures}/{@code sParams}/{@code sBias}): parallel arrays of feature
 *       indices and their weights, plus a separate bias.</li>
 * </ul>
 * {@link #setDense()} and {@link #setSparse()} convert between the two and null out the
 * representation that is no longer live. A freshly constructed model is sparse with no weights
 * and a {@code NaN} bias.
 *
 * <p>Training targets are read as {@code Double}s; a target greater than 0.5 is treated as the
 * positive class.
 */
public class MalletLogisticRegression {
    private static final Logger log = Logger.getLogger(MalletLogisticRegression.class);

    /** Dense weights inside {@code [-SPARSE_EPSILON, SPARSE_EPSILON]} are dropped by {@link #setSparse()}. */
    private static final double SPARSE_EPSILON = 1.0E-8d;

    /** Iteration cap used by {@link #train(InstanceList)}. */
    private static final int DEFAULT_MAX_ITERATIONS = 999;

    protected String name;
    private final Alphabet alphabet;
    private String predicate;
    private double l1wt;
    private double l2wt;
    /** Sparse form: feature indices, parallel to {@link #sParams}; {@code null} while dense. */
    private int[] sFeatures;
    /** Sparse form: weights, parallel to {@link #sFeatures}; {@code null} while dense. */
    private double[] sParams;
    /** Sparse form: bias; set to {@code NaN} while dense. */
    private double sBias;
    /** Dense form: {@code alphabet.size()} weights followed by the bias; {@code null} while sparse. */
    private double[] dParams;
    private boolean preferSparse;

    /**
     * Exposes the regularised, instance-weighted log-likelihood of a training list, and its
     * gradient with respect to the enclosing model's dense parameters, to a Mallet optimizer.
     *
     * <p>Both quantities are computed in a single pass by {@link #getValue()} and cached until a
     * parameter is changed through {@link #setParameter(int, double)} or
     * {@link #setParameters(double[])}.
     */
    public class LogRegOptimizable implements Optimizable.ByGradientValue {
        InstanceList trainingList;
        double[] cachedGradient;
        double cachedValue;
        boolean cachedGradientStale = true;
        boolean cachedValueStale = true;
        int numGetValueCalls = 0;
        int numGetValueGradientCalls = 0;

        /**
         * Switches the enclosing model to its dense representation and sizes the gradient cache
         * to match it.
         *
         * @param instanceList the training instances to score
         */
        public LogRegOptimizable(InstanceList instanceList) {
            this.trainingList = instanceList;
            MalletLogisticRegression.this.setDense();
            this.cachedGradient = new double[MalletLogisticRegression.this.dParams.length];
        }

        /**
         * Returns the objective: the sum over instances of
         * {@code weight * log(P(observed label))}, minus {@code l2wt} times the sum of squared
         * non-bias weights. The gradient is accumulated into {@code cachedGradient} in the same
         * pass.
         *
         * <p>If any instance contributes an infinite term, a warning is logged and that infinite
         * term is cached and returned immediately; the gradient accumulated so far is left as is
         * and no L2 term is applied.
         *
         * @return the cached or freshly computed objective value
         */
        public double getValue() {
            if (!this.cachedValueStale) {
                return this.cachedValue;
            }
            this.numGetValueCalls++;
            this.cachedValue = 0.0d;
            MatrixOps.setAll(this.cachedGradient, 0.0d);
            // The gradient is rebuilt below but still needs the clean-up done in getValueGradient.
            this.cachedGradientStale = true;

            final double[] params = MalletLogisticRegression.this.dParams;
            final int biasIndex = params.length - 1;

            for (int i = 0; i < this.trainingList.size(); i++) {
                Instance instance = (Instance) this.trainingList.get(i);
                double instanceWeight = this.trainingList.getInstanceWeight(instance);
                double target = ((Double) instance.getTarget()).doubleValue();
                boolean positive = target > 0.5d;
                double probability = MalletLogisticRegression.this.classify(instance);
                assert !Double.isInfinite(probability);

                double logLikelihood =
                        instanceWeight * Math.log(positive ? probability : 1.0d - probability);
                if (Double.isNaN(logLikelihood)) {
                    log.info("MalletLogisticRegression: NaN - Instance " + instance.getName() + " has instance weight = " + instanceWeight);
                }
                if (Double.isInfinite(logLikelihood)) {
                    log.warn("Instance " + instance.getSource() + " has infinite value; skipping value and gradient");
                    this.cachedValue = logLikelihood;
                    this.cachedValueStale = false;
                    return logLikelihood;
                }

                FeatureVector featureVector = (FeatureVector) instance.getData();
                this.cachedValue += logLikelihood;

                // d/dz of log(p) is p(1-p)/p and of log(1-p) is p(1-p)/(p-1), where p = sigmoid(z).
                double sigmoidSlope = probability * (1.0d - probability);
                double likelihoodTerm = positive ? probability : probability - 1.0d;
                double scoreGradient = sigmoidSlope * (1.0d / likelihoodTerm);

                // Feature weights occupy row 0 of a one-row matrix with biasIndex columns.
                MatrixOps.rowPlusEquals(this.cachedGradient, biasIndex, 0, featureVector, instanceWeight * scoreGradient);
                this.cachedGradient[biasIndex] += instanceWeight * scoreGradient;
            }

            // L2 penalty. The bias is excluded from the objective, so it must be excluded from
            // the gradient as well; otherwise the gradient handed to the optimizer is not the
            // gradient of the value it is optimising.
            final double l2 = MalletLogisticRegression.this.l2wt;
            final double gradientFactor = (-2.0d) * l2;
            double sumOfSquares = 0.0d;
            for (int j = 0; j < biasIndex; j++) {
                double weight = params[j];
                sumOfSquares += weight * weight;
                this.cachedGradient[j] += weight * gradientFactor;
            }
            this.cachedValue -= l2 * sumOfSquares;

            // Mark the value fresh so repeated getValue()/getValueGradient() calls at the same
            // parameters reuse this pass instead of re-scanning the training list.
            this.cachedValueStale = false;
            return this.cachedValue;
        }

        /**
         * Copies the gradient of {@link #getValue()} into {@code dArr}, recomputing it first if
         * the parameters changed since the last call. Any {@code -Infinity} entries are replaced
         * by zero.
         *
         * @param dArr destination buffer; asserted to be non-null and of length
         *             {@link #getNumParameters()}
         */
        public void getValueGradient(double[] dArr) {
            if (this.cachedGradientStale) {
                this.numGetValueGradientCalls++;
                if (this.cachedValueStale) {
                    getValue();
                }
                MatrixOps.substitute(this.cachedGradient, Double.NEGATIVE_INFINITY, 0.0d);
                this.cachedGradientStale = false;
            }
            assert dArr != null && dArr.length == MalletLogisticRegression.this.dParams.length;
            System.arraycopy(this.cachedGradient, 0, dArr, 0, this.cachedGradient.length);
        }

        /** @return the length of the dense parameter vector (weights plus bias) */
        public int getNumParameters() {
            return MalletLogisticRegression.this.dParams.length;
        }

        /**
         * @param i index into the dense parameter vector
         * @return the parameter at that index
         */
        public double getParameter(int i) {
            return MalletLogisticRegression.this.dParams[i];
        }

        /**
         * Copies the dense parameters into {@code dArr}.
         *
         * <p>NOTE(review): when {@code dArr} is {@code null} or has the wrong length, the copy
         * goes into a freshly allocated local array that the caller never sees, so the call has
         * no visible effect. Callers should pass a buffer of length {@link #getNumParameters()}.
         *
         * @param dArr destination buffer
         */
        public void getParameters(double[] dArr) {
            double[] source = MalletLogisticRegression.this.dParams;
            if (dArr == null || dArr.length != source.length) {
                dArr = new double[source.length];
            }
            System.arraycopy(source, 0, dArr, 0, source.length);
        }

        /**
         * Overwrites one dense parameter and invalidates the cached value and gradient.
         *
         * @param i index into the dense parameter vector
         * @param d new value
         */
        public void setParameter(int i, double d) {
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            MalletLogisticRegression.this.dParams[i] = d;
        }

        /**
         * Replaces all dense parameters with a copy of {@code dArr} and invalidates the cached
         * value and gradient. If the lengths differ, the model's dense array is reallocated to
         * {@code dArr.length}.
         *
         * @param dArr new parameter values; asserted non-null
         */
        public void setParameters(double[] dArr) {
            assert dArr != null;
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            if (dArr.length != MalletLogisticRegression.this.dParams.length) {
                MalletLogisticRegression.this.dParams = new double[dArr.length];
            }
            System.arraycopy(dArr, 0, MalletLogisticRegression.this.dParams, 0, dArr.length);
        }
    }

    /**
     * Switches to the dense representation; a no-op if the model is already dense.
     *
     * <p>Allocates {@code alphabet.size() + 1} slots, scatters the sparse weights into them,
     * stores the bias in the last slot, and then discards the sparse arrays.
     */
    protected void setDense() {
        if (this.dParams != null) {
            return;
        }
        log.debug("Converting to dense vectors...");
        int biasIndex = this.alphabet.size();
        double[] dense = new double[biasIndex + 1];
        for (int k = 0; k < this.sFeatures.length; k++) {
            dense[this.sFeatures[k]] = this.sParams[k];
        }
        dense[biasIndex] = this.sBias;
        this.dParams = dense;
        this.sFeatures = null;
        this.sParams = null;
        this.sBias = Double.NaN;
    }

    /**
     * Switches to the sparse representation; a no-op if the model is already sparse.
     *
     * <p>Only non-bias weights strictly outside {@code [-1e-8, 1e-8]} are kept, in ascending
     * feature-index order. The bias is always kept. The dense array is then discarded, so the
     * dropped weights are lost.
     */
    protected void setSparse() {
        if (this.sParams != null) {
            return;
        }
        log.debug("Converting to sparse vectors...");
        final double[] dense = this.dParams;
        final int biasIndex = dense.length - 1;

        int kept = 0;
        for (int j = 0; j < biasIndex; j++) {
            if (isSignificant(dense[j])) {
                kept++;
            }
        }

        int[] features = new int[kept];
        double[] weights = new double[kept];
        int next = 0;
        for (int j = 0; j < biasIndex; j++) {
            if (isSignificant(dense[j])) {
                features[next] = j;
                weights[next] = dense[j];
                next++;
            }
        }

        this.sFeatures = features;
        this.sParams = weights;
        this.sBias = dense[biasIndex];
        this.dParams = null;
    }

    /** Tells whether a weight is large enough to survive conversion to sparse form (NaN is not). */
    private static boolean isSignificant(double weight) {
        return weight < -SPARSE_EPSILON || weight > SPARSE_EPSILON;
    }

    /**
     * Creates an untrained model over the given feature alphabet, with both regularisation
     * weights at zero and no preference for the sparse representation.
     *
     * @param alphabet the feature alphabet; its size fixes the number of weights
     */
    public MalletLogisticRegression(Alphabet alphabet) {
        this.name = "Logistic Regression";
        this.predicate = "predicate not given";
        this.l1wt = 0.0d;
        this.l2wt = 0.0d;
        this.alphabet = alphabet;
        this.sFeatures = new int[0];
        this.sParams = new double[0];
        this.sBias = Double.NaN;
        this.dParams = null;
        this.preferSparse = false;
    }

    /**
     * Creates an untrained model as {@link #MalletLogisticRegression(Alphabet)} does.
     *
     * @param alphabet the feature alphabet
     * @param z        if {@code true}, the model is converted to sparse form after each training run
     */
    public MalletLogisticRegression(Alphabet alphabet, boolean z) {
        this(alphabet);
        this.preferSparse = z;
    }

    /**
     * Trains with a cap of 999 optimizer iterations.
     *
     * @param instanceList the training data
     */
    public void train(InstanceList instanceList) {
        train(instanceList, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Fits the weights to {@code instanceList}.
     *
     * <p>The optimizer is stepped one iteration at a time until it reports convergence, throws an
     * {@link OptimizationException} (logged and treated as convergence), or the cap is reached.
     * If iterations remain after that first pass, the optimizer is reset (plain L-BFGS) or
     * rebuilt (any other optimizer) and stepped again under the same rules, with both passes
     * sharing the one iteration budget.
     *
     * <p>A {@code NaN} or infinite bias is reset to zero before optimisation starts.
     *
     * @param instanceList the training data; its data alphabet is asserted to be the very
     *                     alphabet this model was built with
     * @param i            the maximum total number of optimizer iterations
     */
    public void train(InstanceList instanceList, int i) {
        assert this.alphabet == instanceList.getDataAlphabet();
        setDense();
        int biasIndex = this.dParams.length - 1;
        if (Double.isNaN(this.dParams[biasIndex]) || Double.isInfinite(this.dParams[biasIndex])) {
            this.dParams[biasIndex] = 0.0d;
        }

        // getOptimizer may hand back either optimizer type, so hold it by the interface.
        Optimizer optimizer = getOptimizer(instanceList);
        int iterations = runUntilConverged(optimizer, 0, i);

        if (iterations < i) {
            if (optimizer instanceof LimitedMemoryBFGS) {
                ((LimitedMemoryBFGS) optimizer).reset();
            } else {
                optimizer = getOptimizer(instanceList);
            }
            iterations = runUntilConverged(optimizer, iterations, i);
        }

        log.info(this.predicate + ": L-BFGS Converged after " + iterations + " iterations!");
        if (this.preferSparse) {
            setSparse();
            log.debug(this.sParams.length + " of " + this.alphabet.size() + " parameters were nonzero after training");
        }
    }

    /**
     * Steps {@code optimizer} one iteration at a time until it converges, fails, or the shared
     * iteration budget is used up.
     *
     * @return the updated total iteration count
     */
    private int runUntilConverged(Optimizer optimizer, int iterationsSoFar, int maxIterations) {
        int iterations = iterationsSoFar;
        boolean converged = false;
        while (!converged && iterations < maxIterations) {
            try {
                converged = optimizer.optimize(1);
            } catch (OptimizationException e) {
                log.warn("Catching " + e + "! saying converged.");
                converged = true;
            }
            iterations++;
        }
        return iterations;
    }

    /**
     * Builds a fresh optimizer over {@code instanceList}: orthant-wise L-BFGS when
     * {@code l1wt > 0}, plain L-BFGS otherwise.
     */
    private Optimizer getOptimizer(InstanceList instanceList) {
        LogRegOptimizable objective = new LogRegOptimizable(instanceList);
        if (this.l1wt > 0.0d) {
            log.info("Using L1 regularization (l1wt:" + this.l1wt + ",l2wt:" + this.l2wt + ")");
            return new OrthantWiseLimitedMemoryBFGS(objective, this.l1wt);
        }
        log.info("Using L2 regularization (l2wt:" + this.l2wt + ")");
        return new LimitedMemoryBFGS(objective);
    }

    /**
     * Scores an instance as {@code sigmoid(bias + sum(featureValue * weight))}, using whichever
     * weight representation is currently live.
     *
     * <p>The sparse path walks the instance's feature locations and {@code sFeatures} in
     * lock-step, which relies on both being in ascending index order.
     *
     * @param instance an instance whose data is a {@link FeatureVector}
     * @return the sigmoid of the linear score
     * @throws RuntimeException wrapping any exception raised while scoring, with the instance
     *                          data in the message
     */
    public double classify(Instance instance) {
        try {
            FeatureVector features = (FeatureVector) instance.getData();
            if (this.dParams != null) {
                double score = this.dParams[this.dParams.length - 1];
                int locations = features.numLocations();
                for (int loc = 0; loc < locations; loc++) {
                    score += features.valueAtLocation(loc) * this.dParams[features.indexAtLocation(loc)];
                }
                return sigmoid(score);
            }

            double score = this.sBias;
            int loc = 0;
            int sparsePos = 0;
            while (loc < features.numLocations() && sparsePos < this.sFeatures.length) {
                int instanceFeature = features.indexAtLocation(loc);
                int modelFeature = this.sFeatures[sparsePos];
                if (instanceFeature == modelFeature) {
                    score += features.valueAtLocation(loc) * this.sParams[sparsePos];
                    loc++;
                    sparsePos++;
                } else if (instanceFeature < modelFeature) {
                    loc++;
                } else {
                    sparsePos++;
                }
            }
            return sigmoid(score);
        } catch (Exception e) {
            throw new RuntimeException("classify(" + instance.getData() + ")", e);
        }
    }

    /** Logistic function {@code 1 / (1 + e^-d)}. */
    private double sigmoid(double d) {
        return 1.0d / (1.0d + Math.exp(-d));
    }

    /** @return the L1 regularisation weight */
    public double getL1wt() {
        return this.l1wt;
    }

    /** @param d the L1 regularisation weight; a value above zero selects the orthant-wise optimizer */
    public void setL1wt(double d) {
        this.l1wt = d;
    }

    /** @return the L2 regularisation weight */
    public double getL2wt() {
        return this.l2wt;
    }

    /** @param d the L2 regularisation weight */
    public void setL2wt(double d) {
        this.l2wt = d;
    }

    /**
     * Returns the sparse weight array, parallel to {@link #getSparseFeatures()}.
     *
     * <p>If the model is currently dense it is converted to sparse and back again, which zeroes
     * any dense weight within {@code [-1e-8, 1e-8]}. The returned array is the model's own array
     * when the model is sparse, not a copy.
     *
     * @return the sparse weights
     */
    public double[] getSparseParams() {
        boolean wasDense = this.sParams == null;
        if (wasDense) {
            setSparse();
        }
        double[] weights = this.sParams;
        if (wasDense) {
            setDense();
        }
        return weights;
    }

    /**
     * Returns the feature indices that have a sparse weight, parallel to
     * {@link #getSparseParams()}.
     *
     * <p>If the model is currently dense it is converted to sparse and back again, which zeroes
     * any dense weight within {@code [-1e-8, 1e-8]}. The returned array is the model's own array
     * when the model is sparse, not a copy.
     *
     * @return the sparse feature indices
     */
    public int[] getSparseFeatures() {
        boolean wasDense = this.sParams == null;
        if (wasDense) {
            setSparse();
        }
        int[] features = this.sFeatures;
        if (wasDense) {
            setDense();
        }
        return features;
    }

    /** @return the bias from whichever representation is currently live */
    public double getBias() {
        return this.dParams == null ? this.sBias : this.dParams[this.dParams.length - 1];
    }
}
