/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.operators.join.stream;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.DefaultOpenContext;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.KeyedStateStore;
import org.apache.flink.streaming.api.operators.AbstractInput;
import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
import org.apache.flink.streaming.api.operators.Input;
import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.api.operators.TimestampedCollector;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.utils.JoinedRowData;
import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
import org.apache.flink.table.runtime.generated.JoinCondition;
import org.apache.flink.table.runtime.generated.MultiJoinCondition;
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
import org.apache.flink.table.runtime.operators.join.stream.keyselector.AttributeBasedJoinKeyExtractor;
import org.apache.flink.table.runtime.operators.join.stream.keyselector.JoinKeyExtractor;
import org.apache.flink.table.runtime.operators.join.stream.state.MultiJoinStateView;
import org.apache.flink.table.runtime.operators.join.stream.state.MultiJoinStateViews;
import org.apache.flink.table.runtime.operators.join.stream.utils.JoinInputSideSpec;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.types.RowKind;

public class StreamingMultiJoinOperator
extends AbstractStreamOperatorV2<RowData>
implements MultipleInputStreamOperator<RowData> {
    private static final long serialVersionUID = 1L;
    private final List<JoinInputSideSpec> inputSpecs;
    private final List<FlinkJoinType> joinTypes;
    private final List<RowType> inputTypes;
    private final long[] stateRetentionTime;
    private final List<Input<RowData>> typedInputs;
    private final MultiJoinCondition multiJoinCondition;
    private final GeneratedJoinCondition[] joinConditions;
    private final JoinKeyExtractor keyExtractor;
    private transient List<MultiJoinStateView> stateHandlers;
    private transient TimestampedCollector<RowData> collector;
    private transient List<RowData> nullRows;
    private transient JoinCondition[] instantiatedJoinConditions;

    public StreamingMultiJoinOperator(StreamOperatorParameters<RowData> parameters, List<RowType> inputTypes, List<JoinInputSideSpec> inputSpecs, List<FlinkJoinType> joinTypes, MultiJoinCondition multiJoinCondition, long[] stateRetentionTime, GeneratedJoinCondition[] joinConditions, JoinKeyExtractor keyExtractor, Map<Integer, List<AttributeBasedJoinKeyExtractor.ConditionAttributeRef>> joinAttributeMap) {
        super(parameters, inputSpecs.size());
        this.inputTypes = inputTypes;
        this.inputSpecs = inputSpecs;
        this.joinTypes = joinTypes;
        this.stateRetentionTime = stateRetentionTime;
        this.joinConditions = joinConditions;
        this.keyExtractor = keyExtractor;
        this.typedInputs = new ArrayList<Input<RowData>>(inputSpecs.size());
        this.multiJoinCondition = multiJoinCondition;
        this.initializeInputs();
    }

    public void open() throws Exception {
        super.open();
        this.initializeCollector();
        this.initializeNullRows();
        this.initializeStateHandlers();
        this.initializeJoinConditions();
    }

    public void close() throws Exception {
        this.closeConditions();
        super.close();
    }

    public void processElement(int inputId, StreamRecord<RowData> element) throws Exception {
        RowData input = (RowData)element.getValue();
        if (input == null) {
            return;
        }
        this.performMultiJoin(input, inputId);
        this.addRecordToState(input, inputId);
    }

    private void performMultiJoin(RowData input, int inputId) throws Exception {
        this.recursiveMultiJoin(0, input, inputId, null, false);
    }

    private void recursiveMultiJoin(int depth, RowData input, int inputId, RowData joinedRowData, boolean isInputRecordActive) throws Exception {
        int associationsPrevLevel;
        boolean anyMatches;
        if (this.isMaxDepth(depth)) {
            this.emitJoinedRow(input, joinedRowData);
            return;
        }
        boolean isLeftJoin = this.isLeftJoinAtDepth(depth);
        Integer associations = this.processRecords(depth, input, inputId, joinedRowData, isInputRecordActive, isLeftJoin);
        if (associations == null) {
            anyMatches = false;
            associationsPrevLevel = 0;
        } else {
            anyMatches = true;
            associationsPrevLevel = associations;
        }
        if (this.isInputLevel(depth, inputId)) {
            this.processInputRecord(depth, input, inputId, joinedRowData, associationsPrevLevel, anyMatches);
        } else if (isLeftJoin && !anyMatches && this.hasNoAssociations(depth, associationsPrevLevel)) {
            this.processWithNullPadding(depth, input, inputId, joinedRowData, isInputRecordActive);
        }
    }

    private void emitJoinedRow(RowData input, RowData joinedRowData) {
        joinedRowData.setRowKind(input.getRowKind());
        this.collector.collect((Object)joinedRowData);
    }

    private Integer processRecords(int depth, RowData input, int inputId, RowData joinedRowData, boolean isInputRecordActive, boolean isLeftJoin) throws Exception {
        boolean anyMatch = false;
        int associations = 0;
        if (StreamingMultiJoinOperator.isInnerJoin(isLeftJoin) && this.isInputLevel(depth, inputId)) {
            return null;
        }
        RowData joinKey = this.keyExtractor.getLeftSideJoinKey(depth, joinedRowData);
        Iterable<RowData> records = this.stateHandlers.get(depth).getRecords(joinKey);
        for (RowData record : records) {
            if (!this.matchesCondition(depth, joinedRowData, record)) continue;
            anyMatch = true;
            if (isLeftJoin && this.canOptimizeAssociationCounting(depth, inputId, input, associations = this.updateAssociationCount(associations, this.shouldIncrementAssociation(isInputRecordActive, input)))) {
                return associations;
            }
            if (!isInputRecordActive && this.isInputLevel(depth, inputId)) continue;
            RowData newJoinedRowData = StreamingMultiJoinOperator.newJoinedRowData(depth, joinedRowData, record);
            this.recursiveMultiJoin(depth + 1, input, inputId, newJoinedRowData, isInputRecordActive);
        }
        if (!anyMatch) {
            return null;
        }
        return associations;
    }

    private static RowData newJoinedRowData(int depth, RowData joinedRowData, RowData record) {
        Object newJoinedRowData = depth == 0 ? record : new JoinedRowData(joinedRowData, record);
        return newJoinedRowData;
    }

    private static boolean isInnerJoin(boolean isLeftJoin) {
        return !isLeftJoin;
    }

    private void processWithNullPadding(int depth, RowData input, int inputId, RowData joinedRowData, boolean isInputRecordActive) throws Exception {
        RowData newJoinedRowData = StreamingMultiJoinOperator.newJoinedRowData(depth, joinedRowData, this.nullRows.get(depth));
        this.recursiveMultiJoin(depth + 1, input, inputId, newJoinedRowData, isInputRecordActive);
    }

    private void processInputRecord(int depth, RowData input, int inputId, RowData joinedRowData, int associationsToPrevLevel, boolean anyMatch) throws Exception {
        int associations = associationsToPrevLevel;
        if (!this.matchesCondition(depth, joinedRowData, input)) {
            return;
        }
        boolean isLeftJoin = this.isLeftJoinAtDepth(depth);
        if (isLeftJoin) {
            associations = this.updateAssociationCount(associations, this.shouldIncrementAssociation(true, input));
        }
        if (this.isUpsert(input) && isLeftJoin && !anyMatch) {
            this.handleRetractBeforeInput(depth, input, inputId, joinedRowData);
        }
        RowData newJoinedRowData = StreamingMultiJoinOperator.newJoinedRowData(depth, joinedRowData, input);
        this.recursiveMultiJoin(depth + 1, input, inputId, newJoinedRowData, true);
        if (this.isRetraction(input) && isLeftJoin && this.hasNoAssociations(depth, associations)) {
            this.handleInsertAfterInput(depth, input, inputId, joinedRowData);
        }
    }

    private void handleRetractBeforeInput(int depth, RowData input, int inputId, RowData joinedRowData) throws Exception {
        JoinedRowData newJoinedRowData = new JoinedRowData(joinedRowData, this.nullRows.get(depth));
        RowKind originalKind = input.getRowKind();
        input.setRowKind(RowKind.DELETE);
        this.recursiveMultiJoin(depth + 1, input, inputId, (RowData)newJoinedRowData, true);
        input.setRowKind(originalKind);
    }

    private void handleInsertAfterInput(int depth, RowData input, int inputId, RowData joinedRowData) throws Exception {
        JoinedRowData newJoinedRowData = new JoinedRowData(joinedRowData, this.nullRows.get(depth));
        RowKind originalKind = input.getRowKind();
        input.setRowKind(RowKind.INSERT);
        this.recursiveMultiJoin(depth + 1, input, inputId, (RowData)newJoinedRowData, true);
        input.setRowKind(originalKind);
    }

    private void addRecordToState(RowData input, int inputId) throws Exception {
        RowData joinKey = this.keyExtractor.getJoinKey(input, inputId);
        if (this.isRetraction(input)) {
            this.stateHandlers.get(inputId).retractRecord(joinKey, input);
        } else {
            this.stateHandlers.get(inputId).addRecord(joinKey, input);
        }
    }

    private void initializeCollector() {
        this.collector = new TimestampedCollector(this.output);
    }

    private void initializeNullRows() {
        this.nullRows = new ArrayList<RowData>(this.inputTypes.size());
        for (RowType inputType : this.inputTypes) {
            this.nullRows.add((RowData)new GenericRowData(inputType.getFieldCount()));
        }
    }

    private void initializeStateHandlers() {
        if (!this.stateHandler.getKeyedStateStore().isPresent()) {
            throw new RuntimeException("Keyed state store not found when initializing keyed state store handlers.");
        }
        this.getRuntimeContext().setKeyedStateStore((KeyedStateStore)this.stateHandler.getKeyedStateStore().get());
        this.stateHandlers = new ArrayList<MultiJoinStateView>(this.inputSpecs.size());
        for (int i = 0; i < this.inputSpecs.size(); ++i) {
            String stateName = "multi-join-input-" + i;
            RowType joinKeyType = this.keyExtractor.getJoinKeyType(i);
            MultiJoinStateView stateView = MultiJoinStateViews.create((RuntimeContext)this.getRuntimeContext(), stateName, this.inputSpecs.get(i), joinKeyType, this.inputTypes.get(i), this.stateRetentionTime[i]);
            this.stateHandlers.add(stateView);
        }
    }

    private void initializeInputs() {
        for (int i = 0; i < this.inputSpecs.size(); ++i) {
            this.typedInputs.add(this.createInput(i + 1));
        }
    }

    private void closeConditions() throws Exception {
        if (this.multiJoinCondition != null) {
            this.multiJoinCondition.close();
        }
        if (this.instantiatedJoinConditions != null) {
            for (JoinCondition jc : this.instantiatedJoinConditions) {
                if (jc == null) continue;
                jc.close();
            }
        }
    }

    private Input<RowData> createInput(final int idx) {
        return new AbstractInput<RowData, RowData>((AbstractStreamOperatorV2)this, idx){

            public void processElement(StreamRecord<RowData> element) throws Exception {
                ((StreamingMultiJoinOperator)this.owner).processElement(idx - 1, element);
            }
        };
    }

    private boolean isUpsert(RowData row) {
        return row.getRowKind() == RowKind.INSERT || row.getRowKind() == RowKind.UPDATE_AFTER;
    }

    private boolean isRetraction(RowData row) {
        return row.getRowKind() == RowKind.DELETE || row.getRowKind() == RowKind.UPDATE_BEFORE;
    }

    private boolean isLeftJoinAtDepth(int depth) {
        return depth > 0 && this.joinTypes.get(depth) == FlinkJoinType.LEFT;
    }

    private boolean matchesCondition(int depth, RowData joinedRowData, RowData record) {
        return depth == 0 || this.instantiatedJoinConditions[depth].apply(joinedRowData, record);
    }

    private int updateAssociationCount(int currentCount, boolean isUpsert) {
        if (isUpsert) {
            return currentCount + 1;
        }
        return currentCount - 1;
    }

    private boolean shouldIncrementAssociation(boolean isInputRecordActive, RowData input) {
        return !isInputRecordActive || this.isUpsert(input);
    }

    public List<Input> getInputs() {
        ArrayList<Input> rawInputs = new ArrayList<Input>(this.typedInputs.size());
        rawInputs.addAll(this.typedInputs);
        return rawInputs;
    }

    private boolean isMaxDepth(int depth) {
        return depth == this.inputSpecs.size();
    }

    private boolean isInputLevel(int depth, int inputId) {
        return depth == inputId;
    }

    private boolean hasNoAssociations(int depth, int associationCountForPrevLevel) {
        return depth > 0 && associationCountForPrevLevel == 0;
    }

    private boolean canOptimizeAssociationCounting(int depth, int inputId, RowData input, int associationCountForPrevLevel) {
        if (depth == 0 || inputId != depth) {
            return false;
        }
        if (this.isUpsert(input)) {
            return associationCountForPrevLevel > 0;
        }
        return associationCountForPrevLevel > 1;
    }

    private void initializeJoinConditions() throws Exception {
        this.instantiatedJoinConditions = new JoinCondition[this.joinConditions.length];
        for (int i = 0; i < this.joinConditions.length; ++i) {
            if (this.joinConditions[i] == null) continue;
            JoinCondition cond = (JoinCondition)this.joinConditions[i].newInstance(this.getRuntimeContext().getUserCodeClassLoader());
            cond.setRuntimeContext((RuntimeContext)this.getRuntimeContext());
            cond.open(DefaultOpenContext.INSTANCE);
            this.instantiatedJoinConditions[i] = cond;
        }
    }
}

