package cz.seznam.euphoria.spark;

import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import cz.seznam.euphoria.core.client.dataset.windowing.MergingWindowing;
import cz.seznam.euphoria.core.client.dataset.windowing.Window;
import cz.seznam.euphoria.core.client.dataset.windowing.Windowing;
import cz.seznam.euphoria.core.client.functional.UnaryFunction;
import cz.seznam.euphoria.core.client.operator.Join;
import cz.seznam.euphoria.core.client.util.Either;
import cz.seznam.euphoria.core.client.util.Pair;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.Optional;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

/**
 * Translates a LEFT or RIGHT {@link Join} into a Spark broadcast hash join.
 *
 * <p>The non-outer side of the join is collected to the driver with {@code collect()}, grouped
 * by {@link KeyedWindow} and shipped to the executors as a {@link Broadcast} variable:
 * <ul>
 *   <li>for a LEFT join, the right input is broadcast;</li>
 *   <li>for a RIGHT join, the left input is broadcast.</li>
 * </ul>
 * That side therefore has to fit in memory. The outer side is then probed against the broadcast
 * table element by element.
 *
 * <p>Only joins carrying the {@link JoinHints#broadcastHashJoin()} hint, of type LEFT or RIGHT,
 * and using non-merging windowing are supported (see {@link #wantTranslate(Join)}).
 */
public class BroadcastHashJoinTranslator implements SparkOperatorTranslator<Join> {

    /**
     * Tags every element of one join input with its join key and with each window that the
     * join's windowing assigns to it.
     */
    private static class KeyExtractor implements PairFlatMapFunction<SparkElement, KeyedWindow, SparkElement> {
        private final UnaryFunction keyExtractor;
        private final Windowing windowing;
        // true when this extractor is applied to the join's left input
        private final boolean left;

        KeyExtractor(UnaryFunction keyExtractor, Windowing windowing, boolean left) {
            this.keyExtractor = keyExtractor;
            this.windowing = windowing;
            this.left = left;
        }

        /**
         * Emits one {@code (KeyedWindow, SparkElement)} pair per window assigned to {@code element}.
         *
         * @param element an element of this extractor's input
         * @return pairs of (key + window + timestamp) and the element re-windowed into that window
         */
        @Override
        public Iterator<Tuple2<KeyedWindow, SparkElement>> call(SparkElement element) throws Exception {
            final Object value = element.getElement();
            final long timestamp = element.getTimestamp();
            // Windows are assigned to the Either-tagged value (left/right depending on the input
            // side). The emitted element itself carries the untagged value.
            final SparkElement tagged = new SparkElement(
                element.getWindow(), timestamp, left ? Either.left(value) : Either.right(value));
            @SuppressWarnings("unchecked")
            final Iterable<Window> windows = windowing.assignWindowsToElement(tagged);
            // The key depends only on the element, not on the window, so compute it once.
            final Object key = keyExtractor.apply(value);
            return Iterators.transform(windows.iterator(), window -> new Tuple2<>(
                new KeyedWindow(window, timestamp, key),
                new SparkElement(window, timestamp, value)));
        }
    }

    /**
     * Decides whether this translator can handle the given join.
     *
     * <p>The join must meet all of the following conditions:
     * <ul>
     *   <li>it carries the {@link JoinHints#broadcastHashJoin()} hint;</li>
     *   <li>it is a LEFT or RIGHT join;</li>
     *   <li>it does not use {@link MergingWindowing}.</li>
     * </ul>
     *
     * @param join the join operator to inspect
     * @return {@code true} if {@link #translate} supports {@code join}
     */
    static boolean wantTranslate(Join join) {
        final Join.Type type = join.getType();
        return join.getHints().contains(JoinHints.broadcastHashJoin())
            && (type == Join.Type.LEFT || type == Join.Type.RIGHT)
            && !(join.getWindowing() instanceof MergingWindowing);
    }

