package smile.regression;

import java.io.Serializable;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.math.matrix.BiconjugateGradient;
import smile.math.matrix.Matrix;
import smile.math.matrix.NaiveMatrix;
import smile.math.matrix.Preconditioner;
import smile.math.matrix.RowMajorMatrix;
import smile.math.matrix.SparseMatrix;
import smile.math.special.Beta;

/* JADX WARN: Classes with same name are omitted:
  input_file:BOOT-INF/classes/libarx-3.7.1.jar:smile/regression/LASSO.class
 */
/* loaded from: input_file:BOOT-INF/lib/libarx-3.7.1.jar:smile/regression/LASSO.class */
/**
 * Lasso (least absolute shrinkage and selection operator) linear regression.
 *
 * <p>The coefficients minimize the L1-penalized residual sum of squares
 *
 * <pre>
 *     ||y - Xw - b||^2 + lambda * ||w||_1
 * </pre>
 *
 * where the intercept {@code b} is handled by centering the response (and, for
 * the {@code double[][]} constructors, also centering and scaling the features).
 *
 * <p>The problem is solved with a truncated-Newton interior-point method: the
 * bound constraints {@code -u < w < u} are enforced by a logarithmic barrier with
 * weight {@code 1/t} (where {@code t} never decreases), every Newton system is
 * solved approximately by a preconditioned biconjugate-gradient solver, and the
 * iteration stops once the relative duality gap falls below the tolerance.
 */
