package smile.classification;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.DoubleArrayList;
import smile.math.Math;
import smile.math.SparseArray;
import smile.math.kernel.LinearKernel;
import smile.math.kernel.MercerKernel;
import smile.util.MulticoreExecutor;

/* loaded from: input_file:libarx-3.7.1.jar:smile/classification/SVM.class */
public class SVM<T> extends SoftClassifier<T> implements Serializable, OnlineClassifier<T> {
    private static final long serialVersionUID = 1;
    private static final Logger logger = LoggerFactory.getLogger((Class<?>) SVM.class);
    // Same value as the 1.0E-12 floor substituted for non-positive curvature terms in the
    // LASVM SMO step; presumably intended as that floor.
    private static final double TAU = 1.0E-12d;
    // Single binary machine; set and used when k == 2.
    private SVM<T>.LASVM svm;
    // Binary machines for k > 2: one per class (ONE_VS_ALL) or one per class pair (ONE_VS_ONE).
    private List<SVM<T>.LASVM> svms;
    // Kernel shared by every machine.
    private MercerKernel<T> kernel;
    // Input dimension used for the explicit linear-kernel weight vector; 0 until inferred.
    private int p;
    // Number of classes.
    private int k;
    // Multiclass decomposition strategy (constructors default it to ONE_VS_ONE).
    private Multiclass strategy;
    // Per-class weights from the weighted multiclass constructor; null otherwise.
    private double[] wi;
    // Convergence tolerance on the gradient gap used by the LASVM training loops (default 0.001).
    private double tol;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:libarx-3.7.1.jar:smile/classification/SVM$LASVM.class */
    /**
     * One binary (+1 / -1) kernel machine trained incrementally with pairwise SMO steps.
     *
     * <p>{@code process} inserts a sample as a candidate support vector and performs one SMO step
     * on it; {@code reprocess} performs one unconstrained SMO step followed by {@code evict};
     * {@code finish} repeats SMO steps until the gap {@code gmax - gmin} falls below the tolerance.
     * The name suggests the LASVM algorithm of Bordes et al. (JMLR 2005) — NOTE(review): confirm
     * against the upstream Smile sources.
     *
     * <p>Dual coefficients are signed: a positive sample's {@code alpha} is kept in
     * {@code [0, weight * Cp]} and a negative sample's in {@code [-weight * Cn, 0]}. The decision
     * value is {@code b + sum_j alpha_j * K(x_j, x)}.
     */
    public final class LASVM implements Serializable {
        private static final long serialVersionUID = 1;

        /** Soft-margin penalty applied to positive (+1) samples. */
        private double Cp;
        /** Soft-margin penalty applied to negative (-1) samples. */
        private double Cn;
        /** Explicit weight vector, built by {@code finish} only for a LinearKernel; otherwise null. */
        double[] w;
        /** Platt scaling fitted by {@code trainPlattScaling}, or null if not fitted. */
        PlattScaling platt;
        /**
         * Candidate support vectors. Entries removed by {@code evict} become null slots, which
         * {@code process} reuses and {@code finish} compacts away.
         */
        List<SupportVector> sv = new ArrayList<>();
        /** Intercept; set to {@code (gmax + gmin) / 2} after each SMO step. */
        double b = 0.0d;
        /** Number of support vectors counted by {@code cleanup}. */
        int nsv = 0;
        /** Number of those support vectors whose alpha sits on a bound. */
        int nbsv = 0;
        /** True while gmin/gmax/svmin/svmax reflect the current contents of {@code sv}. */
        transient boolean minmaxflag = false;
        /** Vector with the smallest gradient among those whose alpha can still decrease. */
        transient SupportVector svmin = null;
        /** Vector with the largest gradient among those whose alpha can still increase. */
        transient SupportVector svmax = null;
        transient double gmin = Double.MAX_VALUE;
        transient double gmax = -Double.MAX_VALUE;

        /* JADX INFO: Access modifiers changed from: package-private */
        /* loaded from: input_file:libarx-3.7.1.jar:smile/classification/SVM$LASVM$SupportVector.class */
        /** One training sample together with its dual coefficient and cached kernel values. */
        public class SupportVector implements Serializable {
            private static final long serialVersionUID = 1;
            /** The sample. */
            T x;
            /** Label, +1 or -1. */
            int y;
            /** Signed dual coefficient, kept within [cmin, cmax]. */
            double alpha;
            /** y - sum_j alpha_j K(x_j, x) at insertion, then updated incrementally by smo(). */
            double g;
            /** Lower bound of alpha. */
            double cmin;
            /** Upper bound of alpha. */
            double cmax;
            /** K(x, x). */
            double k;
            /** K(x, sv[j].x) for every slot j of sv (0.0 for empty slots); null until needed. */
            DoubleArrayList kcache;

            SupportVector() {
            }
        }

        /**
         * @param cp penalty for positive samples
         * @param cn penalty for negative samples
         */
        LASVM(double cp, double cn) {
            this.Cp = cp;
            this.Cn = cn;
        }

        /** Recounts support vectors, drops all kernel caches and logs the counts. */
        void cleanup() {
            this.nsv = 0;
            this.nbsv = 0;
            for (SupportVector v : this.sv) {
                if (v == null) {
                    continue;
                }
                this.nsv++;
                v.kcache = null;
                if (v.alpha == v.cmin || v.alpha == v.cmax) {
                    this.nbsv++;
                }
            }
            SVM.logger.info("{} support vectors, {} bounded\n", this.nsv, this.nbsv);
        }

        /** Turns every non-contributing vector (see {@link #isEvictable}) into an empty slot. */
        void evict() {
            minmax();
            boolean removed = false;
            for (int i = 0; i < this.sv.size(); i++) {
                SupportVector v = this.sv.get(i);
                if (v != null && isEvictable(v)) {
                    this.sv.set(i, null);
                    removed = true;
                }
            }
            if (removed) {
                // An evicted vector may have been cached as svmin; force a recompute.
                this.minmaxflag = false;
            }
        }

