package co.cask.cdap.etl.spark;

import co.cask.cdap.api.spark.JavaSparkExecutionContext;
import co.cask.cdap.etl.api.JoinElement;
import co.cask.cdap.etl.api.Transform;
import co.cask.cdap.etl.api.batch.BatchAggregator;
import co.cask.cdap.etl.api.batch.BatchJoiner;
import co.cask.cdap.etl.api.batch.BatchSink;
import co.cask.cdap.etl.api.batch.SparkCompute;
import co.cask.cdap.etl.api.batch.SparkSink;
import co.cask.cdap.etl.api.streaming.Windower;
import co.cask.cdap.etl.common.DefaultMacroEvaluator;
import co.cask.cdap.etl.common.PipelinePhase;
import co.cask.cdap.etl.planner.StageInfo;
import co.cask.cdap.etl.spark.function.BatchSinkFunction;
import co.cask.cdap.etl.spark.function.InitialJoinFunction;
import co.cask.cdap.etl.spark.function.JoinFlattenFunction;
import co.cask.cdap.etl.spark.function.LeftJoinFlattenFunction;
import co.cask.cdap.etl.spark.function.OuterJoinFlattenFunction;
import co.cask.cdap.etl.spark.function.PluginFunctionContext;
import co.cask.cdap.etl.spark.function.TransformFunction;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:lib/hydrator-spark-core-4.0.1.jar:co/cask/cdap/etl/spark/SparkPipelineRunner.class */
/**
 * Walks the stages of a {@link PipelinePhase} in topological order and chains the matching
 * {@link SparkCollection} operation for each stage onto the data produced by that stage's inputs.
 *
 * <p>Subclasses supply how source data is obtained and how the two ends of a join (keying the
 * inputs, merging the joined values) are performed.
 */
public abstract class SparkPipelineRunner {
    /**
     * Produces the data for a stage that has no input data and whose plugin type equals the source
     * plugin type passed to {@link #runPipeline}.
     *
     * @param stageInfo the source stage
     * @return the collection that downstream stages will consume
     * @throws Exception if the source collection cannot be created
     */
    protected abstract SparkCollection<Object> getSource(StageInfo stageInfo) throws Exception;

    /**
     * Converts one input of a joiner stage into a pair collection. Called once per input of every
     * joiner stage, before any joins are performed.
     *
     * @param stageInfo the joiner stage
     * @param str name of the input stage the collection came from
     * @param sparkCollection data emitted by that input stage
     * @return the input as a pair collection (presumably keyed by the join key -- confirm in implementations)
     * @throws Exception if the conversion fails
     */
    protected abstract SparkPairCollection<Object, Object> addJoinKey(StageInfo stageInfo, String str, SparkCollection<Object> sparkCollection) throws Exception;

    /**
     * Turns the fully joined pair collection of a joiner stage into that stage's output.
     *
     * @param stageInfo the joiner stage
     * @param sparkPairCollection joined data; each value is the list of join elements collected across the inputs
     * @return the output of the joiner stage
     * @throws Exception if merging fails
     */
    protected abstract SparkCollection<Object> mergeJoinResults(StageInfo stageInfo, SparkPairCollection<Object, List<JoinElement<Object>>> sparkPairCollection) throws Exception;