    /**
     * Builds the Spark RDD computing {@code join}.
     *
     * @param join the join operator; must satisfy {@link #wantTranslate(Join)}
     * @param context executor context providing inputs, the Spark context and accumulators
     * @return an RDD of {@link SparkElement}s holding {@code Pair(key, joinerOutput)}
     * @throws IllegalArgumentException if the hint is missing, the join type is not LEFT/RIGHT,
     *     or the windowing is merging
     */
    @Override
    @SuppressWarnings("unchecked")
    public JavaRDD<?> translate(Join join, SparkExecutorContext context) {
        Preconditions.checkArgument(join.getHints().contains(JoinHints.broadcastHashJoin()), "Missing broadcastHashJoin hint");
        Preconditions.checkArgument(join.getType() == Join.Type.LEFT || join.getType() == Join.Type.RIGHT, "BroadcastJoin supports LEFT and RIGHT joins only");
        // KeyExtractor assigns windows element by element and never merges them, so merging
        // windowing is rejected here exactly as wantTranslate() does.
        Preconditions.checkArgument(!(join.getWindowing() instanceof MergingWindowing), "BroadcastJoin does not support merging windowing");

        final List<JavaRDD<?>> inputs = context.getInputs(join);
        final JavaRDD<SparkElement> leftInput = (JavaRDD<SparkElement>) inputs.get(0);
        final JavaRDD<SparkElement> rightInput = (JavaRDD<SparkElement>) inputs.get(1);
        final Windowing windowing = join.getWindowing() == null ? AttachedWindowing.INSTANCE : join.getWindowing();

        final JavaPairRDD<KeyedWindow, SparkElement> leftPairs =
            leftInput.flatMapToPair(new KeyExtractor(join.getLeftKeyExtractor(), windowing, true));
        final JavaPairRDD<KeyedWindow, SparkElement> rightPairs =
            rightInput.flatMapToPair(new KeyExtractor(join.getRightKeyExtractor(), windowing, false));

        // The lambdas below must only capture serializable state (the broadcast handle / the
        // join) and call static helpers, so that Spark can ship them to executors.
        final JavaPairRDD<KeyedWindow, Tuple2<Optional<SparkElement>, Optional<SparkElement>>> joined;
        switch (join.getType()) {
            case LEFT: {
                final Broadcast<Map<KeyedWindow, List<SparkElement>>> rightTable =
                    context.getExecutionEnvironment().broadcast(toBroadcast(rightPairs.collect()));
                joined = leftPairs.flatMapToPair(t -> probe(rightTable.getValue(), t, true));
                break;
            }
            case RIGHT: {
                final Broadcast<Map<KeyedWindow, List<SparkElement>>> leftTable =
                    context.getExecutionEnvironment().broadcast(toBroadcast(leftPairs.collect()));
                joined = rightPairs.flatMapToPair(t -> probe(leftTable.getValue(), t, false));
                break;
            }
            default:
                throw new IllegalStateException("Invalid type: " + join.getType() + ".");
        }

        return joined.flatMap(new FlatMapFunctionWithCollector(
            (pair, collector) -> applyJoiner(
                join,
                (Tuple2<KeyedWindow, Tuple2<Optional<SparkElement>, Optional<SparkElement>>>) pair,
                collector),
            new LazyAccumulatorProvider(context.getAccumulatorFactory(), context.getSettings())));
    }

