package org.apache.mahout.classifier.df.builder;

import com.google.common.collect.Sets;
import java.util.HashSet;
import java.util.Random;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.data.conditions.Condition;
import org.apache.mahout.classifier.df.node.CategoricalNode;
import org.apache.mahout.classifier.df.node.Leaf;
import org.apache.mahout.classifier.df.node.Node;
import org.apache.mahout.classifier.df.node.NumericalNode;
import org.apache.mahout.classifier.df.split.IgSplit;
import org.apache.mahout.classifier.df.split.OptIgSplit;
import org.apache.mahout.classifier.df.split.RegressionSplit;
import org.apache.mahout.classifier.df.split.Split;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:libarx-3.7.1.jar:org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.class */
/**
 * Recursive {@link TreeBuilder} that grows a tree of {@link Leaf}, {@link NumericalNode} and
 * {@link CategoricalNode} instances.
 *
 * <p>At every node a random subset of the not-yet-selected attributes is drawn, the split with
 * the highest information gain (as reported by the configured {@link IgSplit}) is chosen, and the
 * builder recurses into the resulting subsets.
 *
 * <p>{@link #build(Random, Data)} reads and writes instance fields ({@code selected}, {@code m},
 * {@code igSplit}, {@code fullSet}, {@code minVariance}) without any synchronization, and several
 * of them are initialized lazily from the first data passed in and never reset afterwards.
 */
public class DecisionTreeBuilder implements TreeBuilder {
    private static final Logger log = LoggerFactory.getLogger(DecisionTreeBuilder.class);
    private static final int[] NO_ATTRIBUTES = new int[0];
    /** Splits whose information gain is below this threshold are not performed. */
    private static final double EPSILON = 1.0E-6d;
    /** {@code selected[a]} is true when attribute {@code a} must not be drawn (the label is always marked). */
    private boolean[] selected;
    /** Number of attributes drawn at each node; 0 means "derive it on the first build call". */
    private int m;
    /** Split evaluator; chosen lazily from the label type when left null. */
    private IgSplit igSplit;
    /** First non-trivial data handed to {@code build}; source of the value list for complemented categorical nodes. */
    private Data fullSet;
    /** When true, categorical nodes get a child for every value found in {@code fullSet}. */
    private boolean complemented = true;
    /** Minimum size a child subset must have for a split to be performed. */
    private double minSplitNum = 2.0d;
    /** Factor applied to the variance of the first numerical-label data to obtain {@code minVariance}. */
    private double minVarianceProportion = 0.001d;
    /** Variance threshold below which a numerical-label node becomes a leaf; NaN until first computed. */
    private double minVariance = Double.NaN;

    /**
     * Sets the number of attributes drawn at each node.
     *
     * @param m number of attributes; 0 lets {@code build} derive it from the attribute count
     */
    public void setM(int m) {
        this.m = m;
    }

    /**
     * Sets the split evaluator used to score candidate attributes.
     *
     * @param igSplit evaluator to use; null lets {@code build} choose one from the label type
     */
    public void setIgSplit(IgSplit igSplit) {
        this.igSplit = igSplit;
    }

    /**
     * Enables or disables complemented categorical nodes.
     *
     * @param complemented true to give categorical nodes a child for every value of the full data
     */
    public void setComplemented(boolean complemented) {
        this.complemented = complemented;
    }

    /**
     * Sets the minimum subset size required for a split.
     *
     * @param minSplitNum minimum number of instances a child subset must contain
     */
    public void setMinSplitNum(int minSplitNum) {
        this.minSplitNum = minSplitNum;
    }

    /**
     * Sets the proportion used to derive the minimum variance threshold.
     *
     * @param minVarianceProportion factor applied to the variance of the first numerical-label data
     */
    public void setMinVarianceProportion(double minVarianceProportion) {
        this.minVarianceProportion = minVarianceProportion;
    }

