/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxOpParams;
import ai.djl.mxnet.engine.MxSparseNDArray;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.BaseNDManager;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.util.PairList;
import com.sun.jna.Pointer;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Path;

/**
 * MXNet-backed {@link NDManager}.
 *
 * <p>Most factory methods are implemented by invoking a named MXNet operator (for example
 * {@code _npi_zeros} or {@code _npi_arange}) through {@link JnaUtils#op(String)}, passing this
 * manager's {@code device} as the target device where the operator accepts one.
 */
public class MxNDManager extends BaseNDManager {

    /** Process-wide root manager; its attach/detach/close are no-ops. */
    private static final MxNDManager SYSTEM_MANAGER = new SystemManager();

    /** Zero-length array used as the {@code toArray} type token and as "no inputs" for ops. */
    private static final NDArray[] EMPTY = new NDArray[0];

    private MxNDManager(NDManager parent, Device device) {
        super(parent, device);
    }

    /**
     * Returns the shared system manager.
     *
     * @return the singleton root manager created with {@link Device#defaultDevice()}
     */
    static MxNDManager getSystemManager() {
        return SYSTEM_MANAGER;
    }

    /**
     * Allocates a direct {@link ByteBuffer} in the platform's native byte order.
     *
     * @param capacity the buffer capacity in bytes
     * @return a new direct buffer with {@link ByteOrder#nativeOrder()}
     */
    @Override
    public ByteBuffer allocateDirect(int capacity) {
        return ByteBuffer.allocateDirect(capacity).order(ByteOrder.nativeOrder());
    }

    /**
     * Wraps an existing native MXNet array handle in an {@link MxNDArray} owned by this manager.
     *
     * @param handle the native array handle
     * @return the wrapping array
     */
    public MxNDArray create(Pointer handle) {
        return new MxNDArray(this, handle);
    }

    /**
     * Wraps an existing native MXNet sparse array handle in an {@link MxSparseNDArray}.
     *
     * @param handle the native array handle
     * @param fmt the sparse storage format of the handle
     * @return the wrapping sparse array
     */
    public MxSparseNDArray create(Pointer handle, SparseFormat fmt) {
        return new MxSparseNDArray(this, handle, fmt);
    }

    /**
     * Creates a new MXNet array with the given shape and data type on this manager's device.
     *
     * @param shape the array shape
     * @param dataType the element type
     * @return the new array
     */
    @Override
    public MxNDArray create(Shape shape, DataType dataType) {
        Pointer handle = JnaUtils.createNdArray(device, shape, dataType, shape.dimension(), false);
        return new MxNDArray(this, handle, device, shape, dataType, false);
    }