        /**
         * True if {@code v} has alpha == 0 and its gradient lies on the side of the current
         * [gmin, gmax] range where alpha could not move away from 0.
         */
        private boolean isEvictable(SupportVector v) {
            return v.alpha == 0.0d
                    && ((v.g >= this.gmax && 0.0d >= v.cmax) || (v.g <= this.gmin && 0.0d <= v.cmin));
        }

        /** Finishes training with the enclosing classifier's tolerance. */
        void finish() {
            finish(SVM.this.tol);
        }

        /**
         * Runs SMO steps until the gap drops below {@code epsgr}, removes empty slots and
         * non-contributing vectors, releases kernel caches and, for a linear kernel, folds the
         * expansion into {@link #w}.
         *
         * @param epsgr stopping tolerance on {@code gmax - gmin}
         */
        void finish(double epsgr) {
            SVM.logger.info("SVM finalizes the training by reprocess.");
            for (int iter = 1; smo(null, null, epsgr); iter++) {
                if (iter % 1000 == 0) {
                    SVM.logger.info("finishing {} reprocess iterations.", iter);
                }
            }
            SVM.logger.info("SVM finished the reprocess.");

            Iterator<SupportVector> it = this.sv.iterator();
            while (it.hasNext()) {
                SupportVector v = it.next();
                if (v == null || isEvictable(v)) {
                    it.remove();
                }
            }
            cleanup();

            if (SVM.this.kernel instanceof LinearKernel) {
                buildLinearWeights();
            }
        }

        /** Accumulates sum_j alpha_j * x_j into w; leaves w null when no dimension is known. */
        private void buildLinearWeights() {
            // p is only inferred by the batch learn(); after purely online training it is still 0,
            // so fall back to the length of the first dense support vector.
            int dim = SVM.this.p;
            if (dim == 0) {
                for (SupportVector v : this.sv) {
                    if (v.x instanceof double[]) {
                        dim = ((double[]) v.x).length;
                        break;
                    }
                }
            }
            if (dim == 0) {
                // predict() then falls back to the kernel expansion instead of a 0-length dot product.
                this.w = null;
                return;
            }
            this.w = new double[dim];
            for (SupportVector v : this.sv) {
                if (v.x instanceof double[]) {
                    double[] x = (double[]) v.x;
                    for (int j = 0; j < this.w.length; j++) {
                        this.w[j] += v.alpha * x[j];
                    }
                } else if (v.x instanceof int[]) {
                    for (int j : (int[]) v.x) {
                        this.w[j] += v.alpha;
                    }
                } else if (v.x instanceof SparseArray) {
                    Iterator<SparseArray.Entry> entries = ((SparseArray) v.x).iterator();
                    while (entries.hasNext()) {
                        SparseArray.Entry e = entries.next();
                        this.w[e.i] += v.alpha * e.x;
                    }
                }
            }
        }

        /** Fills {@code v.kcache} with K(v.x, s.x) for every slot s of sv (0.0 for empty slots) if absent. */
        private void ensureKernelCache(SupportVector v) {
            if (v.kcache != null) {
                return;
            }
            DoubleArrayList cache = new DoubleArrayList(this.sv.size());
            for (SupportVector other : this.sv) {
                cache.add(other == null ? 0.0d : SVM.this.kernel.k(v.x, other.x));
            }
            v.kcache = cache;
        }

        /** One pass over the data with unit instance weights. */
        void learn(T[] tArr, int[] iArr) {
            learn(tArr, iArr, null);
        }

        /**
         * One pass over the data: seeds up to 5 samples per class if fewer are present, then
         * processes every sample in random order, reprocessing until {@code gmax - gmin <= 1000}.
         *
         * @param tArr samples
         * @param iArr labels, +1 or -1
         * @param dArr instance weights, or null for unit weights
         */
        void learn(T[] tArr, int[] iArr, double[] dArr) {
            if (SVM.this.p == 0 && SVM.this.kernel instanceof LinearKernel) {
                if (!(tArr instanceof double[][]) && !(tArr instanceof float[][])) {
                    throw new UnsupportedOperationException("Unsupported data type for linear kernel.");
                }
                if (tArr.length > 0) {
                    SVM.this.p = (tArr instanceof double[][])
                            ? ((double[]) tArr[0]).length
                            : ((float[]) tArr[0]).length;
                }
            }

            int npos = 0;
            int nneg = 0;
            for (SupportVector v : this.sv) {
                if (v == null) {
                    continue;
                }
                if (v.y > 0) {
                    npos++;
                } else if (v.y < 0) {
                    nneg++;
                }
            }

            int n = tArr.length;
            // Presumably ensures both classes are represented before the randomized pass.
            if (npos < 5 || nneg < 5) {
                for (int idx = 0; idx < n; idx++) {
                    if (iArr[idx] == 1 && npos < 5) {
                        processSample(tArr, iArr, dArr, idx);
                        npos++;
                    }
                    if (iArr[idx] == -1 && nneg < 5) {
                        processSample(tArr, iArr, dArr, idx);
                        nneg++;
                    }
                    if (npos >= 5 && nneg >= 5) {
                        break;
                    }
                }
            }

            int[] order = Math.permutate(n);
            for (int idx = 0; idx < n; idx++) {
                processSample(tArr, iArr, dArr, order[idx]);
                do {
                    reprocess(SVM.this.tol);
                    minmax();
                } while (this.gmax - this.gmin > 1000.0d);
            }
        }

        /** Processes sample {@code idx} with its instance weight (1.0 when no weights are given). */
        private boolean processSample(T[] x, int[] y, double[] weight, int idx) {
            return process(x[idx], y[idx], weight == null ? 1.0d : weight[idx]);
        }

        /** Recomputes gmin/gmax and svmin/svmax unless they are already up to date. */
        void minmax() {
            if (this.minmaxflag) {
                return;
            }
            this.gmin = Double.MAX_VALUE;
            this.gmax = -Double.MAX_VALUE;
            // Clear stale selections so callers can detect "no candidate".
            this.svmin = null;
            this.svmax = null;
            for (SupportVector v : this.sv) {
                if (v == null) {
                    continue;
                }
                if (v.g < this.gmin && v.alpha > v.cmin) {
                    this.svmin = v;
                    this.gmin = v.g;
                }
                if (v.g > this.gmax && v.alpha < v.cmax) {
                    this.svmax = v;
                    this.gmax = v.g;
                }
            }
            this.minmaxflag = true;
        }