    /**
     * Runs every stage of the given phase in topological order.
     *
     * <p>For each stage the data of all its inputs is unioned (joiner stages are the exception:
     * their inputs are kept separate and joined instead), then the operation matching the stage's
     * plugin type is applied. The resulting collection is remembered under the stage name so later
     * stages can consume it. Sink stages store their input and pass it through unchanged.
     *
     * @param pipelinePhase the phase whose stages should be run
     * @param str plugin type that identifies source stages; a stage without input data must be of this type
     * @param javaSparkExecutionContext execution context used for macro evaluation and plugin instantiation
     * @param map optional per-stage integer setting looked up by stage name and handed to aggregate and
     *            join operations (presumably the number of partitions -- confirm against SparkCollection)
     * @throws IllegalStateException if the phase has no DAG, a stage without input is not a source,
     *                               a stage has an unsupported plugin type, or a joiner is missing a required input
     * @throws Exception anything thrown by the plugins or the subclass hooks
     */
    // Plugin and function classes are used as raw types because their type parameters are not
    // visible from this class; the unchecked conversions mirror what the compiled class does.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void runPipeline(PipelinePhase pipelinePhase, String str, JavaSparkExecutionContext javaSparkExecutionContext, Map<String, Integer> map) throws Exception {
        DefaultMacroEvaluator macroEvaluator = new DefaultMacroEvaluator(javaSparkExecutionContext.getWorkflowToken(), javaSparkExecutionContext.getRuntimeArguments(), javaSparkExecutionContext.getLogicalStartTime(), javaSparkExecutionContext, javaSparkExecutionContext.getNamespace());
        // Output of every stage processed so far, keyed by stage name.
        Map<String, SparkCollection<Object>> emittedByStage = new HashMap<String, SparkCollection<Object>>();
        if (pipelinePhase.getDag() == null) {
            throw new IllegalStateException("Pipeline phase has no connections.");
        }
        for (String stageName : pipelinePhase.getDag().getTopologicalOrder()) {
            StageInfo stage = pipelinePhase.getStage(stageName);
            String pluginType = stage.getPluginType();
            boolean isJoiner = BatchJoiner.PLUGIN_TYPE.equals(pluginType);

            Map<String, SparkCollection<Object>> inputs = new HashMap<String, SparkCollection<Object>>();
            for (String inputName : stage.getInputs()) {
                inputs.put(inputName, emittedByStage.get(inputName));
            }

            // Non-joiner stages see the union of all their inputs. Joiner stages only pick up the
            // first input here (so they are not mistaken for a source); the real join happens below.
            SparkCollection<Object> stageData = null;
            if (!inputs.isEmpty()) {
                Iterator<SparkCollection<Object>> inputIter = inputs.values().iterator();
                stageData = inputIter.next();
                while (!isJoiner && inputIter.hasNext()) {
                    stageData = stageData.union(inputIter.next());
                }
            }

            PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stage, javaSparkExecutionContext);
            if (stageData == null) {
                if (!str.equals(pluginType)) {
                    throw new IllegalStateException(String.format("Stage '%s' has no input and is not a source.", stageName));
                }
                stageData = getSource(stage);
            } else if (BatchSink.PLUGIN_TYPE.equals(pluginType)) {
                stageData.store(stage, new BatchSinkFunction(pluginFunctionContext));
            } else if (Transform.PLUGIN_TYPE.equals(pluginType)) {
                stageData = stageData.flatMap(stage, new TransformFunction(pluginFunctionContext));
            } else if (SparkCompute.PLUGIN_TYPE.equals(pluginType)) {
                stageData = stageData.compute(stage, (SparkCompute) javaSparkExecutionContext.getPluginContext().newPluginInstance(stageName, macroEvaluator));
            } else if (SparkSink.PLUGIN_TYPE.equals(pluginType)) {
                stageData.store(stage, (SparkSink<Object>) javaSparkExecutionContext.getPluginContext().newPluginInstance(stageName, macroEvaluator));
            } else if (BatchAggregator.PLUGIN_TYPE.equals(pluginType)) {
                stageData = stageData.aggregate(stage, map.get(stageName));
            } else if (isJoiner) {
                BatchJoiner joiner = (BatchJoiner) javaSparkExecutionContext.getPluginContext().newPluginInstance(stageName, macroEvaluator);
                joiner.initialize(pluginFunctionContext.createJoinerRuntimeContext());
                stageData = joinInputs(stage, stageName, joiner, inputs, map.get(stageName)).cache();
            } else if (Windower.PLUGIN_TYPE.equals(pluginType)) {
                stageData = stageData.window(stage, (Windower) javaSparkExecutionContext.getPluginContext().newPluginInstance(stageName, macroEvaluator));
            } else {
                throw new IllegalStateException(String.format("Stage %s is of unsupported plugin type %s.", stageName, pluginType));
            }

