package io.citrine.lolo.trees.multitask;

import io.citrine.lolo.Model;
import io.citrine.lolo.MultiTaskLearner;
import io.citrine.lolo.TrainingResult;
import io.citrine.lolo.encoders.CategoricalEncoder;
import io.citrine.lolo.encoders.CategoricalEncoder$;
import io.citrine.lolo.trees.ModelNode;
import io.citrine.lolo.trees.classification.ClassificationTree;
import io.citrine.lolo.trees.regression.RegressionTree;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple3;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Vector;
import scala.collection.immutable.Vector$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: MultiTaskTree.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00153A!\u0001\u0002\u0001\u001b\t!R*\u001e7uSR\u000b7o\u001b+sK\u0016dU-\u0019:oKJT!a\u0001\u0003\u0002\u00135,H\u000e^5uCN\\'BA\u0003\u0007\u0003\u0015!(/Z3t\u0015\t9\u0001\"\u0001\u0003m_2|'BA\u0005\u000b\u0003\u001d\u0019\u0017\u000e\u001e:j]\u0016T\u0011aC\u0001\u0003S>\u001c\u0001aE\u0002\u0001\u001dQ\u0001\"a\u0004\n\u000e\u0003AQ\u0011!E\u0001\u0006g\u000e\fG.Y\u0005\u0003'A\u0011a!\u00118z%\u00164\u0007CA\u000b\u0017\u001b\u00051\u0011BA\f\u0007\u0005AiU\u000f\u001c;j)\u0006\u001c8\u000eT3be:,'\u000fC\u0003\u001a\u0001\u0011\u0005!$\u0001\u0004=S:LGO\u0010\u000b\u00027A\u0011A\u0004A\u0007\u0002\u0005!)a\u0004\u0001C!?\u0005)AO]1j]R!\u0001e\f\u001d=!\r\t\u0013\u0006\f\b\u0003E\u001dr!a\t\u0014\u000e\u0003\u0011R!!\n\u0007\u0002\rq\u0012xn\u001c;?\u0013\u0005\t\u0012B\u0001\u0015\u0011\u0003\u001d\u0001\u0018mY6bO\u0016L!AK\u0016\u0003\u0007M+\u0017O\u0003\u0002)!A\u0011Q#L\u0005\u0003]\u0019\u0011a\u0002\u0016:bS:Lgn\u001a*fgVdG\u000fC\u00031;\u0001\u0007\u0011'\u0001\u0004j]B,Ho\u001d\t\u0004C%\u0012\u0004cA\u00114k%\u0011Ag\u000b\u0002\u0007-\u0016\u001cGo\u001c:\u0011\u0005=1\u0014BA\u001c\u0011\u0005\r\te.\u001f\u0005\u0006su\u0001\rAO\u0001\u0007Y\u0006\u0014W\r\\:\u0011\u0007\u0005J3\bE\u0002\"SUBq!P\u000f\u0011\u0002\u0003\u0007a(A\u0004xK&<\u0007\u000e^:\u0011\u0007=y\u0014)\u0003\u0002A!\t1q\n\u001d;j_:\u00042!I\u0015C!\ty1)\u0003\u0002E!\t1Ai\\;cY\u0016\u0004")
/* loaded from: input_file:io/citrine/lolo/trees/multitask/MultiTaskTreeLearner.class */
/*
 * Multi-task decision-tree learner: trains one shared MultiTaskTrainingNode over all label columns
 * and then extracts one RegressionTree or ClassificationTree per task.
 *
 * NOTE(review): this file is decompiled from MultiTaskTree.scala; mirror any fix made here in the
 * Scala source.
 */
public class MultiTaskTreeLearner implements MultiTaskLearner {
    /** Hyperparameter map backing the hypers()/hypers_$eq accessors required by MultiTaskLearner. */
    private Map<String, Object> hypers;

    /*
     * The four forwarders below previously called themselves (e.g. setHypers(map) inside
     * setHypers), which recursed until StackOverflowError. They now invoke the interface's
     * implementation explicitly.
     * NOTE(review): assumes MultiTaskLearner (a Scala 2.12 trait) exposes these as default methods,
     * which the presence of these forwarders implies -- confirm against the trait.
     */