        /**
         * Returns the raw decision value for {@code t}.
         *
         * @throws UnsupportedOperationException if the linear fast path is active and {@code t}
         *     is neither double[] nor SparseArray
         */
        double predict(T t) {
            double f = this.b;
            if (SVM.this.kernel instanceof LinearKernel && this.w != null) {
                if (t instanceof double[]) {
                    f += Math.dot(this.w, (double[]) t);
                } else if (t instanceof SparseArray) {
                    Iterator<SparseArray.Entry> it = ((SparseArray) t).iterator();
                    while (it.hasNext()) {
                        SparseArray.Entry e = it.next();
                        f += this.w[e.i] * e.x;
                    }
                } else {
                    throw new UnsupportedOperationException("Unsupported data type for linear kernel");
                }
            } else {
                for (SupportVector v : this.sv) {
                    if (v != null) {
                        f += v.alpha * SVM.this.kernel.k(v.x, t);
                    }
                }
            }
            return f;
        }

        /** Processes a sample with unit weight. */
        boolean process(T t, int i) {
            return process(t, i, 1.0d);
        }

        /**
         * Offers a sample to the machine and, if it is accepted, runs one SMO step on it.
         *
         * @param t the sample
         * @param i label, +1 or -1
         * @param d positive instance weight scaling the box constraint
         * @return false if the sample was rejected as non-violating, true otherwise (including
         *     when the very same object is already a support vector)
         * @throws IllegalArgumentException on an invalid label or non-positive weight
         */
        boolean process(T t, int i, double d) {
            if (i != 1 && i != -1) {
                throw new IllegalArgumentException("Invalid label: " + i);
            }
            if (d <= 0.0d) {
                throw new IllegalArgumentException("Invalid instance weight: " + d);
            }

            double g = i;
            DoubleArrayList kcache = new DoubleArrayList(this.sv.size() + 1);
            if (!this.sv.isEmpty()) {
                for (SupportVector v : this.sv) {
                    if (v == null) {
                        kcache.add(0.0d);
                        continue;
                    }
                    if (v.x == t) {
                        return true;
                    }
                    double kv = SVM.this.kernel.k(v.x, t);
                    g -= v.alpha * kv;
                    kcache.add(kv);
                }
                minmax();
                if (this.gmin < this.gmax) {
                    if (i > 0 && g < this.gmin) {
                        return false;
                    }
                    if (i < 0 && g > this.gmax) {
                        return false;
                    }
                }
            }

            SupportVector fresh = new SupportVector();
            fresh.x = t;
            fresh.y = i;
            fresh.alpha = 0.0d;
            fresh.g = g;
            fresh.k = SVM.this.kernel.k(t, t);
            fresh.kcache = kcache;
            if (i > 0) {
                fresh.cmin = 0.0d;
                fresh.cmax = d * this.Cp;
            } else {
                fresh.cmin = (-d) * this.Cn;
                fresh.cmax = 0.0d;
            }

            // Reuse the first slot freed by evict(); otherwise append. Every live kcache holds one
            // entry per slot of sv, so the new kernel column is written at the same index.
            int slot = this.sv.indexOf(null);
            if (slot >= 0) {
                this.sv.set(slot, fresh);
                kcache.set(slot, fresh.k);
                for (int j = 0; j < this.sv.size(); j++) {
                    SupportVector other = this.sv.get(j);
                    if (other != null && other.kcache != null) {
                        other.kcache.set(slot, kcache.get(j));
                    }
                }
            } else {
                for (int j = 0; j < this.sv.size(); j++) {
                    SupportVector other = this.sv.get(j);
                    if (other != null && other.kcache != null) {
                        other.kcache.add(kcache.get(j));
                    }
                }
                fresh.kcache.add(fresh.k);
                this.sv.add(fresh);
            }

            if (i > 0) {
                smo(null, fresh, 0.0d);
            } else {
                smo(fresh, null, 0.0d);
            }
            this.minmaxflag = false;
            return true;
        }

        /** One unconstrained SMO step followed by eviction; returns the SMO result. */
        boolean reprocess(double d) {
            boolean progressed = smo(null, null, d);
            evict();
            return progressed;
        }

        /**
         * Performs one SMO step on the pair ({@code v1}, {@code v2}): v1.alpha decreases and
         * v2.alpha increases by the same clipped step. A missing partner is chosen by maximal
         * second-order gain; if both are null the step starts from svmax or svmin.
         *
         * @return true while {@code gmax - gmin >= epsgr} after the step; false if no step was possible
         */
        boolean smo(SupportVector v1, SupportVector v2, double epsgr) {
            if (v1 == null && v2 == null) {
                minmax();
                if (this.gmax > (-this.gmin)) {
                    v2 = this.svmax;
                } else {
                    v1 = this.svmin;
                }
                if (v1 == null && v2 == null) {
                    // No vector can move (e.g. empty machine or zero-width boxes).
                    return false;
                }
            }

            if (v2 == null) {
                v2 = selectPartnerUp(v1);
            } else if (v1 == null) {
                v1 = selectPartnerDown(v2);
            }
            if (v1 == null || v2 == null) {
                return false;
            }

            ensureKernelCache(v1);
            ensureKernelCache(v2);

            double curvature = (v1.k + v2.k) - (2.0d * SVM.this.kernel.k(v1.x, v2.x));
            if (curvature <= 0.0d) {
                curvature = TAU;
            }
            double step = (v2.g - v1.g) / curvature;
            if (step >= 0.0d) {
                double room1 = v1.alpha - v1.cmin;
                if (room1 < step) {
                    step = room1;
                }
                double room2 = v2.cmax - v2.alpha;
                if (room2 < step) {
                    step = room2;
                }
            } else {
                double room2 = v2.cmin - v2.alpha;
                if (room2 > step) {
                    step = room2;
                }
                double room1 = v1.alpha - v1.cmax;
                if (room1 > step) {
                    step = room1;
                }
            }

            v1.alpha -= step;
            v2.alpha += step;
            for (int j = 0; j < this.sv.size(); j++) {
                SupportVector s = this.sv.get(j);
                if (s != null) {
                    s.g -= step * (v2.kcache.get(j) - v1.kcache.get(j));
                }
            }

            this.minmaxflag = false;
            minmax();
            this.b = (this.gmax + this.gmin) / 2.0d;
            return this.gmax - this.gmin >= epsgr;
        }

