package io.citrine.lolo.trees.splits;

import io.citrine.lolo.trees.impurity.MultiImpurityCalculator;
import io.citrine.lolo.trees.impurity.MultiImpurityCalculator$;
import scala.MatchError;
import scala.None$;
import scala.Option$;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.SetLike;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Set$;
import scala.collection.immutable.Vector;
import scala.collection.mutable.BitSet;
import scala.math.Numeric$DoubleIsFractional$;
import scala.math.Ordering$Double$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;
import scala.util.Random$;

/* compiled from: MultiTaskSplitter.scala */
/* loaded from: input_file:io/citrine/lolo/trees/splits/MultiTaskSplitter$.class */
/**
 * Split-selection routines for multi-task decision trees.
 *
 * <p>This is the Java rendering of the Scala {@code object MultiTaskSplitter}. The singleton
 * instance lives in {@link #MODULE$}. The static {@code $anonfun$...} methods hold the bodies of
 * the closures used by the public methods.
 *
 * <p>Every method takes the training rows as {@code Tuple3<Vector<Object>, Object[], Object>}:
 * <ul>
 *   <li>{@code _1}: the feature vector. Each entry is read either as a boxed {@code Double} (real
 *       feature) or as a boxed {@code Character} (categorical feature).</li>
 *   <li>{@code _2}: the label array handed to {@link MultiImpurityCalculator}. Presumably it has
 *       one entry per task (TODO confirm).</li>
 *   <li>{@code _3}: a boxed double handed to the impurity calculator next to the labels. The
 *       categorical splitter also sums it as a weight, so it is presumably the sample weight.</li>
 * </ul>
 */
public final class MultiTaskSplitter$ {
    /** The singleton instance; assigned by the private constructor. */
    public static MultiTaskSplitter$ MODULE$;

    static {
        new MultiTaskSplitter$();
    }

    /**
     * Finds the best split of {@code seq} over a random subset of its features.
     *
     * <p>Up to {@code i} feature indices are drawn without replacement. They come from a shuffle
     * of the first row's feature indices, using the shared {@code scala.util.Random} instance. The
     * type of each sampled feature is decided from the first row only:
     * <ul>
     *   <li>{@code Double} values are evaluated with {@link #getBestRealSplit}.</li>
     *   <li>{@code Character} values are evaluated with {@link #getBestCategoricalSplit}.</li>
     * </ul>
     *
     * @param seq training rows (see the class documentation for the tuple layout)
     * @param i   number of feature indices to sample and evaluate
     * @param i2  minimum number of rows on each side of a split, forwarded to the per-feature
     *            splitters
     * @return the best split paired with the impurity of the unsplit rows minus that split's
     *         score, or {@code (NoSplit, 0.0)} if {@code seq} is empty or no sampled feature
     *         admits a split
     * @throws IllegalArgumentException if a sampled feature is neither a Double nor a Character
     * @throws MatchError if a sampled feature value in the first row is {@code null}
     */
    public Tuple2<Split, Object> getBestSplit(Seq<Tuple3<Vector<Object>, Object[], Object>> seq, int i, int i2) {
        // An empty node cannot be split. Report it the same way as "no split found" below,
        // instead of failing on seq.head().
        if (seq.isEmpty()) {
            return new Tuple2<>(new NoSplit(), BoxesRunTime.boxToDouble(0.0d));
        }
        // Mutable cells so the per-feature closure can update the running best.
        ObjectRef bestSplit = ObjectRef.create(new NoSplit());
        DoubleRef bestScore = DoubleRef.create(Double.MAX_VALUE);
        MultiImpurityCalculator calculator = MultiImpurityCalculator$.MODULE$.build((Seq) seq.map(tuple3 -> {
            return (Object[]) tuple3._2();
        }, Seq$.MODULE$.canBuildFrom()), (Seq) seq.map(tuple32 -> {
            return BoxesRunTime.boxToDouble($anonfun$getBestSplit$2(tuple32));
        }, Seq$.MODULE$.canBuildFrom()));
        // Impurity of the unsplit rows; the returned improvement is measured against this.
        double impurity = calculator.getImpurity();
        Tuple3 firstRow = (Tuple3) seq.head();
        ((IterableLike) Random$.MODULE$.shuffle(((SeqLike) firstRow._1()).indices(), Seq$.MODULE$.canBuildFrom()).take(i)).foreach(i3 -> {
            Tuple2<Split, Object> candidate;
            Object apply = ((Vector) firstRow._1()).apply(i3);
            if (apply instanceof Double) {
                candidate = MODULE$.getBestRealSplit(seq, calculator, i3, i2);
            } else if (apply instanceof Character) {
                candidate = MODULE$.getBestCategoricalSplit(seq, calculator, i3, i2);
            } else if (apply == null) {
                // No case of the original pattern match accepts null.
                throw new MatchError(apply);
            } else {
                throw new IllegalArgumentException("Trying to split unknown feature type");
            }
            Split split = (Split) candidate._1();
            double score = candidate._2$mcD$sp();
            if (score < bestScore.elem) {
                bestScore.elem = score;
                bestSplit.elem = split;
            }
        });
        // Both per-feature splitters report Double.MAX_VALUE when they find nothing.
        if (bestScore.elem == Double.MAX_VALUE) {
            return new Tuple2<>(new NoSplit(), BoxesRunTime.boxToDouble(0.0d));
        }
        return new Tuple2<>((Split) bestSplit.elem, BoxesRunTime.boxToDouble(impurity - bestScore.elem));
    }

