/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;

@Beta
public class Join
extends PrimitiveTensorFunction {
    private final TensorFunction argumentA;
    private final TensorFunction argumentB;
    private final DoubleBinaryOperator combinator;

    public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) {
        Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        Objects.requireNonNull(combinator, "The combinator function cannot be null");
        this.argumentA = argumentA;
        this.argumentB = argumentB;
        this.combinator = combinator;
    }

    public TensorFunction argumentA() {
        return this.argumentA;
    }

    public TensorFunction argumentB() {
        return this.argumentB;
    }

    public DoubleBinaryOperator combinator() {
        return this.combinator;
    }

    @Override
    public List<TensorFunction> functionArguments() {
        return ImmutableList.of((Object)this.argumentA, (Object)this.argumentB);
    }

    @Override
    public TensorFunction replaceArguments(List<TensorFunction> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
        }
        return new Join(arguments.get(0), arguments.get(1), this.combinator);
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new Join(this.argumentA.toPrimitive(), this.argumentB.toPrimitive(), this.combinator);
    }

    @Override
    public String toString(ToStringContext context) {
        return "join(" + this.argumentA.toString(context) + ", " + this.argumentB.toString(context) + ", " + this.combinator + ")";
    }

    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor a = this.argumentA.evaluate(context);
        Tensor b = this.argumentB.evaluate(context);
        TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
        if (this.hasSingleIndexedDimension(a) && this.hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) {
            return this.indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType);
        }
        if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size()) {
            return this.singleSpaceJoin(a, b, joinedType);
        }
        if (a.type().dimensions().containsAll(b.type().dimensions())) {
            return this.subspaceJoin(b, a, joinedType, true);
        }
        if (b.type().dimensions().containsAll(a.type().dimensions())) {
            return this.subspaceJoin(a, b, joinedType, false);
        }
        return this.generalJoin(a, b, joinedType);
    }

    private boolean hasSingleIndexedDimension(Tensor tensor) {
        return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
    }

    private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
        int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
        Iterator<Double> aIterator = a.valueIterator();
        Iterator<Double> bIterator = b.valueIterator();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedLength).build());
        int i = 0;
        while (i < joinedLength) {
            builder.cell(this.combinator.applyAsDouble(aIterator.next(), bIterator.next()), i++);
        }
        return builder.build();
    }

    private Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType) {
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        Iterator<Tensor.Cell> i = a.cellIterator();
        while (i.hasNext()) {
            Map.Entry aCell = i.next();
            double bCellValue = b.get((TensorAddress)aCell.getKey());
            if (Double.isNaN(bCellValue)) continue;
            builder.cell((TensorAddress)aCell.getKey(), this.combinator.applyAsDouble((Double)aCell.getValue(), bCellValue));
        }
        return builder.build();
    }

    private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
        if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) {
            return this.indexedSubspaceJoin((IndexedTensor)subspace, (IndexedTensor)superspace, joinedType, reversedArgumentOrder);
        }
        return this.generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder);
    }

    private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
        if (subspace.size() == 0 || superspace.size() == 0) {
            return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
        }
        DimensionSizes joinedSizes = this.joinedSize(joinedType, subspace, superspace);
        IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes);
        HashSet<String> superDimensionNames = new HashSet<String>(superspace.type().dimensionNames());
        superDimensionNames.removeAll(subspace.type().dimensionNames());
        Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes);
        while (i.hasNext()) {
            IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
            this.joinSubspaces(subspace.valueIterator(), subspace.size(), subspaceInSuper, subspaceInSuper.size(), reversedArgumentOrder, builder);
        }
        return builder.build();
    }

    private void joinSubspaces(Iterator<Double> subspace, int subspaceSize, Iterator<Tensor.Cell> superspace, int superspaceSize, boolean reversedArgumentOrder, IndexedTensor.Builder builder) {
        int joinedLength = Math.min(subspaceSize, superspaceSize);
        if (reversedArgumentOrder) {
            for (int i = 0; i < joinedLength; ++i) {
                Tensor.Cell supercell = superspace.next();
                builder.cell(supercell, this.combinator.applyAsDouble(supercell.getValue(), subspace.next()));
            }
        } else {
            for (int i = 0; i < joinedLength; ++i) {
                Tensor.Cell supercell = superspace.next();
                builder.cell(supercell, this.combinator.applyAsDouble(subspace.next(), supercell.getValue()));
            }
        }
    }

    private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size());
        for (int i = 0; i < builder.dimensions(); ++i) {
            String dimensionName = joinedType.dimensions().get(i).name();
            Optional<Integer> aIndex = a.type().indexOfDimension(dimensionName);
            Optional<Integer> bIndex = b.type().indexOfDimension(dimensionName);
            if (aIndex.isPresent() && bIndex.isPresent()) {
                builder.set(i, Math.min(b.dimensionSizes().size(bIndex.get()), a.dimensionSizes().size(aIndex.get())));
                continue;
            }
            if (aIndex.isPresent()) {
                builder.set(i, a.dimensionSizes().size(aIndex.get()));
                continue;
            }
            if (!bIndex.isPresent()) continue;
            builder.set(i, b.dimensionSizes().size(bIndex.get()));
        }
        return builder.build();
    }

    private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
        int[] subspaceIndexes = this.subspaceIndexes(superspace.type(), subspace.type());
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        Iterator<Tensor.Cell> i = superspace.cellIterator();
        while (i.hasNext()) {
            Map.Entry supercell = i.next();
            TensorAddress subaddress = this.mapAddressToSubspace((TensorAddress)supercell.getKey(), subspaceIndexes);
            double subspaceValue = subspace.get(subaddress);
            if (Double.isNaN(subspaceValue)) continue;
            builder.cell((TensorAddress)supercell.getKey(), reversedArgumentOrder ? this.combinator.applyAsDouble((Double)supercell.getValue(), subspaceValue) : this.combinator.applyAsDouble(subspaceValue, (Double)supercell.getValue()));
        }
        return builder.build();
    }

    private int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
        int[] subspaceIndexes = new int[subtype.dimensions().size()];
        for (int i = 0; i < subtype.dimensions().size(); ++i) {
            subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
        }
        return subspaceIndexes;
    }

    private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
        String[] subspaceLabels = new String[subspaceIndexes.length];
        for (int i = 0; i < subspaceIndexes.length; ++i) {
            subspaceLabels[i] = superAddress.label(subspaceIndexes[i]);
        }
        return TensorAddress.of(subspaceLabels);
    }

    private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) {
        if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
            return this.indexedGeneralJoin((IndexedTensor)a, (IndexedTensor)b, joinedType);
        }
        return this.mappedHashJoin(a, b, joinedType);
    }

    private Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
        DimensionSizes joinedSize = this.joinedSize(joinedType, a, b);
        Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize);
        int[] aToIndexes = this.mapIndexes(a.type(), joinedType);
        int[] bToIndexes = this.mapIndexes(b.type(), joinedType);
        this.joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder);
        this.joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder);
        return builder.build();
    }

    private void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, int[] aToIndexes, int[] bToIndexes, boolean reversedOrder, Tensor.Builder builder) {
        Sets.SetView sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames());
        Sets.SetView dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames());
        DimensionSizes aIterateSize = this.joinedSizeOf(a.type(), joinedType, joinedSize);
        DimensionSizes bIterateSize = this.joinedSizeOf(b.type(), joinedType, joinedSize);
        Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator((Set<String>)dimensionsOnlyInA, aIterateSize);
        while (ia.hasNext()) {
            IndexedTensor.SubspaceIterator aSubspace = ia.next();
            while (aSubspace.hasNext()) {
                Tensor.Cell aCell = aSubspace.next();
                PartialAddress matchingBCells = this.partialAddress(a.type(), aSubspace.address(), (Set<String>)sharedDimensions);
                IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize);
                while (bSubspace.hasNext()) {
                    Tensor.Cell bCell = bSubspace.next();
                    TensorAddress joinedAddress = this.joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
                    double joinedValue = reversedOrder ? this.combinator.applyAsDouble(bCell.getValue(), aCell.getValue()) : this.combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
                    builder.cell(joinedAddress, joinedValue);
                }
            }
        }
    }

    private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
        PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
        for (int i = 0; i < addressType.dimensions().size(); ++i) {
            if (!retainDimensions.contains(addressType.dimensions().get(i).name())) continue;
            builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
        }
        return builder.build();
    }

    private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
        int dimensionIndex = 0;
        for (int i = 0; i < joinedType.dimensions().size(); ++i) {
            if (!type.dimensionNames().contains(joinedType.dimensions().get(i).name())) continue;
            builder.set(dimensionIndex++, joinedSizes.size(i));
        }
        return builder.build();
    }

    private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
        int[] aToIndexes = this.mapIndexes(a.type(), joinedType);
        int[] bToIndexes = this.mapIndexes(b.type(), joinedType);
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        Iterator<Tensor.Cell> aIterator = a.cellIterator();
        while (aIterator.hasNext()) {
            Map.Entry aCell = aIterator.next();
            Iterator<Tensor.Cell> bIterator = b.cellIterator();
            while (bIterator.hasNext()) {
                Map.Entry bCell = bIterator.next();
                TensorAddress combinedAddress = this.joinAddresses((TensorAddress)aCell.getKey(), aToIndexes, (TensorAddress)bCell.getKey(), bToIndexes, joinedType);
                if (combinedAddress == null) continue;
                builder.cell(combinedAddress, this.combinator.applyAsDouble((Double)aCell.getValue(), (Double)bCell.getValue()));
            }
        }
        return builder.build();
    }

    private Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType) {
        boolean swapTensors;
        TensorType commonDimensionType = this.commonDimensions(a, b);
        if (commonDimensionType.dimensions().isEmpty()) {
            return this.mappedGeneralJoin(a, b, joinedType);
        }
        boolean bl = swapTensors = a.size() > b.size();
        if (swapTensors) {
            Tensor temp = a;
            a = b;
            b = temp;
        }
        int[] aIndexesInCommon = this.mapIndexes(commonDimensionType, a.type());
        int[] bIndexesInCommon = this.mapIndexes(commonDimensionType, b.type());
        int[] aIndexesInJoined = this.mapIndexes(a.type(), joinedType);
        int[] bIndexesInJoined = this.mapIndexes(b.type(), joinedType);
        HashMap<TensorAddress, List<Object>> aCellsByCommonAddress = new HashMap<TensorAddress, List<Object>>();
        Iterator<Tensor.Cell> cellIterator = a.cellIterator();
        while (cellIterator.hasNext()) {
            Tensor.Cell aCell = cellIterator.next();
            TensorAddress partialCommonAddress = this.partialCommonAddress(aCell, aIndexesInCommon);
            aCellsByCommonAddress.putIfAbsent(partialCommonAddress, new ArrayList());
            ((List)aCellsByCommonAddress.get(partialCommonAddress)).add(aCell);
        }
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        Iterator<Tensor.Cell> cellIterator2 = b.cellIterator();
        while (cellIterator2.hasNext()) {
            Tensor.Cell bCell = cellIterator2.next();
            TensorAddress partialCommonAddress = this.partialCommonAddress(bCell, bIndexesInCommon);
            for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, Collections.emptyList())) {
                TensorAddress combinedAddress = this.joinAddresses(aCell.getKey(), aIndexesInJoined, bCell.getKey(), bIndexesInJoined, joinedType);
                if (combinedAddress == null) continue;
                double combinedValue = swapTensors ? this.combinator.applyAsDouble(bCell.getValue(), aCell.getValue()) : this.combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
                builder.cell(combinedAddress, combinedValue);
            }
        }
        return builder.build();
    }

    private int[] mapIndexes(TensorType fromType, TensorType toType) {
        int[] toIndexes = new int[fromType.dimensions().size()];
        for (int i = 0; i < fromType.dimensions().size(); ++i) {
            toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
        }
        return toIndexes;
    }

    private TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType joinedType) {
        String[] joinedLabels = new String[joinedType.dimensions().size()];
        this.mapContent(a, joinedLabels, aToIndexes);
        boolean compatible = this.mapContent(b, joinedLabels, bToIndexes);
        if (!compatible) {
            return null;
        }
        return TensorAddress.of(joinedLabels);
    }

    private boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
        for (int i = 0; i < from.size(); ++i) {
            int toIndex = indexMap[i];
            if (to[toIndex] != null && !to[toIndex].equals(from.label(i))) {
                return false;
            }
            to[toIndex] = from.label(i);
        }
        return true;
    }

    private TensorType commonDimensions(Tensor a, Tensor b) {
        TensorType.Builder typeBuilder = new TensorType.Builder();
        TensorType aType = a.type();
        TensorType bType = b.type();
        for (int i = 0; i < aType.dimensions().size(); ++i) {
            TensorType.Dimension aDim = aType.dimensions().get(i);
            for (int j = 0; j < bType.dimensions().size(); ++j) {
                TensorType.Dimension bDim = bType.dimensions().get(j);
                if (!aDim.equals(bDim)) continue;
                typeBuilder.set(bDim);
            }
        }
        return typeBuilder.build();
    }

    private TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) {
        TensorAddress address = cell.getKey();
        String[] labels = new String[indexMap.length];
        for (int i = 0; i < labels.length; ++i) {
            labels[i] = address.label(indexMap[i]);
        }
        return TensorAddress.of(labels);
    }
}