        /** Picks the vector to increase alongside decreasing {@code down}; null if none qualifies. */
        private SupportVector selectPartnerUp(SupportVector down) {
            ensureKernelCache(down);
            SupportVector best = null;
            double bestGain = 0.0d;
            for (int j = 0; j < this.sv.size(); j++) {
                SupportVector cand = this.sv.get(j);
                if (cand == null) {
                    continue;
                }
                double grad = cand.g - down.g;
                double curv = (down.k + cand.k) - (2.0d * down.kcache.get(j));
                if (curv <= 0.0d) {
                    curv = TAU;
                }
                double step = grad / curv;
                if ((step > 0.0d && cand.alpha < cand.cmax) || (step < 0.0d && cand.alpha > cand.cmin)) {
                    double gain = grad * step;
                    if (gain > bestGain) {
                        bestGain = gain;
                        best = cand;
                    }
                }
            }
            return best;
        }

        /** Picks the vector to decrease alongside increasing {@code up}; null if none qualifies. */
        private SupportVector selectPartnerDown(SupportVector up) {
            ensureKernelCache(up);
            SupportVector best = null;
            double bestGain = 0.0d;
            for (int j = 0; j < this.sv.size(); j++) {
                SupportVector cand = this.sv.get(j);
                if (cand == null) {
                    continue;
                }
                double grad = up.g - cand.g;
                double curv = (up.k + cand.k) - (2.0d * up.kcache.get(j));
                if (curv <= 0.0d) {
                    curv = TAU;
                }
                double step = grad / curv;
                if ((step > 0.0d && cand.alpha > cand.cmin) || (step < 0.0d && cand.alpha < cand.cmax)) {
                    double gain = grad * step;
                    if (gain > bestGain) {
                        bestGain = gain;
                        best = cand;
                    }
                }
            }
            return best;
        }

        /** Fits Platt scaling on this machine's decision values for the given samples and labels. */
        void trainPlattScaling(T[] tArr, int[] iArr) {
            int n = iArr.length;
            double[] scores = new double[n];
            for (int i = 0; i < n; i++) {
                scores[i] = predict(tArr[i]);
            }
            this.platt = new PlattScaling(scores, iArr);
        }
    }

    /* loaded from: input_file:libarx-3.7.1.jar:smile/classification/SVM$Multiclass.class */
    /** How a k-class problem (k &gt; 2) is decomposed into binary machines. */
    public enum Multiclass {
        /** One binary machine per pair of classes, k * (k - 1) / 2 in total. */
        ONE_VS_ONE,
        /** One binary machine per class, separating it from all other classes. */
        ONE_VS_ALL
    }

    /* loaded from: input_file:libarx-3.7.1.jar:smile/classification/SVM$PlattScalingTask.class */
    /**
     * Callable that fits Platt scaling for one binary machine on the given samples and labels,
     * then hands back that same machine.
     */
    class PlattScalingTask implements Callable<SVM<T>.LASVM> {
        SVM<T>.LASVM svm;
        T[] x;
        int[] y;

        PlattScalingTask(SVM<T>.LASVM machine, T[] samples, int[] labels) {
            svm = machine;
            x = samples;
            y = labels;
        }

