package smile.math.matrix;

import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;

/* JADX WARN: Classes with same name are omitted:
  input_file:BOOT-INF/classes/libarx-3.7.1.jar:smile/math/matrix/BiconjugateGradient.class
 */
/* loaded from: input_file:BOOT-INF/lib/libarx-3.7.1.jar:smile/math/matrix/BiconjugateGradient.class */
/**
 * Preconditioned biconjugate gradient (BiCG) solver for {@code A * x = b}.
 *
 * <p>The matrix is only accessed through {@code ax} (product with {@code A}) and {@code atx}
 * (product with the transpose of {@code A}), and the preconditioner through
 * {@link Preconditioner#asolve}. The iteration and the {@code itol} convergence criteria follow
 * the {@code linbcg} routine of Numerical Recipes.
 */
public class BiconjugateGradient {
    private static final Logger logger = LoggerFactory.getLogger(BiconjugateGradient.class);

    /** Tolerance used by the overloads that do not take one. */
    private static final double DEFAULT_TOLERANCE = 1.0E-10;

    /** Convergence criterion ({@code itol}) used by the overloads that do not take one. */
    private static final int DEFAULT_ITOL = 1;

    /**
     * Returns a diagonal (Jacobi) preconditioner for {@code matrix}.
     *
     * <p>{@code asolve(b, x)} sets {@code x[i] = b[i] / d[i]}, where {@code d = matrix.diag()}, and
     * {@code x[i] = b[i]} wherever {@code d[i]} is zero. The diagonal is fetched on the first call
     * and reused afterwards, so {@code matrix} must not be modified while the returned
     * preconditioner is in use.
     */
    private static Preconditioner diagonalPreconditioner(final Matrix matrix) {
        return new Preconditioner() {
            /** Cached diagonal of {@code matrix}; {@code null} until the first {@code asolve}. */
            private double[] diag;

            @Override
            public void asolve(double[] b, double[] x) {
                if (diag == null) {
                    diag = matrix.diag();
                }
                for (int i = 0; i < diag.length; i++) {
                    x[i] = diag[i] != 0.0 ? b[i] / diag[i] : b[i];
                }
            }
        };
    }

    /**
     * Solves {@code A * x = b} with the diagonal preconditioner, tolerance {@code 1e-10},
     * {@code itol = 1} and at most {@code 2 * max(A.nrows(), A.ncols())} iterations.
     *
     * @param matrix the coefficient matrix {@code A}.
     * @param b the right-hand side.
     * @param x the initial guess on input; overwritten with the approximate solution.
     * @return the final error estimate.
     * @see #solve(Matrix, Preconditioner, double[], double[], double, int, int)
     */
    public static double solve(Matrix matrix, double[] b, double[] x) {
        return solve(matrix, diagonalPreconditioner(matrix), b, x);
    }

    /**
     * Solves {@code A * x = b} with the given preconditioner, tolerance {@code 1e-10},
     * {@code itol = 1} and at most {@code 2 * max(A.nrows(), A.ncols())} iterations.
     *
     * @param matrix the coefficient matrix {@code A}.
     * @param preconditioner the preconditioner.
     * @param b the right-hand side.
     * @param x the initial guess on input; overwritten with the approximate solution.
     * @return the final error estimate.
     * @see #solve(Matrix, Preconditioner, double[], double[], double, int, int)
     */
    public static double solve(Matrix matrix, Preconditioner preconditioner, double[] b, double[] x) {
        return solve(matrix, preconditioner, b, x, DEFAULT_TOLERANCE);
    }

    /**
     * Solves {@code A * x = b} with the diagonal preconditioner, {@code itol = 1} and at most
     * {@code 2 * max(A.nrows(), A.ncols())} iterations.
     *
     * @param matrix the coefficient matrix {@code A}.
     * @param b the right-hand side.
     * @param x the initial guess on input; overwritten with the approximate solution.
     * @param tol the convergence tolerance; must be positive.
     * @return the final error estimate.
     * @see #solve(Matrix, Preconditioner, double[], double[], double, int, int)
     */
    public static double solve(Matrix matrix, double[] b, double[] x, double tol) {
        return solve(matrix, diagonalPreconditioner(matrix), b, x, tol);
    }

    /**
     * Solves {@code A * x = b} with the given preconditioner, {@code itol = 1} and at most
     * {@code 2 * max(A.nrows(), A.ncols())} iterations.
     *
     * @param matrix the coefficient matrix {@code A}.
     * @param preconditioner the preconditioner.
     * @param b the right-hand side.
     * @param x the initial guess on input; overwritten with the approximate solution.
     * @param tol the convergence tolerance; must be positive.
     * @return the final error estimate.
     * @see #solve(Matrix, Preconditioner, double[], double[], double, int, int)
     */
    public static double solve(Matrix matrix, Preconditioner preconditioner, double[] b, double[] x, double tol) {
        return solve(matrix, preconditioner, b, x, tol, DEFAULT_ITOL);
    }

