/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.image.loader;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteOrder;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UShortIndexer;
import org.bytedeco.javacpp.lept;
import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacpp.opencv_imgcodecs;
import org.bytedeco.javacpp.opencv_imgproc;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.AndroidNativeImageLoader;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.loader.Java2DNativeImageLoader;
import org.datavec.image.transform.ImageTransform;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Image loader that decodes images with OpenCV (falling back to Leptonica when OpenCV
 * cannot decode the bytes) and converts them to and from {@link INDArray}.
 *
 * <p>Arrays produced by this loader are laid out as {@code [channels, rows, cols]}; the
 * {@code asMatrix} methods prepend a leading dimension of size 1.
 *
 * <p>The {@code height}, {@code width}, {@code channels}, {@code centerCropIfNeeded} and
 * {@code imageTransform} settings are fields inherited from {@link BaseImageLoader}.
 */
public class NativeImageLoader extends BaseImageLoader {
    public static final String[] ALLOWED_FORMATS = new String[]{"bmp", "gif", "jpg", "jpeg", "jp2", "pbm", "pgm", "ppm", "pnm", "png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM", "PNG", "TIF", "TIFF", "EXR", "WEBP"};

    /** OpenCV {@code imdecode} flags: {@code IMREAD_ANYDEPTH (2) | IMREAD_ANYCOLOR (4)}. */
    private static final int IMREAD_ANYDEPTH_ANYCOLOR = 6;
    /** Leptonica {@code REMOVE_CMAP_TO_FULL_COLOR} mode for {@code pixRemoveColormap}. */
    private static final int LEPT_REMOVE_CMAP_TO_FULL_COLOR = 2;
    /** OpenCV depth code {@code CV_32F}. */
    private static final int DEPTH_FLOAT = 5;
    /** OpenCV depth code {@code CV_64F}. */
    private static final int DEPTH_DOUBLE = 6;
    /** OpenCV {@code cvtColor} code {@code COLOR_BGR2RGBA}. */
    private static final int COLOR_BGR2RGBA = 2;
    /** OpenCV {@code cvtColor} code {@code COLOR_RGBA2BGR}. */
    private static final int COLOR_RGBA2BGR = 3;
    /** OpenCV {@code cvtColor} code {@code COLOR_BGR2GRAY}. */
    private static final int COLOR_BGR2GRAY = 6;
    /** OpenCV {@code cvtColor} code {@code COLOR_GRAY2BGR}. */
    private static final int COLOR_GRAY2BGR = 8;
    /** OpenCV {@code cvtColor} code {@code COLOR_GRAY2RGBA}. */
    private static final int COLOR_GRAY2RGBA = 9;
    /** OpenCV {@code cvtColor} code {@code COLOR_RGBA2GRAY}. */
    private static final int COLOR_RGBA2GRAY = 11;

    /** Converter between JavaCV {@link Frame} and OpenCV {@code Mat}. */
    protected OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();

    /** Creates a loader that keeps the inherited default settings. */
    public NativeImageLoader() {
    }

    /**
     * Creates a loader that targets the given output size.
     *
     * @param height target height in pixels
     * @param width  target width in pixels
     */
    public NativeImageLoader(int height, int width) {
        this.height = height;
        this.width = width;
    }

    /**
     * Creates a loader that targets the given output size and channel count.
     *
     * @param height   target height in pixels
     * @param width    target width in pixels
     * @param channels target number of channels
     */
    public NativeImageLoader(int height, int width, int channels) {
        this.height = height;
        this.width = width;
        this.channels = channels;
    }

    /**
     * Creates a loader with a target size, channel count and center-crop setting.
     *
     * @param height             target height in pixels
     * @param width              target width in pixels
     * @param channels           target number of channels
     * @param centerCropIfNeeded whether to center-crop images before scaling
     */
    public NativeImageLoader(int height, int width, int channels, boolean centerCropIfNeeded) {
        this(height, width, channels);
        this.centerCropIfNeeded = centerCropIfNeeded;
    }

