/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tinkerpop.gremlin.spark.process.computer.traversal.strategy.optimization.interceptor;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;
import java.util.function.BinaryOperator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.tinkerpop.gremlin.hadoop.structure.io.VertexWritable;
import org.apache.tinkerpop.gremlin.process.computer.Memory;
import org.apache.tinkerpop.gremlin.process.computer.ProgramPhase;
import org.apache.tinkerpop.gremlin.process.computer.traversal.MemoryTraversalSideEffects;
import org.apache.tinkerpop.gremlin.process.computer.traversal.TraversalVertexProgram;
import org.apache.tinkerpop.gremlin.process.computer.traversal.strategy.finalization.ComputerFinalizationStrategy;
import org.apache.tinkerpop.gremlin.process.traversal.Scope;
import org.apache.tinkerpop.gremlin.process.traversal.Step;
import org.apache.tinkerpop.gremlin.process.traversal.Traversal;
import org.apache.tinkerpop.gremlin.process.traversal.Traverser;
import org.apache.tinkerpop.gremlin.process.traversal.step.Barrier;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.CountGlobalStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.FoldStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.GraphStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.GroupCountStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.GroupStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.MaxGlobalStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.MeanGlobalStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.MinGlobalStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.SumGlobalStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.util.ReducingBarrierStep;
import org.apache.tinkerpop.gremlin.process.traversal.strategy.decoration.SubgraphStrategy;
import org.apache.tinkerpop.gremlin.process.traversal.strategy.verification.ComputerVerificationStrategy;
import org.apache.tinkerpop.gremlin.process.traversal.traverser.TraverserRequirement;
import org.apache.tinkerpop.gremlin.process.traversal.traverser.util.TraverserSet;
import org.apache.tinkerpop.gremlin.process.traversal.util.TraversalHelper;
import org.apache.tinkerpop.gremlin.spark.process.computer.SparkMemory;
import org.apache.tinkerpop.gremlin.spark.process.computer.traversal.strategy.SparkVertexProgramInterceptor;
import org.apache.tinkerpop.gremlin.structure.util.ElementHelper;
import org.apache.tinkerpop.gremlin.util.NumberHelper;
import org.apache.tinkerpop.gremlin.util.function.ArrayListSupplier;
import org.apache.tinkerpop.gremlin.util.function.MeanNumberSupplier;
import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils;

/**
 * A {@link SparkVertexProgramInterceptor} for {@link TraversalVertexProgram}.
 *
 * <p>Instead of running the vertex program message-passing loop, it evaluates the traversal with
 * plain Spark RDD operations:
 * <ul>
 *   <li>each vertex selected by the start {@code GraphStep} is run through the traversal's middle
 *       steps independently (a clone per vertex);</li>
 *   <li>the traversal's end {@code ReducingBarrierStep} is applied as a Spark
 *       {@code fold}/{@code mapPartitions} reduction.</li>
 * </ul>
 * {@link #isLegal(Traversal.Admin)} decides which traversals qualify.
 *
 * <p>NOTE(review): this source is CFR-decompiled, so generic type arguments are erased and the
 * top-of-method local declarations are decompiler artifacts.
 */