    /**
     * Solves {@code A * x = b} with the diagonal preconditioner and at most
     * {@code 2 * max(A.nrows(), A.ncols())} iterations.
     *
     * @param matrix the coefficient matrix {@code A}.
     * @param b the right-hand side.
     * @param x the initial guess on input; overwritten with the approximate solution.
     * @param tol the convergence tolerance; must be positive.
     * @param itol the convergence criterion, 1 to 4.
     * @return the final error estimate.
     * @see #solve(Matrix, Preconditioner, double[], double[], double, int, int)
     */
    public static double solve(Matrix matrix, double[] b, double[] x, double tol, int itol) {
        return solve(matrix, diagonalPreconditioner(matrix), b, x, tol, itol);
    }

    /**
     * Solves {@code A * x = b} with the given preconditioner and at most
     * {@code 2 * max(A.nrows(), A.ncols())} iterations.
     *
     * @param matrix the coefficient matrix {@code A}.
     * @param preconditioner the preconditioner.
     * @param b the right-hand side.
     * @param x the initial guess on input; overwritten with the approximate solution.
     * @param tol the convergence tolerance; must be positive.
     * @param itol the convergence criterion, 1 to 4.
     * @return the final error estimate.
     * @see #solve(Matrix, Preconditioner, double[], double[], double, int, int)
     */
    public static double solve(Matrix matrix, Preconditioner preconditioner, double[] b, double[] x, double tol, int itol) {
        return solve(matrix, preconditioner, b, x, tol, itol, 2 * Math.max(matrix.nrows(), matrix.ncols()));
    }

    /**
     * Solves {@code A * x = b} with the diagonal preconditioner.
     *
     * @param matrix the coefficient matrix {@code A}.
     * @param b the right-hand side.
     * @param x the initial guess on input; overwritten with the approximate solution.
     * @param tol the convergence tolerance; must be positive.
     * @param itol the convergence criterion, 1 to 4.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the final error estimate.
     * @see #solve(Matrix, Preconditioner, double[], double[], double, int, int)
     */
    public static double solve(Matrix matrix, double[] b, double[] x, double tol, int itol, int maxIter) {
        return solve(matrix, diagonalPreconditioner(matrix), b, x, tol, itol, maxIter);
    }

