/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.mxnet.engine.MxMatrix;
import ai.djl.mxnet.engine.MxNDArrayEx;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxOpParams;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.mxnet.jna.NativeResource;
import ai.djl.ndarray.Matrix;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.NDIndexBooleans;
import ai.djl.ndarray.index.NDIndexFullSlice;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.internal.NDFormat;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.training.GradReq;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Stack;
import java.util.function.Predicate;
import java.util.stream.IntStream;

public class MxNDArray
extends NativeResource
implements NDArray {
    // NOTE(review): MAX_SIZE, MAX_DEPTH, MAX_ROWS and MAX_COLUMNS are not referenced in the
    // visible part of this file; presumably they bound the printed form produced further down
    // (toString/NDFormat) — confirm before relying on their meaning.
    private static final int MAX_SIZE = 100;
    private static final int MAX_DEPTH = 10;
    private static final int MAX_ROWS = 10;
    private static final int MAX_COLUMNS = 20;
    /** Optional user-assigned name; copied onto the arrays produced by asInDevice/asType. */
    private String name;
    // Lazily populated metadata caches. When a field is null, its getter
    // (getDevice/getSparseFormat/getDataType/getShape) queries the native handle through
    // JnaUtils and stores the answer here.
    private Device device;
    private SparseFormat sparseFormat;
    private DataType dataType;
    private Shape shape;
    /** Manager this array is currently attached to; replaced by attach() and detach(). */
    private MxNDManager manager;
    /** Helper object created once in the constructor and bound to this array. */
    private MxNDArrayEx mxNDArrayEx;
    // Toggled via setShouldFree(); not read in the visible part of this file — presumably
    // consulted when the native handle is released. TODO confirm.
    private boolean shouldFree = true;

    /**
     * Wraps a native handle whose device, shape and data type are already known, so the lazy
     * getters never need to query the native side for them.
     *
     * @param manager the owning manager
     * @param handle the native NDArray handle
     * @param device the device the array lives on
     * @param shape the array shape; every dimension must be non-negative
     * @param dataType the element type
     * @throws IllegalArgumentException if any dimension of {@code shape} is negative
     */
    MxNDArray(MxNDManager manager, Pointer handle, Device device, Shape shape, DataType dataType) {
        this(manager, handle);
        this.device = device;
        for (long dim : shape.getShape()) {
            if (dim < 0L) {
                throw new IllegalArgumentException("The shape must be >= 0");
            }
        }
        this.shape = shape;
        this.dataType = dataType;
    }

    /**
     * Wraps a native handle; all metadata is resolved lazily on first access.
     *
     * @param manager the owning manager
     * @param handle the native NDArray handle
     */
    MxNDArray(MxNDManager manager, Pointer handle) {
        super(handle);
        this.manager = manager;
        this.mxNDArrayEx = new MxNDArrayEx(this);
    }

    /**
     * Wraps a native handle whose storage format is already known.
     *
     * @param manager the owning manager
     * @param handle the native NDArray handle
     * @param fmt the storage (sparse) format of the handle
     */
    MxNDArray(MxNDManager manager, Pointer handle, SparseFormat fmt) {
        this(manager, handle);
        this.sparseFormat = fmt;
    }

    /**
     * Returns the manager that currently owns this array.
     *
     * @return the owning manager
     */
    public NDManager getManager() {
        MxNDManager owner = this.manager;
        return owner;
    }

    /**
     * Returns the user-assigned name of this array.
     *
     * @return the name, or {@code null} if none was set
     */
    public String getName() {
        String current = this.name;
        return current;
    }

    /**
     * Assigns a user-visible name to this array.
     *
     * @param name the new name (may be {@code null})
     */
    public void setName(String name) {
        String newName = name;
        this.name = newName;
    }

    /**
     * Returns the element type, querying the native handle on first use and caching the result.
     *
     * @return the data type
     */
    public DataType getDataType() {
        DataType cached = this.dataType;
        if (cached != null) {
            return cached;
        }
        this.dataType = JnaUtils.getDataType(this.getHandle());
        return this.dataType;
    }

    /**
     * Returns the device, querying the native handle on first use and caching the result.
     *
     * @return the device
     */
    public Device getDevice() {
        Device cached = this.device;
        if (cached != null) {
            return cached;
        }
        this.device = JnaUtils.getDevice(this.getHandle());
        return this.device;
    }

    /**
     * Returns the shape, querying the native handle on first use and caching the result.
     *
     * @return the shape
     */
    public Shape getShape() {
        Shape cached = this.shape;
        if (cached != null) {
            return cached;
        }
        this.shape = JnaUtils.getShape(this.getHandle());
        return this.shape;
    }

    /**
     * Returns the storage format, querying the native handle on first use and caching the result.
     *
     * @return the sparse storage format
     */
    public SparseFormat getSparseFormat() {
        SparseFormat cached = this.sparseFormat;
        if (cached != null) {
            return cached;
        }
        this.sparseFormat = JnaUtils.getStorageType(this.getHandle());
        return this.sparseFormat;
    }

    /**
     * Moves ownership of this array to {@code manager}, detaching it from its current owner first.
     *
     * @param manager the new owner; must be an {@code MxNDManager}
     */
    public void attach(NDManager manager) {
        detach();
        MxNDManager newOwner = (MxNDManager) manager;
        this.manager = newOwner;
        manager.attach(getUid(), (AutoCloseable) this);
    }

    /** Detaches this array from its current manager and hands it to the system manager. */
    public void detach() {
        MxNDManager previousOwner = this.manager;
        previousOwner.detach(getUid());
        this.manager = MxNDManager.getSystemManager();
    }

    /**
     * Returns this array on {@code dev}. When the array already lives there and no copy is
     * requested, a {@code slice()} of it is returned; otherwise a new array is allocated on
     * {@code dev} and the contents are copied into it.
     *
     * @param dev the target device
     * @param copy whether to force a copy even if the device already matches
     * @return the array on {@code dev}
     */
    public NDArray asInDevice(Device dev, boolean copy) {
        boolean sameDevice = dev.equals((Object) this.getDevice());
        if (sameDevice && !copy) {
            return this.slice();
        }
        MxNDArray target = this.manager.create(this.getShape(), this.getDataType(), dev);
        return this.copyKeepingName(target);
    }

    /**
     * Returns this array converted to {@code dtype}. When the type already matches and no copy is
     * requested, a {@code slice()} of it is returned; otherwise a new array of the requested type
     * is allocated on the same device and filled via {@link #copyTo}.
     *
     * @param dtype the target element type
     * @param copy whether to force a copy even if the type already matches
     * @return the converted array
     */
    public NDArray asType(DataType dtype, boolean copy) {
        boolean sameType = dtype.equals((Object) this.getDataType());
        if (sameType && !copy) {
            return this.slice();
        }
        MxNDArray target = this.manager.create(this.getShape(), dtype, this.getDevice());
        return this.copyKeepingName(target);
    }

    /** Copies this array's name and contents into {@code target} and returns it. */
    private MxNDArray copyKeepingName(MxNDArray target) {
        target.name = this.name;
        this.copyTo(target);
        return target;
    }

    /**
     * Returns a {@link Matrix} wrapper around this array.
     *
     * @return a matrix view of this array
     * @throws IllegalStateException if this array's shape is not a matrix shape
     */
    public Matrix asMatrix() {
        // Go through getShape(): the shape field is filled lazily and is still null for arrays
        // that were created from a bare native handle, which made this method throw an NPE.
        if (!this.getShape().isMatrix()) {
            throw new IllegalStateException("NDArray is not a matrix");
        }
        return new MxMatrix(this);
    }

    /**
     * Sets the {@code shouldFree} flag.
     *
     * @param shouldFree the new flag value
     */
    public void setShouldFree(boolean shouldFree) {
        boolean flag = shouldFree;
        this.shouldFree = flag;
    }

    /**
     * Runs autograd backward propagation starting from this array.
     *
     * @param retainGraph whether the recorded graph is kept (passed natively as 1/0)
     */
    public void backward(boolean retainGraph) {
        NDList heads = new NDList(new NDArray[] {this});
        int retainFlag = retainGraph ? 1 : 0;
        JnaUtils.autogradBackward(heads, retainFlag);
    }

    /** Attaches a gradient buffer with {@link GradReq#WRITE} semantics and default storage. */
    public void attachGradient() {
        GradReq request = GradReq.WRITE;
        this.attachGradient(request, null);
    }

    /**
     * Marks this array as an autograd variable with a freshly created gradient buffer.
     * The Java wrapper of the gradient is closed once the variable has been marked.
     *
     * @param gradReq the gradient request type
     * @param format the storage format for the gradient buffer ({@code null} for default)
     */
    private void attachGradient(GradReq gradReq, SparseFormat format) {
        try (MxNDArray gradient = this.createGradient(format)) {
            int requestCode = gradReq.getValue();
            IntBuffer requestTypes = IntBuffer.allocate(1);
            requestTypes.put(0, requestCode);
            Pointer gradientHandle = gradient.getHandle();
            JnaUtils.autogradMarkVariables(1, this.getHandle(), requestTypes, gradientHandle);
        }
    }

    /**
     * Creates a zero-filled gradient buffer matching this array's shape, type and device.
     *
     * <p>Previously a non-null {@code format} was silently ignored and a dense buffer was always
     * returned; sparse formats are now honored by converting the dense zeros.
     *
     * @param format desired storage; {@code null}/{@code UNDEFINED} uses {@code zerosLike()},
     *     {@code DENSE} returns dense zeros, any other format converts via {@code toSparse}
     * @return the new gradient buffer (caller owns it)
     */
    private MxNDArray createGradient(SparseFormat format) {
        if (format == null || format == SparseFormat.UNDEFINED) {
            return (MxNDArray) this.zerosLike();
        }
        MxNDArray zeros =
                (MxNDArray) this.manager.zeros(this.getShape(), this.getDataType(), this.getDevice());
        if (format == SparseFormat.DENSE) {
            return zeros;
        }
        try {
            return (MxNDArray) zeros.toSparse(format);
        } finally {
            // The dense temporary is no longer needed once the sparse copy exists.
            zeros.close();
        }
    }

    /**
     * Returns the gradient buffer attached to this array.
     *
     * @return a new wrapper around the native gradient handle
     * @throws IllegalStateException if no gradient has been attached
     */
    public NDArray getGradient() {
        Pointer gradientHandle = JnaUtils.getGradient(this.getHandle());
        if (gradientHandle != null) {
            return this.manager.create(gradientHandle);
        }
        throw new IllegalStateException("No gradient attached to this NDArray, please call array.attachGradient()on your NDArray or block.setInitializer() on your Block");
    }

    /**
     * Copies the array contents into a newly allocated direct buffer obtained from the manager.
     *
     * @return the buffer holding {@code size * bytesPerElement} bytes
     * @throws ArithmeticException if the byte length or element count exceeds {@code int}
     */
    public ByteBuffer toByteBuffer() {
        Shape currentShape = this.getShape();
        DataType elementType = this.getDataType();
        long elementCount = currentShape.size();
        long byteLength = elementCount * (long) elementType.getNumOfBytes();
        ByteBuffer out = this.manager.allocateDirect(Math.toIntExact(byteLength));
        Pointer destination = Native.getDirectBufferPointer((Buffer) out);
        JnaUtils.syncCopyToCPU(this.getHandle(), destination, Math.toIntExact(elementCount));
        return out;
    }

    /**
     * Copies the remaining elements of {@code data} into this array's native storage.
     *
     * <p>Direct buffers are passed to the native copy unchanged; heap buffers are first staged
     * into a direct buffer obtained from the manager.
     *
     * @param data the source buffer; its element type determines the expected data type
     */
    public void set(Buffer data) {
        int count = data.remaining();
        DataType bufferType = DataType.fromBuffer((Buffer) data);
        this.validate(bufferType, count);
        if (!data.isDirect()) {
            ByteBuffer staged = this.stageInDirectBuffer(data, bufferType, count);
            JnaUtils.syncCopyFromCPU(this.getHandle(), staged, count);
            return;
        }
        JnaUtils.syncCopyFromCPU(this.getHandle(), data, count);
    }

    /** Copies a heap buffer of {@code bufferType} into a new direct byte buffer. */
    private ByteBuffer stageInDirectBuffer(Buffer data, DataType bufferType, int count) {
        ByteBuffer staged = this.manager.allocateDirect(count * bufferType.getNumOfBytes());
        switch (bufferType) {
            case FLOAT32:
                staged.asFloatBuffer().put((FloatBuffer) data);
                break;
            case FLOAT64:
                staged.asDoubleBuffer().put((DoubleBuffer) data);
                break;
            case INT32:
                staged.asIntBuffer().put((IntBuffer) data);
                break;
            case INT64:
                staged.asLongBuffer().put((LongBuffer) data);
                break;
            case UINT8:
            case INT8:
            case BOOLEAN:
                staged.put((ByteBuffer) data);
                break;
            default:
                throw new AssertionError((Object) "Show never happen");
        }
        return staged;
    }

    /**
     * Assigns {@code value} into the region of this array selected by {@code index}.
     *
     * <p>{@code value} is moved to this array's device, reshaped by dropping leading slice
     * dimensions until its element count is not exceeded, then broadcast to the slice shape.
     * All intermediates created along the way are closed, even when a step fails. The caller's
     * {@code value} is never closed.
     *
     * @param index a full-slice compatible index
     * @param value the values to assign
     * @throws UnsupportedOperationException if the index cannot be expressed as a full slice
     */
    public void set(NDIndex index, NDArray value) {
        NDIndexFullSlice fullSlice = index.getAsFullSlice(this.getShape()).orElse(null);
        if (fullSlice == null) {
            throw new UnsupportedOperationException("set() currently supports all, fixed, and slices indices");
        }
        MxOpParams params = new MxOpParams();
        params.addTupleParam("begin", fullSlice.getMin());
        params.addTupleParam("end", fullSlice.getMax());
        params.addTupleParam("step", fullSlice.getStep());
        Shape sliceShape = fullSlice.getShape();
        NDArray onDevice = null;
        NDArray reshaped = null;
        NDArray broadcast = null;
        try {
            onDevice = value.asInDevice(this.getDevice(), false);
            Shape targetShape = sliceShape;
            // Guard on dimension() so a rank-0 shape is never sliced further.
            while (targetShape.dimension() > 0 && targetShape.size() > value.size()) {
                targetShape = targetShape.slice(1);
            }
            reshaped = onDevice.reshape(targetShape);
            broadcast = reshaped.broadcast(sliceShape);
            this.manager.invoke("_npi_slice_assign", new NDArray[] {this, broadcast}, new NDArray[] {this}, params);
        } finally {
            // Previously these leaked whenever reshape/broadcast/invoke threw.
            closeTemporaries(value, onDevice, reshaped, broadcast);
        }
    }

    /** Closes each distinct non-null temporary, skipping the caller-owned {@code keep} array. */
    private static void closeTemporaries(NDArray keep, NDArray... temps) {
        for (int i = 0; i < temps.length; i++) {
            NDArray temp = temps[i];
            if (temp == null || temp == keep) {
                continue;
            }
            boolean seen = false;
            for (int j = 0; j < i; j++) {
                if (temps[j] == temp) {
                    seen = true;
                    break;
                }
            }
            if (!seen) {
                temp.close();
            }
        }
    }

    /**
     * Fills the region of this array selected by {@code index} with the scalar {@code value}.
     *
     * @param index a full-slice compatible index
     * @param value the scalar to write
     * @throws UnsupportedOperationException if the index cannot be expressed as a full slice
     */
    public void set(NDIndex index, Number value) {
        NDIndexFullSlice slice = index.getAsFullSlice(this.getShape()).orElse(null);
        if (slice == null) {
            throw new UnsupportedOperationException("set() currently supports all, fixed, and slices indices");
        }
        MxOpParams opParams = new MxOpParams();
        opParams.addTupleParam("begin", slice.getMin());
        opParams.addTupleParam("end", slice.getMax());
        opParams.addTupleParam("step", slice.getStep());
        opParams.addParam("scalar", value);
        NDArray[] self = new NDArray[] {this};
        this.manager.invoke("_npi_slice_assign_scalar", self, self.clone(), opParams);
    }

    /**
     * Writes {@code value} at a single position selected by {@code index}.
     *
     * @param index an index that selects exactly one element
     * @param value the scalar to write
     * @throws IllegalArgumentException if the index selects more or fewer than one element
     * @throws UnsupportedOperationException if the index cannot be expressed as a full slice
     */
    public void setScalar(NDIndex index, Number value) {
        NDIndexFullSlice slice = index.getAsFullSlice(this.getShape()).orElse(null);
        if (slice == null) {
            throw new UnsupportedOperationException("set() currently supports all, fixed, and slices indices");
        }
        if (slice.getShape().size() != 1L) {
            throw new IllegalArgumentException("The provided index does not set a scalar");
        }
        this.set(index, value);
    }

    /**
     * Returns the sub-array selected by {@code index}.
     *
     * <p>A rank-0 index on a scalar returns a duplicate. A single boolean index is delegated to
     * {@code booleanMask}. Anything else must be expressible as a full slice. Boolean arrays are
     * sliced through an INT32 copy and converted back to BOOLEAN. Every temporary array created on
     * that path is closed; previously the INT32 copy and the INT32 result leaked.
     *
     * @param index the index to apply
     * @return a new array holding the selected elements
     * @throws IllegalArgumentException if more than one boolean index is given
     * @throws UnsupportedOperationException if the index is not supported
     */
    public NDArray get(NDIndex index) {
        if (index.getRank() == 0 && this.getShape().isScalar()) {
            return this.duplicate();
        }
        List<?> indices = index.getIndices();
        if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
            if (indices.size() != 1) {
                throw new IllegalArgumentException("get() currently didn't support more that one boolean NDArray");
            }
            return this.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex());
        }
        NDIndexFullSlice fullSlice = index.getAsFullSlice(this.getShape()).orElse(null);
        if (fullSlice == null) {
            throw new UnsupportedOperationException("get() currently supports all, fixed, and slices indices");
        }
        MxOpParams params = new MxOpParams();
        params.addTupleParam("begin", fullSlice.getMin());
        params.addTupleParam("end", fullSlice.getMax());
        params.addTupleParam("step", fullSlice.getStep());
        boolean isBoolean = this.getDataType() == DataType.BOOLEAN;
        // Boolean input is sliced through an INT32 copy (kept from the original implementation).
        NDArray source = isBoolean ? this.asType(DataType.INT32, false) : this;
        NDArray result;
        try {
            result = this.manager.invoke("_npi_slice", source, params);
        } finally {
            if (source != this) {
                source.close();
            }
        }
        if (!fullSlice.getToSqueeze().isEmpty()) {
            NDArray unsqueezed = result;
            try {
                result = unsqueezed.squeeze(fullSlice.getToSqueeze().stream().mapToInt(i -> i).toArray());
            } finally {
                unsqueezed.close();
            }
        }
        if (!isBoolean || result.getDataType() == DataType.BOOLEAN) {
            return result;
        }
        try {
            return result.asType(DataType.BOOLEAN, false);
        } finally {
            result.close();
        }
    }

    /**
     * Copies this array's contents into {@code ndArray}, which must be an {@code MxNDArray} with
     * exactly the same shape.
     *
     * @param ndArray the destination array
     * @throws IllegalArgumentException if the destination is not an MxNDArray or shapes differ
     */
    public void copyTo(NDArray ndArray) {
        if (ndArray instanceof MxNDArray) {
            Shape source = this.getShape();
            Shape target = ndArray.getShape();
            if (Arrays.equals(source.getShape(), target.getShape())) {
                this.manager.invoke("_npi_copyto", new NDArray[] {this}, new NDArray[] {ndArray}, null);
                return;
            }
            throw new IllegalArgumentException("shape are diff. Required: " + target + ", Actual " + source);
        }
        throw new IllegalArgumentException("Only MxNDArray is supported.");
    }

    /*
     * Exception decompiling
     */
    /**
     * NOTE(review): placeholder left behind by the decompiler — the original body of this method
     * could not be recovered (see the CFR trace below). As written, every call throws
     * IllegalStateException("Decompilation failed"). get(NDIndex) calls booleanMask(NDArray) for
     * boolean indices, and that call presumably delegates here, so boolean-mask indexing likely
     * fails at runtime as well — confirm. TODO restore the implementation from upstream sources.
     *
     * @param index presumably the boolean mask array — confirm against the NDArray interface
     * @param axis presumably the axis the mask applies along — confirm
     * @return never returns normally in this build
     * @throws IllegalStateException always
     */
    public NDArray booleanMask(NDArray index, int axis) {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Started 4 blocks at once
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.getStartingBlocks(Op04StructuredStatement.java:412)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:487)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Returns a new array like this one, filled by the {@code _np_zeros_like} operator.
     *
     * @return the zero-filled array
     */
    public NDArray zerosLike() {
        NDArray zeros = this.manager.invoke("_np_zeros_like", this, null);
        return zeros;
    }

    /**
     * Returns a new array like this one, filled by the {@code _np_ones_like} operator.
     *
     * @return the one-filled array
     */
    public NDArray onesLike() {
        NDArray ones = this.manager.invoke("_np_ones_like", this, null);
        return ones;
    }

    /**
     * Returns whether every element of this array equals {@code number}.
     * All intermediate arrays are closed; previously the {@code all()} result leaked.
     *
     * @param number the value to compare against; {@code null} yields {@code false}
     * @return {@code true} if all elements match
     */
    public boolean contentEquals(Number number) {
        if (number == null) {
            return false;
        }
        try (NDArray matches = this.eq(number)) {
            return allElementsTrue(matches);
        }
    }

    /**
     * Returns whether {@code other} has the same shape, data type and element values.
     * All intermediate arrays are closed; previously both the {@code eq()} result and the
     * {@code all()} result leaked.
     *
     * @param other the array to compare against; {@code null} yields {@code false}
     * @return {@code true} if the arrays are element-wise equal
     */
    public boolean contentEquals(NDArray other) {
        if (other == null || !this.shapeEquals(other)) {
            return false;
        }
        if (this.getDataType() != other.getDataType()) {
            return false;
        }
        try (NDArray matches = this.eq(other)) {
            if (matches.getDataType() == DataType.INT32) {
                return allElementsTrue(matches);
            }
            // The original converts the comparison mask to INT32 before reducing; kept as-is.
            try (NDArray matchesAsInt = matches.asType(DataType.INT32, false)) {
                return allElementsTrue(matchesAsInt);
            }
        }
    }

    /** Reduces {@code mask} with {@code all()} and reads the scalar, closing the reduction. */
    private static boolean allElementsTrue(NDArray mask) {
        try (NDArray reduced = mask.all()) {
            return reduced.getBoolean(new long[0]);
        }
    }

    /** Element-wise {@code this == other} against a scalar. */
    public NDArray eq(Number other) {
        return this.compareWithScalar("_npi_equal_scalar", other);
    }

    /** Element-wise {@code this == other}. */
    public NDArray eq(NDArray other) {
        return this.compareWithArray("_npi_equal", other);
    }

    /** Element-wise {@code this != other} against a scalar. */
    public NDArray neq(Number other) {
        return this.compareWithScalar("_npi_not_equal_scalar", other);
    }

    /** Element-wise {@code this != other}. */
    public NDArray neq(NDArray other) {
        return this.compareWithArray("_npi_not_equal", other);
    }

    /** Element-wise {@code this > other} against a scalar. */
    public NDArray gt(Number other) {
        return this.compareWithScalar("_npi_greater_scalar", other);
    }

    /** Element-wise {@code this > other}. */
    public NDArray gt(NDArray other) {
        return this.compareWithArray("_npi_greater", other);
    }

    /** Element-wise {@code this >= other} against a scalar. */
    public NDArray gte(Number other) {
        return this.compareWithScalar("_npi_greater_equal_scalar", other);
    }

    /** Element-wise {@code this >= other}. */
    public NDArray gte(NDArray other) {
        return this.compareWithArray("_npi_greater_equal", other);
    }

    /** Element-wise {@code this < other} against a scalar. */
    public NDArray lt(Number other) {
        return this.compareWithScalar("_npi_less_scalar", other);
    }

    /** Element-wise {@code this < other}. */
    public NDArray lt(NDArray other) {
        return this.compareWithArray("_npi_less", other);
    }

    /** Element-wise {@code this <= other} against a scalar. */
    public NDArray lte(Number other) {
        return this.compareWithScalar("_npi_less_equal_scalar", other);
    }

    /** Element-wise {@code this <= other}. */
    public NDArray lte(NDArray other) {
        return this.compareWithArray("_npi_less_equal", other);
    }

    /** Invokes a scalar comparison operator, passing the scalar as its string form. */
    private NDArray compareWithScalar(String opName, Number scalar) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", scalar.toString());
        return this.manager.invoke(opName, this, params);
    }

    /** Invokes a binary comparison operator on {@code this} and {@code other}. */
    private NDArray compareWithArray(String opName, NDArray other) {
        NDArray[] operands = new NDArray[] {this, other};
        return this.manager.invoke(opName, operands, null);
    }

    /** Returns {@code this + n}. */
    public NDArray add(Number n) {
        return this.applyScalarOp("_npi_add_scalar", n);
    }

    /** Returns {@code this + other}. */
    public NDArray add(NDArray other) {
        return this.applyArrayOp("_npi_add", other);
    }

    /** Returns {@code this - n}. */
    public NDArray sub(Number n) {
        return this.applyScalarOp("_npi_subtract_scalar", n);
    }

    /** Returns {@code this - other}. */
    public NDArray sub(NDArray other) {
        return this.applyArrayOp("_npi_subtract", other);
    }

    /** Returns {@code this * n}. */
    public NDArray mul(Number n) {
        return this.applyScalarOp("_npi_multiply_scalar", n);
    }

    /** Returns {@code this * other}. */
    public NDArray mul(NDArray other) {
        return this.applyArrayOp("_npi_multiply", other);
    }

    /**
     * Converts this array to the sparse storage format {@code fmt}.
     *
     * @param fmt the target format; must not be {@code DENSE}
     * @return a slice of this array when already in {@code fmt}, otherwise a converted copy
     * @throws IllegalArgumentException if {@code fmt} is {@code DENSE}
     */
    public NDArray toSparse(SparseFormat fmt) {
        if (fmt == SparseFormat.DENSE) {
            throw new IllegalArgumentException("Default type is not allowed");
        }
        return fmt == this.getSparseFormat() ? this.slice() : this.castStorage(fmt);
    }

    /** Runs the {@code cast_storage} operator with the requested storage type. */
    private NDArray castStorage(SparseFormat fmt) {
        MxOpParams storageParams = new MxOpParams();
        storageParams.setParam("stype", fmt.getType());
        return this.manager.invoke("cast_storage", this, storageParams);
    }

    /** Returns {@code this / n} (true division). */
    public NDArray div(Number n) {
        return this.applyScalarOp("_npi_true_divide_scalar", n);
    }

    /** Returns {@code this / other} (true division). */
    public NDArray div(NDArray other) {
        return this.applyArrayOp("_npi_true_divide", other);
    }

    /** Returns {@code this % n}. */
    public NDArray mod(Number n) {
        return this.applyScalarOp("_npi_mod_scalar", n);
    }

    /** Returns {@code this % other}. */
    public NDArray mod(NDArray other) {
        return this.applyArrayOp("_npi_mod", other);
    }

    /** Returns {@code this ^ n}. */
    public NDArray pow(Number n) {
        return this.applyScalarOp("_npi_power_scalar", n);
    }

    /** Returns {@code this ^ other}. */
    public NDArray pow(NDArray other) {
        return this.applyArrayOp("_npi_power", other);
    }

    /** Invokes an out-of-place scalar operator, passing the scalar as its string form. */
    private NDArray applyScalarOp(String opName, Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        return this.manager.invoke(opName, this, params);
    }

    /** Invokes an out-of-place binary operator on {@code this} and {@code other}. */
    private NDArray applyArrayOp(String opName, NDArray other) {
        NDArray[] operands = new NDArray[] {this, other};
        return this.manager.invoke(opName, operands, null);
    }

    /** In place: {@code this += n}; returns {@code this}. */
    public NDArray addi(Number n) {
        return this.applyInPlaceScalar("_npi_add_scalar", n);
    }

    /** In place: {@code this += other}; returns {@code this}. */
    public NDArray addi(NDArray other) {
        return this.applyInPlaceArray("_npi_add", other);
    }

    /** In place: {@code this -= n}; returns {@code this}. */
    public NDArray subi(Number n) {
        return this.applyInPlaceScalar("_npi_subtract_scalar", n);
    }

    /** In place: {@code this -= other}; returns {@code this}. */
    public NDArray subi(NDArray other) {
        return this.applyInPlaceArray("_npi_subtract", other);
    }

    /** In place: {@code this *= n}; returns {@code this}. */
    public NDArray muli(Number n) {
        return this.applyInPlaceScalar("_npi_multiply_scalar", n);
    }

    /** In place: {@code this *= other}; returns {@code this}. */
    public NDArray muli(NDArray other) {
        return this.applyInPlaceArray("_npi_multiply", other);
    }

    /** In place: {@code this /= n}; returns {@code this}. */
    public NDArray divi(Number n) {
        return this.applyInPlaceScalar("_npi_true_divide_scalar", n);
    }

    /** In place: {@code this /= other}; returns {@code this}. */
    public NDArray divi(NDArray other) {
        return this.applyInPlaceArray("_npi_true_divide", other);
    }

    /** In place: {@code this %= n}; returns {@code this}. */
    public NDArray modi(Number n) {
        return this.applyInPlaceScalar("_npi_mod_scalar", n);
    }

    /** In place: {@code this %= other}; returns {@code this}. */
    public NDArray modi(NDArray other) {
        return this.applyInPlaceArray("_npi_mod", other);
    }

    /** In place: {@code this = this ^ n}; returns {@code this}. */
    public NDArray powi(Number n) {
        return this.applyInPlaceScalar("_npi_power_scalar", n);
    }

    /** In place: {@code this = this ^ other}; returns {@code this}. */
    public NDArray powi(NDArray other) {
        return this.applyInPlaceArray("_npi_power", other);
    }

    /** Runs a scalar operator writing its output back into {@code this}. */
    private NDArray applyInPlaceScalar(String opName, Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        this.manager.invoke(opName, new NDArray[] {this}, new NDArray[] {this}, params);
        return this;
    }

    /** Runs a binary operator on {@code this} and {@code other}, writing into {@code this}. */
    private NDArray applyInPlaceArray(String opName, NDArray other) {
        this.manager.invoke(opName, new NDArray[] {this, other}, new NDArray[] {this}, null);
        return this;
    }

    /** Returns {@code -this}. */
    public NDArray neg() {
        return this.applyUnary("_npi_negative");
    }

    /** Negates this array in place and returns it. */
    public NDArray negi() {
        NDArray[] self = new NDArray[] {this};
        this.manager.invoke("_npi_negative", self, new NDArray[] {this}, null);
        return this;
    }

    /** Element-wise absolute value. */
    public NDArray abs() {
        return this.applyUnary("_npi_absolute");
    }

    /** Element-wise square. */
    public NDArray square() {
        return this.applyUnary("_npi_square");
    }

    /** Element-wise cube root. */
    public NDArray cbrt() {
        return this.applyUnary("_npi_cbrt");
    }

    /** Element-wise floor. */
    public NDArray floor() {
        return this.applyUnary("_npi_floor");
    }

    /** Element-wise ceiling. */
    public NDArray ceil() {
        return this.applyUnary("_npi_ceil");
    }

    /** Element-wise rounding via the legacy {@code round} operator. */
    public NDArray round() {
        return this.applyUnary("round");
    }

    /** Element-wise truncation toward zero. */
    public NDArray trunc() {
        return this.applyUnary("_npi_trunc");
    }

    /** Element-wise natural exponential. */
    public NDArray exp() {
        return this.applyUnary("_npi_exp");
    }

    /** Element-wise natural logarithm. */
    public NDArray log() {
        return this.applyUnary("_npi_log");
    }

    /** Element-wise base-10 logarithm. */
    public NDArray log10() {
        return this.applyUnary("_npi_log10");
    }

    /** Element-wise base-2 logarithm. */
    public NDArray log2() {
        return this.applyUnary("_npi_log2");
    }

    /** Element-wise sine. */
    public NDArray sin() {
        return this.applyUnary("_npi_sin");
    }

    /** Element-wise cosine. */
    public NDArray cos() {
        return this.applyUnary("_npi_cos");
    }

    /** Element-wise tangent. */
    public NDArray tan() {
        return this.applyUnary("_npi_tan");
    }

    /** Element-wise arcsine. */
    public NDArray asin() {
        return this.applyUnary("_npi_arcsin");
    }

    /** Element-wise arccosine. */
    public NDArray acos() {
        return this.applyUnary("_npi_arccos");
    }

    /** Element-wise arctangent. */
    public NDArray atan() {
        return this.applyUnary("_npi_arctan");
    }

    /** Element-wise hyperbolic sine. */
    public NDArray sinh() {
        return this.applyUnary("_npi_sinh");
    }

    /** Element-wise hyperbolic cosine. */
    public NDArray cosh() {
        return this.applyUnary("_npi_cosh");
    }

    /** Element-wise hyperbolic tangent. */
    public NDArray tanh() {
        return this.applyUnary("_npi_tanh");
    }

    /** Element-wise inverse hyperbolic sine. */
    public NDArray asinh() {
        return this.applyUnary("_npi_arcsinh");
    }

    /** Element-wise inverse hyperbolic cosine. */
    public NDArray acosh() {
        return this.applyUnary("_npi_arccosh");
    }

    /** Element-wise inverse hyperbolic tangent. */
    public NDArray atanh() {
        return this.applyUnary("_npi_arctanh");
    }

    /** Converts radians to degrees element-wise. */
    public NDArray toDegrees() {
        return this.applyUnary("_npi_degrees");
    }

    /** Converts degrees to radians element-wise. */
    public NDArray toRadians() {
        return this.applyUnary("_npi_radians");
    }

    /** Invokes a parameterless unary operator on this array and returns the new result. */
    private NDArray applyUnary(String opName) {
        return this.manager.invoke(opName, this, null);
    }

    /** Element-wise maximum against a scalar. */
    public NDArray maximum(Number n) {
        return this.extremumWithScalar("_npi_maximum_scalar", n);
    }

    /** Element-wise maximum of {@code this} and {@code other}. */
    public NDArray maximum(NDArray other) {
        return this.extremumWithArray("_npi_maximum", other);
    }

    /** Element-wise minimum against a scalar. */
    public NDArray minimum(Number n) {
        return this.extremumWithScalar("_npi_minimum_scalar", n);
    }

    /** Element-wise minimum of {@code this} and {@code other}. */
    public NDArray minimum(NDArray other) {
        return this.extremumWithArray("_npi_minimum", other);
    }

    /** Invokes a scalar max/min operator, passing the scalar as its string form. */
    private NDArray extremumWithScalar(String opName, Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        return this.manager.invoke(opName, this, params);
    }

    /** Invokes a binary max/min operator on {@code this} and {@code other}. */
    private NDArray extremumWithArray(String opName, NDArray other) {
        return this.manager.invoke(opName, new NDArray[] {this, other}, null);
    }

    /** Maximum over all elements. */
    public NDArray max() {
        return this.manager.invoke("_np_max", this, null);
    }

    /**
     * Maximum along {@code axes}; no {@code keepdims} parameter is passed to the operator.
     *
     * @param axes the axes to reduce
     * @return the reduced array
     */
    public NDArray max(int[] axes) {
        MxOpParams axisOnly = new MxOpParams();
        axisOnly.addTupleParam("axis", axes);
        return this.manager.invoke("_np_max", this, axisOnly);
    }

    /** Maximum along {@code axes}, optionally keeping reduced dimensions. */
    public NDArray max(int[] axes, boolean keepDims) {
        return this.reduceOverAxes("_np_max", axes, keepDims);
    }

    /** Minimum over all elements. */
    public NDArray min() {
        return this.manager.invoke("_np_min", this, null);
    }

    /** Minimum along {@code axes}, optionally keeping reduced dimensions. */
    public NDArray min(int[] axes, boolean keepDims) {
        return this.reduceOverAxes("_np_min", axes, keepDims);
    }

    /** Sum over all elements. */
    public NDArray sum() {
        return this.manager.invoke("_np_sum", this, null);
    }

    /** Sum along {@code axes}, optionally keeping reduced dimensions. */
    public NDArray sum(int[] axes, boolean keepDims) {
        return this.reduceOverAxes("_np_sum", axes, keepDims);
    }

    /** Product over all elements. */
    public NDArray prod() {
        return this.manager.invoke("_np_prod", this, null);
    }

    /** Product along {@code axes}, optionally keeping reduced dimensions. */
    public NDArray prod(int[] axes, boolean keepDims) {
        return this.reduceOverAxes("_np_prod", axes, keepDims);
    }

    /** Mean over all elements. */
    public NDArray mean() {
        return this.manager.invoke("_npi_mean", this, null);
    }

    /** Mean along {@code axes}, optionally keeping reduced dimensions. */
    public NDArray mean(int[] axes, boolean keepDims) {
        return this.reduceOverAxes("_npi_mean", axes, keepDims);
    }

    /**
     * Sum along the diagonal selected by {@code offset}, {@code axis1} and {@code axis2}.
     *
     * @param offset diagonal offset
     * @param axis1 first axis of the 2-D sub-arrays
     * @param axis2 second axis of the 2-D sub-arrays
     * @return the trace
     */
    public NDArray trace(int offset, int axis1, int axis2) {
        MxOpParams traceParams = new MxOpParams();
        traceParams.addParam("offset", offset);
        traceParams.addParam("axis1", axis1);
        traceParams.addParam("axis2", axis2);
        return this.manager.invoke("_np_trace", this, traceParams);
    }

    /** Invokes a reduction operator with {@code axis} and {@code keepdims} parameters. */
    private NDArray reduceOverAxes(String opName, int[] axes, boolean keepDims) {
        MxOpParams params = new MxOpParams();
        params.addTupleParam("axis", axes);
        params.addParam("keepdims", keepDims);
        return this.manager.invoke(opName, this, params);
    }

    /**
     * Splits this array along {@code axis} at the given split points.
     *
     * <p>{@code _npi_split} expects the section boundaries to start at 0, so 0 is prepended when
     * missing. An empty {@code indices} previously threw ArrayIndexOutOfBoundsException; it now
     * becomes {@code [0]}, i.e. a single section covering the whole axis.
     *
     * @param indices split points along {@code axis} (not modified)
     * @param axis the axis to split
     * @return the resulting sections
     */
    public NDList split(int[] indices, int axis) {
        MxOpParams params = new MxOpParams();
        int[] boundaries = indices;
        if (boundaries.length == 0 || boundaries[0] != 0) {
            boundaries = new int[indices.length + 1];
            // boundaries[0] is already 0 from array initialization.
            System.arraycopy(indices, 0, boundaries, 1, indices.length);
        }
        params.addTupleParam("indices", boundaries);
        params.addParam("axis", axis);
        params.addParam("squeeze_axis", false);
        return this.manager.invoke("_npi_split", new NDList(new NDArray[] {this}), params);
    }

    /**
     * Reshapes this array to one dimension.
     *
     * @return the flattened array
     * @throws ArithmeticException if the element count does not fit in an {@code int}
     */
    public NDArray flatten() {
        int total = Math.toIntExact(this.size());
        return this.reshape(new Shape(new long[] {total}));
    }

    /**
     * Reshapes this array via {@code _np_reshape}.
     *
     * @param shape the new shape
     * @return the reshaped array
     */
    public NDArray reshape(Shape shape) {
        MxOpParams reshapeParams = new MxOpParams();
        reshapeParams.addParam("newshape", shape);
        return this.manager.invoke("_np_reshape", this, reshapeParams);
    }

    /**
     * Inserts a size-1 dimension at {@code axis}. A scalar is reshaped to {@code [1]} and
     * {@code axis} is not used in that case.
     *
     * @param axis position of the new dimension
     * @return the expanded array
     */
    public NDArray expandDims(int axis) {
        if (!this.isScalar()) {
            MxOpParams axisParams = new MxOpParams();
            axisParams.addParam("axis", axis);
            return this.manager.invoke("_npi_expand_dims", this, axisParams);
        }
        return this.reshape(new long[] {1L});
    }

    /** Removes all size-1 dimensions. */
    public NDArray squeeze() {
        return this.manager.invoke("_np_squeeze", this, null);
    }

    /**
     * Removes the size-1 dimensions listed in {@code axes}.
     *
     * @param axes the axes to squeeze
     * @return the squeezed array
     */
    public NDArray squeeze(int[] axes) {
        MxOpParams axisParams = new MxOpParams();
        axisParams.addTupleParam("axis", axes);
        return this.manager.invoke("_np_squeeze", this, axisParams);
    }

    /** Element-wise logical AND; the result is BOOLEAN. */
    public NDArray logicalAnd(NDArray other) {
        return this.logicalBinaryOp("broadcast_logical_and", other);
    }

    /** Element-wise logical OR; the result is BOOLEAN. */
    public NDArray logicalOr(NDArray other) {
        return this.logicalBinaryOp("broadcast_logical_or", other);
    }

    /** Element-wise logical XOR; the result is BOOLEAN. */
    public NDArray logicalXor(NDArray other) {
        return this.logicalBinaryOp("broadcast_logical_xor", other);
    }

    /** Element-wise logical NOT. */
    public NDArray logicalNot() {
        return this.manager.invoke("_npi_logical_not", this, null);
    }

    /**
     * Runs a broadcast logical operator and returns a BOOLEAN result.
     *
     * <p>BOOLEAN operands are fed to the operator as INT32 copies, as in the original
     * implementation. The original assigned an {@code NDArray} to an {@code MxNDArray} local (does
     * not compile) and leaked the INT32 copies and the raw operator output. All temporaries are
     * now closed; the caller's {@code other} and {@code this} are never closed.
     */
    private NDArray logicalBinaryOp(String opName, NDArray other) {
        NDArray lhs = this.getDataType() == DataType.BOOLEAN ? this.asType(DataType.INT32, false) : this;
        NDArray rhs = null;
        try {
            rhs = other.getDataType() == DataType.BOOLEAN ? other.asType(DataType.INT32, false) : other;
            NDArray raw = this.manager.invoke(opName, new NDArray[] {lhs, rhs}, null);
            if (raw.getDataType() == DataType.BOOLEAN) {
                return raw;
            }
            try {
                return raw.asType(DataType.BOOLEAN, false);
            } finally {
                raw.close();
            }
        } finally {
            if (rhs != null && rhs != other) {
                rhs.close();
            }
            if (lhs != this) {
                lhs.close();
            }
        }
    }

    /**
     * Returns the INT32 indices that would sort this array along {@code axis}.
     *
     * @param axis the axis to sort along
     * @param ascending sort order
     * @return the sorting indices
     */
    public NDArray argSort(int axis, boolean ascending) {
        MxOpParams sortParams = new MxOpParams();
        sortParams.addParam("axis", axis);
        sortParams.addParam("is_ascend", ascending);
        sortParams.setDataType(DataType.INT32);
        NDArray order = this.manager.invoke("argsort", this, sortParams);
        return order;
    }

    /**
     * Sorts this array along {@code axis}.
     *
     * <p>Empty and scalar arrays are returned as duplicates after validating {@code axis}
     * against their rank. Negative out-of-range axes are now rejected too, and the error message
     * no longer runs the axis value into the following word.
     *
     * @param axis the axis to sort along (negative values count from the end)
     * @return the sorted array
     * @throws IllegalArgumentException if {@code axis} is out of bounds for an empty/scalar array
     */
    public NDArray sort(int axis) {
        if (this.isEmpty() || this.isScalar()) {
            int dim = this.getShape().dimension();
            if (axis >= dim || axis < -dim) {
                throw new IllegalArgumentException("axis " + axis + " is out of bounds for array of dimension " + dim);
            }
            return this.duplicate();
        }
        MxOpParams params = new MxOpParams();
        params.addParam("axis", axis);
        return this.manager.invoke("sort", this, params);
    }

    /**
     * Sorts this array using the operator's default axis.
     *
     * @return the sorted array (a duplicate for empty or scalar arrays)
     */
    public NDArray sort() {
        if (this.isEmpty() || this.isScalar()) {
            return this.duplicate();
        }
        return this.manager.invoke("sort", this, null);
    }

    /** Softmax over {@code axes} with temperature 1.0. */
    public NDArray softmax(int[] axes) {
        return this.softmax(axes, 1.0);
    }

    /**
     * Softmax over {@code axes} with the given temperature.
     *
     * <p>{@code _npx_softmax} accepts a single axis, so multiple axes are handled this way:
     * <ul>
     *   <li>transpose the softmax axes to the front, using a full permutation;</li>
     *   <li>flatten them into one leading axis and apply softmax on axis 0;</li>
     *   <li>reshape back and apply the inverse permutation.</li>
     * </ul>
     *
     * <p>The previous version had three problems, all fixed here:
     * <ul>
     *   <li>It read the lazily-filled {@code shape} field, which can be null.</li>
     *   <li>It passed a partial axis list to {@code transpose}.</li>
     *   <li>It undid the transpose with the same permutation instead of its inverse.</li>
     * </ul>
     * It also leaked every intermediate array; those are now closed.
     *
     * @param axes the axes to normalize over (negative values count from the end)
     * @param temperature softmax temperature
     * @return the normalized array
     * @throws IllegalArgumentException if an axis is out of range or repeated
     */
    public NDArray softmax(int[] axes, double temperature) {
        if (this.isEmpty()) {
            return this.getManager().create(this.getShape());
        }
        MxOpParams params = new MxOpParams();
        if (axes.length == 1) {
            params.addParam("axis", axes[0]);
            params.addParam("temperature", temperature);
            return this.manager.invoke("_npx_softmax", this, params);
        }
        Shape inShape = this.getShape();
        int rank = inShape.dimension();
        int[] permutation = softmaxAxisPermutation(axes, rank);
        int[] inverse = new int[rank];
        for (int i = 0; i < rank; i++) {
            inverse[permutation[i]] = i;
        }
        long groupSize = 1L;
        for (int i = 0; i < axes.length; i++) {
            groupSize *= inShape.get(permutation[i]);
        }
        NDArray transposed = null;
        NDArray flattened = null;
        NDArray normalized = null;
        NDArray restored = null;
        NDArray result = null;
        try {
            transposed = this.transpose(permutation);
            Shape transposedShape = transposed.getShape();
            Shape trailing = transposedShape.slice(axes.length);
            flattened = transposed.reshape(new Shape(new long[] {groupSize}).addAll(trailing));
            params.addParam("axis", 0);
            params.addParam("temperature", temperature);
            normalized = this.manager.invoke("_npx_softmax", flattened, params);
            restored = normalized.reshape(transposedShape);
            result = restored.transpose(inverse);
            return result;
        } finally {
            this.closeSoftmaxTemporaries(result, transposed, flattened, normalized, restored);
        }
    }

    /** Builds a full permutation: the normalized softmax axes first, then the rest in order. */
    private static int[] softmaxAxisPermutation(int[] axes, int rank) {
        boolean[] chosen = new boolean[rank];
        int[] permutation = new int[rank];
        int next = 0;
        for (int axis : axes) {
            int normalized = axis < 0 ? axis + rank : axis;
            if (normalized < 0 || normalized >= rank || chosen[normalized]) {
                throw new IllegalArgumentException("Invalid or duplicate softmax axis " + axis + " for array of dimension " + rank);
            }
            chosen[normalized] = true;
            permutation[next++] = normalized;
        }
        for (int i = 0; i < rank; i++) {
            if (!chosen[i]) {
                permutation[next++] = i;
            }
        }
        return permutation;
    }

    /** Closes each distinct non-null temporary that is neither {@code this} nor {@code keep}. */
    private void closeSoftmaxTemporaries(NDArray keep, NDArray... temps) {
        for (int i = 0; i < temps.length; i++) {
            NDArray temp = temps[i];
            if (temp == null || temp == this || temp == keep) {
                continue;
            }
            boolean seen = false;
            for (int j = 0; j < i; j++) {
                if (temps[j] == temp) {
                    seen = true;
                    break;
                }
            }
            if (!seen) {
                temp.close();
            }
        }
    }

    /**
     * Returns the cumulative sum of this array.
     *
     * <p>No parameters are passed, so axis handling is left to the {@code _np_cumsum} operator's
     * default (presumably flattening, as NumPy does — confirm).
     *
     * @return a new array of running sums
     */
    public NDArray cumSum() {
        final String opName = "_np_cumsum";
        return manager.invoke(opName, this, null);
    }

    /**
     * Returns the cumulative sum of this array along {@code axis}.
     *
     * @param axis the axis to accumulate along
     * @return a new array of running sums
     */
    public NDArray cumSum(int axis) {
        MxOpParams opParams = new MxOpParams();
        opParams.addParam("axis", axis);
        return manager.invoke("_np_cumsum", this, opParams);
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray isInfinite() {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Returns an element-wise mask of NaN entries.
     *
     * <p>Compares the array with itself using {@code _npi_not_equal}. NaN is the only value that
     * compares unequal to itself, so only NaN positions come out true.
     *
     * @return a new mask array
     */
    public NDArray isNaN() {
        NDArray[] selfPair = {this, this};
        return manager.invoke("_npi_not_equal", selfPair, null);
    }

    /**
     * Not supported by this implementation.
     *
     * @param index ignored
     * @throws UnsupportedOperationException always
     */
    public NDArray createMask(NDIndex index) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Not supported by this implementation.
     *
     * <p>Fails fast, like {@link #createMask(NDIndex)}. Returning {@code null} from an
     * {@code NDArray}-returning API would only surface later as a NullPointerException far from
     * the call site.
     *
     * @param predicate ignored
     * @throws UnsupportedOperationException always
     */
    public NDArray createMask(Predicate<Number> predicate) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Returns a dense version of this array.
     *
     * <p>Sparse arrays are converted with {@code castStorage(SparseFormat.DENSE)}. Non-sparse
     * arrays go through {@code slice()} (NOTE(review): presumably a full-range copy/view — confirm).
     *
     * @return a dense array
     */
    public NDArray toDense() {
        return isSparse() ? castStorage(SparseFormat.DENSE) : slice();
    }

    /**
     * Tiles this array {@code repeats} times along every axis.
     *
     * <p>An empty array is duplicated unchanged. A scalar is treated as having one axis.
     *
     * @param repeats the repeat count applied to each axis
     * @return a new tiled array
     */
    public NDArray tile(long repeats) {
        if (isEmpty()) {
            return duplicate();
        }
        int rank = isScalar() ? 1 : getShape().dimension();
        long[] perAxis = new long[rank];
        for (int i = 0; i < rank; i++) {
            perAxis[i] = repeats;
        }
        return tile(perAxis);
    }

    /**
     * Tiles this array {@code repeats} times along a single axis; all other axes are kept as-is.
     *
     * @param axis the axis to tile; negative values count from the end
     * @param repeats the repeat count for {@code axis}
     * @return a new tiled array
     * @throws IllegalArgumentException if this array is a scalar
     */
    public NDArray tile(int axis, long repeats) {
        if (isScalar()) {
            throw new IllegalArgumentException("scalar didn't support specifying axis");
        }
        int rank = getShape().dimension();
        long[] perAxis = new long[rank];
        for (int i = 0; i < rank; i++) {
            perAxis[i] = 1L;
        }
        perAxis[withAxis(axis)] = repeats;
        return tile(perAxis);
    }

    /**
     * Tiles this array using explicit per-axis repeat counts.
     *
     * @param repeats repeat counts, forwarded as the {@code reps} tuple of {@code _npi_tile}
     * @return a new tiled array
     */
    public NDArray tile(long[] repeats) {
        MxOpParams tileParams = new MxOpParams();
        tileParams.addTupleParam("reps", repeats);
        return manager.invoke("_npi_tile", this, tileParams);
    }

    /**
     * Tiles this array until it reaches {@code desiredShape}.
     *
     * @param desiredShape the target shape; each dimension must be a multiple of the current one
     * @return a new tiled array
     */
    public NDArray tile(Shape desiredShape) {
        long[] reps = repeatsToMatchShape(desiredShape);
        return tile(reps);
    }

    /**
     * Repeats every element {@code repeats} times along every axis.
     *
     * <p>An empty array is duplicated unchanged. A scalar is treated as having one axis.
     *
     * @param repeats the repeat count applied to each axis
     * @return a new array
     */
    public NDArray repeat(long repeats) {
        if (isEmpty()) {
            return duplicate();
        }
        int rank = isScalar() ? 1 : getShape().dimension();
        long[] perAxis = new long[rank];
        for (int i = 0; i < rank; i++) {
            perAxis[i] = repeats;
        }
        return repeat(perAxis);
    }

    /**
     * Repeats every element {@code repeats} times along a single axis; other axes are untouched.
     *
     * <p>Scalars are rejected, matching {@code tile(int, long)}. Without the guard, a scalar has
     * dimension 0, and {@code withAxis} would fail with an ArithmeticException from
     * {@code Math.floorMod(axis, 0)}.
     *
     * @param axis the axis to repeat along; negative values count from the end
     * @param repeats the repeat count for {@code axis}
     * @return a new array
     * @throws IllegalArgumentException if this array is a scalar
     */
    public NDArray repeat(int axis, long repeats) {
        if (this.isScalar()) {
            throw new IllegalArgumentException("scalar didn't support specifying axis");
        }
        long[] repeatsArray = new long[this.getShape().dimension()];
        Arrays.fill(repeatsArray, 1L);
        repeatsArray[this.withAxis(axis)] = repeats;
        return this.repeat(repeatsArray);
    }

    /**
     * Repeats elements using per-axis counts aligned to the trailing axes of this array.
     *
     * <p>Each count above 1 triggers one {@code _np_repeat} call on the corresponding axis.
     * Intermediate results are closed as soon as the next one exists. The caller's array
     * ({@code this}) is never closed. If no count exceeds 1, {@code this} itself is returned.
     *
     * @param repeats repeat counts; {@code repeats[i]} applies to axis {@code dimension - repeats.length + i}
     * @return the repeated array (possibly {@code this})
     */
    public NDArray repeat(long[] repeats) {
        int firstAxis = getShape().dimension() - repeats.length;
        MxNDArray current = this;
        for (int offset = 0; offset < repeats.length; offset++) {
            long count = repeats[offset];
            if (count > 1L) {
                MxOpParams opParams = new MxOpParams();
                opParams.addParam("repeats", count);
                opParams.addParam("axis", firstAxis + offset);
                MxNDArray produced = manager.invoke("_np_repeat", current, opParams);
                if (current != this) {
                    current.close();
                }
                current = produced;
            }
        }
        return current;
    }

    /**
     * Repeats elements until this array reaches {@code desiredShape}.
     *
     * @param desiredShape the target shape; each dimension must be a multiple of the current one
     * @return a new array
     */
    public NDArray repeat(Shape desiredShape) {
        long[] counts = repeatsToMatchShape(desiredShape);
        return repeat(counts);
    }

    /**
     * Computes the dot product of this array with {@code other} via {@code _np_dot}.
     *
     * @param other the right-hand operand
     * @return a new array holding the product
     */
    public NDArray dot(NDArray other) {
        NDArray[] operands = {this, other};
        return manager.invoke("_np_dot", operands, null);
    }

    /**
     * Clips every element into the range {@code [min, max]}.
     *
     * @param min the lower bound (forwarded as {@code a_min})
     * @param max the upper bound (forwarded as {@code a_max})
     * @return a new clipped array
     */
    public NDArray clip(Number min, Number max) {
        MxOpParams bounds = new MxOpParams();
        bounds.addParam("a_min", min);
        bounds.addParam("a_max", max);
        return manager.invoke("_npi_clip", this, bounds);
    }

    /**
     * Swaps two axes of this array.
     *
     * @param axis1 the first axis (forwarded as {@code dim1})
     * @param axis2 the second axis (forwarded as {@code dim2})
     * @return a new array with the axes exchanged
     */
    public NDArray swapAxes(int axis1, int axis2) {
        MxOpParams swap = new MxOpParams();
        swap.addParam("dim1", axis1);
        swap.addParam("dim2", axis2);
        return manager.invoke("_npi_swapaxes", this, swap);
    }

    /**
     * Transposes this array using the {@code _np_transpose} operator with no explicit axes.
     *
     * @return a new transposed array
     */
    public NDArray transpose() {
        final String opName = "_np_transpose";
        return manager.invoke(opName, this, null);
    }

    /**
     * Permutes the axes of this array.
     *
     * @param dimensions a permutation of {@code 0 .. dimension-1}; result axis {@code i} is input
     *     axis {@code dimensions[i]}
     * @return a new transposed array
     * @throws UnsupportedOperationException if any entry is negative
     * @throws IllegalArgumentException if {@code dimensions} is not a permutation of all axes
     */
    public NDArray transpose(int ... dimensions) {
        for (int d : dimensions) {
            if (d < 0) {
                throw new UnsupportedOperationException("Passing -1 for broadcasting the dimension is not currently supported");
            }
        }
        int rank = getShape().dimension();
        int[] ordered = dimensions.clone();
        Arrays.sort(ordered);
        boolean isPermutation = ordered.length == rank;
        for (int i = 0; isPermutation && i < rank; i++) {
            isPermutation = ordered[i] == i;
        }
        if (!isPermutation) {
            throw new IllegalArgumentException("You must include each of the dimensions from 0 until " + rank);
        }
        MxOpParams axesParams = new MxOpParams();
        axesParams.addTupleParam("axes", dimensions);
        return manager.invoke("_np_transpose", this, axesParams);
    }

    /**
     * Broadcasts this array to {@code shape} via {@code _np_broadcast_to}.
     *
     * @param shape the target shape
     * @return a new broadcast array
     */
    public NDArray broadcast(Shape shape) {
        MxOpParams target = new MxOpParams();
        target.setShape(shape);
        return manager.invoke("_np_broadcast_to", this, target);
    }

    /**
     * Returns the index of the maximum via {@code _npi_argmax} with no axis parameter.
     *
     * @return a new array holding the index
     */
    public NDArray argMax() {
        final String opName = "_npi_argmax";
        return manager.invoke(opName, this, null);
    }

    /**
     * Returns the indices of the maxima along {@code axis}.
     *
     * @param axis the axis to reduce
     * @return a new array of indices
     */
    public NDArray argMax(int axis) {
        MxOpParams opParams = new MxOpParams();
        opParams.addParam("axis", axis);
        return manager.invoke("_npi_argmax", this, opParams);
    }

    /**
     * Returns the index of the minimum element as a scalar-shaped array.
     *
     * <p>A scalar is first reshaped to a 1-element array before calling the {@code argmin}
     * operator. That reshaped temporary and the raw operator output are both closed before
     * returning.
     *
     * @return a new scalar-shaped array holding the index
     * @throws IllegalArgumentException if this array is empty
     */
    public NDArray argMin() {
        if (this.isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMin of an empty NDArray");
        }
        MxNDArray array = this.isScalar() ? this.reshape(new long[]{1L}) : this;
        try (NDArray temp = this.manager.invoke("argmin", array, null)) {
            return temp.reshape(new Shape(new long[0]));
        } finally {
            // Free the reshaped copy of a scalar; never close the caller's array.
            if (array != this) {
                array.close();
            }
        }
    }

    /**
     * Returns the indices of the minima along {@code axis}.
     *
     * <p>A scalar is reshaped to a 1-element array for the operator, and its result is reshaped
     * back to a scalar shape. Every temporary created here is closed, even if a later step throws.
     *
     * @param axis the axis to reduce
     * @return a new array of indices
     */
    public NDArray argMin(int axis) {
        boolean scalar = this.isScalar();
        MxNDArray array = scalar ? this.reshape(new long[]{1L}) : this;
        MxOpParams params = new MxOpParams();
        params.addParam("axis", axis);
        try {
            NDArray temp = this.manager.invoke("argmin", array, params);
            if (!scalar) {
                return temp;
            }
            try {
                return temp.reshape(new Shape(new long[0]));
            } finally {
                temp.close();
            }
        } finally {
            // Free the reshaped copy of a scalar; never close the caller's array.
            if (array != this) {
                array.close();
            }
        }
    }

    /**
     * Not supported by this implementation.
     *
     * @param percentile ignored
     * @throws UnsupportedOperationException always
     */
    public NDArray percentile(Number percentile) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Not supported by this implementation.
     *
     * @param percentile ignored
     * @param dimension ignored
     * @throws UnsupportedOperationException always
     */
    public NDArray percentile(Number percentile, int[] dimension) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray median() {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Not supported by this implementation.
     *
     * @param axes ignored
     * @throws UnsupportedOperationException always
     */
    public NDArray median(int[] axes) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Returns the indices of the non-zero elements via {@code _npx_nonzero}.
     *
     * <p>A BOOLEAN array is converted to an INT32 copy first (presumably because the operator does
     * not accept BOOLEAN input — confirm). That copy is closed once the operator has been invoked.
     *
     * @return a new array of indices
     */
    public NDArray nonzero() {
        MxNDArray thisArr = this.getDataType() == DataType.BOOLEAN ? this.asType(DataType.INT32, false) : this;
        try {
            return this.manager.invoke("_npx_nonzero", thisArr, null);
        } finally {
            if (thisArr != this) {
                thisArr.close();
            }
        }
    }

    /**
     * Returns the engine-specific extension object held by this array.
     *
     * @return the stored {@code MxNDArrayEx} instance
     */
    public NDArrayEx getNDArrayInternal() {
        NDArrayEx ex = mxNDArrayEx;
        return ex;
    }

    /**
     * Computes per-axis repeat counts that turn this array's shape into {@code desiredShape}.
     *
     * <p>If {@code desiredShape} has fewer dimensions, it is right-aligned: the missing leading
     * dimensions are copied from the current shape, so their count is 1. Each count is the exact
     * quotient of desired over current. Long division is used so no precision is lost on very
     * large dimensions.
     *
     * @param desiredShape the target shape
     * @return one repeat count per axis of this array
     * @throws IllegalArgumentException if {@code desiredShape} has more dimensions than this array,
     *     or any target dimension is not a multiple of a non-zero current dimension
     */
    private long[] repeatsToMatchShape(Shape desiredShape) {
        Shape curShape = this.getShape();
        int dimension = curShape.dimension();
        if (desiredShape.dimension() > dimension) {
            throw new IllegalArgumentException("The desired shape has too many dimensions");
        }
        if (desiredShape.dimension() < dimension) {
            int additionalDimensions = dimension - desiredShape.dimension();
            desiredShape = curShape.slice(0, additionalDimensions).addAll(desiredShape);
        }
        long[] repeats = new long[dimension];
        for (int i = 0; i < dimension; ++i) {
            long current = curShape.get(i);
            long desired = desiredShape.get(i);
            if (current == 0L || desired % current != 0L) {
                throw new IllegalArgumentException("The desired shape is not a multiple of the original shape");
            }
            // Divisibility was just verified, so the integer quotient is exact.
            repeats[i] = desired / current;
        }
        return repeats;
    }

    /**
     * Normalizes {@code axis} into {@code [0, dimension)}, so that e.g. {@code -1} maps to the last axis.
     *
     * <p>NOTE(review): for a 0-dimensional array, {@code Math.floorMod(axis, 0)} throws
     * ArithmeticException.
     *
     * @param axis a possibly negative axis
     * @return the non-negative axis index
     */
    private int withAxis(int axis) {
        int rank = getShape().dimension();
        return Math.floorMod(axis, rank);
    }

    /**
     * Checks that an incoming buffer is compatible with this array.
     *
     * <p>The data type must match exactly. The one exception is INT8 input, which is accepted for
     * UINT8 and BOOLEAN arrays. The element count must also equal this array's size.
     *
     * @param inputType the buffer's data type
     * @param size the buffer's element count
     * @throws IllegalStateException on an incompatible data type
     * @throws IllegalArgumentException on a size mismatch
     */
    private void validate(DataType inputType, int size) {
        boolean sameType = this.getDataType() == inputType;
        if (!sameType) {
            boolean byteStorage = this.dataType == DataType.UINT8 || this.dataType == DataType.BOOLEAN;
            boolean byteInput = inputType == DataType.INT8;
            if (!(byteStorage && byteInput)) {
                throw new IllegalStateException("DataType mismatch, required: " + this.dataType + ", actual: " + inputType);
            }
        }
        long expected = this.getShape().size();
        if (size != expected) {
            throw new IllegalArgumentException("array size (" + size + ") do not match NDArray shape: " + this.shape);
        }
    }

    /** Blocks via {@code JnaUtils.waitToRead} on this array's native handle. */
    public void waitToRead() {
        Pointer nativeHandle = getHandle();
        JnaUtils.waitToRead(nativeHandle);
    }

    /** Blocks via {@code JnaUtils.waitToWrite} on this array's native handle. */
    public void waitToWrite() {
        Pointer nativeHandle = getHandle();
        JnaUtils.waitToWrite(nativeHandle);
    }

    /**
     * Waits on this array's native handle using {@code JnaUtils.waitToRead}.
     *
     * <p>NOTE(review): despite the name, this only calls the per-handle read wait, exactly like
     * {@link #waitToRead()}; it does not appear to wait for all pending engine work. Confirm this
     * is intended.
     */
    public void waitAll() {
        Pointer nativeHandle = getHandle();
        JnaUtils.waitToRead(nativeHandle);
    }

    /**
     * Compares by content.
     *
     * <p>Another {@code MxNDArray} is equal when {@code contentEquals} reports equal contents.
     * Any other object, including {@code null}, is unequal.
     *
     * @param obj the object to compare with
     * @return whether the contents are equal
     */
    public boolean equals(Object obj) {
        if (!(obj instanceof MxNDArray)) {
            return false;
        }
        MxNDArray other = (MxNDArray) obj;
        return contentEquals(other);
    }

    /**
     * Returns a constant hash code.
     *
     * <p>{@code equals} compares array contents. A constant is the simplest hash that stays
     * consistent with it (equal objects must have equal hash codes). The trade-off is that
     * hash-based collections keyed by these arrays perform poorly.
     *
     * @return always {@code 0}
     */
    public int hashCode() {
        return 0;
    }

    /**
     * Formats this array using the class's default display limits.
     *
     * @return a human-readable rendering of the array
     */
    public String toString() {
        return toDebugString(MAX_SIZE, MAX_DEPTH, MAX_ROWS, MAX_COLUMNS);
    }

    /**
     * Formats this array with explicit display limits, or reports that it has been released.
     *
     * @param maxSize forwarded to {@code NDFormat.format}
     * @param maxDepth forwarded to {@code NDFormat.format}
     * @param maxRows forwarded to {@code NDFormat.format}
     * @param maxColumns forwarded to {@code NDFormat.format}
     * @return the formatted text, or a fixed notice when the native resource is released
     */
    public String toDebugString(int maxSize, int maxDepth, int maxRows, int maxColumns) {
        return isReleased()
                ? "This array is already closed"
                : NDFormat.format((NDArray) this, maxSize, maxDepth, maxRows, maxColumns);
    }

    /**
     * Frees the native array and detaches it from its manager.
     *
     * <p>Does nothing when {@code shouldFree} is false. The handle is swapped out atomically, so a
     * second call finds {@code null} and returns without freeing twice. After a successful free,
     * the manager reference is cleared.
     */
    @Override
    public void close() {
        if (shouldFree) {
            Pointer released = this.handle.getAndSet(null);
            if (released == null) {
                return;
            }
            JnaUtils.freeNdArray(released);
            manager.detach(getUid());
            manager = null;
        }
    }
}

