/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.python4j.numpy;

import java.io.File;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.cpython.PyTypeObject;
import org.bytedeco.cpython.global.python;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.numpy.PyArrayObject;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.python4j.PythonConstants;
import org.nd4j.python4j.PythonContextManager;
import org.nd4j.python4j.PythonException;
import org.nd4j.python4j.PythonExecutioner;
import org.nd4j.python4j.PythonGIL;
import org.nd4j.python4j.PythonObject;
import org.nd4j.python4j.PythonType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NumpyArray
extends PythonType<INDArray> {
    private static final Logger log = LoggerFactory.getLogger(NumpyArray.class);
    public static final NumpyArray INSTANCE;
    private static final AtomicBoolean init;
    private static final Map<String, DataBuffer> cache;
    public static final String IMPORT_NUMPY_ARRAY = "org.eclipse.python4j.numpyimport";
    public static final String ADD_JAVACPP_NUMPY_TO_PATH = "org.eclipse.python4j.numpyimport";
    public static final String DEFAULT_IMPORT_NUMPY_ARRAY = "true";
    public static final String DEFAULT_ADD_JAVACPP_NUMPY_TO_PATH = "true";

    public File[] packages() {
        try {
            return new File[]{numpy.cachePackage()};
        }
        catch (Exception e) {
            throw new PythonException((Throwable)e);
        }
    }

    public synchronized void init() {
        if (init.get()) {
            return;
        }
        init.set(true);
        if (Boolean.parseBoolean(System.getProperty("org.eclipse.python4j.numpyimport", "true"))) {
            if (Boolean.parseBoolean(System.getProperty("org.eclipse.python4j.numpyimport", "true"))) {
                python.Py_AddPath((File[])numpy.cachePackages());
            }
            PythonConstants.setInitializePython((boolean)false);
            python.Py_Initialize();
            int err = numpy._import_array();
            if (err < 0) {
                System.out.println("Numpy import failed!");
                throw new PythonException("Numpy import failed!");
            }
        }
        if (PythonGIL.locked()) {
            throw new PythonException("Can not initialize numpy - GIL already acquired.");
        }
    }

    public NumpyArray() {
        super("numpy.ndarray", INDArray.class);
    }

    public INDArray toJava(PythonObject pythonObject) {
        DataType dtype;
        log.debug("Converting PythonObject to INDArray...");
        PyObject np = python.PyImport_ImportModule((String)"numpy");
        PyObject ndarray = python.PyObject_GetAttrString((PyObject)np, (String)"ndarray");
        if (python.PyObject_IsInstance((PyObject)pythonObject.getNativePythonObject(), (PyObject)ndarray) != 1) {
            python.Py_DecRef((PyObject)ndarray);
            python.Py_DecRef((PyObject)np);
            throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
        }
        python.Py_DecRef((PyObject)ndarray);
        python.Py_DecRef((PyObject)np);
        PyArrayObject npArr = new PyArrayObject((Pointer)pythonObject.getNativePythonObject());
        long[] shape = new long[numpy.PyArray_NDIM((PyArrayObject)npArr)];
        SizeTPointer shapePtr = numpy.PyArray_SHAPE((PyArrayObject)npArr);
        if (shapePtr != null) {
            shapePtr.get(shape, 0, shape.length);
        }
        long[] strides = new long[shape.length];
        SizeTPointer stridesPtr = numpy.PyArray_STRIDES((PyArrayObject)npArr);
        if (stridesPtr != null) {
            stridesPtr.get(strides, 0, strides.length);
        }
        int npdtype = numpy.PyArray_TYPE((PyArrayObject)npArr);
        switch (npdtype) {
            case 12: {
                dtype = DataType.DOUBLE;
                break;
            }
            case 11: {
                dtype = DataType.FLOAT;
                break;
            }
            case 3: {
                dtype = DataType.SHORT;
                break;
            }
            case 5: {
                dtype = DataType.INT32;
                break;
            }
            case 7: {
                dtype = DataType.INT64;
                break;
            }
            case 6: {
                dtype = DataType.UINT32;
                break;
            }
            case 1: {
                dtype = DataType.INT8;
                break;
            }
            case 2: {
                dtype = DataType.UINT8;
                break;
            }
            case 0: {
                dtype = DataType.BOOL;
                break;
            }
            case 23: {
                dtype = DataType.FLOAT16;
                break;
            }
            case 9: {
                dtype = DataType.INT64;
                break;
            }
            case 4: {
                dtype = DataType.UINT16;
                break;
            }
            case 8: 
            case 10: {
                dtype = DataType.UINT64;
                break;
            }
            default: {
                throw new PythonException("Unsupported array data type: " + npdtype);
            }
        }
        long size = 1L;
        int i = 0;
        while (i < shape.length) {
            size *= shape[i++];
        }
        long address = numpy.PyArray_DATA((PyArrayObject)npArr).address();
        String key = address + "_" + size + "_" + dtype;
        DataBuffer buff = cache.get(key);
        if (buff == null) {
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address);
                ptr = ptr.limit(size);
                ptr = ptr.capacity(size);
                buff = Nd4j.createBuffer((Pointer)ptr, (long)size, (DataType)dtype);
                cache.put(key, buff);
            }
        }
        int elemSize = buff.getElementSize();
        long[] nd4jStrides = new long[strides.length];
        for (int i2 = 0; i2 < strides.length; ++i2) {
            nd4jStrides[i2] = strides[i2] / (long)elemSize;
        }
        INDArray ret = Nd4j.create((DataBuffer)buff, (long[])shape, (long[])nd4jStrides, (long)0L, (char)Shape.getOrder((long[])shape, (long[])nd4jStrides, (long)1L), (DataType)dtype);
        Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.HOST);
        log.debug("Done creating numpy array.");
        return ret;
    }

    public PythonObject toPython(INDArray indArray) {
        String ctype;
        int numpyType;
        log.debug("Converting INDArray to PythonObject...");
        DataType dataType = indArray.dataType();
        DataBuffer buff = indArray.data();
        String key = buff.pointer().address() + "_" + buff.length() + "_" + dataType;
        cache.put(key, buff);
        switch (dataType) {
            case DOUBLE: {
                numpyType = 12;
                ctype = "c_double";
                break;
            }
            case FLOAT: 
            case BFLOAT16: {
                numpyType = 11;
                ctype = "c_float";
                break;
            }
            case SHORT: {
                numpyType = 3;
                ctype = "c_short";
                break;
            }
            case INT: {
                numpyType = 5;
                ctype = "c_int";
                break;
            }
            case LONG: {
                numpyType = numpy.NPY_INT64;
                ctype = "c_int64";
                break;
            }
            case UINT16: {
                numpyType = 4;
                ctype = "c_uint16";
                break;
            }
            case UINT32: {
                numpyType = 6;
                ctype = "c_uint";
                break;
            }
            case UINT64: {
                numpyType = numpy.NPY_UINT64;
                ctype = "c_uint64";
                break;
            }
            case BOOL: {
                numpyType = 0;
                ctype = "c_bool";
                break;
            }
            case BYTE: {
                numpyType = 1;
                ctype = "c_byte";
                break;
            }
            case UBYTE: {
                numpyType = 2;
                ctype = "c_ubyte";
                break;
            }
            case HALF: {
                numpyType = 23;
                ctype = "c_short";
                break;
            }
            default: {
                throw new RuntimeException("Unsupported dtype: " + dataType);
            }
        }
        long[] shape = indArray.shape();
        INDArray inputArray = indArray;
        if (dataType == DataType.BFLOAT16) {
            log.warn("Creating copy of array as bfloat16 is not supported by numpy.");
            inputArray = indArray.castTo(DataType.FLOAT);
        }
        Nd4j.getAffinityManager().ensureLocation(inputArray, AffinityManager.Location.HOST);
        if (!PythonConstants.releaseGilAutomatically() || PythonConstants.createNpyViaPython()) {
            try (PythonContextManager.Context context = new PythonContextManager.Context("__np_array_converter");){
                log.debug("Stringing exec...");
                String code = "import ctypes\nimport numpy as np\ncArr = (ctypes." + ctype + "*" + indArray.length() + ").from_address(" + indArray.data().pointer().address() + ")\nnpArr = np.frombuffer(cArr, dtype=" + (String)(numpyType == 23 ? "'half'" : "ctypes." + ctype) + ").reshape(" + Arrays.toString(indArray.shape()) + ")";
                PythonExecutioner.exec((String)code);
                log.debug("exec done.");
                PythonObject ret = PythonExecutioner.getVariable((String)"npArr");
                python.Py_IncRef((PyObject)ret.getNativePythonObject());
                PythonObject pythonObject = ret;
                return pythonObject;
            }
        }
        log.debug("NUMPY: PyArray_Type()");
        PyTypeObject pyTypeObject = numpy.PyArray_Type();
        log.debug("NUMPY: PyArray_New()");
        PyObject npArr = numpy.PyArray_New((PyTypeObject)pyTypeObject, (int)shape.length, (SizeTPointer)new SizeTPointer(shape), (int)numpyType, null, (Pointer)inputArray.data().addressPointer(), (int)0, (int)1281, null);
        log.debug("Created numpy array.");
        return new PythonObject(npArr);
    }

    public boolean accepts(Object javaObject) {
        return javaObject instanceof INDArray;
    }

    public INDArray adapt(Object javaObject) {
        if (javaObject instanceof INDArray) {
            return (INDArray)javaObject;
        }
        throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to INDArray");
    }

    static {
        init = new AtomicBoolean(false);
        cache = new HashMap<String, DataBuffer>();
        new PythonExecutioner();
        INSTANCE = new NumpyArray();
        INSTANCE.init();
    }
}

