/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

public class Gather
extends IntermediateOperation {
    private final IntermediateOperation.AttributeMap attributeMap;
    private int axis;

    public Gather(String modelName, String nodeName, List<IntermediateOperation> inputs, IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
    }

    @Override
    protected OrderedTensorType lazyGetType() {
        int i;
        if (!this.allInputTypesPresent(2)) {
            return null;
        }
        OrderedTensorType dataType = ((IntermediateOperation)this.inputs.get(0)).type().get();
        OrderedTensorType indicesType = ((IntermediateOperation)this.inputs.get(1)).type().get();
        this.axis = (int)this.attributeMap.get("axis").orElse((Value)DoubleValue.zero).asDouble();
        if (this.axis < 0) {
            this.axis = dataType.rank() + this.axis;
        }
        OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(this.resultValueType());
        for (i = 0; i < this.axis; ++i) {
            this.addDimension(i, dataType.dimensions().get(i).size().orElse(-1L), typeBuilder);
        }
        for (i = 0; i < indicesType.rank(); ++i) {
            this.addDimension(i + this.axis, indicesType.dimensions().get(i).size().orElse(-1L), typeBuilder);
        }
        for (i = this.axis + 1; i < dataType.rank(); ++i) {
            this.addDimension(i + indicesType.rank(), dataType.dimensions().get(i).size().orElse(-1L), typeBuilder);
        }
        ((IntermediateOperation)this.inputs.get((int)0)).exportAsRankingFunction = true;
        ((IntermediateOperation)this.inputs.get((int)1)).exportAsRankingFunction = true;
        return typeBuilder.build();
    }

    private void addDimension(int dimensionIndex, long size, OrderedTensorType.Builder typeBuilder) {
        String name = String.format("%s_%d", this.vespaName(), dimensionIndex);
        typeBuilder.add(TensorType.Dimension.indexed((String)name, (long)size));
    }

    @Override
    protected TensorFunction lazyGetFunction() {
        if (!this.allInputFunctionsPresent(2)) {
            return null;
        }
        IntermediateOperation data = (IntermediateOperation)this.inputs.get(0);
        IntermediateOperation indices = (IntermediateOperation)this.inputs.get(1);
        OrderedTensorType dataType = data.type().get();
        OrderedTensorType indicesType = indices.type().get();
        String dataFunctionName = data.rankingExpressionFunctionName();
        String indicesFunctionName = indices.rankingExpressionFunctionName();
        ArrayList<Slice.DimensionValue<Reference>> dataSliceDimensions = new ArrayList<Slice.DimensionValue<Reference>>();
        for (int i = 0; i < this.axis; ++i) {
            this.addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i);
        }
        ArrayList<Slice.DimensionValue<Reference>> indicesSliceDimensions = new ArrayList<Slice.DimensionValue<Reference>>();
        for (int i = 0; i < indicesType.rank(); ++i) {
            this.addSliceDimension(indicesSliceDimensions, indicesType.dimensions().get(i).name(), this.axis + i);
        }
        ExpressionNode sliceExpression = this.createSliceExpression(indicesSliceDimensions, indicesFunctionName);
        ExpressionNode indexExpression = this.createIndexExpression(dataType, sliceExpression);
        this.addSliceDimension(dataSliceDimensions, dataType.dimensions().get(this.axis).name(), indexExpression);
        for (int i = this.axis + 1; i < dataType.rank(); ++i) {
            this.addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i + indicesType.rank() - 1);
        }
        sliceExpression = this.createSliceExpression(dataSliceDimensions, dataFunctionName);
        return Generate.bound((TensorType)this.type.type(), (ScalarFunction)TensorFunctionNode.wrapScalar((ExpressionNode)sliceExpression));
    }

    private ExpressionNode createSliceExpression(List<Slice.DimensionValue<Reference>> dimensionValues, String referenceName) {
        TensorFunctionNode.ExpressionTensorFunction inputIndices = new TensorFunctionNode.ExpressionTensorFunction((ExpressionNode)new ReferenceNode(referenceName));
        Slice sliceIndices = new Slice((TensorFunction)inputIndices, dimensionValues);
        return new TensorFunctionNode((TensorFunction)sliceIndices);
    }

    private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) {
        ConstantNode axisSize = new ConstantNode((Value)new DoubleValue((double)((Long)dataType.dimensions().get(this.axis).size().get()).longValue()));
        EmbracedNode plus = new EmbracedNode((ExpressionNode)new ArithmeticNode(slice, ArithmeticOperator.PLUS, (ExpressionNode)axisSize));
        ArithmeticNode mod = new ArithmeticNode((ExpressionNode)plus, ArithmeticOperator.MODULO, (ExpressionNode)axisSize);
        return mod;
    }

    private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, ExpressionNode expr) {
        dimensionValues.add((Slice.DimensionValue<Reference>)new Slice.DimensionValue(Optional.of(dimensionName), TensorFunctionNode.wrapScalar((ExpressionNode)new EmbracedNode(expr))));
    }

    private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, int dimensionIndex) {
        String outputDimensionName = this.type.dimensions().get(dimensionIndex).name();
        this.addSliceDimension(dimensionValues, dimensionName, (ExpressionNode)new ReferenceNode(outputDimensionName));
    }

    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        int i;
        if (!this.allInputTypesPresent(2)) {
            return;
        }
        for (int i2 = 0; i2 < this.type.dimensions().size(); ++i2) {
            renamer.addDimension(this.type.dimensions().get(i2).name());
            for (int j = i2 + 1; j < this.type.dimensions().size(); ++j) {
                renamer.addConstraint(this.type.dimensions().get(i2).name(), this.type.dimensions().get(j).name(), DimensionRenamer.Constraint.lessThan(), this);
            }
        }
        OrderedTensorType dataType = ((IntermediateOperation)this.inputs.get(0)).type().get();
        OrderedTensorType indicesType = ((IntermediateOperation)this.inputs.get(1)).type().get();
        for (i = 0; i < this.axis; ++i) {
            renamer.addConstraint(this.type.dimensions().get(i).name(), dataType.dimensions().get(i).name(), DimensionRenamer.Constraint.equal(), this);
        }
        for (i = 0; i < indicesType.rank(); ++i) {
            renamer.addConstraint(this.type.dimensions().get(i + this.axis).name(), indicesType.dimensions().get(i).name(), DimensionRenamer.Constraint.equal(), this);
        }
        for (i = this.axis + 1; i < dataType.rank(); ++i) {
            renamer.addConstraint(this.type.dimensions().get(i + indicesType.rank() - 1).name(), dataType.dimensions().get(i).name(), DimensionRenamer.Constraint.equal(), this);
        }
    }

    @Override
    public Gather withInputs(List<IntermediateOperation> inputs) {
        return new Gather(this.modelName(), this.name(), inputs, this.attributeMap);
    }

    @Override
    public String operationName() {
        return "Gather";
    }
}