        @Override // java.util.concurrent.Callable
        public SVM<T>.LASVM call() {
            final SVM<T>.LASVM target = svm;
            target.trainPlattScaling(x, y);
            return target;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:libarx-3.7.1.jar:smile/classification/SVM$ProcessTask.class */
    /**
     * Callable that runs {@code finish()} on one binary machine, so that several machines can
     * be finalized in parallel; returns the same machine.
     */
    public class ProcessTask implements Callable<SVM<T>.LASVM> {
        SVM<T>.LASVM svm;

        ProcessTask(SVM<T>.LASVM machine) {
            svm = machine;
        }

        @Override // java.util.concurrent.Callable
        public SVM<T>.LASVM call() {
            final SVM<T>.LASVM target = svm;
            target.finish();
            return target;
        }
    }

    /* loaded from: input_file:libarx-3.7.1.jar:smile/classification/SVM$Trainer.class */
    /**
     * Configures and trains an {@link SVM}.
     *
     * <p>Defaults: ONE_VS_ONE strategy, tolerance 0.001, 2 epochs. {@code train} runs
     * {@code epochs} passes of {@code learn} followed by {@code finish}.
     */
    public static class Trainer<T> extends ClassifierTrainer<T> {
        private MercerKernel<T> kernel;
        private int k;
        private double[] weight;
        private double Cp = 1.0d;
        private double Cn = 1.0d;
        private Multiclass strategy = Multiclass.ONE_VS_ONE;
        private double tol = 0.001d;
        private int epochs = 2;

        /**
         * Binary trainer with separate penalties for the positive and negative class.
         *
         * @throws IllegalArgumentException if either penalty is negative
         */
        public Trainer(MercerKernel<T> kernel, double cp, double cn, TrainingInterrupt interrupt) {
            super(interrupt);
            if (cp < 0.0d) {
                throw new IllegalArgumentException("Invalid postive instance soft margin penalty: " + cp);
            }
            if (cn < 0.0d) {
                throw new IllegalArgumentException("Invalid negative instance soft margin penalty: " + cn);
            }
            this.kernel = kernel;
            this.Cp = cp;
            this.Cn = cn;
            this.k = 2;
        }

        /**
         * Multiclass trainer with per-class weights; the number of classes is {@code classWeight.length}.
         *
         * @throws IllegalArgumentException if the penalty is negative or fewer than 3 weights are given
         */
        public Trainer(MercerKernel<T> kernel, double penalty, double[] classWeight, Multiclass strategy, TrainingInterrupt interrupt) {
            super(interrupt);
            if (penalty < 0.0d) {
                throw new IllegalArgumentException("Invalid soft margin penalty: " + penalty);
            }
            if (classWeight.length < 3) {
                throw new IllegalArgumentException("Invalid number of classes: " + classWeight.length);
            }
            this.kernel = kernel;
            this.Cp = penalty;
            this.Cn = penalty;
            this.k = classWeight.length;
            this.weight = classWeight;
            this.strategy = strategy;
        }

        /**
         * Unweighted multiclass trainer.
         *
         * @throws IllegalArgumentException if the penalty is negative or {@code classes < 3}
         */
        public Trainer(MercerKernel<T> kernel, double penalty, int classes, Multiclass strategy, TrainingInterrupt interrupt) {
            super(interrupt);
            if (penalty < 0.0d) {
                throw new IllegalArgumentException("Invalid soft margin penalty: " + penalty);
            }
            if (classes < 3) {
                throw new IllegalArgumentException("Invalid number of classes: " + classes);
            }
            this.kernel = kernel;
            this.Cp = penalty;
            this.Cn = penalty;
            this.k = classes;
            this.strategy = strategy;
        }

        /**
         * Binary trainer using the same penalty for both classes.
         *
         * @throws IllegalArgumentException if the penalty is negative
         */
        public Trainer(MercerKernel<T> kernel, double penalty, TrainingInterrupt interrupt) {
            super(interrupt);
            if (penalty < 0.0d) {
                throw new IllegalArgumentException("Invalid soft margin penalty: " + penalty);
            }
            this.kernel = kernel;
            this.Cp = penalty;
            this.Cn = penalty;
            this.k = 2;
        }

        /**
         * Sets the number of passes over the training data.
         *
         * @throws IllegalArgumentException if {@code n < 1}
         */
        public Trainer<T> setNumEpochs(int n) {
            if (n >= 1) {
                this.epochs = n;
                return this;
            }
            throw new IllegalArgumentException("Invalid numer of epochs of stochastic learning:" + n);
        }

        /**
         * Sets the convergence tolerance.
         *
         * @throws IllegalArgumentException if {@code tolerance <= 0}
         */
        public Trainer<T> setTolerance(double tolerance) {
            if (tolerance <= 0.0d) {
                throw new IllegalArgumentException("Invalid tolerance of convergence test:" + tolerance);
            }
            this.tol = tolerance;
            return this;
        }

        @Override // smile.classification.ClassifierTrainer
        public SVM<T> train(T[] tArr, int[] iArr) {
            return train(tArr, iArr, null);
        }

        /**
         * Builds an SVM matching this configuration and trains it.
         *
         * @param tArr samples
         * @param iArr class labels
         * @param dArr instance weights, or null
         */
        public SVM<T> train(T[] tArr, int[] iArr, double[] dArr) {
            final SVM<T> model;
            if (this.k == 2) {
                model = new SVM<>(this.kernel, this.Cp, this.Cn, this.interrupt);
            } else if (this.weight == null) {
                model = new SVM<>(this.kernel, this.Cp, this.k, this.strategy, this.interrupt);
            } else {
                model = new SVM<>(this.kernel, this.Cp, this.weight, this.strategy, this.interrupt);
            }
            model.setTolerance(this.tol);
            for (int epoch = 0; epoch < this.epochs; epoch++) {
                model.learn(tArr, iArr, dArr);
            }
            model.finish();
            return model;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:libarx-3.7.1.jar:smile/classification/SVM$TrainingTask.class */
    /**
     * Callable that runs one learning pass of a binary machine over the given samples, labels
     * and (possibly null) instance weights; returns the same machine.
     */
    public class TrainingTask implements Callable<SVM<T>.LASVM> {
        SVM<T>.LASVM svm;
        T[] x;
        int[] y;
        double[] weight;

        TrainingTask(SVM<T>.LASVM machine, T[] samples, int[] labels, double[] weights) {
            svm = machine;
            x = samples;
            y = labels;
            weight = weights;
        }

        @Override // java.util.concurrent.Callable
        public SVM<T>.LASVM call() {
            final SVM<T>.LASVM target = svm;
            target.learn(x, y, weight);
            return target;
        }
    }

    /**
     * Creates a binary classifier. Label 1 is the positive class; any other label is negative.
     *
     * @param kernel the Mercer kernel
     * @param cp soft-margin penalty for positive samples
     * @param cn soft-margin penalty for negative samples
     * @param interrupt training interrupt handle passed to the superclass
     * @throws IllegalArgumentException if either penalty is negative
     */
    public SVM(MercerKernel<T> kernel, double cp, double cn, TrainingInterrupt interrupt) {
        super(interrupt);
        if (cp < 0.0d) {
            throw new IllegalArgumentException("Invalid postive instance soft margin penalty: " + cp);
        }
        if (cn < 0.0d) {
            throw new IllegalArgumentException("Invalid negative instance soft margin penalty: " + cn);
        }
        this.kernel = kernel;
        this.k = 2;
        this.strategy = Multiclass.ONE_VS_ONE;
        this.tol = 0.001d;
        this.svm = new LASVM(cp, cn);
    }

    public SVM(MercerKernel<T> mercerKernel, double d, double[] dArr, Multiclass multiclass, TrainingInterrupt trainingInterrupt) {
        super(trainingInterrupt);
        this.strategy = Multiclass.ONE_VS_ONE;
        this.tol = 0.001d;
        if (d < 0.0d) {
            throw new IllegalArgumentException("Invalid soft margin penalty: " + d);
        }
        if (dArr.length < 3) {
            throw new IllegalArgumentException("Invalid number of classes: " + dArr.length);
        }
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] <= 0.0d) {
                throw new IllegalArgumentException("Invalid class weight: " + dArr[i]);
            }
        }
        this.kernel = mercerKernel;
        this.k = dArr.length;
        this.strategy = multiclass;
        this.wi = dArr;
        if (multiclass == Multiclass.ONE_VS_ALL) {
            this.svms = new ArrayList(this.k);
            for (int i2 = 0; i2 < this.k; i2++) {
                this.svms.add(new LASVM(d, d));
            }
            return;
        }
        this.svms = new ArrayList((this.k * (this.k - 1)) / 2);
        for (int i3 = 0; i3 < this.k; i3++) {
            for (int i4 = i3 + 1; i4 < this.k; i4++) {
                this.svms.add(new LASVM(dArr[i3] * d, dArr[i4] * d));
            }
        }
    }

