package io.citrine.lolo.trees.splits;

import io.citrine.lolo.trees.impurity.GiniCalculator;
import io.citrine.lolo.trees.impurity.GiniCalculator$;
import scala.Double$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable$;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.collection.immutable.Set;
import scala.collection.immutable.Set$;
import scala.collection.immutable.Vector;
import scala.collection.mutable.BitSet;
import scala.math.Numeric$DoubleIsFractional$;
import scala.math.Ordering$Double$;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;
import scala.util.Random$;

/* compiled from: ClassificationSplitter.scala */
/* loaded from: input_file:io/citrine/lolo/trees/splits/ClassificationSplitter$.class */
/**
 * Decompiled form of the Scala singleton {@code object ClassificationSplitter}: searches for the
 * binary split of a set of labeled rows that minimizes the impurity reported by a
 * {@link GiniCalculator}.
 *
 * <p>Each row is a {@code Tuple3} of (feature vector, label unboxed as {@code char}, a
 * {@code double} passed to {@link GiniCalculator} alongside the label — presumably the row's
 * weight; confirm against GiniCalculator). Feature values are boxed {@code Double} (handled as
 * real-valued) or boxed {@code Character} (handled as categorical).
 */
public final class ClassificationSplitter$ {
    /** The singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static ClassificationSplitter$ MODULE$;

    static {
        new ClassificationSplitter$();
    }

    /**
     * Tries a random subset of features and returns the split with the lowest impurity.
     *
     * @param seq rows of (feature vector, label, weight); the first row's vector fixes the number
     *            of features and the type used to dispatch each feature
     * @param i   how many feature indices to try: all indices are shuffled and the first {@code i}
     *            are taken
     * @param i2  minimum number of rows required on each side of a split (forwarded to the
     *            per-feature searches)
     * @param z   if true, real-valued pivots are drawn uniformly between neighbouring values rather
     *            than placed at their midpoint (forwarded to {@link #getBestRealSplit})
     * @return (best split, total impurity of {@code seq} minus the best split's impurity), or
     *         ({@link NoSplit}, 0.0) when {@code seq} is empty or no candidate split was accepted
     * @throws IllegalArgumentException if a selected feature value is neither Double nor Character
     * @throws MatchError if a selected feature value is null, or a per-feature search returns null
     */
    public Tuple2<Split, Object> getBestSplit(Seq<Tuple3<Vector<Object>, Object, Object>> seq, int i, int i2, boolean z) {
        // seq.head() below would throw on empty input; with no rows there is nothing to split.
        if (seq.isEmpty()) {
            return new Tuple2<>(new NoSplit(), BoxesRunTime.boxToDouble(0.0d));
        }
        ObjectRef create = ObjectRef.create(new NoSplit());         // best split found so far
        DoubleRef create2 = DoubleRef.create(Double.MAX_VALUE);     // its impurity (MAX_VALUE = none yet)
        GiniCalculator build = GiniCalculator$.MODULE$.build((Seq) seq.map(tuple3 -> {
            return new Tuple2.mcCD.sp(BoxesRunTime.unboxToChar(tuple3._2()), BoxesRunTime.unboxToDouble(tuple3._3()));
        }, Seq$.MODULE$.canBuildFrom()));
        // Impurity of the unsplit data; the returned score is the reduction from this value.
        double impurity = build.getImpurity();
        Tuple3 tuple32 = (Tuple3) seq.head();
        ((IterableLike) Random$.MODULE$.shuffle(((SeqLike) tuple32._1()).indices(), Seq$.MODULE$.canBuildFrom()).take(i)).foreach(i3 -> {
            Tuple2 bestCategoricalSplit;
            // The first row's value for this feature decides which splitter is used.
            Object apply = ((Vector) tuple32._1()).apply(i3);
            if (apply instanceof Double) {
                bestCategoricalSplit = MODULE$.getBestRealSplit(seq, build, i3, i2, z);
            } else {
                if (!(apply instanceof Character)) {
                    // Only null fails `instanceof Object`; it surfaces as a MatchError.
                    if (!(apply instanceof Object)) {
                        throw new MatchError(apply);
                    }
                    throw new IllegalArgumentException("Trying to split unknown feature type");
                }
                bestCategoricalSplit = MODULE$.getBestCategoricalSplit(seq, build, i3, i2);
            }
            Tuple2 tuple2 = bestCategoricalSplit;
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            Tuple2 tuple22 = new Tuple2((Split) tuple2._1(), BoxesRunTime.boxToDouble(tuple2._2$mcD$sp()));
            Split split = (Split) tuple22._1();
            double _2$mcD$sp = tuple22._2$mcD$sp();
            if (_2$mcD$sp < create2.elem) {
                create2.elem = _2$mcD$sp;
                create.elem = split;
            }
        });
        if (create2.elem == Double.MAX_VALUE) {
            return new Tuple2<>(new NoSplit(), BoxesRunTime.boxToDouble(0.0d));
        }
        return new Tuple2<>((Split) create.elem, BoxesRunTime.boxToDouble(impurity - create2.elem));
    }