    /**
     * Solves {@code A * x = b} by the preconditioned biconjugate gradient method.
     *
     * <p>Writing {@code P^-1} for {@link Preconditioner#asolve} and {@code r = b - A x}, the error
     * compared with {@code tol} depends on {@code itol}:
     * <ol>
     *   <li>{@code |r| / |b|} (Euclidean norms);</li>
     *   <li>{@code |P^-1 r| / |P^-1 b|} (Euclidean norms);</li>
     *   <li>an estimate of the relative error in {@code x}, Euclidean norm;</li>
     *   <li>the same estimate, using the largest absolute component as the norm.</li>
     * </ol>
     *
     * <p>For {@code itol} 3 and 4, the x-error estimate can be unreliable when the preconditioned
     * residual norm stagnates, or when the estimate exceeds half of {@code |x|}. In that case the
     * error is reported as {@code |P^-1 r| / |P^-1 b|}, but that value does not end the iteration.
     *
     * <p>If the iteration breaks down (a zero denominator in the recurrences), it stops and
     * returns the last error estimate. Compare the result with {@code tol} to check whether the
     * solve converged.
     *
     * @param matrix the coefficient matrix {@code A}.
     * @param preconditioner the preconditioner.
     * @param b the right-hand side; not modified.
     * @param x the initial guess on input; overwritten in place with the approximate solution.
     * @param tol the convergence tolerance; must be positive.
     * @param itol the convergence criterion, 1 to 4.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the final error estimate under the chosen {@code itol}.
     * @throws IllegalArgumentException if {@code tol <= 0}, {@code maxIter <= 0} or {@code itol}
     *     is outside 1 to 4.
     */
    public static double solve(Matrix matrix, Preconditioner preconditioner, double[] b, double[] x, double tol, int itol, int maxIter) {
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (itol < 1 || itol > 4) {
            throw new IllegalArgumentException(String.format("Illegal itol: %d", itol));
        }

        int n = b.length;

        // b = 0 has the exact solution x = 0; iterating would divide by a zero norm of b.
        if (isZero(b)) {
            Arrays.fill(x, 0.0);
            return 0.0;
        }

        double[] r = new double[n];   // residual b - A x
        double[] rr = new double[n];  // shadow residual (updated with the transpose of A)
        double[] z = new double[n];   // P^-1 r; temporarily holds A p during an iteration
        double[] zz = new double[n];  // P^-1 rr; temporarily holds A^T pp during an iteration
        double[] p = new double[n];   // search direction
        double[] pp = new double[n];  // shadow search direction

        matrix.ax(x, r);
        for (int j = 0; j < n; j++) {
            r[j] = b[j] - r[j];
            rr[j] = r[j];
        }

        // The initial guess already solves the system exactly; the first step would compute
        // 0 / 0 and overwrite x with NaN.
        if (isZero(r)) {
            return 0.0;
        }

        double bnrm;        // |b| for itol 1, |P^-1 b| otherwise
        double znrm = 0.0;  // |P^-1 r| from the previous iteration (itol 3 and 4 only)
        if (itol == 1) {
            bnrm = snorm(b, itol);
            preconditioner.asolve(r, z);
        } else {
            preconditioner.asolve(b, z);
            bnrm = snorm(z, itol);
            preconditioner.asolve(r, z);
            if (itol >= 3) {
                znrm = snorm(z, itol);
            }
        }

        // Initial relative residual. It is returned unchanged only if the method breaks down in
        // the very first iteration.
        double err = (itol == 1 ? snorm(r, itol) : snorm(z, itol)) / bnrm;

        double bkden = 1.0;
        for (int iter = 1; iter <= maxIter; iter++) {
            preconditioner.asolve(rr, zz);
            double bknum = dot(z, rr);
            if (bknum == 0.0) {
                // The next iteration would divide by bknum.
                logBreakdown(iter, err);
                break;
            }

            if (iter == 1) {
                System.arraycopy(z, 0, p, 0, n);
                System.arraycopy(zz, 0, pp, 0, n);
            } else {
                double bk = bknum / bkden;
                for (int j = 0; j < n; j++) {
                    p[j] = bk * p[j] + z[j];
                    pp[j] = bk * pp[j] + zz[j];
                }
            }
            bkden = bknum;

            matrix.ax(p, z);
            double akden = dot(z, pp);
            if (akden == 0.0) {
                logBreakdown(iter, err);
                break;
            }
            double ak = bknum / akden;
            matrix.atx(pp, zz);
            for (int j = 0; j < n; j++) {
                x[j] += ak * p[j];
                r[j] -= ak * z[j];
                rr[j] -= ak * zz[j];
            }

            preconditioner.asolve(r, z);

            // Only a trusted estimate may end the iteration (see the method documentation).
            boolean trusted = true;
            if (itol == 1) {
                err = snorm(r, itol) / bnrm;
            } else if (itol == 2) {
                err = snorm(z, itol) / bnrm;
            } else {
                double zm1nrm = znrm;
                znrm = snorm(z, itol);
                double dz = Math.abs(zm1nrm - znrm);
                trusted = false;
                if (dz > Math.EPSILON * znrm) {
                    double estimate = znrm / dz * Math.abs(ak) * snorm(p, itol);
                    double xnrm = snorm(x, itol);
                    if (estimate <= 0.5 * xnrm) {
                        err = estimate / xnrm;
                        trusted = true;
                    }
                }
                if (!trusted) {
                    err = znrm / bnrm;
                }
            }

            if (iter % 10 == 0) {
                logger.info(String.format("BCG: the error after %3d iterations: %.5g", iter, err));
            }
            if (trusted && err <= tol) {
                logger.info(String.format("BCG: the error after %3d iterations: %.5g", iter, err));
                break;
            }
        }
        return err;
    }

    /** Logs that the iteration stopped early because a recurrence denominator became zero. */
    private static void logBreakdown(int iter, double err) {
        logger.warn(String.format("BCG: breakdown at iteration %d, error: %.5g", iter, err));
    }

    /** Returns {@code sum(a[j] * b[j])} over the indices of {@code a}, accumulated in index order. */
    private static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int j = 0; j < a.length; j++) {
            sum += a[j] * b[j];
        }
        return sum;
    }

    /** Returns {@code true} if every entry of {@code v} equals zero (true for an empty array). */
    private static boolean isZero(double[] v) {
        for (double value : v) {
            if (value != 0.0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the norm of {@code v} selected by {@code itol}.
     *
     * @param v the vector.
     * @param itol the convergence criterion: {@code itol <= 3} gives the Euclidean norm, otherwise
     *     the largest absolute component is returned (0 for an empty vector).
     * @return the norm.
     */
    private static double snorm(double[] v, int itol) {
        if (itol <= 3) {
            double sum = 0.0;
            for (double value : v) {
                sum += value * value;
            }
            return Math.sqrt(sum);
        }
        double max = 0.0;
        for (double value : v) {
            double magnitude = Math.abs(value);
            if (magnitude > max) {
                max = magnitude;
            }
        }
        return max;
    }
}
