/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.spark.translation;

import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.ReadTranslation;
import org.apache.beam.runners.core.construction.graph.PipelineNode;
import org.apache.beam.runners.core.construction.graph.QueryablePipeline;
import org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils;
import org.apache.beam.runners.fnexecution.wire.WireCoders;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
import org.apache.beam.runners.spark.io.SourceRDD;
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.translation.BoundedDataset;
import org.apache.beam.runners.spark.translation.Dataset;
import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
import org.apache.beam.runners.spark.translation.GroupNonMergingWindowsFunctions;
import org.apache.beam.runners.spark.translation.SparkExecutableStageExtractionFunction;
import org.apache.beam.runners.spark.translation.SparkExecutableStageFunction;
import org.apache.beam.runners.spark.translation.SparkGroupAlsoByWindowViaOutputBufferFn;
import org.apache.beam.runners.spark.translation.SparkTranslationContext;
import org.apache.beam.runners.spark.translation.TranslationUtils;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.BiMap;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Sets;
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import scala.Tuple2;

/**
 * Translates a batch portable Beam {@link RunnerApi.Pipeline} into Spark RDD operations that are
 * registered on a {@link SparkTranslationContext}.
 *
 * <p>Every supported transform URN is mapped to a {@link PTransformTranslator}; transforms are
 * translated in topological order.
 */
public class SparkBatchPortablePipelineTranslator {

    /** URN identifying a fused executable stage transform. */
    private static final String EXECUTABLE_STAGE_URN = "beam:runner:executable_stage:v1";

    /** Translator to use for each supported transform URN. */
    private final ImmutableMap<String, PTransformTranslator> urnToTransformTranslator;

    /**
     * Returns the transform URNs this translator advertises as supported.
     *
     * <p>{@link PTransformTranslation#READ_TRANSFORM_URN} has a registered translator, but it is
     * deliberately left out of the advertised set.
     *
     * @return an unmodifiable view of the registered URNs minus the Read URN
     */
    public Set<String> knownUrns() {
        return Sets.difference(
                urnToTransformTranslator.keySet(),
                ImmutableSet.of(PTransformTranslation.READ_TRANSFORM_URN));
    }

    /** Creates a translator with the Impulse, GroupByKey, ExecutableStage, Flatten and Read translations registered. */
    public SparkBatchPortablePipelineTranslator() {
        // The builder must be parameterized: a raw builder's put(Object, Object) gives the
        // method references below no functional-interface target type.
        ImmutableMap.Builder<String, PTransformTranslator> translatorMap = ImmutableMap.builder();
        translatorMap.put(
                PTransformTranslation.IMPULSE_TRANSFORM_URN,
                SparkBatchPortablePipelineTranslator::translateImpulse);
        translatorMap.put(
                PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN,
                SparkBatchPortablePipelineTranslator::translateGroupByKey);
        translatorMap.put(
                EXECUTABLE_STAGE_URN, SparkBatchPortablePipelineTranslator::translateExecutableStage);
        translatorMap.put(
                PTransformTranslation.FLATTEN_TRANSFORM_URN,
                SparkBatchPortablePipelineTranslator::translateFlatten);
        translatorMap.put(
                PTransformTranslation.READ_TRANSFORM_URN,
                SparkBatchPortablePipelineTranslator::translateRead);
        this.urnToTransformTranslator = translatorMap.build();
    }

    /**
     * Translates every transform of {@code pipeline} into {@code context}.
     *
     * <p>A first pass over the topologically ordered transforms records how often each PCollection
     * is consumed and registers the windowed wire coder of every output PCollection. A second pass
     * hands each transform to the translator registered for its URN.
     *
     * @param pipeline the portable pipeline to translate
     * @param context the Spark translation context that receives the datasets
     * @throws IllegalArgumentException if a transform has a URN with no registered translator
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void translate(RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        QueryablePipeline p =
                QueryablePipeline.forTransforms(
                        pipeline.getRootTransformIdsList(), pipeline.getComponents());
        for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
            RunnerApi.PTransform transform = transformNode.getTransform();
            for (String inputId : transform.getInputsMap().values()) {
                context.incrementConsumptionCountBy(inputId, 1);
            }
            // An executable stage produces one intermediate dataset from which every output is
            // extracted, so that intermediate dataset is consumed once per output.
            if (EXECUTABLE_STAGE_URN.equals(transform.getSpec().getUrn())) {
                context.incrementConsumptionCountBy(
                        getExecutableStageIntermediateId(transformNode),
                        transform.getOutputsMap().size());
            }
            for (String outputId : transform.getOutputsMap().values()) {
                WindowedValue.WindowedValueCoder outputCoder =
                        getWindowedValueCoder(outputId, pipeline.getComponents());
                context.putCoder(outputId, outputCoder);
            }
        }
        for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
            urnToTransformTranslator
                    .getOrDefault(
                            transformNode.getTransform().getSpec().getUrn(),
                            SparkBatchPortablePipelineTranslator::urnNotFound)
                    .translate(transformNode, pipeline, context);
        }
    }

    /** Fallback translator for unsupported URNs; always throws {@link IllegalArgumentException}. */
    private static void urnNotFound(
            PipelineNode.PTransformNode transformNode,
            RunnerApi.Pipeline pipeline,
            SparkTranslationContext context) {
        throw new IllegalArgumentException(
                String.format(
                        "Transform %s has unknown URN %s",
                        transformNode.getId(), transformNode.getTransform().getSpec().getUrn()));
    }

