/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.convolution;

import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.factory.Nd4j;

public class Convolution {
    private Convolution() {
    }

    public static INDArray col2im(INDArray col, int[] stride, int[] padding, int height, int width) {
        return Convolution.col2im(col, stride[0], stride[1], padding[0], padding[1], height, width);
    }

    public static INDArray col2im(INDArray col, int sy, int sx, int ph, int pw, int h, int w) {
        if (col.rank() != 6) {
            throw new IllegalArgumentException("col2im input array must be rank 6");
        }
        INDArray output = Nd4j.create(col.size(0), col.size(1), h, w);
        Col2Im col2Im = Col2Im.builder().inputArrays(new INDArray[]{col}).outputs(new INDArray[]{output}).conv2DConfig(Conv2DConfig.builder().sy(sy).sx(sx).dw(1).dh(1).kh(h).kw(w).ph(ph).pw(pw).build()).build();
        Nd4j.getExecutioner().exec(col2Im);
        return col2Im.outputArguments()[0];
    }

    public static INDArray col2im(INDArray col, INDArray z, int sy, int sx, int ph, int pw, int h, int w, int dh, int dw) {
        if (col.rank() != 6) {
            throw new IllegalArgumentException("col2im input array must be rank 6");
        }
        if (z.rank() != 4) {
            throw new IllegalArgumentException("col2im output array must be rank 4");
        }
        Col2Im col2Im = Col2Im.builder().inputArrays(new INDArray[]{col}).outputs(new INDArray[]{z}).conv2DConfig(Conv2DConfig.builder().sy(sy).sx(sx).dw(dw).dh(dh).kh(h).kw(w).ph(ph).pw(pw).build()).build();
        Nd4j.getExecutioner().exec(col2Im);
        return z;
    }

    public static INDArray im2col(INDArray img, int[] kernel, int[] stride, int[] padding) {
        Nd4j.getCompressor().autoDecompress(img);
        return Convolution.im2col(img, kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false);
    }

    public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, boolean isSameMode) {
        return Convolution.im2col(img, kh, kw, sy, sx, ph, pw, 1, 1, isSameMode);
    }

    public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dh, int dw, boolean isSameMode) {
        Nd4j.getCompressor().autoDecompress(img);
        int outH = Convolution.outputSize(img.size(2), kh, sy, ph, dh, isSameMode);
        int outW = Convolution.outputSize(img.size(3), kw, sx, pw, dw, isSameMode);
        INDArray out = Nd4j.create(new int[]{img.size(0), img.size(1), kh, kw, outH, outW}, 'c');
        return Convolution.im2col(img, kh, kw, sy, sx, ph, pw, dh, dw, isSameMode, out);
    }

    public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, boolean isSameMode, INDArray out) {
        Im2col im2col = Im2col.builder().outputs(new INDArray[]{out}).inputArrays(new INDArray[]{img}).conv2DConfig(Conv2DConfig.builder().kh(kh).pw(pw).ph(ph).sy(sy).sx(sx).kw(kw).kh(kh).dw(1).dh(1).isSameMode(isSameMode).build()).build();
        Nd4j.getExecutioner().exec(im2col);
        return im2col.outputArguments()[0];
    }

    public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dH, int dW, boolean isSameMode, INDArray out) {
        Im2col im2col = Im2col.builder().outputs(new INDArray[]{out}).inputArrays(new INDArray[]{img}).conv2DConfig(Conv2DConfig.builder().kh(kh).pw(pw).ph(ph).sy(sy).sx(sx).kw(kw).kh(kh).dw(dW).dh(dH).isSameMode(isSameMode).build()).build();
        Nd4j.getExecutioner().exec(im2col);
        return im2col.outputArguments()[0];
    }

    public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor, double extra, int virtualHeight, int virtualWidth, INDArray out) {
        Pooling2D pooling = Pooling2D.builder().arrayInputs(new INDArray[]{img}).arrayOutputs(new INDArray[]{out}).config(Pooling2DConfig.builder().dh(dh).dw(dw).extra(extra).kh(kh).kw(kw).ph(ph).pw(pw).isSameMode(isSameMode).sx(sx).sy(sy).virtualHeight(virtualHeight).virtualWidth(virtualWidth).type(type).divisor(divisor).build()).build();
        Nd4j.getExecutioner().exec(pooling);
        return out;
    }

    public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int pval, boolean isSameMode) {
        int oH;
        INDArray output = null;
        if (isSameMode) {
            oH = (int)Math.ceil((float)img.size(2) * 1.0f / (float)sy);
            int oW = (int)Math.ceil((float)img.size(3) * 1.0f / (float)sx);
            output = Nd4j.createUninitialized(new int[]{img.size(0), img.size(1), kh, kw, oH, oW}, 'c');
        } else {
            oH = (img.size(2) - (kh + (kh - 1) * 0) + 2 * ph) / sy + 1;
            int oW = (img.size(3) - (kw + (kw - 1) * 0) + 2 * pw) / sx + 1;
            output = Nd4j.createUninitialized(new int[]{img.size(0), img.size(1), kh, kw, oH, oW}, 'c');
        }
        Im2col im2col = Im2col.builder().inputArrays(new INDArray[]{img}).outputs(new INDArray[]{output}).conv2DConfig(Conv2DConfig.builder().kh(kh).pw(pw).ph(ph).sy(sy).sx(sx).kw(kw).kh(kh).dw(1).dh(1).isSameMode(isSameMode).build()).build();
        Nd4j.getExecutioner().exec(im2col);
        return im2col.outputArguments()[0];
    }

    @Deprecated
    public static int outSize(int size, int k, int s, int p, int dilation, boolean coverAll) {
        k = Convolution.effectiveKernelSize(k, dilation);
        if (coverAll) {
            return (size + p * 2 - k + s - 1) / s + 1;
        }
        return (size + p * 2 - k) / s + 1;
    }

    public static int outputSize(int size, int k, int s, int p, int dilation, boolean isSameMode) {
        k = Convolution.effectiveKernelSize(k, dilation);
        if (isSameMode) {
            return (int)Math.ceil((float)size * 1.0f / (float)s);
        }
        return (size - k + 2 * p) / s + 1;
    }

    public static int effectiveKernelSize(int kernel, int dilation) {
        return kernel + (kernel - 1) * (dilation - 1);
    }

    public static INDArray conv2d(INDArray input, INDArray kernel, Type type) {
        return Nd4j.getConvolution().conv2d(input, kernel, type);
    }

    public static INDArray conv2d(IComplexNDArray input, IComplexNDArray kernel, Type type) {
        return Nd4j.getConvolution().conv2d(input, kernel, type);
    }

    public static INDArray convn(INDArray input, INDArray kernel, Type type, int[] axes) {
        return Nd4j.getConvolution().convn(input, kernel, type, axes);
    }

    public static IComplexNDArray convn(IComplexNDArray input, IComplexNDArray kernel, Type type, int[] axes) {
        return Nd4j.getConvolution().convn(input, kernel, type, axes);
    }

    public static INDArray convn(INDArray input, INDArray kernel, Type type) {
        return Nd4j.getConvolution().convn(input, kernel, type);
    }

    public static IComplexNDArray convn(IComplexNDArray input, IComplexNDArray kernel, Type type) {
        return Nd4j.getConvolution().convn(input, kernel, type);
    }

    public static enum Type {
        FULL,
        VALID,
        SAME;

    }
}

