/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxOpParams;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.PairList;
import java.util.Stack;

/**
 * {@link NDArrayIndexer} implementation that performs pick, slice-get and slice-set operations by
 * invoking MXNet operators through an {@link MxNDManager}.
 */
public class MxNDArrayIndexer extends NDArrayIndexer {

    /** Manager used to convert input arrays and to invoke the MXNet operators. */
    private final MxNDManager manager;

    /**
     * Creates an indexer bound to the given manager.
     *
     * @param manager the manager used to invoke MXNet operators
     */
    MxNDArrayIndexer(MxNDManager manager) {
        this.manager = manager;
    }

    /**
     * Picks elements of {@code array} along {@code fullPick.getAxis()} at the positions given by
     * {@code fullPick.getIndices()} using MXNet's {@code pick} operator.
     *
     * <p>The operator is invoked with {@code keepdims=true} (the picked axis stays in the result
     * with size one) and {@code mode=wrap}.
     *
     * @param array the array to pick from
     * @param fullPick the axis and indices to pick
     * @return the single array produced by the {@code pick} operator
     */
    public NDArray get(NDArray array, NDIndexFullPick fullPick) {
        array = manager.from(array);
        MxOpParams params = new MxOpParams();
        params.addParam("axis", fullPick.getAxis());
        params.addParam("keepdims", true);
        params.add("mode", "wrap");
        return manager.invoke("pick", new NDList(array, fullPick.getIndices()), params)
                .singletonOrThrow();
    }

    /**
     * Returns the slice of {@code array} described by {@code fullSlice}.
     *
     * <p>The slice is taken with {@code _npi_slice} using the slice's begin/end/step, after which
     * the axes reported by {@code fullSlice.getToSqueeze()} (if any) are squeezed away.
     *
     * @param array the array to slice
     * @param fullSlice the slice description
     * @return the (possibly squeezed) slice
     */
    public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
        array = manager.from(array);
        NDArray result = manager.invoke("_npi_slice", array, sliceParams(fullSlice));
        int[] toSqueeze = fullSlice.getToSqueeze();
        if (toSqueeze.length == 0) {
            return result;
        }
        // The unsqueezed slice is only an intermediate: release it whether or not squeeze
        // succeeds, so a failing squeeze does not leak it.
        try {
            return result.squeeze(toSqueeze);
        } finally {
            result.close();
        }
    }

    /**
     * Assigns {@code value} into the slice of {@code array} described by {@code fullSlice}, in
     * place.
     *
     * <p>{@code value} is moved to the device of {@code array}, then reshaped to the trailing part
     * of the slice shape whose size does not exceed the size of {@code value}. The reshaped value
     * is broadcast to the full slice shape and written with {@code _npi_slice_assign}, using
     * {@code array} as both input and output. All intermediate arrays created here are closed
     * afterwards, including when an exception is thrown. {@code value} itself is never closed.
     *
     * <p>NOTE(review): {@code manager.from} presumably returns {@code array} unchanged when it is
     * already an MXNet array. If it returns a copy instead, the assignment updates that copy
     * rather than the caller's array. Confirm against {@code MxNDManager.from}.
     *
     * @param array the array to modify
     * @param fullSlice the slice to assign into
     * @param value the values to assign
     */
    public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
        array = manager.from(array);
        MxOpParams params = sliceParams(fullSlice);
        // Every array derived from value while adapting it to the slice. All entries except
        // value itself are closed in the finally block below.
        Stack<NDArray> prepareValue = new Stack<>();
        prepareValue.add(value);
        try {
            prepareValue.add(prepareValue.peek().toDevice(array.getDevice(), false));
            // Drop leading slice dimensions until the remaining size no longer exceeds value's
            // size. Examples: a value of shape (10) for a (1, 10, 1) slice is reshaped to
            // (1, 10, 1); a value of size 4 for a (3, 4) slice is reshaped to (4) and broadcast.
            Shape targetShape = fullSlice.getShape();
            while (targetShape.size() > value.size()) {
                targetShape = targetShape.slice(1);
            }
            prepareValue.add(prepareValue.peek().reshape(targetShape));
            prepareValue.add(prepareValue.peek().broadcast(fullSlice.getShape()));
            manager.invoke(
                    "_npi_slice_assign",
                    new NDArray[] {array, prepareValue.peek()},
                    new NDArray[] {array},
                    params);
        } finally {
            for (NDArray toClean : prepareValue) {
                if (toClean != value) {
                    toClean.close();
                }
            }
        }
    }

    /**
     * Assigns the scalar {@code value} to every element of the slice of {@code array} described
     * by {@code fullSlice}, in place, via {@code _npi_slice_assign_scalar}.
     *
     * <p>NOTE(review): as in the array overload, this assumes {@code manager.from} returns the
     * caller's array rather than a copy. Confirm.
     *
     * @param array the array to modify
     * @param fullSlice the slice to assign into
     * @param value the scalar to assign
     */
    public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) {
        array = manager.from(array);
        MxOpParams params = sliceParams(fullSlice);
        params.addParam("scalar", value);
        manager.invoke(
                "_npi_slice_assign_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
    }

    /**
     * Builds the {@code begin}/{@code end}/{@code step} operator parameters for a slice.
     *
     * @param fullSlice the slice description
     * @return a new parameter list containing begin, end and step, in that order
     */
    private static MxOpParams sliceParams(NDIndexFullSlice fullSlice) {
        MxOpParams params = new MxOpParams();
        params.addTupleParam("begin", fullSlice.getMin());
        params.addTupleParam("end", fullSlice.getMax());
        params.addTupleParam("step", fullSlice.getStep());
        return params;
    }
}