    /** Translates Impulse into a dataset holding a single empty byte array. */
    private static void translateImpulse(
            PipelineNode.PTransformNode transformNode,
            RunnerApi.Pipeline pipeline,
            SparkTranslationContext context) {
        BoundedDataset<byte[]> output =
                new BoundedDataset<byte[]>(
                        Collections.singletonList(new byte[0]),
                        context.getSparkContext(),
                        ByteArrayCoder.of());
        context.pushDataset(getOutputId(transformNode), output);
    }

    /**
     * Translates GroupByKey.
     *
     * <p>Non-merging windows with an END_OF_WINDOW timestamp combiner use
     * {@link GroupNonMergingWindowsFunctions#groupByKeyAndWindow}. Every other case groups by key
     * first and then assigns windows with {@link SparkGroupAlsoByWindowViaOutputBufferFn} backed by
     * in-memory state internals.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <K, V> void translateGroupByKey(
            PipelineNode.PTransformNode transformNode,
            RunnerApi.Pipeline pipeline,
            SparkTranslationContext context) {
        RunnerApi.Components components = pipeline.getComponents();
        String inputId = getInputId(transformNode);
        JavaRDD inputRdd = ((BoundedDataset) context.popDataset(inputId)).getRDD();
        WindowedValue.WindowedValueCoder inputCoder = getWindowedValueCoder(inputId, components);
        KvCoder<K, V> inputKvCoder = (KvCoder<K, V>) inputCoder.getValueCoder();
        Coder<K> inputKeyCoder = inputKvCoder.getKeyCoder();
        Coder<V> inputValueCoder = inputKvCoder.getValueCoder();
        WindowingStrategy windowingStrategy =
                PipelineTranslatorUtils.getWindowingStrategy(inputId, components);
        WindowFn windowFn = windowingStrategy.getWindowFn();
        WindowedValue.FullWindowedValueCoder wvCoder =
                WindowedValue.FullWindowedValueCoder.of(inputValueCoder, windowFn.windowCoder());

        JavaRDD groupedByKeyAndWindow;
        if (windowFn.isNonMerging()
                && windowingStrategy.getTimestampCombiner() == TimestampCombiner.END_OF_WINDOW) {
            groupedByKeyAndWindow =
                    GroupNonMergingWindowsFunctions.groupByKeyAndWindow(
                            inputRdd, inputKeyCoder, inputValueCoder, windowingStrategy);
        } else {
            // getPartitioner may return null (see its documentation); it is passed through as-is.
            Partitioner partitioner = getPartitioner(context);
            JavaRDD groupedByKeyOnly =
                    GroupCombineFunctions.groupByKeyOnly(
                            inputRdd, inputKeyCoder, wvCoder, partitioner);
            groupedByKeyAndWindow =
                    groupedByKeyOnly.flatMap(
                            new SparkGroupAlsoByWindowViaOutputBufferFn(
                                    windowingStrategy,
                                    new TranslationUtils.InMemoryStateInternalsFactory(),
                                    SystemReduceFn.buffering(inputValueCoder),
                                    context.serializablePipelineOptions,
                                    AggregatorsAccumulator.getInstance()));
        }
        context.pushDataset(getOutputId(transformNode), new BoundedDataset(groupedByKeyAndWindow));
    }

    /**
     * Translates a fused executable stage.
     *
     * <p>Side inputs are broadcast, and the stage runs over each partition of its input through
     * {@link SparkExecutableStageFunction}. The result is an intermediate dataset. Each declared
     * output is then extracted from it with {@link SparkExecutableStageExtractionFunction}. A stage
     * without outputs gets an extra sink dataset that emits nothing.
     *
     * @throws RuntimeException if the stage payload cannot be parsed
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <InputT, OutputT, SideInputT> void translateExecutableStage(
            PipelineNode.PTransformNode transformNode,
            RunnerApi.Pipeline pipeline,
            SparkTranslationContext context) {
        RunnerApi.ExecutableStagePayload stagePayload;
        try {
            stagePayload =
                    RunnerApi.ExecutableStagePayload.parseFrom(
                            transformNode.getTransform().getSpec().getPayload());
        } catch (IOException e) {
            throw new RuntimeException(
                    String.format(
                            "Failed to parse ExecutableStagePayload of transform %s.",
                            transformNode.getId()),
                    e);
        }
        JavaRDD inputRdd = ((BoundedDataset) context.popDataset(stagePayload.getInput())).getRDD();
        Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
        BiMap<String, Integer> outputExtractionMap =
                PipelineTranslatorUtils.createOutputMap(outputs.values());

        // Broadcast every side input, keyed by the id of the PCollection it reads.
        RunnerApi.Components stagePayloadComponents = stagePayload.getComponents();
        ImmutableMap.Builder<
                        String,
                        Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>>
                broadcastVariablesBuilder = ImmutableMap.builder();
        for (RunnerApi.ExecutableStagePayload.SideInputId sideInputId :
                stagePayload.getSideInputsList()) {
            String collectionId =
                    stagePayloadComponents
                            .getTransformsOrThrow(sideInputId.getTransformId())
                            .getInputsOrThrow(sideInputId.getLocalName());
            broadcastVariablesBuilder.put(
                    collectionId, broadcastSideInput(collectionId, stagePayloadComponents, context));
        }

        SparkExecutableStageFunction function =
                new SparkExecutableStageFunction(
                        stagePayload,
                        context.jobInfo,
                        outputExtractionMap,
                        broadcastVariablesBuilder.build(),
                        MetricsAccumulator.getInstance());
        final JavaRDD staged = inputRdd.mapPartitions(function);

        String intermediateId = getExecutableStageIntermediateId(transformNode);
        // The intermediate dataset is pushed and popped straight away.
        // NOTE(review): presumably this lets the context apply its consumption-count based
        // handling (e.g. caching) to the intermediate RDD — confirm in SparkTranslationContext.
        context.pushDataset(
                intermediateId,
                new Dataset() {

                    @Override
                    public void cache(String storageLevel, Coder<?> coder) {
                        StorageLevel level = StorageLevel.fromString(storageLevel);
                        staged.persist(level);
                    }

                    @Override
                    public void action() {
                        staged.foreach(TranslationUtils.emptyVoidFunction());
                    }

                    @Override
                    public void setName(String name) {
                        staged.setName(name);
                    }
                });
        context.popDataset(intermediateId);

        for (String outputId : outputs.values()) {
            JavaRDD outputRdd =
                    staged.flatMap(
                            new SparkExecutableStageExtractionFunction(
                                    outputExtractionMap.get(outputId)));
            context.pushDataset(outputId, new BoundedDataset(outputRdd));
        }
        if (outputs.isEmpty()) {
            // A stage without outputs still gets a (producing-nothing) sink dataset, presumably so
            // that the stage is executed as part of the job.
            FlatMapFunction<Object, Object> emitNothing = ignored -> Collections.emptyIterator();
            JavaRDD outputRdd = staged.flatMap(emitNothing);
            context.pushDataset(
                    String.format("EmptyOutputSink_%d", context.nextSinkId()),
                    new BoundedDataset(outputRdd));
        }
    }

    /**
     * Pops the dataset of PCollection {@code collectionId}, encodes it to bytes with its windowed
     * wire coder and broadcasts the bytes.
     *
     * @return the broadcast bytes paired with the coder needed to decode them
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <T>
            Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<T>> broadcastSideInput(
                    String collectionId,
                    RunnerApi.Components components,
                    SparkTranslationContext context) {
        BoundedDataset dataset = (BoundedDataset) context.popDataset(collectionId);
        WindowedValue.WindowedValueCoder<T> coder = getWindowedValueCoder(collectionId, components);
        List<byte[]> bytes = dataset.getBytes(coder);
        Broadcast<List<byte[]>> broadcast = context.getSparkContext().broadcast(bytes);
        return new Tuple2<>(broadcast, coder);
    }

    /** Translates Flatten into a union of its inputs, or an empty RDD when it has no inputs. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <T> void translateFlatten(
            PipelineNode.PTransformNode transformNode,
            RunnerApi.Pipeline pipeline,
            SparkTranslationContext context) {
        Map<String, String> inputsMap = transformNode.getTransform().getInputsMap();
        JavaRDD unionRDD;
        if (inputsMap.isEmpty()) {
            unionRDD = context.getSparkContext().emptyRDD();
        } else {
            JavaRDD[] rdds = new JavaRDD[inputsMap.size()];
            int index = 0;
            for (String inputId : inputsMap.values()) {
                rdds[index++] = ((BoundedDataset) context.popDataset(inputId)).getRDD();
            }
            unionRDD = context.getSparkContext().union(rdds);
        }
        context.pushDataset(getOutputId(transformNode), new BoundedDataset(unionRDD));
    }

    /**
     * Translates a Read of a {@link BoundedSource} into a {@link SourceRDD.Bounded}.
     *
     * @throws RuntimeException if the bounded source cannot be extracted from the payload
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <T> void translateRead(
            PipelineNode.PTransformNode transformNode,
            RunnerApi.Pipeline pipeline,
            SparkTranslationContext context) {
        String stepName = transformNode.getTransform().getUniqueName();
        JavaSparkContext jsc = context.getSparkContext();
        BoundedSource boundedSource;
        try {
            boundedSource =
                    ReadTranslation.boundedSourceFromProto(
                            RunnerApi.ReadPayload.parseFrom(
                                    transformNode.getTransform().getSpec().getPayload()));
        } catch (IOException e) {
            throw new RuntimeException("Failed to extract BoundedSource from ReadPayload.", e);
        }
        JavaRDD input =
                new SourceRDD.Bounded(
                                jsc.sc(), boundedSource, context.serializablePipelineOptions, stepName)
                        .toJavaRDD();
        context.pushDataset(getOutputId(transformNode), new BoundedDataset(input));
    }

    /**
     * Chooses the partitioner for the group-by-key-only step.
     *
     * @return {@code null} when the configured bundle size is positive; otherwise a
     *     {@link HashPartitioner} sized to the Spark context's default parallelism
     */
    @Nullable
    private static Partitioner getPartitioner(SparkTranslationContext context) {
        Long bundleSize =
                context.serializablePipelineOptions
                        .get()
                        .as(SparkPipelineOptions.class)
                        .getBundleSize();
        return bundleSize > 0L
                ? null
                : new HashPartitioner(context.getSparkContext().defaultParallelism());
    }