    /**
     * Creates a loader with a target size, channel count and an image transform.
     *
     * @param height         target height in pixels
     * @param width          target width in pixels
     * @param channels       target number of channels
     * @param imageTransform transform applied to each image before any other processing
     */
    public NativeImageLoader(int height, int width, int channels, ImageTransform imageTransform) {
        this(height, width, channels);
        this.imageTransform = imageTransform;
    }

    /**
     * Copy constructor: copies size, channel, crop and transform settings from {@code other}.
     *
     * @param other loader whose settings are copied
     */
    protected NativeImageLoader(NativeImageLoader other) {
        this.height = other.height;
        this.width = other.width;
        this.channels = other.channels;
        this.centerCropIfNeeded = other.centerCropIfNeeded;
        this.imageTransform = other.imageTransform;
    }

    /** Returns {@link #ALLOWED_FORMATS}. */
    @Override
    public String[] getAllowedFormats() {
        return ALLOWED_FORMATS;
    }

    /** Loads the file with {@link #asMatrix(File)} and returns the result of {@code ravel()}. */
    @Override
    public INDArray asRowVector(File f) throws IOException {
        return this.asMatrix(f).ravel();
    }

    /** Loads the stream with {@link #asMatrix(InputStream)} and returns the result of {@code ravel()}. */
    @Override
    public INDArray asRowVector(InputStream is) throws IOException {
        return this.asMatrix(is).ravel();
    }

    /** Converts the object with {@link #asMatrix(Object)} and returns the result of {@code ravel()}. */
    public INDArray asRowVector(Object image) throws IOException {
        return this.asMatrix(image).ravel();
    }

    /** Converts the frame with {@link #asMatrix(Frame)} and returns the result of {@code ravel()}. */
    public INDArray asRowVector(Frame image) throws IOException {
        return this.asMatrix(image).ravel();
    }

    /** Converts the Mat with {@link #asMatrix(opencv_core.Mat)} and returns the result of {@code ravel()}. */
    public INDArray asRowVector(opencv_core.Mat image) throws IOException {
        return this.asMatrix(image).ravel();
    }

