/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.tensorflow.conversion;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.javacpp.tensorflow;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.tensorflow.conversion.DummyDeAllocator;

/**
 * Converts between ND4J {@link INDArray} instances and TensorFlow C API
 * tensors, and loads serialized TensorFlow graph definitions.
 *
 * <p>Instances are obtained through {@link #getInstance()}. The lazy
 * initialisation in that method is not synchronized.
 */
public class TensorflowConversion {
    /*
     * Numeric values of the TensorFlow C API TF_DataType enum that this class
     * understands.
     */
    private static final int TF_TYPE_FLOAT = 1;
    private static final int TF_TYPE_DOUBLE = 2;
    private static final int TF_TYPE_INT32 = 3;
    private static final int TF_TYPE_UINT8 = 4;
    private static final int TF_TYPE_INT16 = 5;
    private static final int TF_TYPE_INT8 = 6;
    private static final int TF_TYPE_INT64 = 9;
    private static final int TF_TYPE_UINT16 = 17;
    private static final int TF_TYPE_HALF = 19;

    /** Status code reported by the TensorFlow C API on success (TF_OK). */
    private static final int TF_STATUS_OK = 0;

    /** Deallocator callback handed to TensorFlow for every tensor created here. */
    private static tensorflow.Deallocator_Pointer_long_Pointer calling;
    private static TensorflowConversion INSTANCE;