    /**
     * Finds the best threshold split on the real-valued feature {@code i}.
     *
     * <p>Rows are sorted by the feature value. After a {@code reset()}, they are added to
     * {@code multiImpurityCalculator} one at a time. After adding sorted row {@code j}, a
     * candidate threshold halfway between sorted values {@code j} and {@code j + 1} is recorded.
     * The candidate is scored by the value {@code add} returned, and is kept only if:
     * <ul>
     *   <li>at least {@code i2} rows lie on each side, and</li>
     *   <li>the two values differ by a relative gap greater than {@code 1e-9}.</li>
     * </ul>
     *
     * @param seq training rows; feature {@code i} of every row must be a boxed Double
     * @param multiImpurityCalculator calculator reused for the scan; it is reset here and left
     *        holding the rows added during the scan
     * @param i index of the feature to split on
     * @param i2 minimum number of rows on each side of the threshold (values below 1 behave as 1)
     * @return {@code (RealSplit(i, threshold), score)} for the lowest-scoring candidate, or
     *         {@code (NoSplit, Double.MAX_VALUE)} if there is no candidate
     */
    public Tuple2<Split, Object> getBestRealSplit(Seq<Tuple3<Vector<Object>, Object[], Object>> seq, MultiImpurityCalculator multiImpurityCalculator, int i, int i2) {
        // Project each row onto (feature i as Double, labels, weight) and sort by the feature.
        Seq sortedRows = (Seq) ((SeqLike) seq.map(tuple3 -> {
            return new Tuple3(BoxesRunTime.boxToDouble(BoxesRunTime.unboxToDouble(((Vector) tuple3._1()).apply(i))), tuple3._2(), tuple3._3());
        }, Seq$.MODULE$.canBuildFrom())).sortBy(tuple32 -> {
            return BoxesRunTime.boxToDouble($anonfun$getBestRealSplit$2(tuple32));
        }, Ordering$Double$.MODULE$);
        Seq sortedValues = (Seq) sortedRows.map(tuple33 -> {
            return BoxesRunTime.boxToDouble($anonfun$getBestRealSplit$3(tuple33));
        }, Seq$.MODULE$.canBuildFrom());
        multiImpurityCalculator.reset();
        // Candidate j compares sortedValues(j) with sortedValues(j + 1), so j must stop at
        // size - 2 even when i2 < 1. Stopping before size - i2 leaves at least i2 rows on the
        // right-hand side.
        int scanEnd = seq.size() - Math.max(i2, 1);
        IndexedSeq candidates = (IndexedSeq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), scanEnd).flatMap(obj -> {
            return $anonfun$getBestRealSplit$4(multiImpurityCalculator, i2, sortedRows, sortedValues, BoxesRunTime.unboxToInt(obj));
        }, IndexedSeq$.MODULE$.canBuildFrom());
        if (candidates.isEmpty()) {
            return new Tuple2<>(new NoSplit(), BoxesRunTime.boxToDouble(Double.MAX_VALUE));
        }
        Tuple2 best = (Tuple2) candidates.minBy(tuple22 -> {
            return BoxesRunTime.boxToDouble(tuple22._2$mcD$sp());
        }, Ordering$Double$.MODULE$);
        return new Tuple2<>(new RealSplit(i, best._1$mcD$sp()), BoxesRunTime.boxToDouble(best._2$mcD$sp()));
    }

    /**
     * Finds the best subset split on the categorical feature {@code i}.
     *
     * <p>Rows are grouped by category. If the categories that occur in more than one row carry
     * less than 25% of the total weight, no split is attempted.
     *
     * <p>Otherwise, categories are ordered by the impurity of their own rows (ascending). After a
     * {@code reset()}, they are added to {@code multiImpurityCalculator} cumulatively. After each
     * category, the set of categories added so far becomes a candidate left-hand set, scored by
     * {@code getImpurity()}. The candidate is kept only if each side holds at least {@code i2}
     * rows.
     *
     * @param seq training rows; feature {@code i} of every row must be a boxed Character
     * @param multiImpurityCalculator calculator reused for the scan; it is reset here
     * @param i index of the feature to split on
     * @param i2 minimum number of rows on each side of the split
     * @return {@code (CategoricalSplit(i, codes), score)}, where {@code codes} holds the char
     *         codes of the lowest-scoring category set, or {@code (NoSplit, Double.MAX_VALUE)} if
     *         no split is attempted or none qualifies
     */
    public Tuple2<Split, Object> getBestCategoricalSplit(Seq<Tuple3<Vector<Object>, Object[], Object>> seq, MultiImpurityCalculator multiImpurityCalculator, int i, int i2) {
        // Project each row onto (feature i as Character, labels, weight).
        Seq projected = (Seq) seq.map(tuple3 -> {
            return new Tuple3(BoxesRunTime.boxToCharacter(BoxesRunTime.unboxToChar(((Vector) tuple3._1()).apply(i))), tuple3._2(), tuple3._3());
        }, Seq$.MODULE$.canBuildFrom());
        double totalWeight = BoxesRunTime.unboxToDouble(((TraversableOnce) projected.map(tuple32 -> {
            return BoxesRunTime.boxToDouble($anonfun$getBestCategoricalSplit$2(tuple32));
        }, Seq$.MODULE$.canBuildFrom())).sum(Numeric$DoubleIsFractional$.MODULE$));
        // category -> (impurity of its rows, summed weight of its rows, row count)
        Map statsView = projected.groupBy(tuple33 -> {
            return BoxesRunTime.boxToCharacter($anonfun$getBestCategoricalSplit$3(tuple33));
        }).mapValues(seq3 -> {
            return new Tuple3(BoxesRunTime.boxToDouble(MODULE$.computeImpurity((Seq) seq3.map(tuple34 -> {
                return new Tuple2(tuple34._2(), tuple34._3());
            }, Seq$.MODULE$.canBuildFrom()))), ((TraversableOnce) seq3.map(tuple35 -> {
                return BoxesRunTime.boxToDouble($anonfun$getBestCategoricalSplit$6(tuple35));
            }, Seq$.MODULE$.canBuildFrom())).sum(Numeric$DoubleIsFractional$.MODULE$), BoxesRunTime.boxToDouble(seq3.size()));
        });
        // In Scala 2.12, mapValues returns a lazy view that re-runs computeImpurity on every
        // traversal. Materialize it once and reuse it for both the weight check and the ordering.
        Seq categoryStats = statsView.toSeq();
        // Bail out when repeated categories carry under a quarter of the total weight.
        if (BoxesRunTime.unboxToDouble(((TraversableOnce) ((TraversableLike) categoryStats.filter(tuple2 -> {
            return BoxesRunTime.boxToBoolean($anonfun$getBestCategoricalSplit$7(tuple2));
        })).map(tuple22 -> {
            return BoxesRunTime.boxToDouble($anonfun$getBestCategoricalSplit$8(tuple22));
        }, Seq$.MODULE$.canBuildFrom())).sum(Numeric$DoubleIsFractional$.MODULE$)) / totalWeight < 0.25d) {
            return new Tuple2<>(new NoSplit(), BoxesRunTime.boxToDouble(Double.MAX_VALUE));
        }
        // Categories ordered by ascending within-category impurity.
        Seq orderedCategories = (Seq) ((TraversableLike) categoryStats.sortBy(tuple23 -> {
            return BoxesRunTime.boxToDouble($anonfun$getBestCategoricalSplit$9(tuple23));
        }, Ordering$Double$.MODULE$)).map(tuple24 -> {
            return BoxesRunTime.boxToCharacter(tuple24._1$mcC$sp());
        }, Seq$.MODULE$.canBuildFrom());
        // Running count of rows already added to the calculator (the left-hand side).
        IntRef rowsOnLeft = IntRef.create(0);
        multiImpurityCalculator.reset();
        IndexedSeq candidates = (IndexedSeq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), orderedCategories.size()).flatMap(obj -> {
            return $anonfun$getBestCategoricalSplit$11(multiImpurityCalculator, i2, projected, orderedCategories, rowsOnLeft, BoxesRunTime.unboxToInt(obj));
        }, IndexedSeq$.MODULE$.canBuildFrom());
        if (candidates.isEmpty()) {
            return new Tuple2<>(new NoSplit(), BoxesRunTime.boxToDouble(Double.MAX_VALUE));
        }
        Tuple2 best = (Tuple2) candidates.minBy(tuple26 -> {
            return BoxesRunTime.boxToDouble(tuple26._2$mcD$sp());
        }, Ordering$Double$.MODULE$);
        return new Tuple2<>(new CategoricalSplit(i, new BitSet().$plus$plus((GenTraversableOnce) ((SetLike) best._1()).map(obj2 -> {
            return BoxesRunTime.boxToInteger($anonfun$getBestCategoricalSplit$15(BoxesRunTime.unboxToChar(obj2)));
        }, Set$.MODULE$.canBuildFrom()))), BoxesRunTime.boxToDouble(best._2$mcD$sp()));
    }

    /**
     * Builds a fresh {@link MultiImpurityCalculator} over the given rows and returns its impurity.
     *
     * @param seq (labels, boxed double weight) pairs
     * @return {@code getImpurity()} of a calculator built from {@code seq}
     */
    public double computeImpurity(Seq<Tuple2<Object[], Object>> seq) {
        return MultiImpurityCalculator$.MODULE$.build((Seq) seq.map(tuple2 -> {
            return (Object[]) tuple2._1();
        }, Seq$.MODULE$.canBuildFrom()), (Seq) seq.map(tuple22 -> {
            return BoxesRunTime.boxToDouble(tuple22._2$mcD$sp());
        }, Seq$.MODULE$.canBuildFrom())).getImpurity();
    }

    /** Extracts a row's weight ({@code _3}) as a double. */
    public static final /* synthetic */ double $anonfun$getBestSplit$2(Tuple3 tuple3) {
        return BoxesRunTime.unboxToDouble(tuple3._3());
    }

    /** Sort key: the projected real feature value ({@code _1}). */
    public static final /* synthetic */ double $anonfun$getBestRealSplit$2(Tuple3 tuple3) {
        return BoxesRunTime.unboxToDouble(tuple3._1());
    }

    /** Extracts the projected real feature value ({@code _1}). */
    public static final /* synthetic */ double $anonfun$getBestRealSplit$3(Tuple3 tuple3) {
        return BoxesRunTime.unboxToDouble(tuple3._1());
    }

    /**
     * One step of the real-valued scan: adds sorted row {@code i2} to the calculator and decides
     * whether a threshold between sorted values {@code i2} and {@code i2 + 1} is a candidate.
     *
     * @param multiImpurityCalculator calculator accumulating the left-hand rows
     * @param i minimum number of rows required on the left-hand side
     * @param seq rows sorted by feature value
     * @param seq2 the sorted feature values (boxed doubles)
     * @param i2 index of the row being added; callers keep {@code i2 + 1} in range
     * @return {@code Some((midpoint, score))}, where the score is the value returned by
     *         {@code add}, or {@code None} if the left side is too small or the two values are
     *         indistinguishable
     */
    public static final /* synthetic */ Iterable $anonfun$getBestRealSplit$4(MultiImpurityCalculator multiImpurityCalculator, int i, Seq seq, Seq seq2, int i2) {
        Tuple3 row = (Tuple3) seq.apply(i2);
        double add = multiImpurityCalculator.add((Object[]) row._2(), BoxesRunTime.unboxToDouble(row._3()));
        double current = BoxesRunTime.unboxToDouble(seq2.apply(i2));
        double next = BoxesRunTime.unboxToDouble(seq2.apply(i2 + 1));
        double relativeGap = Math.abs((next - current) / current);
        // Tested as !(gap > tol) rather than gap <= tol, so that a NaN gap is rejected too.
        // That gap is 0/0 when both values are 0, or Inf/Inf for equal infinities.
        if (i2 + 1 < i || !(relativeGap > 1.0E-9d)) {
            return Option$.MODULE$.option2Iterable(None$.MODULE$);
        }
        return Option$.MODULE$.option2Iterable(new Some(new Tuple2<>(BoxesRunTime.boxToDouble((next + current) / 2.0d), BoxesRunTime.boxToDouble(add))));
    }

    /** Extracts a projected row's weight ({@code _3}) as a double. */
    public static final /* synthetic */ double $anonfun$getBestCategoricalSplit$2(Tuple3 tuple3) {
        return BoxesRunTime.unboxToDouble(tuple3._3());
    }

    /** Grouping key: the projected categorical feature value ({@code _1}). */
    public static final /* synthetic */ char $anonfun$getBestCategoricalSplit$3(Tuple3 tuple3) {
        return BoxesRunTime.unboxToChar(tuple3._1());
    }

    /** Extracts a projected row's weight ({@code _3}) as a double. */
    public static final /* synthetic */ double $anonfun$getBestCategoricalSplit$6(Tuple3 tuple3) {
        return BoxesRunTime.unboxToDouble(tuple3._3());
    }

    /** True when the category in a (category, stats) entry occurs in more than one row. */
    public static final /* synthetic */ boolean $anonfun$getBestCategoricalSplit$7(Tuple2 tuple2) {
        return BoxesRunTime.unboxToDouble(((Tuple3) tuple2._2())._3()) > ((double) 1);
    }

    /** Extracts the summed weight from a (category, stats) entry. */
    public static final /* synthetic */ double $anonfun$getBestCategoricalSplit$8(Tuple2 tuple2) {
        return BoxesRunTime.unboxToDouble(((Tuple3) tuple2._2())._2());
    }

    /** Sort key: the within-category impurity from a (category, stats) entry. */
    public static final /* synthetic */ double $anonfun$getBestCategoricalSplit$9(Tuple2 tuple2) {
        return BoxesRunTime.unboxToDouble(((Tuple3) tuple2._2())._1());
    }

    /** True when a projected row belongs to the {@code i}-th category of {@code seq}. */
    public static final /* synthetic */ boolean $anonfun$getBestCategoricalSplit$12(Seq seq, int i, Tuple3 tuple3) {
        return BoxesRunTime.unboxToChar(seq.apply(i)) == BoxesRunTime.unboxToChar(tuple3._1());
    }

    /** Adds one projected row to the calculator and bumps the left-hand row count. */
    public static final /* synthetic */ void $anonfun$getBestCategoricalSplit$13(MultiImpurityCalculator multiImpurityCalculator, IntRef intRef, Tuple3 tuple3) {
        multiImpurityCalculator.add((Object[]) tuple3._2(), BoxesRunTime.unboxToDouble(tuple3._3()));
        intRef.elem++;
    }

    /**
     * One step of the categorical scan: moves every row of category {@code seq2(i2)} to the
     * left-hand side and scores the resulting category set.
     *
     * @param multiImpurityCalculator calculator accumulating the left-hand rows
     * @param i minimum number of rows required on each side
     * @param seq all projected rows
     * @param seq2 categories in scan order
     * @param intRef running count of left-hand rows, updated here
     * @param i2 index of the category being added
     * @return {@code Some((set of the first i2 + 1 categories, getImpurity()))}, or {@code None}
     *         if either side has fewer than {@code i} rows
     */
    public static final /* synthetic */ Iterable $anonfun$getBestCategoricalSplit$11(MultiImpurityCalculator multiImpurityCalculator, int i, Seq seq, Seq seq2, IntRef intRef, int i2) {
        // Side effects only: add each matching row and count it.
        ((TraversableLike) seq.filter(tuple3 -> {
            return BoxesRunTime.boxToBoolean($anonfun$getBestCategoricalSplit$12(seq2, i2, tuple3));
        })).foreach(tuple32 -> {
            $anonfun$getBestCategoricalSplit$13(multiImpurityCalculator, intRef, tuple32);
            return BoxedUnit.UNIT;
        });
        double impurity = multiImpurityCalculator.getImpurity();
        if (intRef.elem < i || seq.size() - intRef.elem < i) {
            return Option$.MODULE$.option2Iterable(None$.MODULE$);
        }
        return Option$.MODULE$.option2Iterable(new Some(new Tuple2(((TraversableOnce) seq2.take(i2 + 1)).toSet(), BoxesRunTime.boxToDouble(impurity))));
    }

    /** Widens a category char to its integer code for storage in a {@link BitSet}. */
    public static final /* synthetic */ int $anonfun$getBestCategoricalSplit$15(char c) {
        return c;
    }

    private MultiTaskSplitter$() {
        MODULE$ = this;
    }
}