    /**
     * Builds the (sub)tree for {@code data}.
     *
     * <p>Returns a {@code Leaf(NaN)} for empty data. Otherwise returns a leaf when the label
     * variance is below the threshold (numerical label), when all instances are identical on the
     * unselected attributes or share one label (categorical label), when no attribute can be
     * drawn, when the best information gain is below {@link #EPSILON}, or when the child subsets
     * are too small; in every other case returns an internal node whose children are built
     * recursively.
     *
     * @param random source of randomness for attribute selection and majority-label computation
     * @param data   instances to build the tree from
     * @return root node of the tree built for {@code data}
     */
    @Override // org.apache.mahout.classifier.df.builder.TreeBuilder
    public Node build(Random random, Data data) {
        Dataset dataset = data.getDataset();
        int labelId = dataset.getLabelId();
        boolean regression = dataset.isNumerical(labelId);

        if (selected == null) {
            selected = new boolean[dataset.nbAttributes()];
            // The label must never be picked as a split attribute.
            selected[labelId] = true;
        }
        if (m == 0) {
            // Default m: ceil(n / 3) for a numerical label, ceil(sqrt(n)) otherwise,
            // n being the number of non-label attributes.
            double nonLabelAttributes = dataset.nbAttributes() - 1;
            if (regression) {
                m = (int) Math.ceil(nonLabelAttributes / 3.0d);
            } else {
                m = (int) Math.ceil(Math.sqrt(nonLabelAttributes));
            }
        }
        if (data.isEmpty()) {
            return new Leaf(Double.NaN);
        }

        double sum = 0.0d;
        if (regression) {
            double sumSquared = 0.0d;
            for (int i = 0; i < data.size(); i++) {
                double label = dataset.getLabel(data.get(i));
                sum += label;
                sumSquared += label * label;
            }
            double variance = (sumSquared - ((sum * sum) / data.size())) / data.size();
            if (Double.isNaN(minVariance)) {
                // First numerical-label data seen by this builder fixes the threshold.
                minVariance = variance * minVarianceProportion;
                log.debug("minVariance:{}", minVariance);
            }
            if (variance < minVariance) {
                log.debug("variance({}) < minVariance({}) Leaf({})", variance, minVariance, sum / data.size());
                return new Leaf(sum / data.size());
            }
        } else {
            if (isIdentical(data)) {
                return new Leaf(data.majorityLabel(random));
            }
            if (data.identicalLabel()) {
                return new Leaf(dataset.getLabel(data.get(0)));
            }
        }

        if (fullSet == null) {
            fullSet = data;
        }

        int[] attributes = randomAttributes(random, selected, m);
        if (attributes == null || attributes.length == 0) {
            double label = leafLabel(random, data, regression, sum);
            log.warn("attribute which can be selected is not found Leaf({})", label);
            return new Leaf(label);
        }

        if (igSplit == null) {
            if (regression) {
                igSplit = new RegressionSplit();
            } else {
                igSplit = new OptIgSplit();
            }
        }

        // Keep the first split with the strictly highest information gain.
        Split best = null;
        for (int attribute : attributes) {
            Split candidate = igSplit.computeSplit(data, attribute);
            if (best == null || best.getIg() < candidate.getIg()) {
                best = candidate;
            }
        }
        if (best.getIg() < EPSILON) {
            double label = leafLabel(random, data, regression, sum);
            log.debug("ig is near to zero Leaf({})", label);
            return new Leaf(label);
        }
        log.debug("best split attr:{}, split:{}, ig:{}", best.getAttr(), best.getSplit(), best.getIg());

        int attr = best.getAttr();
        boolean alreadySelected = selected[attr];
        if (alreadySelected) {
            log.warn("attribute {} already selected in a parent node", attr);
        }

        if (dataset.isNumerical(attr)) {
            Data loSubset = data.subset(Condition.lesser(attr, best.getSplit()));
            Data hiSubset = data.subset(Condition.greaterOrEquals(attr, best.getSplit()));

            // Decide whether to split BEFORE touching the selection state: returning a leaf
            // after mutating it would leak the mutation to sibling subtrees and later builds.
            if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) {
                double label = leafLabel(random, data, regression, sum);
                log.debug("branch is not split Leaf({})", label);
                return new Leaf(label);
            }

            boolean[] savedSelection = null;
            if (loSubset.isEmpty() || hiSubset.isEmpty()) {
                // The split did not partition the data; keep the children from reusing it.
                selected[attr] = true;
            } else {
                // The data changed: children work on a copy in which numerical attributes
                // are available again.
                savedSelection = selected;
                selected = cloneCategoricalAttributes(dataset, selected);
            }

            Node loChild = build(random, loSubset);
            Node hiChild = build(random, hiSubset);

            // Restore the selection state seen by the caller.
            if (savedSelection != null) {
                selected = savedSelection;
            } else {
                selected[attr] = alreadySelected;
            }
            return new NumericalNode(attr, best.getSplit(), loChild, hiChild);
        }