    /**
     * Returns the shared converter, creating it on first use.
     *
     * @return the singleton instance
     */
    public static TensorflowConversion getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new TensorflowConversion();
        }
        return INSTANCE;
    }

    private TensorflowConversion() {
        if (calling == null) {
            calling = DummyDeAllocator.getInstance();
        }
    }

    /**
     * Wraps the data of an ND4J array in a new TensorFlow tensor.
     *
     * <p>Views and arrays that are not in 'c' order are duplicated into a
     * 'c'-ordered array first. The tensor is created directly over the data
     * buffer's pointer, with the shared deallocator callback of this class.
     *
     * @param ndArray the array to convert
     * @return a tensor with the same shape as {@code ndArray}
     * @throws IllegalArgumentException if the array or its data buffer is
     *         {@code null}, or the data type / compression algorithm has no
     *         TensorFlow equivalent known to this class
     */
    public tensorflow.TF_Tensor tensorFromNDArray(INDArray ndArray) {
        if (ndArray == null) {
            throw new IllegalArgumentException("Unable to convert a null array");
        }
        if (ndArray.data() == null) {
            throw new IllegalArgumentException("Unable to infer data type from null databuffer");
        }
        if (ndArray.isView() || ndArray.ordering() != 'c') {
            ndArray = ndArray.dup('c');
        }
        long[] tfShape = ndArray.shape();
        DataBuffer data = ndArray.data();
        int type = tfTypeFor(data);
        try {
            Nd4j.getAffinityManager().ensureLocation(ndArray, AffinityManager.Location.HOST);
        }
        catch (Exception e) {
            // Deliberate best-effort fallback: the data could not be moved to
            // the host through the affinity manager. Reading one element and
            // re-fetching the buffer is the alternative path.
            // NOTE(review): presumably this forces decompression / a host copy
            // of the data - confirm against the backend in use.
            ndArray.getDouble(0L);
            data = ndArray.data();
            type = tfTypeForPlain(data.dataType());
        }
        // LongPointer copies the shape values into native memory.
        LongPointer dims = new LongPointer(tfShape);
        Pointer payload = data.pointer();
        long byteLength = data.length() * (long) data.getElementSize();
        return tensorflow.TF_NewTensor(type, dims, tfShape.length, payload, byteLength, calling, null);
    }

    /**
     * Resolves the TF_DataType code for a data buffer, including half
     * precision and compressed buffers.
     */
    private static int tfTypeFor(DataBuffer data) {
        DataBuffer.Type dataType = data.dataType();
        switch (dataType) {
            case HALF: {
                return TF_TYPE_HALF;
            }
            case COMPRESSED: {
                return tfTypeForCompressed((CompressedDataBuffer) data);
            }
            default: {
                return tfTypeForPlain(dataType);
            }
        }
    }

    /** Resolves the TF_DataType code for the uncompressed double/float/int/long types. */
    private static int tfTypeForPlain(DataBuffer.Type dataType) {
        switch (dataType) {
            case DOUBLE: {
                return TF_TYPE_DOUBLE;
            }
            case FLOAT: {
                return TF_TYPE_FLOAT;
            }
            case INT: {
                return TF_TYPE_INT32;
            }
            case LONG: {
                return TF_TYPE_INT64;
            }
            default: {
                throw new IllegalArgumentException("Unsupported data type: " + dataType);
            }
        }
    }

    /** Resolves the TF_DataType code from a compressed buffer's compression algorithm name. */
    private static int tfTypeForCompressed(CompressedDataBuffer compressedData) {
        CompressionDescriptor desc = compressedData.getCompressionDescriptor();
        String algo = desc.getCompressionAlgorithm();
        switch (algo) {
            case "FLOAT16": {
                return TF_TYPE_HALF;
            }
            case "INT8": {
                return TF_TYPE_INT8;
            }
            case "UINT8": {
                return TF_TYPE_UINT8;
            }
            case "INT16": {
                return TF_TYPE_INT16;
            }
            case "UINT16": {
                return TF_TYPE_UINT16;
            }
            default: {
                throw new IllegalArgumentException("Unsupported compression algorithm: " + algo);
            }
        }
    }

    /**
     * Wraps the memory of a TensorFlow tensor in an ND4J array.
     *
     * <p>The returned array is built over the tensor's data pointer. A rank-0
     * tensor is exposed with shape {@code [1]}.
     *
     * @param tensor the tensor to convert
     * @return an array over the tensor's data, tagged as located on the host
     * @throws IllegalArgumentException if the tensor's data type is not
     *         double, float, 32-bit int or 64-bit int, or a dimension does
     *         not fit in an {@code int}
     */
    public INDArray ndArrayFromTensor(tensorflow.TF_Tensor tensor) {
        int rank = tensorflow.TF_NumDims(tensor);
        int[] ndShape;
        // Accumulated in a long so large shapes do not overflow int arithmetic.
        long length = 1L;
        if (rank == 0) {
            ndShape = new int[]{1};
        } else {
            ndShape = new int[rank];
            for (int i = 0; i < rank; ++i) {
                long dim = tensorflow.TF_Dim(tensor, i);
                if (dim > Integer.MAX_VALUE) {
                    throw new IllegalArgumentException("Dimension " + i + " of size " + dim + " does not fit in an int");
                }
                ndShape[i] = (int) dim;
                length *= dim;
            }
        }
        int tfType = tensorflow.TF_TensorType(tensor);
        DataBuffer.Type nd4jType = this.typeFor(tfType);
        // The capacity is carried over to the typed pointer created by
        // indexerForType, where it counts elements of that type.
        Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(length);
        Indexer indexer = this.indexerForType(nd4jType, pointer);
        DataBuffer buffer = Nd4j.createBuffer(indexer.pointer(), nd4jType, length, indexer);
        INDArray array = Nd4j.create(buffer, ndShape);
        Nd4j.getAffinityManager().tagLocation(array, AffinityManager.Location.HOST);
        return array;
    }

    /**
     * Creates an indexer of the given element type over {@code pointer}.
     *
     * @throws IllegalArgumentException for any type other than double, float, int or long
     */
    private Indexer indexerForType(DataBuffer.Type type, Pointer pointer) {
        switch (type) {
            case DOUBLE: {
                return DoubleIndexer.create(new DoublePointer(pointer));
            }
            case FLOAT: {
                return FloatIndexer.create(new FloatPointer(pointer));
            }
            case INT: {
                return IntIndexer.create(new IntPointer(pointer));
            }
            case LONG: {
                return LongIndexer.create(new LongPointer(pointer));
            }
            default: {
                throw new IllegalArgumentException("Illegal type " + type);
            }
        }
    }

    /**
     * Maps a TF_DataType code to the ND4J buffer type with the same element width.
     *
     * @throws IllegalArgumentException for any code other than double, float, int32 or int64
     */
    private DataBuffer.Type typeFor(int tensorflowType) {
        switch (tensorflowType) {
            case TF_TYPE_DOUBLE: {
                return DataBuffer.Type.DOUBLE;
            }
            case TF_TYPE_FLOAT: {
                return DataBuffer.Type.FLOAT;
            }
            case TF_TYPE_INT32: {
                // 32-bit elements must not be read through a 64-bit indexer:
                // that would read past the end of the tensor's memory.
                return DataBuffer.Type.INT;
            }
            case TF_TYPE_INT64: {
                return DataBuffer.Type.LONG;
            }
            default: {
                throw new IllegalArgumentException("Illegal type " + tensorflowType);
            }
        }
    }

    /**
     * Loads a graph from a file containing a serialized graph definition.
     *
     * @param filePath path of the file to read
     * @return the imported graph
     * @throws IOException if the file cannot be read
     * @throws RuntimeException if TensorFlow rejects the graph definition
     */
    public tensorflow.TF_Graph loadGraph(String filePath) throws IOException {
        byte[] bytes = Files.readAllBytes(Paths.get(filePath));
        return this.loadGraph(bytes);
    }

    /**
     * Builds the TensorFlow device name for the calling thread.
     *
     * @return {@code "/device:gpu:N"} when the ND4J backend class name
     *         contains {@code JCublasBackend}, otherwise {@code "/device:cpu:N"},
     *         where N is the device the affinity manager reports for the
     *         current thread
     */
    public static String defaultDeviceForThread() {
        Integer deviceForThread = Nd4j.getAffinityManager().getDeviceForThread(Thread.currentThread());
        boolean cuda = Nd4j.getBackend().getClass().getName().contains("JCublasBackend");
        return (cuda ? "/device:gpu:" : "/device:cpu:") + deviceForThread;
    }

    /**
     * Imports a serialized graph definition into a new TensorFlow graph.
     *
     * <p>All temporary native objects (buffer, status, import options) are
     * released whether or not the import succeeds; the graph itself is
     * released only when the import fails.
     *
     * @param content the serialized graph definition
     * @return the imported graph
     * @throws IllegalArgumentException if {@code content} is {@code null}
     * @throws RuntimeException if TensorFlow reports a non-OK status
     */
    public tensorflow.TF_Graph loadGraph(byte[] content) {
        if (content == null) {
            throw new IllegalArgumentException("Graph content must not be null");
        }
        tensorflow.TF_Buffer graphDef = tensorflow.TF_NewBufferFromString(new BytePointer(content), content.length);
        tensorflow.TF_Status status = tensorflow.TF_NewStatus();
        tensorflow.TF_Graph graph = tensorflow.TF_NewGraph();
        tensorflow.TF_ImportGraphDefOptions opts = tensorflow.TF_NewImportGraphDefOptions();
        boolean imported = false;
        try {
            tensorflow.TF_GraphImportGraphDef(graph, graphDef, opts, status);
            if (tensorflow.TF_GetCode(status) != TF_STATUS_OK) {
                // The message must be read before the status is deleted below.
                String message = tensorflow.TF_Message(status).getString();
                throw new RuntimeException("ERROR: Unable to import graph " + message);
            }
            imported = true;
            return graph;
        }
        finally {
            tensorflow.TF_DeleteImportGraphDefOptions(opts);
            tensorflow.TF_DeleteStatus(status);
            tensorflow.TF_DeleteBuffer(graphDef);
            if (!imported) {
                tensorflow.TF_DeleteGraph(graph);
            }
        }
    }
}