            if (shouldCache(pipelinePhase, stage)) {
                stageData = stageData.cache();
            }
            emittedByStage.put(stageName, stageData);
        }
    }

    /**
     * Builds the output of a joiner stage from its separate inputs.
     *
     * <p>Inputs the joiner's join config lists as required are combined with {@code join}. Every
     * remaining input is then attached with {@code leftOuterJoin}, or with {@code fullOuterJoin}
     * when the joiner has no required inputs at all.
     *
     * @param stage the joiner stage
     * @param stageName name of the joiner stage, used in error messages
     * @param joiner the initialized joiner plugin
     * @param inputs data of each input stage, keyed by input stage name
     * @param partitions optional integer forwarded to the join calls when non-null
     *                   (presumably the partition count -- confirm against SparkPairCollection)
     * @return the result of {@link #mergeJoinResults} applied to the joined inputs
     * @throws IllegalStateException if a required input has no keyed collection, or there are no inputs
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private SparkCollection<Object> joinInputs(StageInfo stage, String stageName, BatchJoiner joiner, Map<String, SparkCollection<Object>> inputs, Integer partitions) throws Exception {
        Map<String, SparkPairCollection<Object, Object>> keyedInputs = new HashMap<String, SparkPairCollection<Object, Object>>();
        for (Map.Entry<String, SparkCollection<Object>> input : inputs.entrySet()) {
            String inputName = input.getKey();
            keyedInputs.put(inputName, addJoinKey(stage, inputName, input.getValue()));
        }

        // Starts as all inputs; required ones are removed as they are joined.
        Set<String> optionalInputs = new HashSet<String>();
        optionalInputs.addAll(inputs.keySet());

        SparkPairCollection<Object, List<JoinElement<Object>>> joined = null;
        for (String requiredInput : joiner.getJoinConfig().getRequiredInputs()) {
            SparkPairCollection<Object, Object> keyed = keyedInputs.get(requiredInput);
            if (keyed == null) {
                // Fail with a descriptive error instead of a NullPointerException further down.
                throw new IllegalStateException(String.format("Join stage '%s' has no keyed input collection for required input stage '%s'.", stageName, requiredInput));
            }
            if (joined == null) {
                joined = keyed.mapValues(new InitialJoinFunction(requiredInput));
            } else {
                JoinFlattenFunction flatten = new JoinFlattenFunction(requiredInput);
                if (partitions == null) {
                    joined = joined.join(keyed).mapValues(flatten);
                } else {
                    joined = joined.join(keyed, partitions.intValue()).mapValues(flatten);
                }
            }
            optionalInputs.remove(requiredInput);
        }

        // With no required inputs at all, every input is optional and a full outer join is used.
        boolean fullOuter = joined == null;
        for (String optionalInput : optionalInputs) {
            SparkPairCollection<Object, Object> keyed = keyedInputs.get(optionalInput);
            if (joined == null) {
                joined = keyed.mapValues(new InitialJoinFunction(optionalInput));
            } else if (fullOuter) {
                OuterJoinFlattenFunction flatten = new OuterJoinFlattenFunction(optionalInput);
                if (partitions == null) {
                    joined = joined.fullOuterJoin(keyed).mapValues(flatten);
                } else {
                    joined = joined.fullOuterJoin(keyed, partitions.intValue()).mapValues(flatten);
                }
            } else {
                LeftJoinFlattenFunction flatten = new LeftJoinFlattenFunction(optionalInput);
                if (partitions == null) {
                    joined = joined.leftOuterJoin(keyed).mapValues(flatten);
                } else {
                    joined = joined.leftOuterJoin(keyed, partitions.intValue()).mapValues(flatten);
                }
            }
        }

        if (joined == null) {
            throw new IllegalStateException("There are no inputs into join stage " + stageName);
        }
        return mergeJoinResults(stage, joined);
    }

    /**
     * Decides whether a stage's output should be cached: either the stage feeds more than one
     * downstream stage, or one of its downstream stages has more than one input.
     */
    private boolean shouldCache(PipelinePhase pipelinePhase, StageInfo stageInfo) {
        Set<String> outputs = pipelinePhase.getStageOutputs(stageInfo.getName());
        if (outputs.size() > 1) {
            return true;
        }
        for (String outputName : outputs) {
            if (pipelinePhase.getStage(outputName).getInputs().size() > 1) {
                return true;
            }
        }
        return false;
    }
}