    public SVM(MercerKernel<T> mercerKernel, double d, int i, Multiclass multiclass, TrainingInterrupt trainingInterrupt) {
        super(trainingInterrupt);
        this.strategy = Multiclass.ONE_VS_ONE;
        this.tol = 0.001d;
        if (d < 0.0d) {
            throw new IllegalArgumentException("Invalid soft margin penalty: " + d);
        }
        if (i < 3) {
            throw new IllegalArgumentException("Invalid number of classes: " + i);
        }
        this.kernel = mercerKernel;
        this.k = i;
        this.strategy = multiclass;
        if (multiclass == Multiclass.ONE_VS_ALL) {
            this.svms = new ArrayList(i);
            for (int i2 = 0; i2 < i; i2++) {
                this.svms.add(new LASVM(d, d));
            }
            return;
        }
        this.svms = new ArrayList((i * (i - 1)) / 2);
        for (int i3 = 0; i3 < i; i3++) {
            for (int i4 = i3 + 1; i4 < i; i4++) {
                this.svms.add(new LASVM(d, d));
            }
        }
    }

    /**
     * Creates a binary classifier that uses the same soft-margin penalty for both classes.
     *
     * @throws IllegalArgumentException if the penalty is negative
     */
    public SVM(MercerKernel<T> kernel, double penalty, TrainingInterrupt interrupt) {
        this(kernel, penalty, /* negative-class penalty: */ penalty, interrupt);
    }

    /**
     * Finalizes training of every binary machine.
     *
     * <p>For k &gt; 2 the machines are finished in parallel via {@code MulticoreExecutor}. A
     * failure is logged rather than propagated; if the failure is an interruption, the thread's
     * interrupt status is restored so callers can still observe it.
     */
    public void finish() {
        if (this.k == 2) {
            this.svm.finish();
            return;
        }
        List<Callable<SVM<T>.LASVM>> tasks = new ArrayList<>(this.svms.size());
        for (SVM<T>.LASVM machine : this.svms) {
            tasks.add(new ProcessTask(machine));
        }
        try {
            MulticoreExecutor.run(tasks);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Do not swallow the interrupt: re-assert it for code further up the stack.
                Thread.currentThread().interrupt();
            }
            logger.error("Failed to train SVM on multi-core", (Throwable) e);
        }
    }

    /**
     * Returns true if Platt scaling has been fitted.
     *
     * <p>For a binary model this checks the single machine. For k &gt; 2, where {@code svm} is
     * null, it requires every binary machine to have been fitted.
     */
    public boolean hasPlattScaling() {
        if (this.k == 2) {
            return this.svm.platt != null;
        }
        if (this.svms == null || this.svms.isEmpty()) {
            return false;
        }
        for (SVM<T>.LASVM machine : this.svms) {
            if (machine.platt == null) {
                return false;
            }
        }
        return true;
    }

    /**
     * Online update with a single sample of unit instance weight.
     *
     * @param t the sample
     * @param i its class label in [0, k)
     * @throws IllegalArgumentException if the label is out of range
     */
    @Override // smile.classification.OnlineClassifier
    public void learn(T t, int i) {
        // The decompiled form cast t to SVM<T>, which matches no overload; delegate directly.
        learn(t, i, 1.0d);
    }

    /**
     * Online update with a single weighted sample.
     *
     * <p>Binary models map label 1 to +1 and label 0 to -1. ONE_VS_ALL feeds the sample to every
     * machine (+1 for its own class), with the instance weight scaled by the class weight when
     * class weights were given. ONE_VS_ONE feeds it only to the machines whose pair contains the
     * label (+1 when it is the lower class index of the pair, -1 otherwise).
     *
     * @param t the sample
     * @param i class label in [0, k)
     * @param d positive instance weight
     * @throws IllegalArgumentException if the label or weight is invalid
     */
    public void learn(T t, int i, double d) {
        if (i < 0 || i >= this.k) {
            throw new IllegalArgumentException("Invalid label");
        }
        if (d <= 0.0d) {
            throw new IllegalArgumentException("Invalid instance weight: " + d);
        }

        if (this.k == 2) {
            this.svm.process(t, i == 1 ? 1 : -1, d);
            return;
        }

        if (this.strategy == Multiclass.ONE_VS_ALL) {
            final double scaled = (this.wi == null) ? d : d * this.wi[i];
            for (int c = 0; c < this.k; c++) {
                this.svms.get(c).process(t, c == i ? 1 : -1, scaled);
            }
            return;
        }

        int pair = 0;
        for (int first = 0; first < this.k; first++) {
            for (int second = first + 1; second < this.k; second++, pair++) {
                if (i == first) {
                    this.svms.get(pair).process(t, 1, d);
                } else if (i == second) {
                    this.svms.get(pair).process(t, -1, d);
                }
            }
        }
    }

    /**
     * Batch training without per-instance weights.
     *
     * @param tArr the training samples
     * @param iArr the class labels, one per sample
     */
    public void learn(T[] tArr, int[] iArr) {
        final double[] noWeights = null;
        learn(tArr, iArr, noWeights);
    }