    @Override // io.citrine.lolo.MultiTaskLearner
    public MultiTaskLearner setHypers(Map<String, Object> map) {
        return MultiTaskLearner.super.setHypers(map);
    }

    @Override // io.citrine.lolo.MultiTaskLearner
    public MultiTaskLearner setHyper(String str, Object obj) {
        return MultiTaskLearner.super.setHyper(str, obj);
    }

    @Override // io.citrine.lolo.MultiTaskLearner
    public Map<String, Object> getHypers() {
        return MultiTaskLearner.super.getHypers();
    }

    @Override // io.citrine.lolo.MultiTaskLearner
    public Option<Seq<Object>> train$default$3() {
        return MultiTaskLearner.super.train$default$3();
    }

    @Override // io.citrine.lolo.MultiTaskLearner
    public Map<String, Object> hypers() {
        return this.hypers;
    }

    @Override // io.citrine.lolo.MultiTaskLearner
    public void hypers_$eq(Map<String, Object> map) {
        this.hypers = map;
    }

    /**
     * Trains one tree per label column using a single shared multi-task training node.
     *
     * @param inputs  feature vectors, one per training row. Each column's type is decided from the
     *                first row: Double columns are used as-is, all others are categorically encoded.
     *                {@code inputs.head()} is read, so an empty sequence throws.
     * @param labels  one label sequence per task, each indexed by training row. Null entries are
     *                excluded when building categorical label encoders.
     * @param weights optional per-row weights. Each defaults to 1.0 when absent, and rows whose
     *                weight is not {@code > 0.0} are dropped before training.
     * @return one training result per task, in the order of {@code labels}
     */
    @Override // io.citrine.lolo.MultiTaskLearner
    public Seq<TrainingResult> train(Seq<Vector<Object>> inputs, Seq<Seq<Object>> labels, Option<Seq<Object>> weights) {
        // Row-major view of the labels: element i holds every task's label for training row i.
        Vector labelsByRow = labels.toVector().transpose(Predef$.MODULE$.$conforms());

        // One encoder slot per input column; the first row decides whether a column is real-valued.
        Seq inputEncoders = (Seq) ((TraversableLike) ((Vector) inputs.head()).zipWithIndex(Vector$.MODULE$.canBuildFrom())).map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            Object firstValue = tuple2._1();
            int column = tuple2._2$mcI$sp();
            return firstValue instanceof Double ? None$.MODULE$ : new Some(CategoricalEncoder$.MODULE$.buildEncoder((Seq) inputs.map(row -> {
                return row.apply(column);
            }, Seq$.MODULE$.canBuildFrom())));
        }, Vector$.MODULE$.canBuildFrom());
        Seq encodedInputs = (Seq) inputs.map(row -> {
            return CategoricalEncoder$.MODULE$.encodeInput(row, inputEncoders);
        }, Seq$.MODULE$.canBuildFrom());

        // One encoder slot per task: None for real-valued tasks, otherwise a categorical encoder.
        // The first NON-NULL label decides the task type. Looking only at the first label would
        // turn a regression task into a classification task whenever row 0's label is missing.
        Seq outputEncoders = (Seq) labels.indices().map(obj -> {
            Seq column = (Seq) labels.apply(BoxesRunTime.unboxToInt(obj));
            return isRealValued(column) ? None$.MODULE$ : new Some(CategoricalEncoder$.MODULE$.buildEncoder((Seq) ((TraversableLike) column).filterNot(value -> {
                return BoxesRunTime.boxToBoolean($anonfun$train$6(value));
            })));
        }, IndexedSeq$.MODULE$.canBuildFrom());
        Vector encodedLabels = (Vector) labelsByRow.map(rowLabels -> {
            return CategoricalEncoder$.MODULE$.encodeInput(rowLabels, outputEncoders);
        }, Vector$.MODULE$.canBuildFrom());

        // (encoded inputs, encoded labels, weight) per row; rows with non-positive weight are dropped.
        MultiTaskTrainingNode root = new MultiTaskTrainingNode((IndexedSeq) ((TraversableLike) inputs.indices().map(obj -> {
            return $anonfun$train$8(weights, encodedInputs, encodedLabels, BoxesRunTime.unboxToInt(obj));
        }, IndexedSeq$.MODULE$.canBuildFrom())).filter(tuple3 -> {
            return BoxesRunTime.boxToBoolean($anonfun$train$11(tuple3));
        }));

        // Extract one model node per task from the shared training tree.
        IndexedSeq nodes = (IndexedSeq) labels.indices().map(obj -> {
            return root.getNode(BoxesRunTime.unboxToInt(obj));
        }, IndexedSeq$.MODULE$.canBuildFrom());
        return (Seq) ((IndexedSeq) labels.indices().map(obj -> {
            return $anonfun$train$13(labels, inputEncoders, outputEncoders, nodes, BoxesRunTime.unboxToInt(obj));
        }, IndexedSeq$.MODULE$.canBuildFrom())).map(model -> {
            return new MultiTaskTreeTrainingResult(model, this.hypers());
        }, IndexedSeq$.MODULE$.canBuildFrom());
    }

    /**
     * Decides whether a task's labels should be modelled by regression.
     *
     * @param column labels for one task; null entries are skipped
     * @return true iff the first non-null label is a Double; false when every label is null
     */
    private static boolean isRealValued(Seq column) {
        scala.collection.Iterator iterator = column.iterator();
        while (iterator.hasNext()) {
            Object value = iterator.next();
            if (value != null) {
                return value instanceof Double;
            }
        }
        return false;
    }

    /** Predicate used to drop null labels before building a categorical label encoder. */
    public static final /* synthetic */ boolean $anonfun$train$6(Object obj) {
        return obj == null;
    }

    /** Reads the weight for row {@code i} from the supplied weight sequence. */
    public static final /* synthetic */ double $anonfun$train$9(int i, Seq seq) {
        return BoxesRunTime.unboxToDouble(seq.apply(i));
    }

    /**
     * Assembles the training tuple for row {@code i}.
     *
     * <p>The tuple holds the encoded input row, the encoded labels as an array, and the row's weight
     * (1.0 when no weights were supplied).
     */
    public static final /* synthetic */ Tuple3 $anonfun$train$8(Option option, Seq seq, Vector vector, int i) {
        return new Tuple3(seq.apply(i), ((TraversableOnce) vector.apply(i)).toArray(ClassTag$.MODULE$.AnyVal()), option.map(weightSeq -> {
            return BoxesRunTime.boxToDouble($anonfun$train$9(i, weightSeq));
        }).getOrElse(() -> {
            return 1.0d;
        }));
    }

    /** Keeps only training rows whose weight (third tuple element) is strictly positive. */
    public static final /* synthetic */ boolean $anonfun$train$11(Tuple3 tuple3) {
        return BoxesRunTime.unboxToDouble(tuple3._3()) > 0.0d;
    }

    /**
     * Builds the model for task {@code i}.
     *
     * <p>Returns a RegressionTree when no output encoder was built for the task, otherwise a
     * ClassificationTree using that encoder. Keying off the encoder, rather than re-inspecting the
     * labels, guarantees that the model type agrees with how the labels were encoded. It also means
     * {@code Option.get()} is never called on None.
     *
     * @param seq         label columns; retained for signature compatibility, no longer read
     * @param seq2        input encoders
     * @param seq3        per-task output encoders (Option of CategoricalEncoder)
     * @param indexedSeq  per-task model nodes
     * @param i           task index
     */
    public static final /* synthetic */ Model $anonfun$train$13(Seq seq, Seq seq2, Seq seq3, IndexedSeq indexedSeq, int i) {
        Option outputEncoder = (Option) seq3.apply(i);
        ModelNode node = (ModelNode) indexedSeq.apply(i);
        return outputEncoder.isEmpty() ? new RegressionTree(node, seq2) : new ClassificationTree(node, seq2, (CategoricalEncoder) outputEncoder.get());
    }

    /** Creates a learner with an empty hyperparameter map. */
    public MultiTaskTreeLearner() {
        hypers_$eq((Map) Predef$.MODULE$.Map().apply(Nil$.MODULE$));
    }
}