public class LASSO implements Regression<double[]>, Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(LASSO.class);

    /** The number of independent variables (columns of the design matrix). */
    private int p;
    /** The shrinkage/regularization parameter used for training. */
    private double lambda;
    /** The intercept. */
    private double b;
    /** The linear coefficients. */
    private double[] w;
    /** The mean of the training response. */
    private double ym;
    /** The column means of the training features ({@code double[][]} constructors only). */
    private double[] center;
    /** The column scale factors of the training features ({@code double[][]} constructors only). */
    private double[] scale;
    /** The residuals of the training samples. */
    private double[] residuals;
    /** The residual sum of squares. */
    private double RSS;
    /** The residual standard error. */
    private double error;
    /** The residual degrees of freedom, {@code n - p - 1}. */
    private int df;
    /** The coefficient of determination. */
    private double RSquared;
    /** The adjusted coefficient of determination. */
    private double adjustedRSquared;
    /** The F-statistic of the fit. */
    private double F;
    /** The p-value of the F-statistic. */
    private double pvalue;

    /**
     * The Hessian of the barrier objective as a linear operator, together with
     * its block-diagonal preconditioner, so the 2p x 2p Newton system can be
     * solved iteratively without ever being formed:
     *
     * <pre>
     *     H = | 2X'X + D1   D2 |
     *         |     D2      D1 |
     * </pre>
     *
     * <p>The arrays {@code d1}, {@code d2}, {@code prb} and {@code prs} are shared
     * with the training loop, which updates them in place before every solve.
     */
    public class PCGMatrix extends Matrix implements Preconditioner {
        /** The design matrix X. */
        Matrix A;
        /** Cached X'X, or null when each product is computed as X'(Xv). */
        Matrix AtA;
        /** Diagonal of the upper-left/lower-right barrier blocks. */
        double[] d1;
        /** Diagonal of the off-diagonal barrier blocks. */
        double[] d2;
        /** Preconditioner diagonal for the lower-right block. */
        double[] prb;
        /** Determinants of the 2x2 preconditioner blocks. */
        double[] prs;
        /** Workspace for Xv. */
        double[] ax;
        /** Workspace for X'Xv. */
        double[] atax;

        PCGMatrix(Matrix matrix, double[] d1, double[] d2, double[] prb, double[] prs) {
            this.A = matrix;
            this.d1 = d1;
            this.d2 = d2;
            this.prb = prb;
            this.prs = prs;
            this.ax = new double[matrix.nrows()];
            this.atax = new double[LASSO.this.p];

            // X'X is p x p; it is only precomputed for dense X with fewer than
            // 10000 columns, otherwise every product goes through X'(Xv).
            if (matrix.ncols() < 10000 && !(matrix instanceof SparseMatrix)) {
                this.AtA = matrix.ata();
            }
        }

        @Override
        public int nrows() {
            return 2 * LASSO.this.p;
        }

        @Override
        public int ncols() {
            return 2 * LASSO.this.p;
        }

        /** Computes y = H x. */
        @Override
        public double[] ax(double[] x, double[] y) {
            int n = LASSO.this.p;
            if (AtA != null) {
                AtA.ax(x, atax);
            } else {
                A.ax(x, ax);
                A.atx(ax, atax);
            }

            for (int i = 0; i < n; i++) {
                y[i] = 2.0 * atax[i] + d1[i] * x[i] + d2[i] * x[i + n];
                y[i + n] = d2[i] * x[i] + d1[i] * x[i + n];
            }
            return y;
        }

        /** H is symmetric, so H'x = Hx. */
        @Override
        public double[] atx(double[] x, double[] y) {
            return ax(x, y);
        }

        /** Applies the inverse of the 2x2 block-diagonal preconditioner. */
        @Override
        public void asolve(double[] b, double[] x) {
            int n = LASSO.this.p;
            for (int i = 0; i < n; i++) {
                x[i] = (d1[i] * b[i] - d2[i] * b[i + n]) / prs[i];
                x[i + n] = (-d2[i] * b[i] + prb[i] * b[i + n]) / prs[i];
            }
        }

        @Override
        public Matrix transpose() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public Matrix aat() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public Matrix ata() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double get(int i, int j) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double apply(int i, int j) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double[] axpy(double[] x, double[] y) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double[] axpy(double[] x, double[] y, double b) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double[] atxpy(double[] x, double[] y) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double[] atxpy(double[] x, double[] y, double b) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double[] diag() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double trace() {
            throw new UnsupportedOperationException("Not supported yet.");
        }
    }

    /**
     * Trainer for LASSO models.
     *
     * <p>Note: the trainer defaults to tolerance {@code 1E-3}, whereas the
     * three-argument {@code LASSO} constructors use {@code 1E-4}.
     */
    public static class Trainer extends RegressionTrainer<double[]> {
        private double lambda;
        private double tol = 0.001d;
        private int maxIter = 1000;

        /**
         * @param lambda the shrinkage/regularization parameter; must be positive.
         * @throws IllegalArgumentException if {@code lambda <= 0}. This matches the
         *     check done at training time, so the error is reported here rather than
         *     on the first call to {@code train}.
         */
        public Trainer(double lambda) {
            if (lambda <= 0.0) {
                throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + lambda);
            }
            this.lambda = lambda;
        }

        /**
         * Sets the relative duality-gap tolerance used as the stopping criterion.
         *
         * @throws IllegalArgumentException if {@code tol <= 0}.
         */
        public Trainer setTolerance(double tol) {
            if (tol <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            this.tol = tol;
            return this;
        }

        /**
         * Sets the maximum number of interior-point (Newton) iterations.
         *
         * @throws IllegalArgumentException if {@code maxIter <= 0}.
         */
        public Trainer setMaxNumIteration(int maxIter) {
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
            this.maxIter = maxIter;
            return this;
        }

        @Override
        public LASSO train(double[][] x, double[] y) {
            return new LASSO(x, y, lambda, tol, maxIter);
        }

        /** Trains on a design matrix that is used as is (no centering or scaling). */
        public LASSO train(Matrix x, double[] y) {
            return new LASSO(x, y, lambda, tol, maxIter);
        }
    }

    /**
     * Trains a model with tolerance {@code 1E-4} and at most 1000 iterations.
     *
     * @param x the training samples, one row per sample.
     * @param y the response variable.
     * @param lambda the shrinkage/regularization parameter.
     */
    public LASSO(double[][] x, double[] y, double lambda) {
        this(x, y, lambda, 1.0E-4d, 1000);
    }

    /**
     * Trains a model. The features are centered to zero mean and scaled to unit
     * root-mean-square before optimization; the fitted coefficients and the
     * intercept are mapped back to the original feature space. Constant columns
     * are centered but not scaled.
     *
     * @param x the training samples, one row per sample.
     * @param y the response variable.
     * @param lambda the shrinkage/regularization parameter; must be positive.
     * @param tol the relative duality-gap tolerance; must be positive.
     * @param maxIter the maximum number of Newton iterations; must be positive.
     * @throws IllegalArgumentException if {@code x} is empty, the sizes of
     *     {@code x} and {@code y} differ, or a parameter is invalid.
     */
    public LASSO(double[][] x, double[] y, double lambda, double tol, int maxIter) {
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }

        int n = x.length;
        int ncols = x[0].length;

        // Center every column.
        this.center = Math.colMean(x);
        RowMajorMatrix standardized = new RowMajorMatrix(n, ncols);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < ncols; j++) {
                standardized.set(i, j, x[i][j] - center[j]);
            }
        }

        // Root-mean-square of every centered column.
        this.scale = new double[ncols];
        for (int j = 0; j < ncols; j++) {
            double sumOfSquares = 0.0;
            for (int i = 0; i < n; i++) {
                sumOfSquares += Math.sqr(standardized.get(i, j));
            }
            scale[j] = Math.sqrt(sumOfSquares / n);
        }

        // Scale non-constant columns to unit root-mean-square.
        for (int j = 0; j < ncols; j++) {
            if (!Math.isZero(scale[j])) {
                for (int i = 0; i < n; i++) {
                    standardized.div(i, j, scale[j]);
                }
            }
        }

        train(standardized, y, lambda, tol, maxIter);

        // Map the coefficients and intercept back to the original feature space.
        for (int j = 0; j < ncols; j++) {
            if (!Math.isZero(scale[j])) {
                w[j] /= scale[j];
            }
        }
        this.b = ym - Math.dot(w, center);

        fitness(new NaiveMatrix(x), y);
    }

    /**
     * Trains a model with tolerance {@code 1E-4} and at most 1000 iterations.
     *
     * @param x the design matrix, used as is.
     * @param y the response variable.
     * @param lambda the shrinkage/regularization parameter.
     */
    public LASSO(Matrix x, double[] y, double lambda) {
        this(x, y, lambda, 1.0E-4d, 1000);
    }

    /**
     * Trains a model on a design matrix that is used as is: no centering or
     * scaling is applied and the intercept is fixed to {@code mean(y)}. That
     * intercept is only the least-squares optimum when the columns of {@code x}
     * are already centered.
     *
     * @param x the design matrix.
     * @param y the response variable.
     * @param lambda the shrinkage/regularization parameter; must be positive.
     * @param tol the relative duality-gap tolerance; must be positive.
     * @param maxIter the maximum number of Newton iterations; must be positive.
     * @throws IllegalArgumentException if the sizes of {@code x} and {@code y}
     *     differ or a parameter is invalid.
     */
    public LASSO(Matrix x, double[] y, double lambda, double tol, int maxIter) {
        train(x, y, lambda, tol, maxIter);
        fitness(x, y);
    }

    /**
     * Fits {@code w} with the interior-point method, following the l1_ls algorithm.
     * This method sets {@code p}, {@code lambda}, {@code ym} and {@code w}, and sets
     * the intercept {@code b} to {@code ym}.
     *
     * @throws IllegalArgumentException if the sizes of {@code x} and {@code y}
     *     differ or a parameter is invalid.
     */
    private void train(Matrix x, double[] y, double lambda, double tol, int maxIter) {
        if (x.nrows() != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.nrows(), y.length));
        }
        if (lambda <= 0.0) {
            throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + lambda);
        }
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        final int MU = 2;                // factor by which t may grow per iteration
        final double ALPHA = 0.01;       // minimum fraction of decrease in the line search
        final double BETA = 0.5;         // step size decrease factor
        final int MAX_LS_ITER = 100;     // maximum backtracking line search iterations
        final int PCG_MAX_ITER = 5000;   // maximum iterations of each PCG solve
        final double PCG_ETA = 1E-3;     // relative tolerance factor for PCG

        this.lambda = lambda;
        int n = x.nrows();
        this.p = x.ncols();

        // Centered response.
        double[] yc = new double[n];
        this.ym = Math.mean(y);
        for (int i = 0; i < n; i++) {
            yc[i] = y[i] - ym;
        }

        double t = Math.min(Math.max(1.0, 1.0 / lambda), 2 * p / 1E-3);
        double dobj = Double.NEGATIVE_INFINITY;   // best dual objective value so far
        double s = Double.POSITIVE_INFINITY;      // last line search step size

        this.w = new double[p];
        this.b = ym;
        double[] u = new double[p];
        double[] z = new double[n];               // residual Xw - yc
        double[][] f = new double[2][p];          // constraint values, all < 0 when feasible
        Arrays.fill(u, 1.0);
        for (int i = 0; i < p; i++) {
            f[0][i] = w[i] - u[i];
            f[1][i] = -w[i] - u[i];
        }

        double[] neww = new double[p];
        double[] newu = new double[p];
        double[] newz = new double[n];
        double[][] newf = new double[2][p];

        double[] dx = new double[p];
        double[] du = new double[p];
        double[] dxu = new double[2 * p];         // Newton step, also the warm start of each solve
        double[] grad = new double[2 * p];        // negated gradient of the barrier objective

        // Stand-in for diag(2X'X) in the preconditioner. This is exact only when
        // every column of X has unit norm; otherwise only PCG efficiency suffers.
        double[] diagxtx = new double[p];
        Arrays.fill(diagxtx, 2.0);

        double[] nu = new double[n];              // dual point
        double[] xnu = new double[p];

        double[] q1 = new double[p];
        double[] q2 = new double[p];
        double[] d1 = new double[p];
        double[] d2 = new double[p];
        double[][] gradphi = new double[2][p];
        double[] prb = new double[p];
        double[] prs = new double[p];

        PCGMatrix pcg = new PCGMatrix(x, d1, d2, prb, prs);

        // Set to PCG_MAX_ITER once any PCG solve misses its tolerance; while it
        // stays 0, later solves use a tighter tolerance.
        int pcgIter = 0;

        int iter = 0;
        for (; iter < maxIter; iter++) {
            x.ax(w, z);
            for (int i = 0; i < n; i++) {
                z[i] -= yc[i];
                nu[i] = 2.0 * z[i];
            }

            // Scale the dual point into the feasible set ||X'nu||_inf <= lambda.
            x.atx(nu, xnu);
            double maxXnu = Math.normInf(xnu);
            if (maxXnu > lambda) {
                double lnu = lambda / maxXnu;
                for (int i = 0; i < n; i++) {
                    nu[i] *= lnu;
                }
            }

            double pobj = Math.dot(z, z) + lambda * Math.norm1(w);
            dobj = Math.max(-0.25 * Math.dot(nu, nu) - Math.dot(nu, yc), dobj);
            if (iter % 10 == 0) {
                logger.info(String.format("LASSO: primal and dual objective function value after %3d iterations: %.5g\t%.5g%n", iter, pobj, dobj));
            }

            double gap = pobj - dobj;
            if (gap / dobj < tol) {
                logger.info(String.format("LASSO: primal and dual objective function value after %3d iterations: %.5g\t%.5g%n", iter, pobj, dobj));
                break;
            }

            // Grow t only after a reasonably long step.
            if (s >= 0.5) {
                t = Math.max(Math.min(2 * p * MU / gap, MU * t), t);
            }

            // Barrier terms of the Hessian.
            for (int i = 0; i < p; i++) {
                double q1i = 1.0 / (u[i] + w[i]);
                double q2i = 1.0 / (u[i] - w[i]);
                q1[i] = q1i;
                q2[i] = q2i;
                d1[i] = (q1i * q1i + q2i * q2i) / t;
                d2[i] = (q1i * q1i - q2i * q2i) / t;
            }

            // Gradient of the barrier objective, and its negation as the Newton RHS.
            x.atx(z, gradphi[0]);
            for (int i = 0; i < p; i++) {
                gradphi[0][i] = 2.0 * gradphi[0][i] - (q1[i] - q2[i]) / t;
                gradphi[1][i] = lambda - (q1[i] + q2[i]) / t;
                grad[i] = -gradphi[0][i];
                grad[i + p] = -gradphi[1][i];
            }

            // Preconditioner blocks.
            for (int i = 0; i < p; i++) {
                prb[i] = diagxtx[i] + d1[i];
                prs[i] = prb[i] * d1[i] - d2[i] * d2[i];
            }

            double pcgTol = Math.min(0.1, PCG_ETA * gap / Math.min(1.0, Math.norm(grad)));
            if (iter != 0 && pcgIter == 0) {
                pcgTol *= 0.1;
            }

            // Solve H * dxu = -gradphi approximately.
            double pcgError = BiconjugateGradient.solve(pcg, pcg, grad, dxu, pcgTol, 1, PCG_MAX_ITER);
            if (pcgError > pcgTol) {
                pcgIter = PCG_MAX_ITER;
            }

            System.arraycopy(dxu, 0, dx, 0, p);
            System.arraycopy(dxu, p, du, 0, p);

            // Backtracking line search with the Armijo sufficient-decrease condition.
            double phi = Math.dot(z, z) + lambda * Math.sum(u) - sumlogneg(f) / t;
            s = 1.0;
            // Directional derivative gradphi . dxu, which is negative along a descent
            // direction. grad holds -gradphi, hence the negation.
            double gdx = -Math.dot(grad, dxu);

            int lsIter = 0;
            for (; lsIter < MAX_LS_ITER; lsIter++) {
                for (int i = 0; i < p; i++) {
                    neww[i] = w[i] + s * dx[i];
                    newu[i] = u[i] + s * du[i];
                    newf[0][i] = neww[i] - newu[i];
                    newf[1][i] = -neww[i] - newu[i];
                }

                // Only consider strictly feasible points.
                if (Math.max(newf) < 0.0) {
                    x.ax(neww, newz);
                    for (int i = 0; i < n; i++) {
                        newz[i] -= yc[i];
                    }

                    double newphi = Math.dot(newz, newz) + lambda * Math.sum(newu) - sumlogneg(newf) / t;
                    if (newphi - phi <= ALPHA * s * gdx) {
                        break;
                    }
                }
                s *= BETA;
            }

            if (lsIter == MAX_LS_ITER) {
                logger.error("LASSO: Too many iterations of line search.");
                break;
            }

            System.arraycopy(neww, 0, w, 0, p);
            System.arraycopy(newu, 0, u, 0, p);
            System.arraycopy(newf[0], 0, f[0], 0, p);
            System.arraycopy(newf[1], 0, f[1], 0, p);
        }

        // Reached only when the iteration budget was exhausted without a break.
        if (iter == maxIter) {
            logger.error("LASSO: Too many iterations.");
        }
    }

    /**
     * Computes the residuals and goodness-of-fit statistics on the training data.
     * When {@code n - p - 1 <= 0}, the statistics that depend on the residual
     * degrees of freedom are undefined and are set to NaN: the residual standard
     * error, adjusted R-squared, F-statistic and p-value.
     */
    private void fitness(Matrix x, double[] y) {
        int n = y.length;
        double[] yhat = new double[n];
        x.ax(w, yhat);

        double ybar = Math.mean(y);
        double tss = 0.0;
        this.RSS = 0.0;
        this.residuals = new double[n];
        for (int i = 0; i < n; i++) {
            double r = y[i] - yhat[i] - b;
            residuals[i] = r;
            RSS += Math.sqr(r);
            tss += Math.sqr(y[i] - ybar);
        }

        this.df = n - p - 1;
        this.RSquared = 1.0 - RSS / tss;

        if (df > 0) {
            this.error = Math.sqrt(RSS / df);
            this.adjustedRSquared = 1.0 - ((1.0 - RSquared) * (n - 1)) / df;
            this.F = ((tss - RSS) * df) / (RSS * p);
            this.pvalue = Beta.regularizedIncompleteBetaFunction(0.5 * df, 0.5 * p, df / (df + p * F));
        } else {
            this.error = Double.NaN;
            this.adjustedRSquared = Double.NaN;
            this.F = Double.NaN;
            this.pvalue = Double.NaN;
        }
    }

    /** Returns the sum of log(-v) over every entry v of the rows of {@code f}. */
    private double sumlogneg(double[][] f) {
        int m = f[0].length;
        double sum = 0.0;
        for (double[] row : f) {
            for (int i = 0; i < m; i++) {
                sum += Math.log(-row[i]);
            }
        }
        return sum;
    }

    /** Returns the linear coefficients. The array is the model's own storage; do not modify it. */
    public double[] coefficients() {
        return w;
    }

    /** Returns the intercept. */
    public double intercept() {
        return b;
    }

    /** Returns the shrinkage/regularization parameter used for training. */
    public double shrinkage() {
        return lambda;
    }

    /**
     * Predicts the response for one sample.
     *
     * @throws IllegalArgumentException if {@code x.length != p}.
     */
    @Override
    public double predict(double[] x) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }
        return Math.dot(x, w) + b;
    }

    /** Returns the residuals of the training samples. The array is the model's own storage. */
    public double[] residuals() {
        return residuals;
    }

    /** Returns the residual sum of squares. */
    public double RSS() {
        return RSS;
    }

    /** Returns the residual standard error, or NaN when {@code df() <= 0}. */
    public double error() {
        return error;
    }

    /** Returns the residual degrees of freedom, {@code n - p - 1}. */
    public int df() {
        return df;
    }

    /** Returns R-squared. */
    public double RSquared() {
        return RSquared;
    }

    /** Returns adjusted R-squared, or NaN when {@code df() <= 0}. */
    public double adjustedRSquared() {
        return adjustedRSquared;
    }

    /** Returns the F-statistic, or NaN when {@code df() <= 0}. */
    public double ftest() {
        return F;
    }

    /** Returns the p-value of the F-statistic, or NaN when {@code df() <= 0}. */
    public double pvalue() {
        return pvalue;
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("LASSO:\n");

        // Work on a copy, presumably because the quantile helpers may reorder their input.
        double[] r = residuals.clone();
        builder.append("\nResiduals:\n");
        builder.append("\t       Min\t        1Q\t    Median\t        3Q\t       Max\n");
        builder.append(String.format("\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.4f%n", Math.min(r), Math.q1(r), Math.median(r), Math.q3(r), Math.max(r)));

        builder.append("\nCoefficients:\n");
        builder.append("            Estimate\n");
        builder.append(String.format("Intercept%11.4f%n", b));
        for (int i = 0; i < p; i++) {
            builder.append(String.format("Var %d\t %11.4f%n", i + 1, w[i]));
        }

        builder.append(String.format("\nResidual standard error: %.4f on %d degrees of freedom%n", error, df));
        builder.append(String.format("Multiple R-squared: %.4f,    Adjusted R-squared: %.4f%n", RSquared, adjustedRSquared));
        builder.append(String.format("F-statistic: %.4f on %d and %d DF,  p-value: %.4g%n", F, p, df, pvalue));
        return builder.toString();
    }
}
