package smile.classification;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Arrays;
import smile.math.Math;
import smile.math.distance.Metric;
import smile.math.matrix.ColumnMajorMatrix;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.QRDecomposition;
import smile.math.rbf.RadialBasisFunction;
import smile.util.SmileUtils;

/* loaded from: input_file:libarx-3.7.1.jar:smile/classification/RBFNetwork.class */
/**
 * Radial basis function (RBF) network classifier.
 *
 * <p>Each input {@code x} is mapped to the feature vector
 * {@code [rbf[0].f(d(x, centers[0])), ..., rbf[m-1].f(d(x, centers[m-1])), 1]}, i.e. the responses of
 * the basis functions plus a constant bias term. A linear output layer with one column per class is
 * fitted to these features by least squares (QR decomposition) against one-vs-all targets. Prediction
 * returns the class whose output is largest.
 *
 * <p>In normalized mode the training target of the true class is the sum of the basis responses
 * (instead of 1), and at prediction time the class outputs are divided by that sum.
 *
 * @param <T> the type of input objects
 */
public class RBFNetwork<T> implements Classifier<T>, Serializable {
    private static final long serialVersionUID = 1;

    /** Number of classes; class labels must be exactly 0, 1, ..., k - 1. */
    private final int k;
    /** Centers of the basis functions; {@code centers[j]} is paired with {@code rbf[j]}. */
    private final T[] centers;
    /**
     * Output-layer weights with {@code rbf.length + 1} rows and {@code k} columns. Row {@code j} for
     * {@code j < rbf.length} weights basis function {@code j}; the last row holds the per-class bias.
     */
    private final DenseMatrix w;
    /** Metric used to measure the distance between an input and each center. */
    private final Metric<T> distance;
    /** Radial basis functions applied to the distances to the centers. */
    private final RadialBasisFunction[] rbf;
    /** Whether this is a normalized RBF network. */
    private final boolean normalized;

    /**
     * Trainer for {@link RBFNetwork}. By default it learns 10 centers from the data and uses Gaussian
     * basis functions fitted by {@code SmileUtils.learnGaussianRadialBasis}.
     *
     * @param <T> the type of input objects
     */
    public static class Trainer<T> extends ClassifierTrainer<T> {
        /** Number of centers (and basis functions) to use when the centers are learned. */
        private int m = 10;
        /** Metric used by the trained network. */
        private final Metric<T> distance;
        /** Explicit basis functions, or {@code null} to use the learned Gaussian basis. */
        private RadialBasisFunction[] rbf;
        /** Whether to train a normalized RBF network. */
        private boolean normalized = false;

        /**
         * Creates a trainer.
         *
         * @param distance the metric between inputs and centers
         * @param interrupt the training interrupt handle passed to {@link ClassifierTrainer}
         */
        public Trainer(Metric<T> distance, TrainingInterrupt interrupt) {
            super(interrupt);
            this.distance = distance;
        }

        /**
         * Sets whether to train a normalized RBF network.
         *
         * @param normalized {@code true} for a normalized network
         * @return this trainer
         */
        public Trainer<T> setNormalized(boolean normalized) {
            this.normalized = normalized;
            return this;
        }

        /**
         * Uses the same basis function for every center.
         *
         * @param rbf the basis function
         * @param m the number of centers to learn
         * @return this trainer
         */
        public Trainer<T> setRBF(RadialBasisFunction rbf, int m) {
            this.m = m;
            this.rbf = rep(rbf, m);
            return this;
        }

        /**
         * Uses one basis function per center; the number of centers is {@code rbf.length}.
         *
         * @param rbf the basis functions
         * @return this trainer
         */
        public Trainer<T> setRBF(RadialBasisFunction[] rbf) {
            this.m = rbf.length;
            this.rbf = rbf;
            return this;
        }

        /**
         * Learns {@code m} centers from {@code x} and trains a network on them.
         *
         * <p>If no basis functions were configured via {@code setRBF}, the Gaussian basis returned by
         * {@code SmileUtils.learnGaussianRadialBasis} is used; otherwise the configured ones are.
         *
         * @param x the training inputs
         * @param y the class labels, which must be 0, 1, ..., k - 1
         * @return the trained network
         */
        @Override
        public RBFNetwork<T> train(T[] x, int[] y) {
            // Allocate with the runtime component type of x so the array can hold the learned centers.
            @SuppressWarnings("unchecked")
            T[] centers = (T[]) Array.newInstance(x.getClass().getComponentType(), m);
            if (rbf == null) {
                return new RBFNetwork<>(x, y, distance,
                        SmileUtils.learnGaussianRadialBasis(x, centers, distance), centers, normalized);
            }
            // learnGaussianRadialBasis is relied on to fill `centers` in place. It must run even when
            // explicit basis functions are configured; otherwise the centers would stay null.
            // The returned Gaussian basis is deliberately discarded in this branch.
            SmileUtils.learnGaussianRadialBasis(x, centers, distance);
            return new RBFNetwork<>(x, y, distance, rbf, centers, normalized);
        }

        /**
         * Trains a network on caller-supplied centers using the configured basis functions.
         *
         * @param x the training inputs
         * @param y the class labels, which must be 0, 1, ..., k - 1
         * @param centers the centers; must have as many entries as the configured basis functions
         * @return the trained network
         * @throws IllegalStateException if {@code setRBF} has not been called
         */
        public RBFNetwork<T> train(T[] x, int[] y, T[] centers) {
            if (rbf == null) {
                throw new IllegalStateException(
                        "Radial basis functions are not set; call setRBF() before training with given centers.");
            }
            return new RBFNetwork<>(x, y, distance, rbf, centers, normalized);
        }
    }

