/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.spark.classloader_interface;

import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskInputs;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkPage;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.SerializedTaskStats;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.util.CollectionAccumulator;
import scala.Tuple2;

/**
 * Static factories for the Spark function objects that run a Presto-on-Spark task.
 *
 * <p>Every function produced here reads the partition id and the attempt number from the
 * current {@link TaskContext}, and then delegates to the object returned by
 * {@link PrestoSparkTaskExecutorFactoryProvider#get()}, passing along the serialized task
 * descriptor, the task inputs and the task stats accumulator. The three overloads differ
 * only in how many upstream inputs they wire into {@link PrestoSparkTaskInputs}.
 */
public class TaskProcessors {
    private TaskProcessors() {
    }

    /**
     * Creates a processor for a task with no upstream inputs. The task descriptor is
     * taken from the partition contents itself.
     *
     * @param taskExecutorFactoryProvider provider of the factory that creates the task output iterator
     * @param taskStatsCollector accumulator handed to the created task
     * @return a function mapping a partition holding exactly one serialized task descriptor
     *         to the iterator returned by the task executor factory
     */
    public static PairFlatMapFunction<Iterator<SerializedPrestoSparkTaskDescriptor>, Integer, SerializedPrestoSparkPage> createTaskProcessor(final PrestoSparkTaskExecutorFactoryProvider taskExecutorFactoryProvider, final CollectionAccumulator<SerializedTaskStats> taskStatsCollector) {
        return new PairFlatMapFunction<Iterator<SerializedPrestoSparkTaskDescriptor>, Integer, SerializedPrestoSparkPage>(){

            /**
             * @throws IllegalArgumentException if the partition is empty or contains more
             *         than one task descriptor
             */
            @Override
            public Iterator<Tuple2<Integer, SerializedPrestoSparkPage>> call(Iterator<SerializedPrestoSparkTaskDescriptor> serializedTaskRequestIterator) {
                // Validate the empty case explicitly: calling next() on an empty iterator
                // would otherwise surface as a NoSuchElementException with no explanation.
                if (!serializedTaskRequestIterator.hasNext()) {
                    throw new IllegalArgumentException("each partition is expected to contain an exactly one task descriptor");
                }
                SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor = serializedTaskRequestIterator.next();
                if (serializedTaskRequestIterator.hasNext()) {
                    throw new IllegalArgumentException("each partition is expected to contain an exactly one task descriptor");
                }
                int partitionId = TaskContext.get().partitionId();
                int attemptNumber = TaskContext.get().attemptNumber();
                // No upstream inputs for this kind of task, hence the empty inputs map.
                return taskExecutorFactoryProvider.get().create(partitionId, attemptNumber, serializedTaskDescriptor, new PrestoSparkTaskInputs(Collections.emptyMap()), taskStatsCollector);
            }
        };
    }

    /**
     * Creates a processor for a task with a single upstream input.
     *
     * @param taskExecutorFactoryProvider provider of the factory that creates the task output iterator
     * @param serializedTaskDescriptor descriptor of the task to run
     * @param planNodeId key under which the partition's input iterator is registered in the task inputs
     * @param taskStatsCollector accumulator handed to the created task
     * @return a function mapping the partition's input iterator to the iterator returned
     *         by the task executor factory
     */
    public static PairFlatMapFunction<Iterator<Tuple2<Integer, SerializedPrestoSparkPage>>, Integer, SerializedPrestoSparkPage> createTaskProcessor(final PrestoSparkTaskExecutorFactoryProvider taskExecutorFactoryProvider, final SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor, final String planNodeId, final CollectionAccumulator<SerializedTaskStats> taskStatsCollector) {
        return new PairFlatMapFunction<Iterator<Tuple2<Integer, SerializedPrestoSparkPage>>, Integer, SerializedPrestoSparkPage>(){

            @Override
            public Iterator<Tuple2<Integer, SerializedPrestoSparkPage>> call(Iterator<Tuple2<Integer, SerializedPrestoSparkPage>> input) {
                int partitionId = TaskContext.get().partitionId();
                int attemptNumber = TaskContext.get().attemptNumber();
                return taskExecutorFactoryProvider.get().create(partitionId, attemptNumber, serializedTaskDescriptor, new PrestoSparkTaskInputs(Collections.singletonMap(planNodeId, input)), taskStatsCollector);
            }
        };
    }

    /**
     * Creates a processor for a task with two upstream inputs.
     *
     * <p>Both inputs are stored in one map keyed by plan node id, so if {@code planNodeId1}
     * equals {@code planNodeId2} the second input replaces the first.
     *
     * @param taskExecutorFactoryProvider provider of the factory that creates the task output iterator
     * @param serializedTaskDescriptor descriptor of the task to run
     * @param planNodeId1 key under which the first input iterator is registered in the task inputs
     * @param planNodeId2 key under which the second input iterator is registered in the task inputs
     * @param taskStatsCollector accumulator handed to the created task
     * @return a function mapping the two input iterators to the iterator returned by the
     *         task executor factory
     */
    public static FlatMapFunction2<Iterator<Tuple2<Integer, SerializedPrestoSparkPage>>, Iterator<Tuple2<Integer, SerializedPrestoSparkPage>>, Tuple2<Integer, SerializedPrestoSparkPage>> createTaskProcessor(final PrestoSparkTaskExecutorFactoryProvider taskExecutorFactoryProvider, final SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor, final String planNodeId1, final String planNodeId2, final CollectionAccumulator<SerializedTaskStats> taskStatsCollector) {
        return new FlatMapFunction2<Iterator<Tuple2<Integer, SerializedPrestoSparkPage>>, Iterator<Tuple2<Integer, SerializedPrestoSparkPage>>, Tuple2<Integer, SerializedPrestoSparkPage>>(){

            @Override
            public Iterator<Tuple2<Integer, SerializedPrestoSparkPage>> call(Iterator<Tuple2<Integer, SerializedPrestoSparkPage>> input1, Iterator<Tuple2<Integer, SerializedPrestoSparkPage>> input2) {
                int partitionId = TaskContext.get().partitionId();
                int attemptNumber = TaskContext.get().attemptNumber();
                HashMap<String, Iterator<Tuple2<Integer, SerializedPrestoSparkPage>>> inputsMap = new HashMap<>();
                inputsMap.put(planNodeId1, input1);
                inputsMap.put(planNodeId2, input2);
                // Hand out a read-only view so the task cannot alter the input wiring.
                return taskExecutorFactoryProvider.get().create(partitionId, attemptNumber, serializedTaskDescriptor, new PrestoSparkTaskInputs(Collections.unmodifiableMap(inputsMap)), taskStatsCollector);
            }
        };
    }
}