    /** Default for the {@code z} (randomize pivot) argument of {@link #getBestSplit}. */
    public boolean getBestSplit$default$4() {
        return false;
    }

    /**
     * Finds the best threshold split on real-valued feature {@code i}.
     *
     * <p>Rows are sorted by the feature and moved into the calculator's left partition one at a
     * time. A candidate between sorted rows {@code j} and {@code j + 1} is accepted when it has the
     * following properties:
     * <ul>
     *   <li>its impurity is strictly below the best seen so far;</li>
     *   <li>the left side holds at least {@code i2} rows;</li>
     *   <li>the right side holds at least {@code max(i2, 1)} rows;</li>
     *   <li>the two neighbouring values differ by more than a relative 1e-9.</li>
     * </ul>
     *
     * @param seq            rows of (feature vector, label, weight); feature {@code i} must unbox to double
     * @param giniCalculator calculator that is reset here and then fed rows left-to-right
     * @param i              index of the feature to split on
     * @param i2             minimum number of rows on each side
     * @param z              if true, place the pivot uniformly at random between the neighbours;
     *                       otherwise at their midpoint
     * @return (RealSplit(i, pivot), best impurity); when no candidate is accepted the pivot is
     *         Scala {@code Double.MinValue} and the impurity is {@code Double.MAX_VALUE}
     */
    public Tuple2<RealSplit, Object> getBestRealSplit(Seq<Tuple3<Vector<Object>, Object, Object>> seq, GiniCalculator giniCalculator, int i, int i2, boolean z) {
        // Rows projected onto feature i (boxed Double), sorted ascending by that value.
        Seq seq2 = (Seq) ((SeqLike) seq.map(tuple3 -> {
            return new Tuple3(BoxesRunTime.boxToDouble(BoxesRunTime.unboxToDouble(((Vector) tuple3._1()).apply(i))), tuple3._2(), tuple3._3());
        }, Seq$.MODULE$.canBuildFrom())).sortBy(tuple32 -> {
            return BoxesRunTime.boxToDouble($anonfun$getBestRealSplit$2(tuple32));
        }, Ordering$Double$.MODULE$);
        DoubleRef create = DoubleRef.create(Double.MAX_VALUE);            // best impurity so far
        DoubleRef create2 = DoubleRef.create(Double$.MODULE$.MinValue()); // pivot of the best split
        giniCalculator.reset();
        // Row i3 is the last row of the left partition, so row i3 + 1 must exist. Bounding by
        // max(i2, 1) keeps at least one row on the right. With i2 <= 0 the previous bound
        // (size - i2) let i3 reach size - 1 and index past the end of seq2.
        int rightMinimum = Math.max(i2, 1);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), seq.size() - rightMinimum).foreach$mVc$sp(i3 -> {
            double add = giniCalculator.add(BoxesRunTime.unboxToChar(((Tuple3) seq2.apply(i3))._2()), BoxesRunTime.unboxToDouble(((Tuple3) seq2.apply(i3))._3()));
            if (add >= create.elem || i3 + 1 < i2) {
                return;
            }
            double lower = BoxesRunTime.unboxToDouble(((Tuple3) seq2.apply(i3))._1());
            double upper = BoxesRunTime.unboxToDouble(((Tuple3) seq2.apply(i3 + 1))._1());
            // Skip pivots inside a run of (nearly) equal values. The gap is relative to `lower`, so
            // lower == upper == 0.0 yields 0/0 = NaN. "!(gap > eps)" rejects NaN. The previous
            // "gap <= eps" accepted it and produced pivot 0.0, which cannot separate two zeros.
            double relativeGap = Math.abs((upper - lower) / lower);
            if (!(relativeGap > 1.0E-9d)) {
                return;
            }
            create.elem = add;
            create2.elem = z ? ((upper - lower) * Random$.MODULE$.nextDouble()) + lower : (upper + lower) / 2.0d;
        });
        return new Tuple2<>(new RealSplit(i, create2.elem), BoxesRunTime.boxToDouble(create.elem));
    }

    /** Default for the {@code z} (randomize pivot) argument of {@link #getBestRealSplit}. */
    public boolean getBestRealSplit$default$5() {
        return false;
    }

    /**
     * Finds the best subset split on categorical feature {@code i}.
     *
     * <p>For each category the method first computes three things:
     * <ul>
     *   <li>the summed weight per label;</li>
     *   <li>the sum of squared label-weight fractions, which equals one minus the category's own
     *       Gini impurity;</li>
     *   <li>its row count.</li>
     * </ul>
     * Categories are then ordered by that sum of squares, ascending. Each prefix of that order,
     * excluding the full set, is scored as the left side. A prefix is accepted when its impurity is
     * strictly lower than the best so far and both sides hold at least {@code i2} rows.
     *
     * @param seq            rows of (feature vector, label, weight); feature {@code i} must unbox to char
     * @param giniCalculator calculator that is reset here and then fed whole categories
     * @param i              index of the feature to split on
     * @param i2             minimum number of rows on each side
     * @return (CategoricalSplit(i, char codes of the left-side categories), best impurity); when no
     *         prefix is accepted the set is empty and the impurity is {@code Double.MAX_VALUE}
     */
    public Tuple2<CategoricalSplit, Object> getBestCategoricalSplit(Seq<Tuple3<Vector<Object>, Object, Object>> seq, GiniCalculator giniCalculator, int i, int i2) {
        // Rows projected onto feature i (boxed Character).
        Seq seq2 = (Seq) seq.map(tuple3 -> {
            return new Tuple3(BoxesRunTime.boxToCharacter(BoxesRunTime.unboxToChar(((Vector) tuple3._1()).apply(i))), tuple3._2(), tuple3._3());
        }, Seq$.MODULE$.canBuildFrom());
        // category -> (label -> summed weight, sum of squared label fractions, row count).
        // NOTE(review): in Scala 2.x collections `mapValues` returns a lazily evaluated view, so these
        // groupBy/sum computations may be redone on every `mapValues.apply` in the loop below —
        // consider materializing the map.
        Map mapValues = seq2.groupBy(tuple32 -> {
            return BoxesRunTime.boxToCharacter($anonfun$getBestCategoricalSplit$2(tuple32));
        }).mapValues(seq3 -> {
            Map mapValues2 = seq3.groupBy(tuple33 -> {
                return BoxesRunTime.boxToCharacter($anonfun$getBestCategoricalSplit$4(tuple33));
            }).mapValues(seq3 -> {
                return BoxesRunTime.boxToDouble($anonfun$getBestCategoricalSplit$5(seq3));
            });
            return new Tuple3(mapValues2, BoxesRunTime.boxToDouble(BoxesRunTime.unboxToDouble(((TraversableOnce) mapValues2.values().map(d -> {
                return Math.pow(d, 2.0d);
            }, Iterable$.MODULE$.canBuildFrom())).sum(Numeric$DoubleIsFractional$.MODULE$)) / Math.pow(BoxesRunTime.unboxToDouble(mapValues2.values().sum(Numeric$DoubleIsFractional$.MODULE$)), 2.0d)), BoxesRunTime.boxToInteger(seq3.size()));
        });
        // Category keys ordered by ascending sum of squared label fractions.
        Seq seq4 = (Seq) ((TraversableLike) mapValues.toSeq().sortBy(tuple2 -> {
            return BoxesRunTime.boxToDouble($anonfun$getBestCategoricalSplit$8(tuple2));
        }, Ordering$Double$.MODULE$)).map(tuple22 -> {
            return BoxesRunTime.boxToCharacter(tuple22._1$mcC$sp());
        }, Seq$.MODULE$.canBuildFrom());
        IntRef create = IntRef.create(0);                                          // rows on the left so far
        DoubleRef create2 = DoubleRef.create(Double.MAX_VALUE);                    // best impurity so far
        ObjectRef create3 = ObjectRef.create(Predef$.MODULE$.Set().empty());       // categories of best left side
        giniCalculator.reset();
        // Stop one short of the last category so the right side always keeps at least one category.
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), seq4.size() - 1).foreach$mVc$sp(i3 -> {
            Tuple3 tuple33 = (Tuple3) mapValues.apply(seq4.apply(i3));
            // Move every label's summed weight for this category into the left partition.
            ((Map) tuple33._1()).foreach(tuple23 -> {
                return BoxesRunTime.boxToDouble($anonfun$getBestCategoricalSplit$11(giniCalculator, tuple23));
            });
            create.elem += BoxesRunTime.unboxToInt(tuple33._3());
            double impurity = giniCalculator.getImpurity();
            if (impurity >= create2.elem || create.elem < i2 || seq2.size() - create.elem < i2) {
                return;
            }
            create2.elem = impurity;
            create3.elem = ((TraversableOnce) seq4.slice(0, i3 + 1)).toSet();
        });
        return new Tuple2<>(new CategoricalSplit(i, new BitSet().$plus$plus((GenTraversableOnce) ((Set) create3.elem).map(obj -> {
            return BoxesRunTime.boxToInteger($anonfun$getBestCategoricalSplit$12(BoxesRunTime.unboxToChar(obj)));
        }, Set$.MODULE$.canBuildFrom()))), BoxesRunTime.boxToDouble(create2.elem));
    }

    /** Sort key for real splits: the projected feature value. */
    public static final /* synthetic */ double $anonfun$getBestRealSplit$2(Tuple3 tuple3) {
        return BoxesRunTime.unboxToDouble(tuple3._1());
    }

    /** Grouping key: the projected category. */
    public static final /* synthetic */ char $anonfun$getBestCategoricalSplit$2(Tuple3 tuple3) {
        return BoxesRunTime.unboxToChar(tuple3._1());
    }

    /** Grouping key: the row's label. */
    public static final /* synthetic */ char $anonfun$getBestCategoricalSplit$4(Tuple3 tuple3) {
        return BoxesRunTime.unboxToChar(tuple3._2());
    }

    /** The row's weight component. */
    public static final /* synthetic */ double $anonfun$getBestCategoricalSplit$6(Tuple3 tuple3) {
        return BoxesRunTime.unboxToDouble(tuple3._3());
    }

    /** Sum of the weight components of a group of rows. */
    public static final /* synthetic */ double $anonfun$getBestCategoricalSplit$5(Seq seq) {
        return BoxesRunTime.unboxToDouble(((TraversableOnce) seq.map(tuple3 -> {
            return BoxesRunTime.boxToDouble($anonfun$getBestCategoricalSplit$6(tuple3));
        }, Seq$.MODULE$.canBuildFrom())).sum(Numeric$DoubleIsFractional$.MODULE$));
    }

    /** Sort key for categories: their sum of squared label fractions. */
    public static final /* synthetic */ double $anonfun$getBestCategoricalSplit$8(Tuple2 tuple2) {
        return BoxesRunTime.unboxToDouble(((Tuple3) tuple2._2())._2());
    }

    /** Adds one (label, summed weight) entry to the calculator; null entries raise MatchError. */
    public static final /* synthetic */ double $anonfun$getBestCategoricalSplit$11(GiniCalculator giniCalculator, Tuple2 tuple2) {
        if (tuple2 != null) {
            return giniCalculator.add(tuple2._1$mcC$sp(), tuple2._2$mcD$sp());
        }
        throw new MatchError(tuple2);
    }

    /** Widens a category char to its int code for storage in the split's BitSet. */
    public static final /* synthetic */ int $anonfun$getBestCategoricalSplit$12(char c) {
        return c;
    }

    private ClassificationSplitter$() {
        MODULE$ = this;
    }
}