        double[] values = data.values(attr);
        HashSet<Double> subsetValues = null;
        if (complemented) {
            // Remember which values occur in this data, then branch on every value of the full set.
            subsetValues = Sets.newHashSet();
            for (double value : values) {
                subsetValues.add(value);
            }
            values = fullSet.values(attr);
        }

        int largeEnoughSubsets = 0;
        Data[] subsets = new Data[values.length];
        for (int i = 0; i < values.length; i++) {
            if (subsetValues == null || subsetValues.contains(values[i])) {
                subsets[i] = data.subset(Condition.equals(attr, values[i]));
                if (subsets[i].size() >= minSplitNum) {
                    largeEnoughSubsets++;
                }
            }
        }
        if (largeEnoughSubsets < 2) {
            double label = leafLabel(random, data, regression, sum);
            log.debug("branch is not split Leaf({})", label);
            return new Leaf(label);
        }

        // A categorical attribute is not reused below the node that splits on it.
        selected[attr] = true;
        Node[] children = new Node[values.length];
        for (int i = 0; i < values.length; i++) {
            if (subsetValues == null || subsetValues.contains(values[i])) {
                children[i] = build(random, subsets[i]);
            } else {
                // Value absent from this data: predict this node's own fallback label.
                double label = leafLabel(random, data, regression, sum);
                log.debug("complemented Leaf({})", label);
                children[i] = new Leaf(label);
            }
        }
        selected[attr] = alreadySelected;
        return new CategoricalNode(attr, values, children);
    }

    /**
     * Computes the label of a fallback leaf: the mean label ({@code sum / data.size()}) for a
     * numerical label, the majority label otherwise.
     */
    private static double leafLabel(Random random, Data data, boolean regression, double sum) {
        return regression ? sum / data.size() : data.majorityLabel(random);
    }

    /**
     * Checks whether all instances of {@code data} have the same value on every attribute that is
     * not currently marked as selected.
     *
     * @return true for empty data or when no unselected attribute distinguishes two instances
     */
    private boolean isIdentical(Data data) {
        if (data.isEmpty()) {
            return true;
        }
        Instance first = data.get(0);
        for (int attr = 0; attr < selected.length; attr++) {
            if (selected[attr]) {
                continue;
            }
            for (int index = 1; index < data.size(); index++) {
                if (data.get(index).get(attr) != first.get(attr)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns a copy of {@code selected} that keeps the marks of non-numerical attributes only;
     * numerical attributes are unmarked and the label is always marked.
     */
    private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
        boolean[] cloned = new boolean[selected.length];
        for (int attr = 0; attr < selected.length; attr++) {
            cloned[attr] = !dataset.isNumerical(attr) && selected[attr];
        }
        cloned[dataset.getLabelId()] = true;
        return cloned;
    }

    /**
     * Draws up to {@code m} distinct attributes among those not marked in {@code selected}.
     *
     * <p>When at most {@code m} attributes are unmarked they are all returned in index order
     * without consuming randomness. Otherwise attributes are drawn by rejection sampling;
     * {@code selected} is used as scratch space during the draw and restored before returning.
     *
     * @return the drawn attribute indices, or an empty array when every attribute is marked
     */
    private static int[] randomAttributes(Random random, boolean[] selected, int m) {
        int available = 0;
        for (boolean marked : selected) {
            if (!marked) {
                available++;
            }
        }
        if (available == 0) {
            log.warn("All attributes are selected !");
            return NO_ATTRIBUTES;
        }

        int[] result;
        if (available <= m) {
            result = new int[available];
            int next = 0;
            for (int attr = 0; attr < selected.length; attr++) {
                if (!selected[attr]) {
                    result[next++] = attr;
                }
            }
        } else {
            result = new int[m];
            for (int i = 0; i < m; i++) {
                int candidate;
                do {
                    candidate = random.nextInt(selected.length);
                } while (selected[candidate]);
                result[i] = candidate;
                // Temporarily mark the pick so it cannot be drawn twice.
                selected[candidate] = true;
            }
            for (int picked : result) {
                selected[picked] = false;
            }
        }
        return result;
    }
}