    /**
     * Batch training, optionally with per-instance weights.
     *
     * <p>Binary models map label 1 to +1 and every other label to -1, then hand the
     * data to the single LASVM. Multi-class models build one {@code TrainingTask}
     * per sub-machine and run them through {@code MulticoreExecutor}:
     * <ul>
     * <li>ONE_VS_ALL trains machine {@code c} on all samples (+1 for class c,
     * -1 otherwise). When class weights {@code wi} are configured, each sample's
     * weight becomes {@code wi[y] * dArr[n]} (or just {@code wi[y]} without dArr).</li>
     * <li>One-vs-one trains machine (a, b) only on samples of classes a (+1) and
     * b (-1).</li>
     * </ul>
     * A failure of the parallel run is logged, not rethrown.
     *
     * @param tArr the training samples
     * @param iArr the class labels in {@code [0, k)}, one per sample
     * @param dArr optional instance weights, one per sample; may be {@code null}
     * @throws IllegalArgumentException on size mismatches or out-of-range labels
     */
    public void learn(T[] tArr, int[] iArr, double[] dArr) {
        if (tArr.length != iArr.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", tArr.length, iArr.length));
        }
        if (dArr != null && tArr.length != dArr.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and instance weight don't match: %d != %d", tArr.length, dArr.length));
        }
        int min = Math.min(iArr);
        if (min < 0) {
            throw new IllegalArgumentException("Negative class label:" + min);
        }
        int max = Math.max(iArr);
        if (max >= this.k) {
            throw new IllegalArgumentException("Invalid class label:" + max);
        }

        if (this.k == 2) {
            int[] signs = new int[iArr.length];
            for (int n = 0; n < iArr.length; n++) {
                signs[n] = iArr[n] == 1 ? 1 : -1;
            }
            if (dArr == null) {
                this.svm.learn(tArr, signs);
            } else {
                this.svm.learn(tArr, signs, dArr);
            }
            return;
        }

        List<TrainingTask> tasks;
        if (this.strategy == Multiclass.ONE_VS_ALL) {
            tasks = new ArrayList<>(this.k);
            for (int c = 0; c < this.k; c++) {
                int[] yc = new int[iArr.length];
                // Without class weights the caller's weights (possibly null) are passed through as-is.
                double[] wc = this.wi == null ? dArr : new double[iArr.length];
                for (int n = 0; n < iArr.length; n++) {
                    yc[n] = iArr[n] == c ? 1 : -1;
                    if (this.wi != null) {
                        wc[n] = dArr == null ? this.wi[iArr[n]] : this.wi[iArr[n]] * dArr[n];
                    }
                }
                tasks.add(new TrainingTask(this.svms.get(c), tArr, yc, wc));
            }
        } else {
            tasks = new ArrayList<>((this.k * (this.k - 1)) / 2);
            int m = 0;
            for (int a = 0; a < this.k; a++) {
                for (int b = a + 1; b < this.k; b++, m++) {
                    int size = 0;
                    for (int label : iArr) {
                        if (label == a || label == b) {
                            size++;
                        }
                    }
                    // Keep the runtime component type of the caller's array so the
                    // subset really is a T[] (an Object[] would not match T[]).
                    @SuppressWarnings("unchecked")
                    T[] xab = (T[]) Array.newInstance(tArr.getClass().getComponentType(), size);
                    int[] yab = new int[size];
                    double[] wab = dArr == null ? null : new double[size];
                    int pos = 0;
                    for (int n = 0; n < iArr.length; n++) {
                        if (iArr[n] == a || iArr[n] == b) {
                            xab[pos] = tArr[n];
                            yab[pos] = iArr[n] == a ? 1 : -1;
                            if (dArr != null) {
                                wab[pos] = dArr[n];
                            }
                            pos++;
                        }
                    }
                    tasks.add(new TrainingTask(this.svms.get(m), xab, yab, wab));
                }
            }
        }

        try {
            MulticoreExecutor.run(tasks);
        } catch (Exception e) {
            logger.error("Failed to train SVM on multi-core", e);
        }
    }

    /**
     * Predicts the class label of an instance.
     * <ul>
     * <li>Binary: 1 when the decision value is positive, otherwise 0.</li>
     * <li>ONE_VS_ALL: the machine with the largest decision value wins; the first
     * machine wins on ties.</li>
     * <li>One-vs-one: majority vote over all pairwise machines. A positive decision
     * value votes for the lower class of the pair; vote ties go to the lowest class
     * index.</li>
     * </ul>
     *
     * @param t the instance
     * @return the predicted class label
     */
    @Override // smile.classification.Classifier
    public int predict(T t) {
        if (this.k == 2) {
            double score = this.svm.predict(t);
            if (score > 0.0d) {
                return 1;
            }
            return 0;
        }

        if (this.strategy == Multiclass.ONE_VS_ALL) {
            int winner = 0;
            double top = Double.NEGATIVE_INFINITY;
            int c = 0;
            for (SVM<T>.LASVM machine : this.svms) {
                double score = machine.predict(t);
                if (score > top) {
                    top = score;
                    winner = c;
                }
                c++;
            }
            return winner;
        }

        int[] votes = new int[this.k];
        int m = 0;
        for (int a = 0; a < this.k; a++) {
            for (int b = a + 1; b < this.k; b++) {
                int chosen = this.svms.get(m++).predict(t) > 0.0d ? a : b;
                votes[chosen]++;
            }
        }

        int best = 0;
        for (int c = 1; c < this.k; c++) {
            if (votes[c] > votes[best]) {
                best = c;
            }
        }
        return best;
    }

