/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxNDArrayIndexer;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxOpParams;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDResource;
import ai.djl.ndarray.NDUtils;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.recurrent.RNN;
import ai.djl.util.Preconditions;
import java.util.Arrays;
import java.util.List;

/**
 * MXNet implementation of {@link NDArrayEx}.
 *
 * <p>Each method builds an {@link MxOpParams} parameter list and dispatches a named MXNet operator
 * (for example {@code _npx_pooling} or {@code _npx_rnn}) through the {@link MxNDManager} that owns
 * the wrapped array. The in-place variants ({@code rdivi}, {@code rsubi}, {@code rmodi},
 * {@code rpowi}) write their result back into the wrapped array and return it.
 */
class MxNDArrayEx
implements NDArrayEx {

    /** The array that operators read from and that the in-place variants write to. */
    private final MxNDArray array;

    /**
     * Creates the MXNet extension for {@code parent}.
     *
     * @param parent the array this extension operates on
     */
    MxNDArrayEx(MxNDArray parent) {
        this.array = parent;
    }

    /**
     * Computes the broadcast of two shapes: dimensions are aligned from the right, missing leading
     * dimensions count as 1, and a dimension of size 1 stretches to match the other operand.
     *
     * @param lhs the first shape
     * @param rhs the second shape
     * @return the broadcast shape
     * @throws IllegalArgumentException if two aligned dimensions differ and neither is 1
     */
    private static Shape deriveBroadcastedShape(Shape lhs, Shape rhs) {
        int rank = Math.max(lhs.dimension(), rhs.dimension());
        int lhsOffset = rank - lhs.dimension();
        int rhsOffset = rank - rhs.dimension();
        long[] result = new long[rank];
        for (int i = 0; i < rank; ++i) {
            long l = i >= lhsOffset ? lhs.get(i - lhsOffset) : 1L;
            long r = i >= rhsOffset ? rhs.get(i - rhsOffset) : 1L;
            if (l != r && l != 1L && r != 1L) {
                throw new IllegalArgumentException(
                        "operands could not be broadcast together with shapes " + lhs + " " + rhs);
            }
            result[i] = l == 1L ? r : l;
        }
        return new Shape(result);
    }

    // ---------------------------------------------------------------------------------------
    // Reverse arithmetic
    // ---------------------------------------------------------------------------------------

    /**
     * Computes {@code n / array} element-wise.
     *
     * @param n the scalar numerator
     * @return a new array with the quotients
     */
    public NDArray rdiv(Number n) {
        return getManager().invoke("_rdiv_scalar", array, scalarParams(n));
    }

    /**
     * Computes {@code b / array} element-wise by delegating to {@code b.div(array)}.
     *
     * @param b the numerator array
     * @return a new array with the quotients
     */
    public NDArray rdiv(NDArray b) {
        return b.div(array);
    }

    /**
     * Computes {@code n / array} element-wise, storing the result in {@code array}.
     *
     * @param n the scalar numerator
     * @return the wrapped array, now holding the quotients
     */
    public NDArray rdivi(Number n) {
        getManager().invoke(
                "_rdiv_scalar", new NDArray[] {array}, new NDArray[] {array}, scalarParams(n));
        return array;
    }

    /**
     * Computes {@code b / array} element-wise, storing the result in {@code array}.
     *
     * @param b the numerator array
     * @return the wrapped array, now holding the quotients
     */
    public NDArray rdivi(NDArray b) {
        // NOTE(review): this uses elemwise_div rather than the NDArray.div path taken by
        // rdiv(NDArray); broadcasting support may differ - confirm b must match array's shape.
        getManager().invoke("elemwise_div", new NDArray[] {b, array}, new NDArray[] {array}, null);
        return array;
    }

    /**
     * Computes {@code n - array} as {@code -(array - n)}.
     *
     * @param n the scalar minuend
     * @return a new array with the differences
     */
    public NDArray rsub(Number n) {
        // neg() (not negi()) is kept deliberately: the intermediate may be recorded by autograd.
        return array.sub(n).neg();
    }

    /**
     * Computes {@code b - array} as {@code -(array - b)}.
     *
     * @param b the minuend array
     * @return a new array with the differences
     */
    public NDArray rsub(NDArray b) {
        return array.sub(b).neg();
    }

    /**
     * Computes {@code n - array} in place.
     *
     * @param n the scalar minuend
     * @return the wrapped array, now holding the differences
     */
    public NDArray rsubi(Number n) {
        return array.subi(n).negi();
    }

    /**
     * Computes {@code b - array} in place.
     *
     * @param b the minuend array
     * @return the wrapped array, now holding the differences
     */
    public NDArray rsubi(NDArray b) {
        return array.subi(b).negi();
    }

    /**
     * Computes {@code n % array} element-wise.
     *
     * @param n the scalar dividend
     * @return a new array with the remainders
     */
    public NDArray rmod(Number n) {
        return getManager().invoke("_npi_rmod_scalar", array, scalarParams(n));
    }

    /**
     * Computes {@code b % array} element-wise by delegating to {@code b.mod(array)}.
     *
     * @param b the dividend array
     * @return a new array with the remainders
     */
    public NDArray rmod(NDArray b) {
        return b.mod(array);
    }

    /**
     * Computes {@code n % array} element-wise, storing the result in {@code array}.
     *
     * @param n the scalar dividend
     * @return the wrapped array, now holding the remainders
     */
    public NDArray rmodi(Number n) {
        getManager().invoke(
                "_npi_rmod_scalar", new NDArray[] {array}, new NDArray[] {array}, scalarParams(n));
        return array;
    }

    /**
     * Computes {@code b % array} element-wise, storing the result in {@code array}.
     *
     * @param b the dividend array
     * @return the wrapped array, now holding the remainders
     */
    public NDArray rmodi(NDArray b) {
        getManager().invoke("_npi_mod", new NDArray[] {b, array}, new NDArray[] {array}, null);
        return array;
    }

    /**
     * Computes {@code n ^ array} element-wise.
     *
     * @param n the scalar base
     * @return a new array with the powers
     */
    public NDArray rpow(Number n) {
        return getManager().invoke("_npi_rpower_scalar", array, scalarParams(n));
    }

    /**
     * Computes {@code n ^ array} element-wise, storing the result in {@code array}.
     *
     * @param n the scalar base
     * @return the wrapped array, now holding the powers
     */
    public NDArray rpowi(Number n) {
        getManager().invoke(
                "_npi_rpower_scalar",
                new NDArray[] {array},
                new NDArray[] {array},
                scalarParams(n));
        return array;
    }

    // ---------------------------------------------------------------------------------------
    // Activations
    // ---------------------------------------------------------------------------------------

    /** @return {@code relu(array)} computed by {@code _npx_activation} */
    public NDArray relu() {
        return getManager().invoke("_npx_activation", array, actTypeParams("relu"));
    }

    /** @return {@code sigmoid(array)} computed by {@code _npx_activation} */
    public NDArray sigmoid() {
        return getManager().invoke("_npx_activation", array, actTypeParams("sigmoid"));
    }

    /** @return {@code tanh(array)} computed by {@code _npx_activation} */
    public NDArray tanh() {
        return getManager().invoke("_npx_activation", array, actTypeParams("tanh"));
    }

    /** @return softplus of {@code array} (MXNet act_type {@code softrelu}) */
    public NDArray softPlus() {
        return getManager().invoke("_npx_activation", array, actTypeParams("softrelu"));
    }

    /** @return softsign of {@code array} computed by {@code _npx_activation} */
    public NDArray softSign() {
        return getManager().invoke("_npx_activation", array, actTypeParams("softsign"));
    }

    /**
     * Applies leaky ReLU with negative slope {@code alpha}.
     *
     * @param alpha the slope passed to MXNet as {@code slope}
     * @return a new activated array
     */
    public NDArray leakyRelu(float alpha) {
        MxOpParams params = actTypeParams("leaky");
        params.addParam("slope", alpha);
        return getManager().invoke("_npx_leaky_relu", array, params);
    }

    /**
     * Applies ELU with parameter {@code alpha}.
     *
     * @param alpha the ELU parameter passed to MXNet as {@code slope}
     * @return a new activated array
     */
    public NDArray elu(float alpha) {
        MxOpParams params = actTypeParams("elu");
        params.addParam("slope", alpha);
        return getManager().invoke("_npx_leaky_relu", array, params);
    }

    /** @return SELU of {@code array} computed by {@code _npx_leaky_relu} */
    public NDArray selu() {
        return getManager().invoke("_npx_leaky_relu", array, actTypeParams("selu"));
    }

    /** @return GELU of {@code array} computed by {@code _npx_leaky_relu} */
    public NDArray gelu() {
        return getManager().invoke("_npx_leaky_relu", array, actTypeParams("gelu"));
    }

    // ---------------------------------------------------------------------------------------
    // Pooling
    // ---------------------------------------------------------------------------------------

    /**
     * Applies max pooling.
     *
     * @param kernelShape the pooling window
     * @param stride the window stride
     * @param padding the implicit padding
     * @param ceilMode {@code true} for the "full" pooling convention, {@code false} for "valid"
     * @return a new pooled array
     */
    public NDArray maxPool(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", kernelShape);
        params.add("pool_type", "max");
        params.addParam("stride", stride);
        params.addParam("pad", padding);
        params.add("pooling_convention", poolingConvention(ceilMode));
        return getManager().invoke("_npx_pooling", getArray(), params);
    }

    /**
     * Applies global max pooling and reshapes the result to the first two input dimensions.
     *
     * @return a new 2-D array
     * @throws IllegalStateException if the input does not have 3 to 5 dimensions
     */
    public NDArray globalMaxPool() {
        return globalPool("max");
    }

    /**
     * Applies average pooling.
     *
     * @param kernelShape the pooling window
     * @param stride the window stride
     * @param padding the implicit padding
     * @param ceilMode {@code true} for the "full" pooling convention, {@code false} for "valid"
     * @param countIncludePad whether padded elements count towards the average
     * @return a new pooled array
     */
    public NDArray avgPool(
            Shape kernelShape,
            Shape stride,
            Shape padding,
            boolean ceilMode,
            boolean countIncludePad) {
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", kernelShape);
        params.add("pool_type", "avg");
        params.addParam("stride", stride);
        params.addParam("pad", padding);
        params.add("pooling_convention", poolingConvention(ceilMode));
        params.addParam("count_include_pad", countIncludePad);
        return getManager().invoke("_npx_pooling", getArray(), params);
    }

    /**
     * Applies global average pooling and reshapes the result to the first two input dimensions.
     *
     * @return a new 2-D array
     * @throws IllegalStateException if the input does not have 3 to 5 dimensions
     */
    public NDArray globalAvgPool() {
        return globalPool("avg");
    }

    /**
     * Applies Lp pooling.
     *
     * @param normType the p of the Lp norm; must be integral for MXNet
     * @param kernelShape the pooling window
     * @param stride the window stride
     * @param padding the implicit padding
     * @param ceilMode {@code true} for the "full" pooling convention, {@code false} for "valid"
     * @return a new pooled array
     * @throws IllegalArgumentException if {@code normType} is not an integer value
     */
    public NDArray lpPool(
            float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        int pValue = integralNormType(normType);
        MxOpParams params = new MxOpParams();
        params.addParam("p_value", pValue);
        params.addParam("kernel", kernelShape);
        params.add("pool_type", "lp");
        params.addParam("stride", stride);
        params.addParam("pad", padding);
        params.add("pooling_convention", poolingConvention(ceilMode));
        return getManager().invoke("_npx_pooling", getArray(), params);
    }

    /**
     * Applies global Lp pooling and reshapes the result to the first two input dimensions.
     *
     * @param normType the p of the Lp norm; must be integral for MXNet
     * @return a new 2-D array
     * @throws IllegalArgumentException if {@code normType} is not an integer value
     */
    public NDArray globalLpPool(float normType) {
        int pValue = integralNormType(normType);
        MxOpParams params = new MxOpParams();
        params.add("pool_type", "lp");
        params.addParam("p_value", pValue);
        params.addParam("global_pool", true);
        return poolAndFlatten(params);
    }

    // ---------------------------------------------------------------------------------------
    // Optimizer updates
    // ---------------------------------------------------------------------------------------

    /**
     * Performs one Adadelta step using DJL array ops (no fused MXNet kernel is used).
     *
     * <p>{@code inputs} is read as {@code [weight, grad, s, delta]}. The weight, s and delta
     * arrays are updated in place. The gradient is also rescaled in place with {@code muli}.
     *
     * @param inputs {@code [weight, grad, s, delta]}
     * @param weights temporarily attached to a sub-manager alongside {@code inputs}
     * @param weightDecay the L2 weight-decay factor
     * @param rescaleGrad the factor the gradient is multiplied by first
     * @param clipGrad the clip bound for the gradient; ignored unless positive
     * @param rho the decay rate for the running averages
     * @param epsilon the numerical-stability term
     */
    public void adadeltaUpdate(
            NDList inputs,
            NDList weights,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float rho,
            float epsilon) {
        NDArray weight = inputs.get(0);
        NDArray grad = inputs.get(1);
        NDArray s = inputs.get(2);
        NDArray delta = inputs.get(3);
        // The sub-manager owns (and closes) every intermediate created below; tempAttachAll
        // lends it the inputs for the duration of this scope only.
        try (NDManager subManager = NDManager.newBaseManager()) {
            subManager.tempAttachAll(inputs, weights);
            grad.muli(rescaleGrad);
            if (clipGrad > 0.0f) {
                grad = grad.clip(-clipGrad, clipGrad);
            }
            grad.addi(weight.mul(weightDecay));
            s.muli(rho).addi(grad.square().mul(1.0f - rho));
            NDArray g = delta.add(epsilon).sqrt().div(s.add(epsilon).sqrt()).mul(grad);
            delta.muli(rho).addi(g.square().mul(1.0f - rho));
            weight.subi(g);
        }
    }

    /**
     * Performs one Adagrad step with MXNet's {@code adagrad_update}.
     *
     * @param inputs the operator inputs
     * @param weights the arrays the operator writes to
     * @param learningRate the learning rate
     * @param weightDecay the weight-decay factor
     * @param rescaleGrad the gradient rescale factor
     * @param clipGrad the gradient clip bound
     * @param epsilon the numerical-stability term
     */
    public void adagradUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float epsilon) {
        MxOpParams params = optimizerParams(learningRate, weightDecay, rescaleGrad, clipGrad);
        params.addParam("epsilon", epsilon);
        getManager().invoke("adagrad_update", inputs, weights, params);
    }

    /**
     * Performs one Adam step with MXNet's {@code adam_update}.
     *
     * @param inputs the operator inputs
     * @param weights the arrays the operator writes to
     * @param learningRate the learning rate
     * @param weightDecay the weight-decay factor
     * @param rescaleGrad the gradient rescale factor
     * @param clipGrad the gradient clip bound
     * @param beta1 the first-moment decay
     * @param beta2 the second-moment decay
     * @param epsilon the numerical-stability term
     * @param lazyUpdate forwarded as MXNet's {@code lazy_update}
     */
    public void adamUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float beta1,
            float beta2,
            float epsilon,
            boolean lazyUpdate) {
        MxOpParams params = optimizerParams(learningRate, weightDecay, rescaleGrad, clipGrad);
        params.addParam("beta1", beta1);
        params.addParam("beta2", beta2);
        params.addParam("epsilon", epsilon);
        params.addParam("lazy_update", lazyUpdate);
        getManager().invoke("adam_update", inputs, weights, params);
    }

    /**
     * Performs one RMSProp step: {@code rmsprop_update}, or {@code rmspropalex_update} (which
     * additionally receives {@code gamma2}) when {@code centered} is set.
     *
     * @param inputs the operator inputs
     * @param weights the arrays the operator writes to
     * @param learningRate the learning rate
     * @param weightDecay the weight-decay factor
     * @param rescaleGrad the gradient rescale factor
     * @param clipGrad the gradient clip bound
     * @param gamma1 forwarded as {@code gamma1}
     * @param gamma2 forwarded as {@code gamma2}; only used when {@code centered}
     * @param epsilon the numerical-stability term
     * @param centered selects the centered ({@code rmspropalex_update}) variant
     */
    public void rmspropUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float gamma1,
            float gamma2,
            float epsilon,
            boolean centered) {
        MxOpParams params = optimizerParams(learningRate, weightDecay, rescaleGrad, clipGrad);
        params.addParam("gamma1", gamma1);
        params.addParam("epsilon", epsilon);
        if (centered) {
            params.addParam("gamma2", gamma2);
            getManager().invoke("rmspropalex_update", inputs, weights, params);
        } else {
            getManager().invoke("rmsprop_update", inputs, weights, params);
        }
    }

    /**
     * Performs one Nesterov-momentum step with MXNet's {@code nag_mom_update}.
     *
     * @param inputs the operator inputs
     * @param weights the arrays the operator writes to
     * @param learningRate the learning rate
     * @param weightDecay the weight-decay factor
     * @param rescaleGrad the gradient rescale factor
     * @param clipGrad the gradient clip bound
     * @param momentum the momentum factor
     */
    public void nagUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum) {
        MxOpParams params = optimizerParams(learningRate, weightDecay, rescaleGrad, clipGrad);
        params.addParam("momentum", momentum);
        getManager().invoke("nag_mom_update", inputs, weights, params);
    }

    /**
     * Performs one SGD step: {@code sgd_mom_update} when {@code momentum} is non-zero, otherwise
     * {@code sgd_update}.
     *
     * @param inputs the operator inputs
     * @param weights the arrays the operator writes to
     * @param learningRate the learning rate
     * @param weightDecay the weight-decay factor
     * @param rescaleGrad the gradient rescale factor
     * @param clipGrad the gradient clip bound
     * @param momentum the momentum factor; {@code 0} selects plain SGD
     * @param lazyUpdate forwarded as MXNet's {@code lazy_update}
     */
    public void sgdUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum,
            boolean lazyUpdate) {
        MxOpParams params = optimizerParams(learningRate, weightDecay, rescaleGrad, clipGrad);
        params.addParam("lazy_update", lazyUpdate);
        if (momentum != 0.0f) {
            params.addParam("momentum", momentum);
            getManager().invoke("sgd_mom_update", inputs, weights, params);
        } else {
            getManager().invoke("sgd_update", inputs, weights, params);
        }
    }

    // ---------------------------------------------------------------------------------------
    // Neural-network layers
    // ---------------------------------------------------------------------------------------

    /**
     * Runs {@code _npx_convolution}. The kernel size is taken from {@code weight}'s dimensions
     * from index 2 on, and the filter count from its dimension 0.
     *
     * @param input the input batch
     * @param weight the convolution weight
     * @param bias the bias, or {@code null} for none
     * @param stride the stride
     * @param padding the padding
     * @param dilation the dilation
     * @param groups the number of groups
     * @return the operator outputs
     */
    public NDList convolution(
            NDArray input,
            NDArray weight,
            NDArray bias,
            Shape stride,
            Shape padding,
            Shape dilation,
            int groups) {
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", weight.getShape().slice(2));
        params.addParam("stride", stride);
        params.addParam("pad", padding);
        params.addParam("dilate", dilation);
        params.addParam("num_group", groups);
        params.addParam("num_filter", weight.getShape().get(0));
        NDList inputs = inputsWithOptionalBias(params, input, weight, bias);
        return getManager().invoke("_npx_convolution", inputs, params);
    }

    /**
     * Runs {@code _npx_deconvolution}. The kernel size is taken from {@code weight}'s dimensions
     * from index 2 on, and the filter count from its dimension 0.
     *
     * @param input the input batch
     * @param weight the deconvolution weight
     * @param bias the bias, or {@code null} for none
     * @param stride the stride
     * @param padding the padding
     * @param outPadding forwarded as MXNet's {@code adj}
     * @param dilation the dilation
     * @param groups the number of groups
     * @return the operator outputs
     */
    public NDList deconvolution(
            NDArray input,
            NDArray weight,
            NDArray bias,
            Shape stride,
            Shape padding,
            Shape outPadding,
            Shape dilation,
            int groups) {
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", weight.getShape().slice(2));
        params.addParam("stride", stride);
        params.addParam("pad", padding);
        params.addParam("adj", outPadding);
        params.addParam("dilate", dilation);
        params.addParam("num_group", groups);
        params.addParam("num_filter", weight.getShape().get(0));
        NDList inputs = inputsWithOptionalBias(params, input, weight, bias);
        return getManager().invoke("_npx_deconvolution", inputs, params);
    }

    /**
     * Runs {@code _npx_fully_connected} without flattening the input.
     *
     * @param input the input
     * @param weight the weight; its dimension 0 is the number of hidden units
     * @param bias the bias, or {@code null} for none
     * @return the operator outputs
     */
    public NDList linear(NDArray input, NDArray weight, NDArray bias) {
        MxOpParams params = new MxOpParams();
        params.addParam("num_hidden", weight.size(0));
        params.addParam("flatten", false);
        params.addParam("no_bias", bias == null);
        NDList inputs = new NDList(input, weight);
        if (bias != null) {
            inputs.add(bias);
        }
        return getManager().invoke("_npx_fully_connected", inputs, params);
    }

    /**
     * Runs {@code _npx_embedding}. The vocabulary size is {@code weight}'s dimension 0 and the
     * embedding size is its dimension 1.
     *
     * @param input the indices to look up
     * @param weight the embedding table
     * @param sparse the gradient format; only DENSE and ROW_SPARSE are accepted
     * @return the operator outputs
     * @throws IllegalArgumentException for any other sparse format
     */
    public NDList embedding(NDArray input, NDArray weight, SparseFormat sparse) {
        if (sparse != SparseFormat.DENSE && sparse != SparseFormat.ROW_SPARSE) {
            throw new IllegalArgumentException("MXNet only supports row sparse");
        }
        MxOpParams params = new MxOpParams();
        params.addParam("input_dim", weight.getShape().get(0));
        params.addParam("output_dim", weight.getShape().get(1));
        params.addParam("sparse_grad", sparse.getValue());
        return getManager().invoke("_npx_embedding", new NDList(input, weight), params);
    }

    /**
     * Applies parametric ReLU with learned slopes {@code alpha}.
     *
     * @param input the input
     * @param alpha the slope array
     * @return the operator outputs
     */
    public NDList prelu(NDArray input, NDArray alpha) {
        return getManager()
                .invoke("_npx_leaky_relu", new NDList(input, alpha), actTypeParams("prelu"));
    }

    /**
     * Applies dropout with drop probability {@code rate}.
     *
     * @param input the input
     * @param rate the drop probability
     * @param training must match the current MXNet autograd training flag
     * @return the operator outputs
     * @throws IllegalArgumentException if {@code training} disagrees with autograd's mode
     */
    public NDList dropout(NDArray input, float rate, boolean training) {
        checkTrainingMode("dropout", training);
        MxOpParams params = new MxOpParams();
        params.addParam("p", rate);
        return getManager().invoke("_npx_dropout", new NDList(input), params);
    }

    /**
     * Applies layer normalization over the trailing {@code normalizedShape} dimensions.
     *
     * <p>MXNet normalizes a single axis, so the trailing dimensions of {@code input} and the
     * {@code gamma}/{@code beta} arrays are flattened to one axis first. The normalized result is
     * reshaped back to {@code input}'s shape. All flattened copies are closed before returning.
     *
     * @param input the input
     * @param normalizedShape the trailing shape to normalize over
     * @param gamma the scale, reshaped to {@code normalizedShape.size()} elements
     * @param beta the shift, reshaped to {@code normalizedShape.size()} elements
     * @param eps the numerical-stability term
     * @return a single-element list containing the normalized array
     */
    public NDList layerNorm(
            NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
        MxOpParams params = new MxOpParams();
        params.addParam("axis", -1);
        params.addParam("eps", eps);
        Shape inputShape = input.getShape();
        long normalizedSize = normalizedShape.size();
        int leadingDims = inputShape.dimension() - normalizedShape.dimension();
        Shape flatShape = inputShape.slice(0, leadingDims).add(normalizedSize);
        try (NDArray flatInput = input.reshape(flatShape);
                NDArray flatGamma = gamma.reshape(normalizedSize);
                NDArray flatBeta = beta.reshape(normalizedSize);
                NDList normalized =
                        getManager()
                                .invoke(
                                        "_npx_layer_norm",
                                        new NDList(flatInput, flatGamma, flatBeta),
                                        params)) {
            return new NDList(normalized.get(0).reshape(inputShape));
        }
    }

    /**
     * Applies batch normalization with MXNet's {@code _npx_batch_norm}.
     *
     * @param input the input
     * @param runningMean the running mean
     * @param runningVar the running variance
     * @param gamma the scale; when {@code null}, MXNet's {@code fix_gamma} is set
     * @param beta the shift
     * @param axis the channel axis
     * @param momentum the running-statistics momentum
     * @param eps the numerical-stability term
     * @param training must match the current MXNet autograd training flag
     * @return the operator outputs
     * @throws IllegalArgumentException if {@code training} disagrees with autograd's mode
     */
    public NDList batchNorm(
            NDArray input,
            NDArray runningMean,
            NDArray runningVar,
            NDArray gamma,
            NDArray beta,
            int axis,
            float momentum,
            float eps,
            boolean training) {
        checkTrainingMode("batchNorm", training);
        MxOpParams params = new MxOpParams();
        params.addParam("axis", axis);
        params.addParam("fix_gamma", gamma == null);
        params.addParam("eps", eps);
        params.addParam("momentum", momentum);
        return getManager()
                .invoke(
                        "_npx_batch_norm",
                        new NDList(input, gamma, beta, runningMean, runningVar),
                        params);
    }

    /**
     * Runs a vanilla RNN ({@code rnn_tanh} or {@code rnn_relu}) through {@code _npx_rnn}.
     *
     * @param input the input sequence
     * @param state the initial hidden state; its last dimension is the state size
     * @param params the per-layer weights (and biases), concatenated into one flat array
     * @param hasBiases whether {@code params} includes biases
     * @param numLayers the number of layers
     * @param activation TANH selects {@code rnn_tanh}; anything else selects {@code rnn_relu}
     * @param dropRate the dropout probability
     * @param training must match the current MXNet autograd training flag
     * @param bidirectional whether the RNN is bidirectional
     * @param batchFirst whether axis 0 of {@code input} (and of the output) is the batch axis
     * @return {@code [output, hiddenState]}
     * @throws IllegalArgumentException on a wrong parameter count or mode mismatch
     */
    public NDList rnn(
            NDArray input,
            NDArray state,
            NDList params,
            boolean hasBiases,
            int numLayers,
            RNN.Activation activation,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        checkRnnArguments("rnn", params, hasBiases, numLayers, bidirectional, training);
        NDArray timeMajorInput = batchFirst ? input.swapAxes(0, 1) : input;
        String mode = activation == RNN.Activation.TANH ? "rnn_tanh" : "rnn_relu";
        MxOpParams opParams =
                rnnOpParams(mode, dropRate, state.getShape().tail(), numLayers, bidirectional);
        NDList inputs = new NDList();
        inputs.add(timeMajorInput);
        inputs.add(packRnnParams(params, timeMajorInput));
        inputs.add(state);
        return invokeRnn(inputs, opParams, batchFirst, 1);
    }

    /**
     * Runs a GRU through {@code _npx_rnn}.
     *
     * @param input the input sequence
     * @param state the initial hidden state; its last dimension is the state size
     * @param params the per-layer weights (and biases), concatenated into one flat array
     * @param hasBiases whether {@code params} includes biases
     * @param numLayers the number of layers
     * @param dropRate the dropout probability
     * @param training must match the current MXNet autograd training flag
     * @param bidirectional whether the GRU is bidirectional
     * @param batchFirst whether axis 0 of {@code input} (and of the output) is the batch axis
     * @return {@code [output, hiddenState]}
     * @throws IllegalArgumentException on a wrong parameter count or mode mismatch
     */
    public NDList gru(
            NDArray input,
            NDArray state,
            NDList params,
            boolean hasBiases,
            int numLayers,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        checkRnnArguments("gru", params, hasBiases, numLayers, bidirectional, training);
        NDArray timeMajorInput = batchFirst ? input.swapAxes(0, 1) : input;
        MxOpParams opParams =
                rnnOpParams("gru", dropRate, state.getShape().tail(), numLayers, bidirectional);
        NDList inputs = new NDList();
        inputs.add(timeMajorInput);
        inputs.add(packRnnParams(params, timeMajorInput));
        inputs.add(state);
        return invokeRnn(inputs, opParams, batchFirst, 1);
    }

    /**
     * Runs an LSTM through {@code _npx_rnn} with {@code lstm_state_clip_nan} enabled.
     *
     * @param input the input sequence
     * @param states the initial states; the last dimension of the first one is the state size
     * @param params the per-layer weights (and biases), concatenated into one flat array
     * @param hasBiases whether {@code params} includes biases
     * @param numLayers the number of layers
     * @param dropRate the dropout probability
     * @param training must match the current MXNet autograd training flag
     * @param bidirectional whether the LSTM is bidirectional
     * @param batchFirst whether axis 0 of {@code input} (and of the output) is the batch axis
     * @return {@code [output, hiddenState, cellState]}
     * @throws IllegalArgumentException on a wrong parameter count or mode mismatch
     */
    public NDList lstm(
            NDArray input,
            NDList states,
            NDList params,
            boolean hasBiases,
            int numLayers,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        checkRnnArguments("lstm", params, hasBiases, numLayers, bidirectional, training);
        NDArray timeMajorInput = batchFirst ? input.swapAxes(0, 1) : input;
        MxOpParams opParams =
                rnnOpParams(
                        "lstm",
                        dropRate,
                        states.head().getShape().tail(),
                        numLayers,
                        bidirectional);
        opParams.addParam("lstm_state_clip_nan", true);
        NDList inputs = new NDList();
        inputs.add(timeMajorInput);
        inputs.add(packRnnParams(params, timeMajorInput));
        inputs.addAll(states);
        return invokeRnn(inputs, opParams, batchFirst, 2);
    }

    // ---------------------------------------------------------------------------------------
    // Image transforms
    // ---------------------------------------------------------------------------------------

    /**
     * Normalizes an image with per-channel {@code mean} and {@code std}.
     *
     * @param mean the per-channel means
     * @param std the per-channel standard deviations
     * @return a new normalized array
     */
    public NDArray normalize(float[] mean, float[] std) {
        MxOpParams params = new MxOpParams();
        params.addTupleParam("mean", mean);
        params.addTupleParam("std", std);
        return getManager().invoke("_npx__image_normalize", array, params);
    }

    /** @return the result of MXNet's {@code _image_to_tensor} on the wrapped array */
    public NDArray toTensor() {
        return getManager().invoke("_npx__image_to_tensor", array, null);
    }

    /**
     * Resizes an image.
     *
     * @param width the target width
     * @param height the target height
     * @param interpolation forwarded as MXNet's {@code interp}
     * @return a new resized array
     * @throws IllegalArgumentException if the wrapped array is empty
     */
    public NDArray resize(int width, int height, int interpolation) {
        if (array.isEmpty()) {
            throw new IllegalArgumentException("attempt to resize of an empty NDArray");
        }
        MxOpParams params = new MxOpParams();
        params.addTupleParam("size", width, height);
        params.addParam("interp", interpolation);
        return getManager().invoke("_npx__image_resize", array, params);
    }

    /**
     * Crops a {@code width x height} region whose corner is at {@code (x, y)}.
     *
     * @param x the corner's x coordinate
     * @param y the corner's y coordinate
     * @param width the crop width
     * @param height the crop height
     * @return a new cropped array
     */
    public NDArray crop(int x, int y, int width, int height) {
        MxOpParams params = new MxOpParams();
        params.add("x", x);
        params.add("y", y);
        params.add("width", width);
        params.add("height", height);
        return getManager().invoke("_npx__image_crop", array, params);
    }

    /**
     * Randomly flips the image left-right.
     *
     * @return a new, possibly flipped array
     * @throws UnsupportedOperationException if the array is on a GPU
     */
    public NDArray randomFlipLeftRight() {
        requireNonGpu("randomFlipLeftRight");
        return getManager().invoke("_npx__image_random_flip_left_right", array, null);
    }

    /**
     * Randomly flips the image top-bottom.
     *
     * @return a new, possibly flipped array
     * @throws UnsupportedOperationException if the array is on a GPU
     */
    public NDArray randomFlipTopBottom() {
        requireNonGpu("randomFlipTopBottom");
        return getManager().invoke("_npx__image_random_flip_top_bottom", array, null);
    }

    /**
     * Randomly scales brightness by a factor in {@code [max(0, 1 - b), 1 + b]}.
     *
     * @param brightness the maximum relative change {@code b}
     * @return a new adjusted array
     * @throws UnsupportedOperationException if the array is on a GPU
     */
    public NDArray randomBrightness(float brightness) {
        requireNonGpu("randomBrightness");
        MxOpParams params = new MxOpParams();
        params.addParam("min_factor", Math.max(0.0f, 1.0f - brightness));
        params.addParam("max_factor", 1.0f + brightness);
        return getManager().invoke("_npx__image_random_brightness", array, params);
    }

    /**
     * Randomly adjusts hue with factors in {@code [max(0, 1 - h), 1 + h]}.
     *
     * @param hue the maximum relative change {@code h}
     * @return a new adjusted array
     * @throws UnsupportedOperationException if the array is on a GPU
     */
    public NDArray randomHue(float hue) {
        requireNonGpu("randomHue");
        MxOpParams params = new MxOpParams();
        params.addParam("min_factor", Math.max(0.0f, 1.0f - hue));
        params.addParam("max_factor", 1.0f + hue);
        return getManager().invoke("_npx__image_random_hue", array, params);
    }

    /**
     * Randomly jitters brightness, contrast, saturation and hue.
     *
     * @param brightness the brightness jitter
     * @param contrast the contrast jitter
     * @param saturation the saturation jitter
     * @param hue the hue jitter
     * @return a new adjusted array
     * @throws UnsupportedOperationException if the array is on a GPU
     */
    public NDArray randomColorJitter(
            float brightness, float contrast, float saturation, float hue) {
        requireNonGpu("randomColorJitter");
        MxOpParams params = new MxOpParams();
        params.addParam("brightness", brightness);
        params.addParam("contrast", contrast);
        params.addParam("saturation", saturation);
        params.addParam("hue", hue);
        return getManager().invoke("_npx__image_random_color_jitter", array, params);
    }

    // ---------------------------------------------------------------------------------------
    // Miscellaneous
    // ---------------------------------------------------------------------------------------

    /** @return a new MXNet indexer bound to the wrapped array's manager */
    public NDArrayIndexer getIndexer() {
        return new MxNDArrayIndexer(array.getManager());
    }

    /**
     * Selects elements from {@code array} where {@code condition} is non-zero, and from
     * {@code other} elsewhere.
     *
     * <p>{@code array} and {@code other} are broadcast to a common shape if needed. A BOOLEAN
     * condition is first converted to INT32. Every temporary created here (broadcast copies and
     * the converted condition) is closed before returning, including on failure.
     *
     * @param condition the selector
     * @param other the values used where {@code condition} is zero
     * @return a new array with the selected values
     * @throws IllegalArgumentException if the data types differ or the shapes cannot broadcast
     */
    public NDArray where(NDArray condition, NDArray other) {
        // Validate before allocating anything, so a mismatch cannot leak a converted condition.
        if (array.getDataType() != other.getDataType()) {
            throw new IllegalArgumentException(
                    "DataType mismatch, required "
                            + array.getDataType()
                            + " actual "
                            + other.getDataType());
        }
        NDArray lhs = array;
        NDArray rhs = other;
        NDArray selector = condition;
        try {
            if (!array.shapeEquals(other)) {
                Shape target = deriveBroadcastedShape(array.getShape(), other.getShape());
                if (!target.equals(array.getShape())) {
                    lhs = array.broadcast(target);
                }
                if (!target.equals(other.getShape())) {
                    rhs = other.broadcast(target);
                }
            }
            if (condition.getDataType() == DataType.BOOLEAN) {
                selector = condition.toType(DataType.INT32, false);
            }
            return getManager().invoke("where", new NDArray[] {selector, lhs, rhs}, null);
        } finally {
            if (selector != condition) {
                selector.close();
            }
            if (lhs != array) {
                lhs.close();
            }
            if (rhs != other) {
                rhs.close();
            }
        }
    }

    /**
     * Stacks the wrapped array followed by {@code arrays} along a new {@code axis}.
     *
     * @param arrays the arrays to stack after this one
     * @param axis the new axis
     * @return a new stacked array
     */
    public NDArray stack(NDList arrays, int axis) {
        MxOpParams params = new MxOpParams();
        params.addParam("axis", axis);
        return getManager().invoke("_npi_stack", selfFollowedBy(arrays), params);
    }

    /**
     * Concatenates the wrapped array followed by {@code list} along {@code axis}.
     *
     * @param list the arrays to append; validated by {@code NDUtils.checkConcatInput}
     * @param axis the concatenation axis
     * @return a new concatenated array
     */
    public NDArray concat(NDList list, int axis) {
        NDUtils.checkConcatInput(list);
        MxOpParams params = new MxOpParams();
        params.addParam("axis", axis);
        return getManager().invoke("_npi_concatenate", selfFollowedBy(list), params);
    }

    /**
     * Runs MXNet's {@code MultiBoxTarget}.
     *
     * @param inputs the operator inputs
     * @param iouThreshold forwarded as {@code overlap_threshold}
     * @param ignoreLabel forwarded as {@code ignore_label}
     * @param negativeMiningRatio forwarded as {@code negative_mining_ratio}
     * @param negativeMiningThreshold forwarded as {@code negative_mining_thresh}
     * @param minNegativeSamples forwarded as {@code minimum_negative_samples}
     * @return the operator outputs
     */
    public NDList multiBoxTarget(
            NDList inputs,
            float iouThreshold,
            float ignoreLabel,
            float negativeMiningRatio,
            float negativeMiningThreshold,
            int minNegativeSamples) {
        MxOpParams parameters = new MxOpParams();
        parameters.add("minimum_negative_samples", minNegativeSamples);
        parameters.add("overlap_threshold", Float.valueOf(iouThreshold));
        parameters.add("ignore_label", Float.valueOf(ignoreLabel));
        parameters.add("negative_mining_ratio", Float.valueOf(negativeMiningRatio));
        parameters.add("negative_mining_thresh", Float.valueOf(negativeMiningThreshold));
        return getManager().invoke("MultiBoxTarget", inputs, parameters);
    }

    /**
     * Runs MXNet's {@code MultiBoxPrior} on the wrapped array.
     *
     * @param sizes forwarded as {@code sizes}
     * @param ratios forwarded as {@code ratios}
     * @param steps forwarded as {@code steps}
     * @param offsets forwarded as {@code offsets}
     * @param clip forwarded as {@code clip}
     * @return the operator outputs
     */
    public NDList multiBoxPrior(
            List<Float> sizes,
            List<Float> ratios,
            List<Float> steps,
            List<Float> offsets,
            boolean clip) {
        MxOpParams parameters = new MxOpParams();
        parameters.add("sizes", sizes);
        parameters.add("ratios", ratios);
        parameters.add("steps", steps);
        parameters.add("offsets", offsets);
        parameters.add("clip", clip);
        return getManager().invoke("MultiBoxPrior", new NDList(array), parameters);
    }

    /**
     * Runs MXNet's {@code MultiBoxDetection}.
     *
     * @param inputs the operator inputs
     * @param clip forwarded as {@code clip}
     * @param threshold forwarded as {@code threshold}
     * @param backgroundId forwarded as {@code background_id}
     * @param nmsThreashold forwarded as {@code nms_threshold}
     * @param forceSuppress forwarded as {@code force_suppress}
     * @param nmsTopK forwarded as {@code nms_topk}
     * @return the operator outputs
     */
    public NDList multiBoxDetection(
            NDList inputs,
            boolean clip,
            float threshold,
            int backgroundId,
            float nmsThreashold,
            boolean forceSuppress,
            int nmsTopK) {
        MxOpParams parameters = new MxOpParams();
        parameters.add("clip", clip);
        parameters.add("threshold", Float.valueOf(threshold));
        parameters.add("background_id", backgroundId);
        parameters.add("nms_threshold", Float.valueOf(nmsThreashold));
        parameters.add("force_suppress", forceSuppress);
        parameters.add("nms_topk", nmsTopK);
        return getManager().invoke("MultiBoxDetection", inputs, parameters);
    }

    /** @return the wrapped array */
    public NDArray getArray() {
        return array;
    }

    // ---------------------------------------------------------------------------------------
    // Private helpers
    // ---------------------------------------------------------------------------------------

    /** Returns the manager that owns the wrapped array and dispatches MXNet operators. */
    private MxNDManager getManager() {
        return array.getManager();
    }

    /** Builds the single {@code scalar} parameter used by the reverse scalar operators. */
    private static MxOpParams scalarParams(Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        return params;
    }

    /** Builds a parameter list holding only {@code act_type}. */
    private static MxOpParams actTypeParams(String actType) {
        MxOpParams params = new MxOpParams();
        params.addParam("act_type", actType);
        return params;
    }

    /** Maps the ceil-mode flag to MXNet's {@code pooling_convention} value. */
    private static String poolingConvention(boolean ceilMode) {
        return ceilMode ? "full" : "valid";
    }

    /** Returns {@code normType} as an int, rejecting values with a fractional part. */
    private static int integralNormType(float normType) {
        if ((int) normType != normType) {
            throw new IllegalArgumentException(
                    "float type of normType is not supported in MXNet engine, please use integer"
                            + " instead");
        }
        return (int) normType;
    }

    /** Runs a global max/avg pooling with an explicit all-ones kernel and zero padding. */
    private NDArray globalPool(String poolType) {
        MxOpParams params = new MxOpParams();
        params.add("kernel", getGlobalPoolingShapes(1L));
        params.add("pad", getGlobalPoolingShapes(0L));
        params.add("pool_type", poolType);
        params.addParam("global_pool", true);
        return poolAndFlatten(params);
    }

    /** Runs {@code _npx_pooling} and reshapes the output to its first two dimensions. */
    private NDArray poolAndFlatten(MxOpParams params) {
        try (NDArray pooled = getManager().invoke("_npx_pooling", getArray(), params)) {
            Shape pooledShape = pooled.getShape();
            return pooled.reshape(pooledShape.get(0), pooledShape.get(1));
        }
    }

    /** Builds the parameters shared by every fused MXNet optimizer operator. */
    private static MxOpParams optimizerParams(
            float learningRate, float weightDecay, float rescaleGrad, float clipGrad) {
        MxOpParams params = new MxOpParams();
        params.addParam("lr", learningRate);
        params.addParam("wd", weightDecay);
        params.addParam("rescale_grad", rescaleGrad);
        params.addParam("clip_gradient", clipGrad);
        return params;
    }

    /** Builds {@code [input, weight(, bias)]} and records {@code no_bias} in {@code params}. */
    private static NDList inputsWithOptionalBias(
            MxOpParams params, NDArray input, NDArray weight, NDArray bias) {
        NDList inputs = new NDList(input, weight);
        params.add("no_bias", bias == null);
        if (bias != null) {
            inputs.add(bias);
        }
        return inputs;
    }

    /** Fails if {@code training} disagrees with MXNet's current autograd training flag. */
    private static void checkTrainingMode(String opName, boolean training) {
        if (training != JnaUtils.autogradIsTraining()) {
            throw new IllegalArgumentException(
                    "the mode of "
                            + opName
                            + " in MXNet should align with the mode of GradientCollector");
        }
    }

    /** Validates the recurrent-layer parameter count and the training mode. */
    private static void checkRnnArguments(
            String opName,
            NDList params,
            boolean hasBiases,
            int numLayers,
            boolean bidirectional,
            boolean training) {
        int numParams = numLayers * (hasBiases ? 4 : 2) * (bidirectional ? 2 : 1);
        Preconditions.checkArgument(
                params.size() == numParams,
                "The size of Params is incorrect expect "
                        + numParams
                        + " parameters but got "
                        + params.size());
        checkTrainingMode(opName, training);
    }

    /** Builds the {@code _npx_rnn} parameters shared by rnn, gru and lstm. */
    private static MxOpParams rnnOpParams(
            String mode, double dropRate, long stateSize, int numLayers, boolean bidirectional) {
        MxOpParams opParams = new MxOpParams();
        opParams.addParam("mode", mode);
        opParams.addParam("p", dropRate);
        opParams.addParam("state_size", stateSize);
        opParams.addParam("num_layers", numLayers);
        opParams.addParam("bidirectional", bidirectional);
        opParams.addParam("state_outputs", true);
        return opParams;
    }

    /**
     * Flattens every parameter and concatenates them into one array attached to {@code input}'s
     * manager. The flattened temporaries are closed; the concatenated array is returned.
     */
    private static NDArray packRnnParams(NDList params, NDArray input) {
        try (NDList flattened = new NDList()) {
            for (NDArray param : params) {
                flattened.add(param.flatten());
            }
            NDArray packed = NDArrays.concat(flattened);
            packed.attach(input.getManager());
            return packed;
        }
    }

    /**
     * Invokes {@code _npx_rnn}. When {@code batchFirst}, the time-major output is swapped back to
     * batch-major and closed, and the first {@code numStates} state outputs are passed through.
     */
    private NDList invokeRnn(
            NDList inputs, MxOpParams opParams, boolean batchFirst, int numStates) {
        NDList result = getManager().invoke("_npx_rnn", inputs, opParams);
        if (!batchFirst) {
            return result;
        }
        try (NDArray timeMajorOutput = result.head()) {
            NDArray[] outputs = new NDArray[numStates + 1];
            outputs[0] = timeMajorOutput.swapAxes(0, 1);
            for (int i = 1; i <= numStates; ++i) {
                outputs[i] = result.get(i);
            }
            return new NDList(outputs);
        }
    }

    /** Throws if the wrapped array lives on a GPU device. */
    private void requireNonGpu(String opName) {
        if (array.getDevice().getDeviceType().equals("gpu")) {
            throw new UnsupportedOperationException(opName + " is not supported on GPU");
        }
    }

    /** Returns a new array holding the wrapped array followed by every element of {@code others}. */
    private NDArray[] selfFollowedBy(NDList others) {
        NDArray[] all = new NDArray[others.size() + 1];
        all[0] = array;
        for (int i = 0; i < others.size(); ++i) {
            all[i + 1] = others.get(i);
        }
        return all;
    }

    /**
     * Returns the number of spatial dimensions for global pooling: the input rank minus the
     * leading two dimensions.
     *
     * @throws IllegalStateException unless the result is between 1 and 3
     */
    private int getGlobalPoolingDim() {
        int poolDim = getArray().getShape().dimension() - 2;
        if (poolDim < 1 || poolDim > 3) {
            throw new IllegalStateException(
                    "GlobalPooling only supports 1 to 3 dimensions, "
                            + poolDim
                            + "D is not supported.");
        }
        return poolDim;
    }

    /** Returns a shape with one entry per spatial dimension, every entry set to {@code fillValue}. */
    private Shape getGlobalPoolingShapes(long fillValue) {
        long[] shape = new long[getGlobalPoolingDim()];
        Arrays.fill(shape, fillValue);
        return new Shape(shape);
    }
}

