package smile.regression;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import smile.data.Attribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.util.MulticoreExecutor;
import smile.util.SmileUtils;
import smile.validation.RMSE;
import smile.validation.RegressionMeasure;

/* JADX WARN: Classes with same name are omitted:
  input_file:BOOT-INF/classes/libarx-3.7.1.jar:smile/regression/RandomForest.class
 */
/* loaded from: input_file:BOOT-INF/lib/libarx-3.7.1.jar:smile/regression/RandomForest.class */
/**
 * Random forest for regression: an ensemble of {@link RegressionTree}s, each grown on a random
 * sample of the training rows, whose predictions are averaged.
 *
 * <p>While training, every tree predicts the rows it was not trained on (its out-of-bag rows).
 * These predictions are averaged per row to produce the out-of-bag RMSE reported by
 * {@link #error()}. The per-feature importance reported by each tree is summed over the forest and
 * exposed by {@link #importance()}.
 */
public class RandomForest implements Regression<double[]>, Serializable {
    private static final long serialVersionUID = 1;

    /** The regression trees that make up the forest. */
    private List<RegressionTree> trees;
    /**
     * Out-of-bag root mean squared error, or 0 if no training row was ever left out of a tree's
     * sample.
     */
    private double error;
    /** Per-feature importance summed over all trees of the forest. */
    private double[] importance;

    /**
     * Builder-style trainer for {@link RandomForest}.
     */
    public static class Trainer extends RegressionTrainer<double[]> {
        /** Number of trees to grow. */
        private int ntrees = 500;
        /**
         * Number of features randomly considered at each split. A non-positive value (the default)
         * means {@code max(1, p / 3)}, where {@code p} is the number of features in the data passed
         * to {@link #train(double[][], double[])}.
         */
        private int mtry = -1;
        /** Minimum size of leaf nodes. */
        private int nodeSize = 5;
        /** Maximum number of leaf nodes per tree. */
        private int maxNodes = 100;
        /**
         * Sampling rate. 1.0 means bootstrap sampling with replacement; smaller values mean
         * sampling that fraction of the rows without replacement.
         */
        private double subsample = 1.0d;

        /**
         * Creates a trainer without attribute metadata (all features are treated as numeric).
         *
         * @param ntrees the number of trees; must be at least 1
         * @throws IllegalArgumentException if {@code ntrees < 1}
         */
        public Trainer(int ntrees) {
            this.ntrees = checkNumTrees(ntrees);
        }

        /**
         * Creates a trainer with the given feature attributes.
         *
         * @param attributes the feature attributes
         * @param ntrees the number of trees; must be at least 1
         * @throws IllegalArgumentException if {@code ntrees < 1}
         */
        public Trainer(Attribute[] attributes, int ntrees) {
            super(attributes);
            this.ntrees = checkNumTrees(ntrees);
        }

        /** Validates a tree count; shared by the constructors and {@link #setNumTrees(int)}. */
        private static int checkNumTrees(int ntrees) {
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }
            return ntrees;
        }

        /**
         * Sets the number of trees.
         *
         * @param ntrees the number of trees; must be at least 1
         * @return this trainer
         * @throws IllegalArgumentException if {@code ntrees < 1}
         */
        public Trainer setNumTrees(int ntrees) {
            this.ntrees = checkNumTrees(ntrees);
            return this;
        }

        /**
         * Sets the number of features randomly considered at each split.
         *
         * @param mtry the number of features; must be at least 1 (and at most the number of
         *     features, which is checked at training time)
         * @return this trainer
         * @throws IllegalArgumentException if {@code mtry < 1}
         */
        public Trainer setNumRandomFeatures(int mtry) {
            if (mtry < 1) {
                throw new IllegalArgumentException("Invalid number of random selected features for splitting: " + mtry);
            }
            this.mtry = mtry;
            return this;
        }