    /**
     * Predicts the class label of an instance and fills in posterior probabilities.
     *
     * <ul>
     * <li>Binary: {@code dArr[1]} is the Platt posterior of the positive class and
     * {@code dArr[0]} is its complement.</li>
     * <li>ONE_VS_ALL: {@code dArr[c]} is the posterior of machine c, then the vector
     * is normalized with {@code Math.unitize1}. The label is the machine with the
     * largest decision value.</li>
     * <li>One-vs-one: pairwise posteriors are combined by
     * {@code PlattScaling.multiclass}. The label is the majority vote, with ties
     * going to the lowest index.</li>
     * </ul>
     *
     * @param t the instance
     * @param dArr output array for posterior probabilities; must hold at least k entries
     * @return the predicted class label
     * @throws IllegalArgumentException if {@code dArr} is shorter than k
     * @throws UnsupportedOperationException if Platt scaling has not been trained;
     *         in that case {@code dArr} is left untouched
     */
    @Override // smile.classification.SoftClassifier
    public int predict(T t, double[] dArr) {
        if (dArr.length < this.k) {
            throw new IllegalArgumentException("Invalid posteriori vector size: " + dArr.length + ", expected at least: " + this.k);
        }

        if (this.k == 2) {
            if (this.svm.platt == null) {
                throw new UnsupportedOperationException("PlattScaling was not trained yet. Please call SVM.trainPlattScaling() first.");
            }
            double score = this.svm.predict(t);
            dArr[1] = posterior(this.svm, score);
            dArr[0] = 1.0d - dArr[1];
            return score > 0.0d ? 1 : 0;
        }

        // Verify every scaler before writing anything, so a missing one cannot
        // leave the caller's array partially overwritten.
        for (SVM<T>.LASVM machine : this.svms) {
            if (machine.platt == null) {
                throw new UnsupportedOperationException("PlattScaling was not trained yet. Please call SVM.trainPlattScaling() first.");
            }
        }

        if (this.strategy == Multiclass.ONE_VS_ALL) {
            int winner = 0;
            double top = Double.NEGATIVE_INFINITY;
            for (int c = 0; c < this.svms.size(); c++) {
                SVM<T>.LASVM machine = this.svms.get(c);
                double score = machine.predict(t);
                dArr[c] = posterior(machine, score);
                if (score > top) {
                    winner = c;
                    top = score;
                }
            }
            Math.unitize1(dArr);
            return winner;
        }

        // One-vs-one: r[a][b] estimates P(a | a or b); r[b][a] is its complement.
        int[] votes = new int[this.k];
        double[][] r = new double[this.k][this.k];
        int m = 0;
        for (int a = 0; a < this.k; a++) {
            for (int b = a + 1; b < this.k; b++, m++) {
                SVM<T>.LASVM machine = this.svms.get(m);
                double score = machine.predict(t);
                r[a][b] = posterior(machine, score);
                r[b][a] = 1.0d - r[a][b];
                if (score > 0.0d) {
                    votes[a]++;
                } else {
                    votes[b]++;
                }
            }
        }
        PlattScaling.multiclass(this.k, r, dArr);

        int best = 0;
        for (int c = 1; c < this.k; c++) {
            if (votes[c] > votes[best]) {
                best = c;
            }
        }
        return best;
    }

    /**
     * Sets the tolerance of the convergence test.
     *
     * @param d the tolerance; must be a positive number
     * @return this model, for call chaining
     * @throws IllegalArgumentException if {@code d} is not positive (NaN included)
     */
    public SVM<T> setTolerance(double d) {
        // !(d > 0) also rejects NaN, which "d <= 0" would let through.
        if (!(d > 0.0d)) {
            throw new IllegalArgumentException("Invalid tolerance of convergence test:" + d);
        }
        this.tol = d;
        return this;
    }

    /**
     * Fits Platt scaling on every underlying binary machine so that
     * {@link #predict(Object, double[])} can produce posterior probabilities.
     *
     * <ul>
     * <li>Binary: the labels are handed unchanged to the single LASVM.</li>
     * <li>ONE_VS_ALL: machine c is calibrated on all samples, with +1 for class c
     * and -1 otherwise.</li>
     * <li>One-vs-one: machine (a, b) is calibrated only on samples of classes a (+1)
     * and b (-1). It was trained on those two classes only, and its posterior is
     * consumed as P(a | a or b) by pairwise coupling.</li>
     * </ul>
     * Multi-class fits run through {@code MulticoreExecutor}; a failure is logged,
     * not rethrown.
     *
     * @param tArr the calibration samples
     * @param iArr their class labels, one per sample
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public void trainPlattScaling(T[] tArr, int[] iArr) {
        if (tArr.length != iArr.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", tArr.length, iArr.length));
        }
        if (this.k == 2) {
            this.svm.trainPlattScaling(tArr, iArr);
            return;
        }

        List<PlattScalingTask> tasks = new ArrayList<>(this.svms.size());
        if (this.strategy == Multiclass.ONE_VS_ALL) {
            for (int c = 0; c < this.svms.size(); c++) {
                int[] yc = new int[iArr.length];
                for (int n = 0; n < iArr.length; n++) {
                    yc[n] = iArr[n] == c ? 1 : -1;
                }
                tasks.add(new PlattScalingTask(this.svms.get(c), tArr, yc));
            }
        } else {
            int m = 0;
            for (int a = 0; a < this.k; a++) {
                for (int b = a + 1; b < this.k; b++, m++) {
                    int size = 0;
                    for (int label : iArr) {
                        if (label == a || label == b) {
                            size++;
                        }
                    }
                    // NOTE(review): if neither class occurs in the calibration data the
                    // scaler is fit on zero samples -- confirm PlattScaling tolerates that.
                    @SuppressWarnings("unchecked")
                    T[] xab = (T[]) Array.newInstance(tArr.getClass().getComponentType(), size);
                    int[] yab = new int[size];
                    int pos = 0;
                    for (int n = 0; n < iArr.length; n++) {
                        if (iArr[n] == a || iArr[n] == b) {
                            xab[pos] = tArr[n];
                            yab[pos] = iArr[n] == a ? 1 : -1;
                            pos++;
                        }
                    }
                    tasks.add(new PlattScalingTask(this.svms.get(m), xab, yab));
                }
            }
        }

        try {
            MulticoreExecutor.run(tasks);
        } catch (Exception e) {
            logger.error("Failed to train Platt Scaling on multi-core", e);
        }
    }

    /**
     * Maps a raw decision value to a posterior probability with the machine's
     * Platt scaler, clamped to {@code [1e-7, 1 - 1e-7]}. Presumably the clamp keeps
     * downstream normalization and coupling away from exact 0 or 1.
     *
     * @param lasvm the machine whose Platt scaler is used (must be non-null)
     * @param d the decision value produced by that machine
     * @return the clamped posterior probability
     */
    private double posterior(SVM<T>.LASVM lasvm, double d) {
        final double floor = 1.0E-7d;
        final double ceiling = 0.9999999d;
        double p = lasvm.platt.predict(d);
        return Math.max(floor, Math.min(p, ceiling));
    }
}