    /**
     * Converts a Leptonica {@code PIX} into a newly allocated 8-bit OpenCV {@code Mat}.
     *
     * <p>Colormapped images have their colormap removed and images with a depth of 1, 2 or 4
     * bits are expanded to 8 bits first. The caller keeps ownership of {@code pix}; any
     * intermediate PIX created here is destroyed before returning.
     *
     * @param pix source image
     * @return a Mat holding a copy of the pixel data
     * @throws IllegalArgumentException if the depth is below 8 and is not 1, 2 or 4
     */
    static opencv_core.Mat convert(lept.PIX pix) {
        lept.PIX tempPix = null;
        if (pix.colormap() != null) {
            tempPix = lept.pixRemoveColormap(pix, LEPT_REMOVE_CMAP_TO_FULL_COLOR);
            pix = tempPix;
        } else if (pix.d() < 8) {
            switch (pix.d()) {
                case 1:
                    tempPix = lept.pixConvert1To8(null, pix, (byte) 0, (byte) 0xFF);
                    break;
                case 2:
                    tempPix = lept.pixConvert2To8(pix, (byte) 0, (byte) 0x55, (byte) 0xAA, (byte) 0xFF, 0);
                    break;
                case 4:
                    tempPix = lept.pixConvert4To8(pix, 0);
                    break;
                default:
                    // Previously an assert followed by a NullPointerException when assertions were off.
                    throw new IllegalArgumentException("Unsupported pixel depth: " + pix.d());
            }
            pix = tempPix;
        }
        int height = pix.h();
        int width = pix.w();
        int channels = pix.d() / 8;
        // Header over the PIX buffer; wpl is counted in 32-bit words, hence the factor of 4 bytes.
        opencv_core.Mat mat = new opencv_core.Mat(height, width, opencv_core.CV_8UC(channels), (Pointer) pix.data(), 4L * pix.wpl());
        opencv_core.Mat mat2 = new opencv_core.Mat(height, width, opencv_core.CV_8UC(channels));
        // On little-endian platforms the channel order inside each 32-bit word is reversed.
        int[] swap = new int[]{0, channels - 1, 1, channels - 2, 2, channels - 3, 3, channels - 4};
        int[] copy = new int[]{0, 0, 1, 1, 2, 2, 3, 3};
        int[] fromTo = channels > 1 && ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN) ? swap : copy;
        opencv_core.mixChannels(mat, 1L, mat2, 1L, fromTo, (long) Math.min(channels, fromTo.length / 2));
        // mat only wraps the PIX buffer; release the header now that mat2 holds its own copy.
        mat.deallocate();
        if (tempPix != null) {
            lept.pixDestroy(tempPix);
        }
        return mat2;
    }

    /**
     * Decodes an encoded image, first with OpenCV {@code imdecode} and, if that yields no
     * image, with Leptonica {@code pixReadMem}.
     *
     * @param bytes encoded image bytes
     * @return the decoded image; never {@code null}. The caller owns the returned Mat.
     * @throws IOException if neither decoder can read the bytes
     */
    private static opencv_core.Mat decodeImage(byte[] bytes) throws IOException {
        opencv_core.Mat encoded = new opencv_core.Mat(bytes);
        opencv_core.Mat image;
        try {
            image = opencv_imgcodecs.imdecode(encoded, IMREAD_ANYDEPTH_ANYCOLOR);
        } finally {
            encoded.deallocate();
        }
        if (image != null && !image.empty()) {
            return image;
        }
        if (image != null) {
            image.deallocate();
        }
        lept.PIX pix = lept.pixReadMem(bytes, (long) bytes.length);
        if (pix == null) {
            throw new IOException("Could not decode image from input stream");
        }
        try {
            return NativeImageLoader.convert(pix);
        } finally {
            lept.pixDestroy(pix);
        }
    }

    /**
     * Loads an image file into an array.
     *
     * @param f image file
     * @return the result of {@link #asMatrix(InputStream)} on the file contents
     * @throws IOException if the file cannot be read or decoded
     */
    @Override
    public INDArray asMatrix(File f) throws IOException {
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
            return this.asMatrix(bis);
        }
    }

    /**
     * Reads the whole stream, decodes it and converts it with {@link #asMatrix(opencv_core.Mat)}.
     * The stream is not closed by this method.
     *
     * @param is stream of encoded image bytes
     * @return the image as an array
     * @throws IOException if the stream cannot be read or decoded
     */
    @Override
    public INDArray asMatrix(InputStream is) throws IOException {
        opencv_core.Mat image = decodeImage(IOUtils.toByteArray(is));
        try {
            return this.asMatrix(image);
        } finally {
            image.deallocate();
        }
    }

    /**
     * Loads an image file into an {@link Image}.
     *
     * @param f image file
     * @return the result of {@link #asImageMatrix(InputStream)} on the file contents
     * @throws IOException if the file cannot be read or decoded
     */
    @Override
    public Image asImageMatrix(File f) throws IOException {
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
            return this.asImageMatrix(bis);
        }
    }

    /**
     * Reads the whole stream, decodes it and wraps the converted array in an {@link Image}
     * together with the channel count, rows and columns of the decoded (pre-transform) image.
     * The stream is not closed by this method.
     *
     * @param is stream of encoded image bytes
     * @return the image array and the decoded image's dimensions
     * @throws IOException if the stream cannot be read or decoded
     */
    @Override
    public Image asImageMatrix(InputStream is) throws IOException {
        opencv_core.Mat image = decodeImage(IOUtils.toByteArray(is));
        try {
            INDArray array = this.asMatrix(image);
            return new Image(array, image.channels(), image.rows(), image.cols());
        } finally {
            image.deallocate();
        }
    }

    /**
     * Converts a platform-specific image object by delegating to
     * {@link AndroidNativeImageLoader} and then, if that gave no result, to
     * {@link Java2DNativeImageLoader}.
     *
     * @param image platform-specific image object
     * @return the converted array, or {@code null} if neither delegate produced one
     * @throws IOException if a delegate fails while converting
     */
    public INDArray asMatrix(Object image) throws IOException {
        INDArray array = null;
        try {
            array = new AndroidNativeImageLoader(this).asMatrix(image);
        } catch (NoClassDefFoundError ignored) {
            // Presumably the Android classes are not on the classpath; try Java 2D next.
        }
        if (array == null) {
            try {
                array = new Java2DNativeImageLoader(this).asMatrix(image);
            } catch (NoClassDefFoundError ignored) {
                // Presumably the Java 2D classes are not available either; null is returned.
            }
        }
        return array;
    }

    /**
     * Copies the pixels of {@code image} into {@code ret}, mapping Mat element
     * {@code (row, col, channel)} to array element {@code (channel, row, col)}.
     *
     * <p>A typed indexer is used when {@code ret} is backed by a float or double pointer and
     * the Mat indexer is unsigned byte, unsigned short, int or float; otherwise the copy
     * falls back to {@code putScalar}.
     *
     * @param image source image
     * @param ret   destination array; its length must equal {@code rows * cols * channels}
     * @throws ND4JIllegalStateException if the destination length does not match the image
     */
    protected void fillNDArray(opencv_core.Mat image, INDArray ret) {
        int rows = image.rows();
        int cols = image.cols();
        int channels = image.channels();
        // Computed in long so that very large images cannot overflow int.
        long length = (long) rows * cols * channels;
        if (ret.lengthLong() != length) {
            throw new ND4JIllegalStateException("INDArray provided to store image not equal to image: {channels: " + channels + ", rows: " + rows + ", columns: " + cols + "}");
        }
        boolean direct = !Loader.getPlatform().startsWith("android");
        Indexer idx = image.createIndexer(direct);
        Pointer pointer = ret.data().pointer();
        int[] stride = ret.stride();
        long byteOffset = ret.data().offset() * (long) Nd4j.sizeOfDataType((DataBuffer.Type) ret.data().dataType());
        PagedPointer pagedPointer = new PagedPointer(pointer, length, byteOffset);
        boolean done = false;
        if (pointer instanceof FloatPointer) {
            FloatIndexer retidx = FloatIndexer.create(pagedPointer.asFloatPointer(), new long[]{channels, rows, cols}, new long[]{stride[0], stride[1], stride[2]}, direct);
            done = copyToFloat(idx, retidx, channels, rows, cols);
        } else if (pointer instanceof DoublePointer) {
            DoubleIndexer retidx = DoubleIndexer.create(pagedPointer.asDoublePointer(), new long[]{channels, rows, cols}, new long[]{stride[0], stride[1], stride[2]}, direct);
            done = copyToDouble(idx, retidx, channels, rows, cols);
        }
        if (!done) {
            // Generic fallback for buffer or indexer types without a typed fast path.
            for (int k = 0; k < channels; ++k) {
                for (int i = 0; i < rows; ++i) {
                    for (int j = 0; j < cols; ++j) {
                        if (channels > 1) {
                            ret.putScalar(k, i, j, idx.getDouble(new long[]{i, j, k}));
                        } else {
                            ret.putScalar(i, j, idx.getDouble(new long[]{i, j}));
                        }
                    }
                }
            }
        }
        // NOTE(review): result unused; presumably keeps image reachable while the indexer is in use.
        image.data();
        Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.HOST);
    }

    /**
     * Copies {@code src(row, col, channel)} into {@code dst(channel, row, col)} as floats.
     *
     * @return {@code true} if {@code src} was a supported indexer type and the copy was done
     */
    private static boolean copyToFloat(Indexer src, FloatIndexer dst, int channels, int rows, int cols) {
        if (src instanceof UByteIndexer) {
            UByteIndexer in = (UByteIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (float) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof UShortIndexer) {
            UShortIndexer in = (UShortIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (float) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof IntIndexer) {
            IntIndexer in = (IntIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (float) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof FloatIndexer) {
            FloatIndexer in = (FloatIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        return false;
    }

    /**
     * Copies {@code src(row, col, channel)} into {@code dst(channel, row, col)} as doubles.
     *
     * @return {@code true} if {@code src} was a supported indexer type and the copy was done
     */
    private static boolean copyToDouble(Indexer src, DoubleIndexer dst, int channels, int rows, int cols) {
        if (src instanceof UByteIndexer) {
            UByteIndexer in = (UByteIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (double) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof UShortIndexer) {
            UShortIndexer in = (UShortIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (double) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof IntIndexer) {
            IntIndexer in = (IntIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (double) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof FloatIndexer) {
            FloatIndexer in = (FloatIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (double) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        return false;
    }

    /**
     * Reads the whole stream, decodes it and writes the processed image into {@code view}.
     * The stream is not closed by this method.
     *
     * @param is   stream of encoded image bytes
     * @param view destination array
     * @throws IOException if the stream cannot be read or decoded
     */
    public void asMatrixView(InputStream is, INDArray view) throws IOException {
        opencv_core.Mat image = decodeImage(IOUtils.toByteArray(is));
        try {
            this.asMatrixView(image, view);
        } finally {
            image.deallocate();
        }
    }

    /**
     * Loads an image file and writes the processed image into {@code view}.
     *
     * @param f    image file
     * @param view destination array
     * @throws IOException if the file cannot be read or decoded
     */
    public void asMatrixView(File f, INDArray view) throws IOException {
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
            this.asMatrixView(bis, view);
        }
    }

    /**
     * Processes {@code image} with {@link #transformImage} and writes it into {@code view}.
     *
     * @param image source image
     * @param view  destination array
     * @throws IOException if the channel count cannot be converted
     */
    public void asMatrixView(opencv_core.Mat image, INDArray view) throws IOException {
        this.transformImage(image, view);
    }

    /**
     * Converts a frame to a Mat and then to an array via {@link #asMatrix(opencv_core.Mat)}.
     *
     * @param image source frame
     * @return the image as an array
     * @throws IOException if the channel count cannot be converted
     */
    public INDArray asMatrix(Frame image) throws IOException {
        return this.asMatrix(this.converter.convert(image));
    }

    /**
     * Processes {@code image} with {@link #transformImage} and reshapes the result so that a
     * leading dimension of size 1 is prepended to its shape.
     *
     * @param image source image
     * @return the image as an array
     * @throws IOException if the channel count cannot be converted
     */
    public INDArray asMatrix(opencv_core.Mat image) throws IOException {
        INDArray ret = this.transformImage(image, null);
        return ret.reshape(ArrayUtil.combine((int[][]) new int[][]{{1}, ret.shape()}));
    }

    /**
     * Returns the {@code cvtColor} code that converts between the given channel counts.
     *
     * @param from channel count of the source image
     * @param to   requested channel count
     * @return the conversion code, or {@code -1} if the combination is not supported
     */
    private static int colorConversionCode(int from, int to) {
        switch (from) {
            case 1:
                if (to == 3) {
                    return COLOR_GRAY2BGR;
                }
                if (to == 4) {
                    return COLOR_GRAY2RGBA;
                }
                break;
            case 3:
                if (to == 1) {
                    return COLOR_BGR2GRAY;
                }
                if (to == 4) {
                    return COLOR_BGR2RGBA;
                }
                break;
            case 4:
                if (to == 1) {
                    return COLOR_RGBA2GRAY;
                }
                if (to == 3) {
                    return COLOR_RGBA2BGR;
                }
                break;
            default:
                break;
        }
        return -1;
    }

    /**
     * Runs the processing pipeline on {@code image} and stores the pixels in an array:
     * optional image transform, channel conversion when the configured channel count is
     * positive and differs from the image's, optional center crop, then scaling.
     *
     * <p>Intermediate Mats created here are released before returning, including when an
     * exception is thrown. The input {@code image} itself is not released.
     *
     * @param image source image
     * @param ret   destination array, or {@code null} to allocate a new
     *              {@code [channels, rows, cols]} array
     * @return {@code ret}, or the newly allocated array
     * @throws IOException if the image's channel count cannot be converted to the configured one
     */
    protected INDArray transformImage(opencv_core.Mat image, INDArray ret) throws IOException {
        if (this.imageTransform != null && this.converter != null) {
            ImageWritable writable = new ImageWritable(this.converter.convert(image));
            writable = this.imageTransform.transform(writable);
            image = this.converter.convert(writable.getFrame());
        }
        opencv_core.Mat converted = null;
        opencv_core.Mat cropped = null;
        opencv_core.Mat scaled = null;
        try {
            if (this.channels > 0 && image.channels() != this.channels) {
                int code = colorConversionCode(image.channels(), this.channels);
                if (code < 0) {
                    throw new IOException("Cannot convert from " + image.channels() + " to " + this.channels + " channels.");
                }
                converted = new opencv_core.Mat();
                opencv_imgproc.cvtColor(image, converted, code);
                image = converted;
            }
            if (this.centerCropIfNeeded) {
                opencv_core.Mat candidate = this.centerCropIfNeeded(image);
                if (candidate != image) {
                    cropped = candidate;
                    image = candidate;
                }
            }
            opencv_core.Mat candidate = this.scalingIfNeed(image);
            if (candidate != image) {
                scaled = candidate;
                image = candidate;
            }
            if (ret == null) {
                ret = Nd4j.create((int[]) new int[]{image.channels(), image.rows(), image.cols()});
            }
            this.fillNDArray(image, ret);
            // NOTE(review): result unused; presumably keeps image reachable until the copy is done.
            image.data();
            return ret;
        } finally {
            if (converted != null) {
                converted.deallocate();
            }
            if (cropped != null) {
                cropped.deallocate();
            }
            if (scaled != null) {
                scaled.deallocate();
            }
        }
    }

    /**
     * Returns the centered square region of {@code img} whose side equals the shorter of
     * the image's two dimensions. A square image yields a region covering the whole image.
     *
     * <p>The longer dimension is reduced to the shorter one, so the region is both square
     * and centered. (Subtracting only half the difference, as was done before, produced a
     * region that ran from the offset to the far edge.)
     *
     * @param img source image
     * @return a Mat for the cropped region, obtained from {@code img.apply(Rect)}
     */
    protected opencv_core.Mat centerCropIfNeeded(opencv_core.Mat img) {
        int x = 0;
        int y = 0;
        int height = img.rows();
        int width = img.cols();
        int diff = Math.abs(width - height) / 2;
        if (width > height) {
            x = diff;
            width = height;
        } else if (height > width) {
            y = diff;
            height = width;
        }
        return img.apply(new opencv_core.Rect(x, y, width, height));
    }

    /**
     * Scales {@code image} to this loader's configured height and width if needed.
     *
     * @param image source image
     * @return the scaled image, or {@code image} itself if no scaling was needed
     */
    protected opencv_core.Mat scalingIfNeed(opencv_core.Mat image) {
        return this.scalingIfNeed(image, this.height, this.width);
    }

    /**
     * Scales {@code image} to the given size when both target dimensions are positive and
     * the image does not already have that size.
     *
     * @param image     source image
     * @param dstHeight target height in pixels; values &lt;= 0 disable scaling
     * @param dstWidth  target width in pixels; values &lt;= 0 disable scaling
     * @return a new scaled Mat, or {@code image} itself if no scaling was needed
     */
    protected opencv_core.Mat scalingIfNeed(opencv_core.Mat image, int dstHeight, int dstWidth) {
        boolean sizeGiven = dstHeight > 0 && dstWidth > 0;
        if (!sizeGiven || (image.rows() == dstHeight && image.cols() == dstWidth)) {
            return image;
        }
        opencv_core.Mat scaled = new opencv_core.Mat();
        opencv_imgproc.resize(image, scaled, new opencv_core.Size(dstWidth, dstHeight));
        return scaled;
    }

    /**
     * Decodes an image file into an {@link ImageWritable} without applying any processing.
     * The decoded Mat is not released here because it is handed to the frame converter.
     *
     * @param f image file
     * @return a writable wrapping the decoded frame
     * @throws IOException if the file cannot be read or decoded
     */
    public ImageWritable asWritable(File f) throws IOException {
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
            opencv_core.Mat image = decodeImage(IOUtils.toByteArray(bis));
            return new ImageWritable(this.converter.convert(image));
        }
    }

    /**
     * Converts the frame held by {@code writable} via {@link #asMatrix(opencv_core.Mat)}.
     *
     * @param writable image to convert
     * @return the image as an array
     * @throws IOException if the channel count cannot be converted
     */
    public INDArray asMatrix(ImageWritable writable) throws IOException {
        opencv_core.Mat image = this.converter.convert(writable.getFrame());
        return this.asMatrix(image);
    }

    /**
     * Converts an array to a frame via {@link #asMat(INDArray)}.
     *
     * @param array source array
     * @return the array contents as a frame
     */
    public Frame asFrame(INDArray array) {
        return this.converter.convert(this.asMat(array));
    }

    /**
     * Converts an array to a frame with the given frame depth.
     *
     * @param array    source array
     * @param dataType frame depth, translated with {@code OpenCVFrameConverter.getMatDepth}
     * @return the array contents as a frame
     */
    public Frame asFrame(INDArray array, int dataType) {
        return this.converter.convert(this.asMat(array, OpenCVFrameConverter.getMatDepth(dataType)));
    }

    /**
     * Converts an array to a Mat, choosing the depth from the array's buffer type.
     *
     * @param array source array
     * @return a new Mat holding the array contents
     */
    public opencv_core.Mat asMat(INDArray array) {
        return this.asMat(array, -1);
    }

    /**
     * Converts a {@code [channels, rows, cols]} array (or a rank-4 array whose first
     * dimension has size 1) into a new Mat, mapping array element
     * {@code (channel, row, col)} to Mat element {@code (row, col, channel)}.
     *
     * @param array    source array
     * @param dataType OpenCV depth code for the Mat; a negative value selects 64-bit float
     *                 when the array is backed by a double pointer and 32-bit float otherwise
     * @return a new Mat holding the array contents
     * @throws UnsupportedOperationException if the array is not rank 3, or rank 4 with
     *                                       {@code size(0) == 1}
     */
    public opencv_core.Mat asMat(INDArray array, int dataType) {
        int rank = array.rank();
        // Ranks below 3 previously slipped past this check and failed later on size(2)/size(3).
        if (rank < 3 || rank > 4 || (rank == 4 && array.size(0) != 1)) {
            throw new UnsupportedOperationException("Only rank 3 (or rank 4 with size(0) == 1) arrays supported");
        }
        // Index of the channel dimension; rows and columns follow it.
        int base = rank == 3 ? 0 : 1;
        int[] stride = array.stride();
        long offset = array.data().offset();
        Pointer pointer = array.data().pointer().position(offset);
        int channels = array.size(base);
        int rows = array.size(base + 1);
        int cols = array.size(base + 2);
        if (dataType < 0) {
            dataType = pointer instanceof DoublePointer ? DEPTH_DOUBLE : DEPTH_FLOAT;
        }
        opencv_core.Mat mat = new opencv_core.Mat(rows, cols, opencv_core.CV_MAKETYPE(dataType, channels));
        boolean direct = !Loader.getPlatform().startsWith("android");
        Indexer matidx = mat.createIndexer(direct);
        Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.HOST);
        boolean done = false;
        if (pointer instanceof FloatPointer && dataType == DEPTH_FLOAT) {
            FloatIndexer ptridx = FloatIndexer.create((FloatPointer) pointer, new long[]{channels, rows, cols}, new long[]{stride[base], stride[base + 1], stride[base + 2]}, direct);
            FloatIndexer idx = (FloatIndexer) matidx;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        idx.put(i, j, k, ptridx.get(k, i, j));
                    }
                }
            }
            done = true;
        } else if (pointer instanceof DoublePointer && dataType == DEPTH_DOUBLE) {
            DoubleIndexer ptridx = DoubleIndexer.create((DoublePointer) pointer, new long[]{channels, rows, cols}, new long[]{stride[base], stride[base + 1], stride[base + 2]}, direct);
            DoubleIndexer idx = (DoubleIndexer) matidx;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        idx.put(i, j, k, ptridx.get(k, i, j));
                    }
                }
            }
            done = true;
        }
        if (!done) {
            // Generic fallback when the buffer type and the requested depth do not match.
            for (int k = 0; k < channels; ++k) {
                for (int i = 0; i < rows; ++i) {
                    for (int j = 0; j < cols; ++j) {
                        double value = rank == 3
                                ? array.getDouble(new int[]{k, i, j})
                                : array.getDouble(new int[]{0, k, i, j});
                        matidx.putDouble(new long[]{i, j, k}, value);
                    }
                }
            }
        }
        return mat;
    }
}