    /**
     * Creates a CSR sparse array.
     *
     * <p>The values, {@code indptr} and {@code indices} are first staged in temporary dense
     * arrays. They are then copied into the sparse array at index -1 (values) and aux indices
     * 0 ({@code indptr}) and 1 ({@code indices}). The temporaries are closed once copied.
     *
     * @param data the non-zero values; its element type determines the array's data type
     * @param indptr the CSR row pointer array
     * @param indices the CSR column index array
     * @param shape the logical (dense) shape of the result
     * @return the new sparse array, owned by this manager
     */
    @Override
    public MxSparseNDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) {
        SparseFormat fmt = SparseFormat.CSR;
        DataType dataType = DataType.fromBuffer(data);
        // The staging arrays are only needed until their contents have been synchronously copied
        // into the sparse array; close them instead of leaving their native memory attached to
        // this manager until the manager itself is closed.
        try (MxNDArray indptrNd = create(new Shape(indptr.length), DataType.INT64);
                MxNDArray indicesNd = create(new Shape(indices.length), DataType.INT64)) {
            indptrNd.set(indptr);
            indicesNd.set(indices);
            Pointer handle =
                    JnaUtils.createSparseNdArray(
                            fmt,
                            device,
                            shape,
                            dataType,
                            new DataType[] {indptrNd.getDataType(), indicesNd.getDataType()},
                            new Shape[] {indptrNd.getShape(), indicesNd.getShape()},
                            false);
            MxSparseNDArray sparse = create(handle, fmt);
            try (MxNDArray dataNd = create(new Shape(data.remaining()), dataType)) {
                dataNd.set(data);
                JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);
            }
            JnaUtils.ndArraySyncCopyFromNdArray(sparse, indptrNd, 0);
            JnaUtils.ndArraySyncCopyFromNdArray(sparse, indicesNd, 1);
            return sparse;
        }
    }

    /**
     * Creates a row-sparse array.
     *
     * <p>The row values and row indices are staged in temporary dense arrays. They are copied
     * into the sparse array at index -1 (values) and aux index 0 (indices). The temporaries are
     * closed once copied.
     *
     * @param data the stored row values; its element type determines the array's data type
     * @param dataShape the shape of the stored row values
     * @param indices the indices of the stored rows
     * @param shape the logical (dense) shape of the result
     * @return the new sparse array, owned by this manager
     */
    @Override
    public MxSparseNDArray createRowSparse(
            Buffer data, Shape dataShape, long[] indices, Shape shape) {
        SparseFormat fmt = SparseFormat.ROW_SPARSE;
        DataType dataType = DataType.fromBuffer(data);
        // See createCSR: staging arrays are released once their contents are copied.
        try (MxNDArray indicesNd = create(new Shape(indices.length), DataType.INT64)) {
            indicesNd.set(indices);
            Pointer handle =
                    JnaUtils.createSparseNdArray(
                            fmt,
                            device,
                            shape,
                            dataType,
                            new DataType[] {indicesNd.getDataType()},
                            new Shape[] {indicesNd.getShape()},
                            false);
            MxSparseNDArray sparse = create(handle, fmt);
            try (MxNDArray dataNd = create(dataShape, dataType)) {
                dataNd.set(data);
                JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);
            }
            JnaUtils.ndArraySyncCopyFromNdArray(sparse, indicesNd, 0);
            return sparse;
        }
    }

    /**
     * Loads arrays from a file onto this manager's device.
     *
     * @param path the file to load
     * @return the loaded arrays
     */
    @Override
    public NDList load(Path path) {
        return JnaUtils.loadNdArray(this, path, device);
    }

    /** {@inheritDoc} Implemented with the {@code _npi_zeros} operator. */
    @Override
    public NDArray zeros(Shape shape, DataType dataType) {
        return fill("_npi_zeros", shape, dataType);
    }

    /** {@inheritDoc} Implemented with the {@code _npi_ones} operator. */
    @Override
    public NDArray ones(Shape shape, DataType dataType) {
        return fill("_npi_ones", shape, dataType);
    }

    /** {@inheritDoc} Implemented with the {@code _npi_full} operator. */
    @Override
    public NDArray full(Shape shape, float value, DataType dataType) {
        MxOpParams params = new MxOpParams();
        params.addParam("shape", shape);
        params.addParam("value", value);
        params.setDataType(dataType);
        params.setDevice(device);
        return invoke("_npi_full", params);
    }

    /**
     * {@inheritDoc} Implemented with the {@code _npi_arange} operator. When {@code dataType} is
     * {@link DataType#UNKNOWN}, no dtype parameter is passed to the operator.
     */
    @Override
    public NDArray arange(float start, float stop, float step, DataType dataType) {
        MxOpParams params = new MxOpParams();
        params.addParam("start", start);
        params.addParam("stop", stop);
        params.addParam("step", step);
        if (dataType != DataType.UNKNOWN) {
            params.setDataType(dataType);
        }
        params.setDevice(device);
        return invoke("_npi_arange", params);
    }

    /** {@inheritDoc} Implemented with the {@code _npi_eye} operator. */
    @Override
    public NDArray eye(int rows, int cols, int k, DataType dataType) {
        MxOpParams params = new MxOpParams();
        params.addParam("N", rows);
        params.addParam("M", cols);
        params.addParam("k", k);
        params.setDataType(dataType);
        params.setDevice(device);
        return invoke("_npi_eye", params);
    }

    /**
     * {@inheritDoc} Implemented with the {@code _npi_linspace} operator.
     *
     * @throws IllegalArgumentException if {@code num} is negative
     */
    @Override
    public NDArray linspace(float start, float stop, int num, boolean endpoint) {
        if (num < 0) {
            throw new IllegalArgumentException("Num argument must be non-negative");
        }
        MxOpParams params = new MxOpParams();
        params.addParam("start", start);
        params.addParam("stop", stop);
        params.addParam("num", num);
        params.addParam("endpoint", endpoint);
        params.setDevice(device);
        return invoke("_npi_linspace", params);
    }

    /**
     * {@inheritDoc} Implemented with the {@code _npi_uniform} operator. When {@code dataType} is
     * {@link DataType#UNKNOWN}, no dtype parameter is passed to the operator.
     */
    @Override
    public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) {
        MxOpParams params = new MxOpParams();
        params.addParam("low", low);
        params.addParam("high", high);
        params.addParam("size", shape);
        params.setDevice(device);
        if (dataType != DataType.UNKNOWN) {
            params.setDataType(dataType);
        }
        return invoke("_npi_uniform", params);
    }

    /**
     * {@inheritDoc} Implemented with the {@code _npi_normal} operator. When {@code dataType} is
     * {@link DataType#UNKNOWN}, no dtype parameter is passed to the operator.
     */
    @Override
    public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataType) {
        MxOpParams params = new MxOpParams();
        params.addParam("loc", loc);
        params.addParam("scale", scale);
        params.addParam("size", shape);
        params.setDevice(device);
        if (dataType != DataType.UNKNOWN) {
            params.setDataType(dataType);
        }
        return invoke("_npi_normal", params);
    }

    /**
     * {@inheritDoc} Implemented with the {@code _npi_multinomial} operator. No device parameter
     * is set here; presumably the operator follows {@code pValues}' device — TODO confirm.
     */
    @Override
    public NDArray randomMultinomial(int n, NDArray pValues, Shape shape) {
        MxOpParams params = new MxOpParams();
        params.addParam("n", n);
        params.addParam("size", shape);
        return invoke("_npi_multinomial", pValues, params);
    }

    /** {@inheritDoc} Implemented with the {@code _npi_multinomial} operator, without a size. */
    @Override
    public NDArray randomMultinomial(int n, NDArray pValues) {
        MxOpParams params = new MxOpParams();
        params.addParam("n", n);
        return invoke("_npi_multinomial", pValues, params);
    }

    /** Creates a child manager on this manager's device. */
    @Override
    public MxNDManager newSubManager() {
        return newSubManager(device);
    }

    /**
     * Creates a child manager on {@code dev} and attaches it to this manager.
     *
     * @param dev the device for the child manager
     * @return the new child manager
     */
    @Override
    public MxNDManager newSubManager(Device dev) {
        MxNDManager manager = new MxNDManager(this, dev);
        attach(manager.uid, manager);
        return manager;
    }

    /**
     * Invokes the named MXNet operator, writing outputs into {@code dest}.
     *
     * @param operation the operator name
     * @param src the input arrays
     * @param dest the output arrays
     * @param params the operator parameters
     */
    @Override
    public void invoke(
            String operation, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
        JnaUtils.op(operation).invoke(this, src, dest, params);
    }

    /**
     * Invokes the named MXNet operator and returns all of its outputs.
     *
     * @param operation the operator name
     * @param src the input arrays
     * @param params the operator parameters
     * @return the operator outputs
     */
    @Override
    public NDList invoke(String operation, NDList src, PairList<String, ?> params) {
        return new NDList(JnaUtils.op(operation).invoke(this, src.toArray(EMPTY), params));
    }

    /**
     * Invokes the named MXNet operator, writing outputs into {@code dest}.
     *
     * @param operation the operator name
     * @param src the input arrays
     * @param dest the output arrays
     * @param params the operator parameters
     */
    public void invoke(String operation, NDList src, NDList dest, PairList<String, ?> params) {
        invoke(operation, src.toArray(EMPTY), dest.toArray(EMPTY), params);
    }

    /**
     * Invokes the named MXNet operator and returns its first output.
     *
     * @param operation the operator name
     * @param src the input arrays
     * @param params the operator parameters
     * @return the first output of the operator
     */
    public NDArray invoke(String operation, NDArray[] src, PairList<String, ?> params) {
        return JnaUtils.op(operation).invoke(this, src, params)[0];
    }

    /**
     * Invokes the named MXNet operator on a single input and returns its first output.
     *
     * @param operation the operator name
     * @param src the input array
     * @param params the operator parameters
     * @return the first output of the operator
     */
    public NDArray invoke(String operation, NDArray src, PairList<String, ?> params) {
        return invoke(operation, new NDArray[] {src}, params);
    }

    /**
     * Invokes the named MXNet operator with no inputs and returns its first output.
     *
     * @param operation the operator name
     * @param params the operator parameters
     * @return the first output of the operator
     */
    public NDArray invoke(String operation, PairList<String, ?> params) {
        return invoke(operation, EMPTY, params);
    }

    /** Returns the engine registered under the name {@code "MXNet"}. */
    @Override
    public Engine getEngine() {
        return Engine.getEngine("MXNet");
    }

    /**
     * Shared implementation of {@link #zeros} and {@link #ones}: runs a shape-only fill operator.
     *
     * @param opName the operator name, e.g. {@code "_npi_zeros"}
     * @param shape the output shape; must not be {@code null}
     * @param dataType the output data type
     * @return the filled array
     * @throws IllegalArgumentException if {@code shape} is {@code null}
     */
    private NDArray fill(String opName, Shape shape, DataType dataType) {
        if (shape == null) {
            // Strip the leading '_' so the message reads e.g. "npi_zeros".
            throw new IllegalArgumentException("Shape is required for " + opName.substring(1));
        }
        MxOpParams params = new MxOpParams();
        params.addParam("shape", shape);
        params.setDevice(device);
        params.setDataType(dataType);
        return invoke(opName, params);
    }

    /**
     * Root manager with no parent, created on {@link Device#defaultDevice()}.
     *
     * <p>{@code attach}, {@code detach} and {@code close} are overridden as no-ops, so resources
     * are not recorded by this manager and closing it has no effect.
     */
    private static final class SystemManager extends MxNDManager {

        SystemManager() {
            super(null, Device.defaultDevice());
        }

        /** No-op: the system manager does not record attached resources. */
        @Override
        public void attach(String resourceId, AutoCloseable resource) {}

        /** No-op: nothing is recorded, so there is nothing to detach. */
        @Override
        public void detach(String resourceId) {}

        /** No-op: the system manager is never closed. */
        @Override
        public void close() {}
    }
}