public final class SparkStarBarrierInterceptor
implements SparkVertexProgramInterceptor<TraversalVertexProgram> {
    /**
     * Runs the traversal as a per-vertex RDD pipeline, then reduces it with the end barrier step.
     *
     * <p>The reduced value, when it is non-null, is stored in {@code memory} as a
     * single-traverser {@code TraverserSet} under the key
     * {@code "gremlin.traversalVertexProgram.haltedTraversers"}. The iteration counter is then
     * incremented.
     *
     * @param vertexProgram the traversal vertex program whose traversal is evaluated
     * @param inputRDD the graph RDD; it is only read and is returned unchanged
     * @param memory the Spark-backed memory. It is used for program setup and to back traversal
     *               side effects, and it receives the halted traversers.
     * @return {@code inputRDD}, unmodified
     * @throws IllegalArgumentException if the end step is not one of the supported reducing
     *                                  barriers
     */
    public JavaPairRDD<Object, VertexWritable> apply(TraversalVertexProgram vertexProgram, JavaPairRDD<Object, VertexWritable> inputRDD, SparkMemory memory) {
        // Decompiler-hoisted locals: assigned in the branch-specific reductions below.
        BinaryOperator biOperator;
        Object result;
        vertexProgram.setup((Memory)memory);
        // Work on a clone of the pure traversal so the program's own traversal is not mutated.
        Traversal.Admin traversal = vertexProgram.getTraversal().getPure().clone();
        GraphStep graphStep = (GraphStep)traversal.getStartStep();
        // Ids passed to the start GraphStep (if any) are used below to filter the input vertices.
        Object[] graphStepIds = graphStep.getIds();
        // Keep a reference to the reducing barrier. It drives the final reduction and is the step
        // the halted traverser is generated at.
        ReducingBarrierStep endStep = (ReducingBarrierStep)traversal.getEndStep();
        // Strip the start GraphStep and the end barrier. Only the per-vertex middle steps remain.
        traversal.removeStep(0);
        traversal.removeStep(traversal.getSteps().size() - 1);
        // Recompile without the computer verification/finalization strategies. The remaining
        // steps are executed as ordinary (non-computer) traversals on each vertex.
        traversal.setStrategies(traversal.clone().getStrategies().removeStrategies(new Class[]{ComputerVerificationStrategy.class, ComputerFinalizationStrategy.class}));
        traversal.applyStrategies();
        // With no middle steps, each selected vertex maps directly to one traverser (fast path).
        boolean identityTraversal = traversal.getSteps().isEmpty();
        // Bind the traversal's side effects to the Spark memory for the EXECUTE phase.
        MemoryTraversalSideEffects.setMemorySideEffects((Traversal.Admin)traversal, (Memory)memory, (ProgramPhase)ProgramPhase.EXECUTE);
        // NOTE(review): the matching setInExecute(false) below is not in a finally block. If
        // anything between here and there throws (including the unsupported-barrier
        // IllegalArgumentException), the flag stays set.
        memory.setInExecute(true);
        // Keep only vertices whose id matches the GraphStep ids, then expand each vertex into the
        // traversers produced by the middle steps.
        JavaRDD nextRDD = inputRDD.values().filter((Function & Serializable)vertexWritable -> ElementHelper.idExists((Object)vertexWritable.get().id(), (Object[])graphStepIds)).flatMap((FlatMapFunction & Serializable)vertexWritable -> {
            if (identityTraversal) {
                // No middle steps: emit a single bulk-1 traverser on the vertex at the GraphStep.
                return IteratorUtils.of((Object)traversal.getTraverserGenerator().generate((Object)vertexWritable.get(), (Step)graphStep, 1L));
            }
            // A fresh clone per vertex keeps each vertex's computation isolated.
            Traversal.Admin clone = traversal.clone();
            clone.getStartStep().addStart(clone.getTraverserGenerator().generate((Object)vertexWritable.get(), (Step)graphStep, 1L));
            // The end step itself is returned as the result iterator. This presumes Step is
            // iterable as traversers for this Spark FlatMapFunction contract; confirm against the
            // Spark version in use.
            return clone.getEndStep();
        });
        // Reduce nextRDD according to the concrete type of the removed end barrier.
        // NOTE(review): several branches call nextRDD.isEmpty() and then fold/map. These are
        // separate Spark actions, so nextRDD may be computed more than once unless it is persisted.
        if (endStep instanceof CountGlobalStep) {
            // count(): sum of traverser bulks, starting at 0L.
            result = nextRDD.map(Traverser::bulk).fold((Object)0L, (Function2 & Serializable)(a, b) -> a + b);
        } else if (endStep instanceof SumGlobalStep) {
            // sum(): null when there are no traversers; otherwise sum of bulk * value.
            result = nextRDD.isEmpty() ? null : nextRDD.map((Function & Serializable)traverser -> NumberHelper.mul((Number)traverser.bulk(), (Number)((Number)traverser.get()))).fold((Object)0, NumberHelper::add);
        } else if (endStep instanceof MeanGlobalStep) {
            // mean(): null when empty. Otherwise each traverser becomes a bulk-weighted
            // MeanNumber; these are folded together and finalized with getFinal().
            result = nextRDD.isEmpty() ? (Number)null : (Number)((MeanGlobalStep.MeanNumber)nextRDD.map((Function & Serializable)traverser -> new MeanGlobalStep.MeanNumber((Number)traverser.get(), traverser.bulk())).fold((Object)MeanNumberSupplier.instance().get(), MeanGlobalStep.MeanNumber::add)).getFinal();
        } else if (endStep instanceof MinGlobalStep) {
            // min(): null when empty; otherwise folded with NumberHelper::min from a Double.NaN
            // seed. NOTE(review): this relies on NumberHelper.min ignoring the NaN seed; confirm.
            result = nextRDD.isEmpty() ? null : nextRDD.map((Function & Serializable)traverser -> (Comparable)traverser.get()).fold((Object)Double.NaN, NumberHelper::min);
        } else if (endStep instanceof MaxGlobalStep) {
            // max(): same shape as min(), using NumberHelper::max. The NaN-seed caveat applies.
            result = nextRDD.isEmpty() ? null : nextRDD.map((Function & Serializable)traverser -> (Comparable)traverser.get()).fold((Object)Double.NaN, NumberHelper::max);
        } else if (endStep instanceof FoldStep) {
            biOperator = endStep.getBiOperator();
            result = nextRDD.map((Function & Serializable)traverser -> {
                // For list-seeded folds, expand each traverser into `bulk` copies of its value so
                // that bulked traversers contribute the correct number of elements.
                if (endStep.getSeedSupplier() instanceof ArrayListSupplier) {
                    ArrayList<Object> list = new ArrayList<Object>();
                    for (long i = 0L; i < traverser.bulk(); ++i) {
                        list.add(traverser.get());
                    }
                    return list;
                }
                // Otherwise hand the raw value to the step's bi-operator. The bulk is not applied.
                return traverser.get();
            }).fold(endStep.getSeedSupplier().get(), biOperator::apply);
        } else if (endStep instanceof GroupStep) {
            // group(): project traversers with a per-partition clone of the step, merge the
            // projections with the step's GroupBiOperator, then post-process with
            // generateFinalResult.
            biOperator = (GroupStep.GroupBiOperator)endStep.getBiOperator();
            result = ((GroupStep)endStep).generateFinalResult((Map)nextRDD.mapPartitions((FlatMapFunction & Serializable)partitions -> {
                GroupStep clone = (GroupStep)endStep.clone();
                return IteratorUtils.map((Iterator)partitions, arg_0 -> ((GroupStep)clone).projectTraverser(arg_0));
            }).fold(((GroupStep)endStep).getSeedSupplier().get(), (arg_0, arg_1) -> ((GroupStep.GroupBiOperator)biOperator).apply(arg_0, arg_1)));
        } else if (endStep instanceof GroupCountStep) {
            // groupCount(): same per-partition projection, merged with the shared
            // GroupCountBiOperator instance. No final post-processing step is applied.
            biOperator = GroupCountStep.GroupCountBiOperator.instance();
            result = nextRDD.mapPartitions((FlatMapFunction & Serializable)partitions -> {
                GroupCountStep clone = (GroupCountStep)endStep.clone();
                return IteratorUtils.map((Iterator)partitions, arg_0 -> ((GroupCountStep)clone).projectTraverser(arg_0));
            }).fold(((GroupCountStep)endStep).getSeedSupplier().get(), (arg_0, arg_1) -> ((GroupCountStep.GroupCountBiOperator)biOperator).apply(arg_0, arg_1));
        } else {
            throw new IllegalArgumentException("The end step is an unsupported barrier: " + endStep);
        }
        memory.setInExecute(false);
        // A null result (an empty sum/mean/min/max) leaves the halted-traverser key unset.
        if (result != null) {
            TraverserSet haltedTraversers = new TraverserSet();
            haltedTraversers.add(traversal.getTraverserGenerator().generate(result, (Step)endStep, 1L));
            memory.set("gremlin.traversalVertexProgram.haltedTraversers", haltedTraversers);
        }
        memory.incrIteration();
        // The graph itself is not changed by this interceptor.
        return inputRDD;
    }

    /**
     * Reports whether {@code traversal} can be handled by this interceptor.
     *
     * <p>All of the following must hold:
     * <ul>
     *   <li>no {@code SubgraphStrategy} is registered;</li>
     *   <li>the start step's class is exactly {@code GraphStep} (subclasses are rejected) and it
     *       does not return edges;</li>
     *   <li>the end step's class is exactly one of count/sum/mean/max/min/fold/group/groupCount;</li>
     *   <li>exactly one {@code Barrier} exists, searching global scope recursively;</li>
     *   <li>the traversal does not require sacks;</li>
     *   <li>{@code TraversalHelper.isLocalStarGraph(traversal)} is true.</li>
     * </ul>
     *
     * @param traversal the traversal to check
     * @return {@code true} if every condition holds
     */
    public static boolean isLegal(Traversal.Admin<?, ?> traversal) {
        Step startStep = traversal.getStartStep();
        Step endStep = traversal.getEndStep();
        if (traversal.getStrategies().toList().stream().filter(strategy -> strategy instanceof SubgraphStrategy).findAny().isPresent()) {
            return false;
        }
        // Exact class match: GraphStep subclasses are not accepted.
        if (!startStep.getClass().equals(GraphStep.class) || ((GraphStep)startStep).returnsEdge()) {
            return false;
        }
        // Exact class match against the reductions implemented in apply().
        if (!(endStep.getClass().equals(CountGlobalStep.class) || endStep.getClass().equals(SumGlobalStep.class) || endStep.getClass().equals(MeanGlobalStep.class) || endStep.getClass().equals(MaxGlobalStep.class) || endStep.getClass().equals(MinGlobalStep.class) || endStep.getClass().equals(FoldStep.class) || endStep.getClass().equals(GroupStep.class) || endStep.getClass().equals(GroupCountStep.class))) {
            return false;
        }
        // The end step must be the only barrier; apply() runs the remaining steps per vertex.
        if (TraversalHelper.getStepsOfAssignableClassRecursively((Scope)Scope.global, Barrier.class, traversal).size() != 1) {
            return false;
        }
        if (traversal.getTraverserRequirements().contains(TraverserRequirement.SACK)) {
            return false;
        }
        return TraversalHelper.isLocalStarGraph(traversal);
    }
}