    /** Returns the id of the single input PCollection; fails if there is not exactly one. */
    private static String getInputId(PipelineNode.PTransformNode transformNode) {
        return Iterables.getOnlyElement(transformNode.getTransform().getInputsMap().values());
    }

    /** Returns the id of the single output PCollection; fails if there is not exactly one. */
    private static String getOutputId(PipelineNode.PTransformNode transformNode) {
        return Iterables.getOnlyElement(transformNode.getTransform().getOutputsMap().values());
    }

    /**
     * Instantiates the runner wire coder of PCollection {@code pCollectionId} as a windowed value
     * coder.
     *
     * @throws RuntimeException wrapping the {@link IOException} raised while instantiating the coder
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <T> WindowedValue.WindowedValueCoder<T> getWindowedValueCoder(
            String pCollectionId, RunnerApi.Components components) {
        RunnerApi.PCollection pCollection = components.getPcollectionsOrThrow(pCollectionId);
        PipelineNode.PCollectionNode pCollectionNode =
                PipelineNode.pCollection(pCollectionId, pCollection);
        try {
            return (WindowedValue.WindowedValueCoder)
                    WireCoders.instantiateRunnerWireCoder(pCollectionNode, components);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the dataset id of an executable stage's intermediate result (the transform id). */
    private static String getExecutableStageIntermediateId(PipelineNode.PTransformNode transformNode) {
        return transformNode.getId();
    }

    /** Translates a single transform of a portable pipeline into a Spark translation context. */
    @FunctionalInterface
    interface PTransformTranslator {
        /**
         * Translates {@code transformNode} of {@code pipeline} into {@code context}.
         *
         * @param transformNode the transform to translate
         * @param pipeline the pipeline containing the transform
         * @param context the context that receives the resulting datasets
         */
        void translate(
                PipelineNode.PTransformNode transformNode,
                RunnerApi.Pipeline pipeline,
                SparkTranslationContext context);
    }
}

