/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.tensor.functions.Map;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
import java.util.function.DoubleUnaryOperator;
import onnx.Onnx;

public class OnnxCast
extends IntermediateOperation {
    private final IntermediateOperation.AttributeMap attributeMap;
    private final Onnx.TensorProto.DataType toType;

    public OnnxCast(String modelName, String nodeName, List<IntermediateOperation> inputs, IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
        if (attributeMap.get("to").isEmpty()) {
            throw new IllegalArgumentException("OnnxCast in " + this.name + ": Required attribute 'to' is missing.");
        }
        this.toType = Onnx.TensorProto.DataType.forNumber((int)attributeMap.get("to").get().asDouble());
    }

    @Override
    protected OrderedTensorType lazyGetType() {
        if (!this.allInputTypesPresent(1)) {
            return null;
        }
        return ((IntermediateOperation)this.inputs.get(0)).type().orElse(null);
    }

    @Override
    protected TensorFunction lazyGetFunction() {
        if (!this.allInputFunctionsPresent(1)) {
            return null;
        }
        TensorFunction input = ((IntermediateOperation)this.inputs.get(0)).function().get();
        switch (this.toType) {
            case BOOL: {
                return new Map(input, (DoubleUnaryOperator)new AsBool());
            }
            case INT8: 
            case INT16: 
            case INT32: 
            case INT64: 
            case UINT8: 
            case UINT16: 
            case UINT32: 
            case UINT64: {
                return new Map(input, (DoubleUnaryOperator)new AsInt());
            }
            case FLOAT: 
            case DOUBLE: 
            case FLOAT16: {
                return input;
            }
            case STRING: {
                throw new IllegalArgumentException("OnnxCast in " + this.name + ": Casting to string is not implemented.");
            }
        }
        throw new IllegalArgumentException("OnnxCast in " + this.name + ": Unknown or undefined cast: " + this.toType.name());
    }

    @Override
    public OnnxCast withInputs(List<IntermediateOperation> inputs) {
        return new OnnxCast(this.modelName(), this.name(), inputs, this.attributeMap);
    }

    @Override
    public String operationName() {
        return "Cast";
    }

    private static class AsInt
    implements DoubleUnaryOperator {
        private AsInt() {
        }

        @Override
        public double applyAsDouble(double operand) {
            return operand < 0.0 ? Math.ceil(operand) : Math.floor(operand);
        }

        public String toString() {
            return "f(a)(if (a < 0, ceil(a), floor(a)))";
        }
    }

    private static class AsBool
    implements DoubleUnaryOperator {
        private AsBool() {
        }

        @Override
        public double applyAsDouble(double operand) {
            return operand != 0.0 ? 1.0 : 0.0;
        }

        public String toString() {
            return "f(a)(a!=0)";
        }
    }
}