    /**
     * Returns an array of length {@code n} in which every element is {@code rbf}.
     * (Public in this build; the decompiler notes the upstream original was private.)
     *
     * @param rbf the basis function to repeat
     * @param n the array length
     * @return the filled array
     */
    public static RadialBasisFunction[] rep(RadialBasisFunction rbf, int n) {
        RadialBasisFunction[] out = new RadialBasisFunction[n];
        Arrays.fill(out, rbf);
        return out;
    }

    /**
     * Trains a non-normalized network that uses the same basis function for every center.
     *
     * @param x the training inputs
     * @param y the class labels, which must be 0, 1, ..., k - 1
     * @param distance the metric between inputs and centers
     * @param rbf the basis function shared by all centers
     * @param centers the centers
     */
    public RBFNetwork(T[] x, int[] y, Metric<T> distance, RadialBasisFunction rbf, T[] centers) {
        this(x, y, distance, rbf, centers, false);
    }

    /**
     * Trains a network that uses the same basis function for every center.
     *
     * @param x the training inputs
     * @param y the class labels, which must be 0, 1, ..., k - 1
     * @param distance the metric between inputs and centers
     * @param rbf the basis function shared by all centers
     * @param centers the centers
     * @param normalized whether to train a normalized RBF network
     */
    public RBFNetwork(T[] x, int[] y, Metric<T> distance, RadialBasisFunction rbf, T[] centers, boolean normalized) {
        this(x, y, distance, rep(rbf, centers.length), centers, normalized);
    }

    /**
     * Trains a non-normalized network with one basis function per center.
     *
     * @param x the training inputs
     * @param y the class labels, which must be 0, 1, ..., k - 1
     * @param distance the metric between inputs and centers
     * @param rbf the basis functions, one per center
     * @param centers the centers
     */
    public RBFNetwork(T[] x, int[] y, Metric<T> distance, RadialBasisFunction[] rbf, T[] centers) {
        this(x, y, distance, rbf, centers, false);
    }

    /**
     * Trains a network with one basis function per center.
     *
     * @param x the training inputs
     * @param y the class labels, which must be 0, 1, ..., k - 1
     * @param distance the metric between inputs and centers
     * @param rbf the basis functions, one per center
     * @param centers the centers
     * @param normalized whether to train a normalized RBF network
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length, {@code rbf} and
     *     {@code centers} differ in length, a label is negative, a label in 0..max is missing, or
     *     there are fewer than two classes
     */
    public RBFNetwork(T[] x, int[] y, Metric<T> distance, RadialBasisFunction[] rbf, T[] centers, boolean normalized) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (rbf.length != centers.length) {
            throw new IllegalArgumentException(String.format("The sizes of RBF functions and centers don't match: %d != %d", rbf.length, centers.length));
        }

        this.k = countClasses(y);
        // Defensive copies so later changes to the caller's arrays cannot corrupt the model.
        this.centers = centers.clone();
        this.rbf = rbf.clone();
        this.distance = distance;
        this.normalized = normalized;

        int n = x.length;
        int m = rbf.length;
        // Design matrix: one column per basis function plus a trailing bias column of ones.
        ColumnMajorMatrix g = new ColumnMajorMatrix(n, m + 1);
        // One-vs-all targets: each row sets only the column of its true label.
        ColumnMajorMatrix b = new ColumnMajorMatrix(n, k);
        for (int i = 0; i < n; i++) {
            double sum = 0.0;
            for (int j = 0; j < m; j++) {
                double r = rbf[j].f(distance.d(x[i], centers[j]));
                g.set(i, j, r);
                sum += r;
            }
            g.set(i, m, 1.0);
            // Normalized networks divide outputs by the total response at prediction time,
            // so the target is scaled by that total here.
            b.set(i, y[i], normalized ? sum : 1.0);
        }

        this.w = new ColumnMajorMatrix(m + 1, k);
        new QRDecomposition(g).solve(b, this.w);
    }

    /**
     * Validates the class labels and returns the number of classes.
     *
     * @param y the class labels
     * @return the number of distinct labels
     * @throws IllegalArgumentException if a label is negative, the labels are not exactly
     *     0..k-1, or there are fewer than two classes
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            // Labels are sorted and distinct, so they are 0..k-1 iff each one is its predecessor + 1
            // and the first one is 0.
            int expected = (i == 0) ? 0 : labels[i - 1] + 1;
            if (labels[i] != expected) {
                throw new IllegalArgumentException("Missing class: " + expected);
            }
        }
        if (labels.length < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        return labels.length;
    }

    /**
     * Predicts the class of {@code x}.
     *
     * @param x the input
     * @return the label in 0..k-1 with the largest network output (the first one on ties, 0 if
     *     every output is NaN)
     */
    @Override
    public int predict(T x) {
        int m = rbf.length;
        double[] scores = new double[k];
        double sum = 0.0;
        for (int i = 0; i < m; i++) {
            double f = rbf[i].f(distance.d(x, centers[i]));
            sum += f;
            for (int j = 0; j < k; j++) {
                scores[j] += w.get(i, j) * f;
            }
        }

        // When the total response is 0, dividing would turn every score into +/-Inf or NaN and
        // make the argmax meaningless, so in that case only the bias-adjusted scores are compared.
        boolean divide = normalized && sum != 0.0;
        for (int j = 0; j < k; j++) {
            scores[j] += w.get(m, j); // row m holds the bias
            if (divide) {
                scores[j] /= sum;
            }
        }

        double best = Double.NEGATIVE_INFINITY;
        int label = 0;
        for (int j = 0; j < k; j++) {
            if (best < scores[j]) {
                best = scores[j];
                label = j;
            }
        }
        return label;
    }
}