        /**
         * Sets the maximum number of leaf nodes per tree.
         *
         * @param maxNodes the maximum number of leaf nodes; must be at least 2
         * @return this trainer
         * @throws IllegalArgumentException if {@code maxNodes < 2}
         */
        public Trainer setMaxNodes(int maxNodes) {
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }
            this.maxNodes = maxNodes;
            return this;
        }

        /**
         * Sets the minimum size of leaf nodes.
         *
         * @param nodeSize the minimum leaf size; must be at least 2, the same bound the
         *     {@link RandomForest} constructor enforces, so invalid values fail here rather than
         *     at training time
         * @return this trainer
         * @throws IllegalArgumentException if {@code nodeSize < 2}
         */
        public Trainer setNodeSize(int nodeSize) {
            if (nodeSize < 2) {
                throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
            }
            this.nodeSize = nodeSize;
            return this;
        }

        /**
         * Sets the sampling rate.
         *
         * @param subsample the rate in (0, 1]; 1.0 selects bootstrap sampling with replacement
         * @return this trainer
         * @throws IllegalArgumentException if {@code subsample} is outside (0, 1]
         */
        public Trainer setSamplingRates(double subsample) {
            if (subsample <= 0.0d || subsample > 1.0d) {
                throw new IllegalArgumentException("Invalid sampling rate: " + subsample);
            }
            this.subsample = subsample;
            return this;
        }

        /**
         * Trains a random forest with this trainer's settings. If no explicit number of random
         * features was set, {@code max(1, p / 3)} is used.
         */
        @Override
        public RandomForest train(double[][] x, double[] y) {
            // mtry defaults to -1 ("unset"); passing it through unchanged would always be rejected
            // by the RandomForest constructor, so resolve it against the actual data here.
            int numFeatures = mtry > 0 ? mtry : defaultNumRandomFeatures(x);
            return new RandomForest(attributes, x, y, ntrees, maxNodes, nodeSize, numFeatures, subsample);
        }
    }

    /**
     * Grows a single tree on a random sample of the rows and accumulates its predictions for the
     * rows left out of that sample into the shared {@code prediction}/{@code oob} arrays.
     */
    static class TrainingTask implements Callable<RegressionTree> {
        final Attribute[] attributes;
        final double[][] x;
        final double[] y;
        final int[][] order;
        final int mtry;
        final int nodeSize;
        final int maxNodes;
        final double subsample;
        /** Shared sum of out-of-bag predictions per training row. */
        final double[] prediction;
        /** Shared count of out-of-bag predictions per training row. */
        final int[] oob;

        TrainingTask(Attribute[] attributes, double[][] x, double[] y, int maxNodes, int nodeSize, int mtry,
                double subsample, int[][] order, double[] prediction, int[] oob) {
            this.attributes = attributes;
            this.x = x;
            this.y = y;
            this.order = order;
            this.mtry = mtry;
            this.nodeSize = nodeSize;
            this.maxNodes = maxNodes;
            this.subsample = subsample;
            this.prediction = prediction;
            this.oob = oob;
        }

        @Override
        public RegressionTree call() {
            int n = x.length;
            // samples[i] = how many times row i was drawn for this tree (0 => out of bag).
            int[] samples = new int[n];
            if (subsample == 1.0d) {
                // Bootstrap: n draws with replacement.
                for (int i = 0; i < n; i++) {
                    samples[Math.randomInt(n)]++;
                }
            } else {
                // Take the first round(n * subsample) rows of a random permutation (no replacement).
                int[] perm = new int[n];
                for (int i = 0; i < n; i++) {
                    perm[i] = i;
                }
                Math.permutate(perm);
                int m = (int) Math.round(n * subsample);
                for (int i = 0; i < m; i++) {
                    samples[perm[i]]++;
                }
            }

            RegressionTree tree = new RegressionTree(attributes, x, y, maxNodes, nodeSize, mtry, order, samples, null);

            for (int i = 0; i < n; i++) {
                if (samples[i] == 0) {
                    double yhat = tree.predict(x[i]);
                    // Several tasks update the same slots concurrently; the row array x[i] is
                    // used as a per-row lock object.
                    synchronized (x[i]) {
                        prediction[i] += yhat;
                        oob[i]++;
                    }
                }
            }
            return tree;
        }
    }

    /**
     * Learns a random forest with numeric features and default settings.
     *
     * @param x the training rows
     * @param y the response values
     * @param ntrees the number of trees
     */
    public RandomForest(double[][] x, double[] y, int ntrees) {
        this(null, x, y, ntrees);
    }

    /**
     * Learns a random forest with numeric features.
     *
     * @param x the training rows
     * @param y the response values
     * @param ntrees the number of trees
     * @param maxNodes the maximum number of leaf nodes per tree
     * @param nodeSize the minimum size of leaf nodes
     * @param mtry the number of features randomly considered at each split
     */
    public RandomForest(double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize, int mtry) {
        this(null, x, y, ntrees, maxNodes, nodeSize, mtry);
    }

    /** Same as the full constructor with {@code maxNodes = 100}, {@code nodeSize = 5}, default {@code mtry} and {@code subsample = 1}. */
    public RandomForest(Attribute[] attributes, double[][] x, double[] y, int ntrees) {
        this(attributes, x, y, ntrees, 100);
    }

    /** Same as the full constructor with {@code nodeSize = 5}, default {@code mtry} and {@code subsample = 1}. */
    public RandomForest(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes) {
        this(attributes, x, y, ntrees, maxNodes, 5);
    }

    /** Same as the full constructor with {@code mtry = max(1, p / 3)} and {@code subsample = 1}. */
    public RandomForest(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize) {
        this(attributes, x, y, ntrees, maxNodes, nodeSize, defaultNumRandomFeatures(x));
    }

    /** Same as the full constructor with {@code subsample = 1} (bootstrap sampling). */
    public RandomForest(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize, int mtry) {
        this(attributes, x, y, ntrees, maxNodes, nodeSize, mtry, 1.0d);
    }

    /**
     * Learns a random forest.
     *
     * @param attributes the feature attributes, or {@code null} to treat every feature as numeric
     * @param x the training rows; must be non-empty
     * @param y the response values, one per row of {@code x}
     * @param ntrees the number of trees; at least 1
     * @param maxNodes the maximum number of leaf nodes per tree; at least 2
     * @param nodeSize the minimum size of leaf nodes; at least 2
     * @param mtry the number of features randomly considered at each split; in [1, p]
     * @param subsample the sampling rate in (0, 1]; 1.0 means bootstrap sampling with replacement
     * @throws IllegalArgumentException if any argument is out of range or {@code x} is empty
     */
    public RandomForest(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        int p = x[0].length;
        if (mtry < 1 || mtry > p) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry);
        }
        if (nodeSize < 2) {
            throw new IllegalArgumentException("Invalid minimum size of leaves: " + nodeSize);
        }
        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum number of leaves: " + maxNodes);
        }
        if (subsample <= 0.0d || subsample > 1.0d) {
            throw new IllegalArgumentException("Invalid sampling rate: " + subsample);
        }

        if (attributes == null) {
            attributes = new Attribute[p];
            for (int j = 0; j < p; j++) {
                attributes[j] = new NumericAttribute("V" + (j + 1));
            }
        }

        int n = x.length;
        int[][] order = SmileUtils.sort(attributes, x);
        double[] prediction = new double[n];
        int[] oob = new int[n];

        try {
            trees = MulticoreExecutor.run(
                    newTasks(ntrees, attributes, x, y, maxNodes, nodeSize, mtry, subsample, order, prediction, oob));
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            e.printStackTrace();
            // Tasks that completed (or are still finishing) before the failure have already added
            // their OOB predictions to the shared arrays. Retrying on fresh accumulators keeps
            // those partial results from being counted twice.
            prediction = new double[n];
            oob = new int[n];
            trees = new ArrayList<>(ntrees);
            for (TrainingTask task : newTasks(ntrees, attributes, x, y, maxNodes, nodeSize, mtry, subsample, order, prediction, oob)) {
                trees.add(task.call());
            }
        }

        error = oobRmse(y, prediction, oob);

        importance = new double[attributes.length];
        for (RegressionTree tree : trees) {
            double[] treeImportance = tree.importance();
            for (int j = 0; j < treeImportance.length; j++) {
                importance[j] += treeImportance[j];
            }
        }
    }

    /** Default number of features per split: {@code max(1, p / 3)}, or 1 when {@code x} is empty. */
    private static int defaultNumRandomFeatures(double[][] x) {
        if (x.length == 0) {
            // Let the main constructor reject the empty data set with a descriptive message.
            return 1;
        }
        return java.lang.Math.max(1, x[0].length / 3);
    }

    /** Creates one training task per tree, all accumulating into the given OOB arrays. */
    private static List<TrainingTask> newTasks(int ntrees, Attribute[] attributes, double[][] x, double[] y,
            int maxNodes, int nodeSize, int mtry, double subsample, int[][] order, double[] prediction, int[] oob) {
        List<TrainingTask> tasks = new ArrayList<>(ntrees);
        for (int t = 0; t < ntrees; t++) {
            tasks.add(new TrainingTask(attributes, x, y, maxNodes, nodeSize, mtry, subsample, order, prediction, oob));
        }
        return tasks;
    }

    /** RMSE of the averaged OOB predictions over rows that were out of bag at least once; 0 if none. */
    private static double oobRmse(double[] y, double[] prediction, int[] oob) {
        double sse = 0.0;
        int m = 0;
        for (int i = 0; i < y.length; i++) {
            if (oob[i] > 0) {
                m++;
                sse += Math.sqr(prediction[i] / oob[i] - y[i]);
            }
        }
        return m > 0 ? Math.sqrt(sse / m) : 0.0;
    }

    /**
     * Returns the out-of-bag estimate of the root mean squared error (0 if no row was ever out of
     * bag).
     */
    public double error() {
        return error;
    }

    /**
     * Returns the per-feature importance summed over all trees. The internal array is returned,
     * not a copy.
     */
    public double[] importance() {
        return importance;
    }

    /** Returns the number of trees in the forest. */
    public int size() {
        return trees.size();
    }

    /**
     * Keeps only the first {@code ntrees} trees of the forest.
     *
     * @param ntrees the new number of trees; in [1, size()]
     * @throws IllegalArgumentException if {@code ntrees} is out of range
     */
    public void trim(int ntrees) {
        if (ntrees > trees.size()) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        trees = new ArrayList<>(trees.subList(0, ntrees));
    }

    /** Predicts the response for {@code x} as the mean of all trees' predictions. */
    @Override
    public double predict(double[] x) {
        double sum = 0.0d;
        for (RegressionTree tree : trees) {
            sum += tree.predict(x);
        }
        return sum / trees.size();
    }

    /**
     * Evaluates the RMSE of the forest truncated to its first 1, 2, ..., size() trees.
     *
     * @param x the test rows
     * @param y the true responses
     * @return element {@code t} is the RMSE using the first {@code t + 1} trees
     */
    public double[] test(double[][] x, double[] y) {
        double[][] results = test(x, y, new RegressionMeasure[] {new RMSE()});
        double[] rmse = new double[results.length];
        for (int t = 0; t < results.length; t++) {
            rmse[t] = results[t][0];
        }
        return rmse;
    }

    /**
     * Evaluates the given measures on the forest truncated to its first 1, 2, ..., size() trees.
     *
     * @param x the test rows
     * @param y the true responses
     * @param measures the measures to evaluate
     * @return {@code result[t][k]} is measure {@code k} using the first {@code t + 1} trees
     */
    public double[][] test(double[][] x, double[] y, RegressionMeasure[] measures) {
        int ntrees = trees.size();
        int n = x.length;
        double[][] results = new double[ntrees][measures.length];
        double[] sum = new double[n];
        double[] prediction = new double[n];
        for (int t = 0; t < ntrees; t++) {
            RegressionTree tree = trees.get(t);
            for (int i = 0; i < n; i++) {
                sum[i] += tree.predict(x[i]);
                prediction[i] = sum[i] / (t + 1);
            }
            for (int k = 0; k < measures.length; k++) {
                results[t][k] = measures[k].measure(y, prediction);
            }
        }
        return results;
    }

    /** Returns a snapshot array of the trees in the forest. */
    public RegressionTree[] getTrees() {
        return trees.toArray(new RegressionTree[0]);
    }
}