    /**
     * Looks up one element of the streamed (outer) side in the broadcast table.
     *
     * @param table broadcast side grouped by {@link KeyedWindow}
     * @param streamed one element of the outer side, tagged with its key and window
     * @param streamedIsLeft whether the streamed side is the join's left input
     * @return one {@code (left, right)} pair per broadcast match. When nothing matches, a single
     *     pair whose broadcast side is empty.
     */
    private static Iterator<Tuple2<KeyedWindow, Tuple2<Optional<SparkElement>, Optional<SparkElement>>>> probe(
            Map<KeyedWindow, List<SparkElement>> table,
            Tuple2<KeyedWindow, SparkElement> streamed,
            boolean streamedIsLeft) {
        final KeyedWindow key = streamed._1;
        final Optional<SparkElement> streamedSide = opt(streamed._2);
        // toBroadcast never stores null or empty lists, so null here means "no match"
        final List<SparkElement> matches = table.get(key);
        if (matches == null) {
            final Optional<SparkElement> missing = Optional.empty();
            return Collections.singletonList(
                new Tuple2<>(key, orient(streamedSide, missing, streamedIsLeft))).iterator();
        }
        return Iterators.transform(matches.iterator(),
            match -> new Tuple2<>(key, orient(streamedSide, opt(match), streamedIsLeft)));
    }

    /**
     * Orders a (streamed, broadcast) pair of optional elements as (left, right).
     */
    private static Tuple2<Optional<SparkElement>, Optional<SparkElement>> orient(
            Optional<SparkElement> streamed, Optional<SparkElement> broadcast, boolean streamedIsLeft) {
        return streamedIsLeft ? new Tuple2<>(streamed, broadcast) : new Tuple2<>(broadcast, streamed);
    }

    /**
     * Runs the user joiner on one joined pair and wraps everything it emits.
     *
     * <p>At least one side of {@code pair} is present, because the streamed side is always
     * non-empty.
     * <ul>
     *   <li>The output window is taken from the left element if present, otherwise the right.</li>
     *   <li>The output timestamp is the maximum of both sides' timestamps; a missing side
     *       contributes {@code window.maxTimestamp() - 1}.</li>
     * </ul>
     *
     * @param join the join whose joiner is applied
     * @param pair key/window plus the optional left and right elements
     * @param collector reusable collector; it is cleared before the joiner runs
     * @return elements holding {@code Pair.of(key, output)} for every joiner output
     */
    @SuppressWarnings("unchecked")
    private static Iterator<SparkElement> applyJoiner(
            Join join,
            Tuple2<KeyedWindow, Tuple2<Optional<SparkElement>, Optional<SparkElement>>> pair,
            FunctionCollectorMem collector) {
        final SparkElement leftElement = pair._2._1.orNull();
        final SparkElement rightElement = pair._2._2.orNull();
        final Window<?> window = leftElement != null ? leftElement.getWindow() : rightElement.getWindow();
        final long leftTimestamp = leftElement != null ? leftElement.getTimestamp() : window.maxTimestamp() - 1;
        final long rightTimestamp = rightElement != null ? rightElement.getTimestamp() : window.maxTimestamp() - 1;
        final long timestamp = Math.max(leftTimestamp, rightTimestamp);

        collector.clear();
        collector.setWindow(window);
        join.getJoiner().apply(
            leftElement == null ? null : leftElement.getElement(),
            rightElement == null ? null : rightElement.getElement(),
            collector);
        return Iterators.transform(collector.getOutputIterator(),
            out -> new SparkElement(window, timestamp, Pair.of(pair._1.key(), out)));
    }

    private static <T> Optional<T> opt(T t) {
        return Optional.ofNullable(t);
    }

    /**
     * Groups collected {@code (KeyedWindow, SparkElement)} pairs into a lookup table.
     *
     * @param pairs pairs collected from the broadcast side
     * @return map from key+window to all elements sharing it. The lists are never empty, and
     *     insertion order within each list is preserved.
     */
    private static Map<KeyedWindow, List<SparkElement>> toBroadcast(List<Tuple2<KeyedWindow, SparkElement>> pairs) {
        final Map<KeyedWindow, List<SparkElement>> table = new HashMap<>();
        for (Tuple2<KeyedWindow, SparkElement> pair : pairs) {
            table.computeIfAbsent(pair._1, k -> new ArrayList<>()).add(pair._2);
        }
        return table;
    }
}
