/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.tensorflow.engine;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.NDIndexBooleans;
import ai.djl.ndarray.index.NDIndexFullSlice;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.tensorflow.engine.TfDataType;
import ai.djl.tensorflow.engine.TfNDArrayEx;
import ai.djl.tensorflow.engine.TfNDManager;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.BroadcastTo;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Gather;
import org.tensorflow.op.core.Max;
import org.tensorflow.op.core.Min;
import org.tensorflow.op.core.Prod;
import org.tensorflow.op.core.Range;
import org.tensorflow.op.core.ReduceAll;
import org.tensorflow.op.core.ReduceAny;
import org.tensorflow.op.core.Slice;
import org.tensorflow.op.core.Squeeze;
import org.tensorflow.op.core.Sum;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.linalg.Transpose;
import org.tensorflow.op.math.Cumsum;
import org.tensorflow.op.math.Equal;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.math.NotEqual;
import org.tensorflow.op.nn.LogSoftmax;
import org.tensorflow.op.nn.TopK;
import org.tensorflow.op.train.BatchMatMul;
import org.tensorflow.tools.Shape;
import org.tensorflow.tools.buffer.ByteDataBuffer;
import org.tensorflow.tools.buffer.DataBuffers;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;
import org.tensorflow.types.family.TType;

public class TfNDArray
implements NDArray {
    // NOTE(review): MAX_SIZE / MAX_DEPTH / MAX_ROWS / MAX_COLUMNS are not referenced in this part
    // of the file; presumably they bound a textual dump (e.g. toString) defined further below.
    private static final int MAX_SIZE = 100;
    private static final int MAX_DEPTH = 10;
    private static final int MAX_ROWS = 10;
    private static final int MAX_COLUMNS = 20;
    // Maximum number of split outputs produced by a single op; split() below hard-codes the same
    // value (8) and chunks larger requests around it.
    private static final int MAX_OUTPUTS_PER_OP = 8;
    // Random identifier used as the key when this array is attached to its manager.
    private String uid = UUID.randomUUID().toString();
    // The materialized TensorFlow tensor backing this array.
    private Tensor<?> tensor;
    // DJL view of the tensor shape; rebuilt lazily by getShape() when null.
    private ai.djl.ndarray.types.Shape shape;
    private TfNDManager manager;
    // TensorFlow op builder obtained from the manager.
    private Ops tf;
    // Cached operand for this tensor; reset to null in copyTo() and via clearOperand() after an
    // in-place update (presumably rebuilt on demand by asOperand(), defined outside this view).
    private Operand<?> operand;
    // Optional user-assigned name; null unless setName() is called.
    private String name;
    // Companion object for extended ops, created in every constructor.
    private TfNDArrayEx tfNDArrayEx;

    /**
     * Wraps an existing TensorFlow tensor.
     *
     * <p>The new array is attached to {@code manager} under its uid before any other field is set,
     * and its DJL shape is copied from the tensor's shape.
     *
     * @param manager the owning manager; must be a {@link TfNDManager}
     * @param tensor the tensor to wrap
     */
    TfNDArray(NDManager manager, Tensor<?> tensor) {
        this.manager = (TfNDManager)manager;
        this.manager.attach(this.getUid(), (AutoCloseable)((Object)this));
        this.tensor = tensor;
        this.shape = new ai.djl.ndarray.types.Shape(tensor.shape().asArray());
        this.tf = this.manager.getTf();
        this.tfNDArrayEx = new TfNDArrayEx(this);
    }

    /**
     * Wraps the result of a TensorFlow op.
     *
     * <p>The op's output tensor is fetched via {@code out.asOutput().tensor()} and the DJL shape is
     * taken from that tensor.
     *
     * @param manager the owning manager; must be a {@link TfNDManager}
     * @param out the operand whose output tensor backs this array
     */
    TfNDArray(NDManager manager, Operand<?> out) {
        this.manager = (TfNDManager)manager;
        this.manager.attach(this.getUid(), (AutoCloseable)((Object)this));
        this.tensor = out.asOutput().tensor();
        this.shape = new ai.djl.ndarray.types.Shape(this.tensor.shape().asArray());
        this.tf = this.manager.getTf();
        this.tfNDArrayEx = new TfNDArrayEx(this);
    }

    /**
     * Creates a float32 tensor of the given shape from {@code data}.
     *
     * @param manager the owning manager; must be a {@link TfNDManager}
     * @param shape the DJL shape of the new tensor (stored as-is)
     * @param data the float values; converted by {@code toDataBuffer} (defined elsewhere in this file)
     */
    public TfNDArray(NDManager manager, ai.djl.ndarray.types.Shape shape, FloatBuffer data) {
        this.manager = (TfNDManager)manager;
        this.manager.attach(this.getUid(), (AutoCloseable)((Object)this));
        this.tensor = Tensor.of((org.tensorflow.DataType)TFloat32.DTYPE, (Shape)TfNDArray.toTfShape(shape), (ByteDataBuffer)TfNDArray.toDataBuffer(data));
        this.shape = shape;
        this.tf = this.manager.getTf();
        this.tfNDArrayEx = new TfNDArrayEx(this);
    }

    /**
     * Creates a uint8 tensor of the given shape from raw bytes.
     *
     * @param manager the owning manager; must be a {@link TfNDManager}
     * @param shape the DJL shape of the new tensor (stored as-is)
     * @param data the bytes, wrapped with {@code DataBuffers.of}
     */
    TfNDArray(NDManager manager, ai.djl.ndarray.types.Shape shape, ByteBuffer data) {
        this.manager = (TfNDManager)manager;
        this.manager.attach(this.getUid(), (AutoCloseable)((Object)this));
        this.shape = shape;
        this.tf = this.manager.getTf();
        this.tensor = Tensor.of((org.tensorflow.DataType)TUint8.DTYPE, (Shape)TfNDArray.toTfShape(shape), (ByteDataBuffer)DataBuffers.of((ByteBuffer)data));
        this.tfNDArrayEx = new TfNDArrayEx(this);
    }

    /** Returns the manager this array was attached to when it was constructed. */
    public NDManager getManager() {
        return manager;
    }

    /** Returns the user-assigned name, or {@code null} when {@link #setName} was never called. */
    public String getName() {
        return name;
    }

    /**
     * Assigns a user-visible name to this array.
     *
     * @param name the new name; may be {@code null}
     */
    public void setName(String name) {
        this.name = name;
    }

    /** Returns the random UUID string under which this array is registered with its manager. */
    public final String getUid() {
        return uid;
    }

    /** Returns the DJL data type translated from the underlying TensorFlow tensor type. */
    public DataType getDataType() {
        org.tensorflow.DataType<? extends TType> nativeType = getTfDataType();
        return TfDataType.fromTf(nativeType);
    }

    /** Returns the device of the owning manager. */
    public Device getDevice() {
        return manager.getDevice();
    }

    /** Returns the DJL shape, rebuilding it from the tensor if it has been cleared. */
    public ai.djl.ndarray.types.Shape getShape() {
        if (shape != null) {
            return shape;
        }
        shape = new ai.djl.ndarray.types.Shape(tensor.shape().asArray());
        return shape;
    }

    /** Returns the raw TensorFlow data type of the backing tensor. */
    public org.tensorflow.DataType<? extends TType> getTfDataType() {
        return tensor.dataType();
    }

    /** Not supported by this engine. */
    public SparseFormat getSparseFormat() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not supported by this engine. */
    public boolean isSparse() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not supported by this engine. */
    public NDArray toDevice(Device device, boolean copy) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Casts this array to {@code dataType} with {@code tf.dtypes.cast}.
     *
     * <p>A new array is always returned, even when the type already matches. When {@code copy} is
     * true the cast result is additionally passed through {@code tf.deepCopy}.
     *
     * @param dataType the target DJL data type
     * @param copy whether to deep-copy the cast result
     * @return a new {@link TfNDArray} holding the converted values
     */
    public NDArray toType(DataType dataType, boolean copy) {
        // Declared as Operand (not Cast): tf.deepCopy returns a different op type, so a Cast-typed
        // local cannot hold its result.
        Operand output = this.tf.dtypes.cast(this.asOperand(), TfDataType.toTf(dataType), new Cast.Options[0]);
        if (copy) {
            output = this.tf.deepCopy(output);
        }
        return new TfNDArray((NDManager)this.manager, (Operand<?>)output);
    }

    /** Intentionally a no-op: this engine does not track gradients. */
    public void attachGradient() {
    }

    /**
     * Intentionally a no-op: this engine does not track gradients.
     *
     * @param sparseFormat ignored
     */
    public void attachGradient(SparseFormat sparseFormat) {
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray getGradient() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Copies the tensor contents into a new {@code double[]}.
     *
     * @return the values in row-major order
     * @throws IllegalStateException if the array is not {@code FLOAT64}
     */
    public double[] toDoubleArray() {
        this.requireDataType(DataType.FLOAT64, "double");
        double[] result = new double[Math.toIntExact(this.getShape().size())];
        this.tensor.rawData().asDoubles().read(result);
        return result;
    }

    /**
     * Copies the tensor contents into a new {@code float[]}.
     *
     * @return the values in row-major order
     * @throws IllegalStateException if the array is not {@code FLOAT32}
     */
    public float[] toFloatArray() {
        this.requireDataType(DataType.FLOAT32, "float");
        float[] result = new float[Math.toIntExact(this.getShape().size())];
        this.tensor.rawData().asFloats().read(result);
        return result;
    }

    /**
     * Copies the tensor contents into a new {@code int[]}.
     *
     * @return the values in row-major order
     * @throws IllegalStateException if the array is not {@code INT32}
     */
    public int[] toIntArray() {
        this.requireDataType(DataType.INT32, "int");
        int[] result = new int[Math.toIntExact(this.getShape().size())];
        this.tensor.rawData().asInts().read(result);
        return result;
    }

    /**
     * Copies the tensor contents into a new {@code long[]}.
     *
     * @return the values in row-major order
     * @throws IllegalStateException if the array is not {@code INT64}
     */
    public long[] toLongArray() {
        this.requireDataType(DataType.INT64, "long");
        long[] result = new long[Math.toIntExact(this.getShape().size())];
        this.tensor.rawData().asLongs().read(result);
        return result;
    }

    /**
     * Copies the tensor contents into a new {@code boolean[]}.
     *
     * @return the values in row-major order
     * @throws IllegalStateException if the array is not {@code BOOLEAN}
     */
    public boolean[] toBooleanArray() {
        this.requireDataType(DataType.BOOLEAN, "boolean");
        boolean[] result = new boolean[Math.toIntExact(this.getShape().size())];
        this.tensor.rawData().asBooleans().read(result);
        return result;
    }

    /**
     * Guards the typed readers above. The raw buffer views reinterpret bytes without checking the
     * tensor type, so reading e.g. a float tensor as ints would silently return garbage.
     *
     * @param expected the data type the caller requires
     * @param javaType the Java element type name, used in the error message
     * @throws IllegalStateException if this array's data type differs from {@code expected}
     */
    private void requireDataType(DataType expected, String javaType) {
        DataType actual = this.getDataType();
        if (actual != expected) {
            throw new IllegalStateException("DataType mismatch, Required " + javaType + " Actual " + actual);
        }
    }

    /**
     * Copies the raw tensor bytes into a new heap {@link ByteBuffer}.
     *
     * <p>The buffer length is {@code numOfBytes(dataType) * shape.size()}, and the buffer is set to
     * the platform's native byte order.
     *
     * @return a buffer positioned at 0 containing the tensor data
     * @throws ArithmeticException if the byte length does not fit in an {@code int}
     */
    public ByteBuffer toByteBuffer() {
        ai.djl.ndarray.types.Shape sh = this.getShape();
        DataType dType = this.getDataType();
        long len = (long)dType.getNumOfBytes() * sh.size();
        byte[] buf = new byte[Math.toIntExact(len)];
        this.tensor.rawData().read(buf);
        // ByteBuffer.wrap defaults to BIG_ENDIAN. Host tensor memory is presumably in native order,
        // so tag the buffer accordingly; otherwise multi-byte views (asFloatBuffer(), asIntBuffer())
        // decode garbage on little-endian machines.
        return ByteBuffer.wrap(buf).order(ByteOrder.nativeOrder());
    }

    // TensorFlow-backed arrays are treated as immutable here: every setter below rejects the call.

    /**
     * Not supported; tensors cannot be modified after creation.
     *
     * @throws UnsupportedOperationException always
     */
    public void set(Buffer data) {
        throw new UnsupportedOperationException("Tensor cannot be modified after creation");
    }

    /**
     * Not supported; tensors cannot be modified after creation.
     *
     * @throws UnsupportedOperationException always
     */
    public void set(NDIndex index, NDArray value) {
        throw new UnsupportedOperationException("Tensor cannot be modified after creation");
    }

    /**
     * Not supported; tensors cannot be modified after creation.
     *
     * @throws UnsupportedOperationException always
     */
    public void set(NDIndex index, Number value) {
        throw new UnsupportedOperationException("Tensor cannot be modified after creation");
    }

    /**
     * Not supported; tensors cannot be modified after creation.
     *
     * @throws UnsupportedOperationException always
     */
    public void setScalar(NDIndex index, Number value) {
        throw new UnsupportedOperationException("Tensor cannot be modified after creation");
    }

    /**
     * Returns the sub-array selected by {@code index}.
     *
     * <p>Supported forms:
     * <ul>
     *   <li>a rank-0 index on a scalar, which returns this array itself;</li>
     *   <li>a single boolean-mask index, which delegates to {@code booleanMask};</li>
     *   <li>any index {@code NDIndex.getAsFullSlice} can express. It is compiled to {@code tf.slice},
     *       then {@code tf.squeeze} on the axes reported by {@code getToSqueeze()}.</li>
     * </ul>
     *
     * @param index the index to apply
     * @return the selected values
     * @throws IllegalArgumentException if more than one boolean index is given
     * @throws UnsupportedOperationException if the index cannot be expressed as a full slice
     */
    public NDArray get(NDIndex index) {
        if (index.getRank() == 0 && this.getShape().isScalar()) {
            return this;
        }
        List<?> indices = index.getIndices();
        if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
            if (indices.size() != 1) {
                throw new IllegalArgumentException("get() currently didn't support more that one boolean NDArray");
            }
            return this.booleanMask(((NDIndexBooleans)indices.get(0)).getIndex());
        }
        NDIndexFullSlice fullSlice = index.getAsFullSlice(this.getShape()).orElse(null);
        if (fullSlice == null) {
            throw new UnsupportedOperationException("get() currently supports all, fixed, and slices indices");
        }
        long[] begin = fullSlice.getMin();
        long[] end = fullSlice.getMax();
        long[] size = new long[begin.length];
        Arrays.setAll(size, i -> end[i] - begin[i]);
        // Typed as Operand rather than Slice: the optional squeeze below yields a different op type.
        Operand sliced = this.tf.slice(this.asOperand(), (Operand)this.tf.constant(begin), (Operand)this.tf.constant(size));
        if (!fullSlice.getToSqueeze().isEmpty()) {
            sliced = this.tf.squeeze(sliced, new Squeeze.Options[]{Squeeze.axis(fullSlice.getToSqueeze().stream().map(Integer::longValue).collect(Collectors.toList()))});
        }
        return new TfNDArray((NDManager)this.manager, (Operand<?>)sliced);
    }

    /**
     * Deep-copies this array's values into {@code ndArray} by replacing its backing tensor.
     *
     * <p>The destination's cached operand is cleared and its shape refreshed. Its previous tensor
     * is not closed here.
     *
     * @param ndArray the destination; must be a {@link TfNDArray} with exactly this shape
     * @throws IllegalArgumentException if the destination is not a TfNDArray or the shapes differ
     */
    public void copyTo(NDArray ndArray) {
        if (!(ndArray instanceof TfNDArray)) {
            throw new IllegalArgumentException("Only TfNDArray is supported.");
        }
        ai.djl.ndarray.types.Shape sourceShape = this.getShape();
        ai.djl.ndarray.types.Shape targetShape = ndArray.getShape();
        boolean sameDims = Arrays.equals(sourceShape.getShape(), targetShape.getShape());
        if (!sameDims) {
            throw new IllegalArgumentException("shape are diff. Required: " + targetShape + ", Actual " + sourceShape);
        }
        TfNDArray target = (TfNDArray)ndArray;
        target.tensor = this.tf.deepCopy(this.asOperand()).asOutput().tensor();
        target.operand = null;
        target.shape = new ai.djl.ndarray.types.Shape(this.tensor.shape().asArray());
    }

    /**
     * Selects the entries of this array whose mask value is true along {@code axis}.
     *
     * <p>Scalars follow NumPy. A true scalar mask yields a 1-element array of shape (1), and a
     * false one yields an empty array of shape (0).
     *
     * @param index the boolean mask
     * @param axis the axis along which masked positions are gathered
     * @return the selected entries
     * @throws IllegalArgumentException if this array is a scalar but {@code index} is not
     */
    public NDArray booleanMask(NDArray index, int axis) {
        if (this.isScalar()) {
            if (!index.isScalar()) {
                throw new IllegalArgumentException("Input is scalar, index must also be scalar.");
            }
            if (index.toBooleanArray()[0]) {
                return this.expandDims(0);
            }
            // A false mask selects nothing. Return an empty 1-D array of shape (0); a rank-0 shape
            // would still hold one element.
            return this.manager.create(new ai.djl.ndarray.types.Shape(new long[]{0L}));
        }
        // tf.where yields one row of coordinates per true entry. Squeezing axis 1 flattens that to a
        // list of positions (assumes a 1-D mask - TODO confirm for higher-rank masks), and
        // tf.gather picks those positions along `axis`.
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.gather(this.asOperand(), (Operand)this.tf.squeeze((Operand)this.tf.where(((TfNDArray)index).asOperand()), new Squeeze.Options[]{Squeeze.axis(Collections.singletonList(1L))}), (Operand)this.tf.constant(axis), new Gather.Options[0]));
    }

    /**
     * Not yet supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray sequenceMask(NDArray sequenceLength, float value) {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /**
     * Not yet supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray sequenceMask(NDArray sequenceLength) {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /** Returns a new array of zeros with this array's shape and type ({@code tf.zerosLike}). */
    public NDArray zerosLike() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.zerosLike(this.asOperand()));
    }

    /** Returns a new array of ones with this array's shape and type ({@code tf.onesLike}). */
    public NDArray onesLike() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.onesLike(this.asOperand()));
    }

    /**
     * Checks whether every element equals {@code number}.
     *
     * @param number the value to compare with; {@code null} yields false
     * @return true if all elements compare equal
     */
    public boolean contentEquals(Number number) {
        if (number == null) {
            return false;
        }
        // Close both the comparison mask and its reduction so neither lingers in the manager.
        try (NDArray result = this.eq(number);
                NDArray allEqual = result.all()) {
            return allEqual.getBoolean(new long[0]);
        }
    }

    /**
     * Checks whether {@code other} has the same shape, data type and element values.
     *
     * @param other the array to compare with; {@code null} yields false
     * @return true if shapes, data types and all elements match
     */
    public boolean contentEquals(NDArray other) {
        if (other == null || !this.shapeEquals(other)) {
            return false;
        }
        if (this.getDataType() != other.getDataType()) {
            return false;
        }
        // Previously the eq mask and its all() reduction were never closed.
        try (NDArray eq = this.eq(other);
                NDArray allEqual = eq.all()) {
            return allEqual.toBooleanArray()[0];
        }
    }

    // Element-wise comparisons. Each Number overload materializes the scalar with manager.create,
    // casts it to this array's data type, and delegates to the NDArray overload. The temporary
    // scalar array is not closed here and stays attached to the manager.

    /** Element-wise {@code this == other}; see the NDArray overload. */
    public NDArray eq(Number other) {
        return this.eq(this.manager.create(other).toType(this.getDataType(), false));
    }

    /** Element-wise {@code this == other} via {@code tf.math.equal}; {@code other} must be a TfNDArray. */
    public NDArray eq(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.equal(this.asOperand(), ((TfNDArray)other).asOperand(), new Equal.Options[0]).asOutput());
    }

    /** Element-wise {@code this != other}; see the NDArray overload. */
    public NDArray neq(Number other) {
        return this.neq(this.manager.create(other).toType(this.getDataType(), false));
    }

    /** Element-wise {@code this != other} via {@code tf.math.notEqual}; {@code other} must be a TfNDArray. */
    public NDArray neq(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.notEqual(this.asOperand(), ((TfNDArray)other).asOperand(), new NotEqual.Options[0]).asOutput());
    }

    /** Element-wise {@code this > other}; see the NDArray overload. */
    public NDArray gt(Number other) {
        return this.gt(this.manager.create(other).toType(this.getDataType(), false));
    }

    /** Element-wise {@code this > other} via {@code tf.math.greater}; {@code other} must be a TfNDArray. */
    public NDArray gt(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.greater(this.asOperand(), ((TfNDArray)other).asOperand()).asOutput());
    }

    /** Element-wise {@code this >= other}; see the NDArray overload. */
    public NDArray gte(Number other) {
        return this.gte(this.manager.create(other).toType(this.getDataType(), false));
    }

    /** Element-wise {@code this >= other} via {@code tf.math.greaterEqual}; {@code other} must be a TfNDArray. */
    public NDArray gte(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.greaterEqual(this.asOperand(), ((TfNDArray)other).asOperand()).asOutput());
    }

    /** Element-wise {@code this < other}; see the NDArray overload. */
    public NDArray lt(Number other) {
        return this.lt(this.manager.create(other).toType(this.getDataType(), false));
    }

    /** Element-wise {@code this < other} via {@code tf.math.less}; {@code other} must be a TfNDArray. */
    public NDArray lt(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.less(this.asOperand(), ((TfNDArray)other).asOperand()).asOutput());
    }

    /** Element-wise {@code this <= other}; see the NDArray overload. */
    public NDArray lte(Number other) {
        return this.lte(this.manager.create(other).toType(this.getDataType(), false));
    }

    /** Element-wise {@code this <= other} via {@code tf.math.lessEqual}; {@code other} must be a TfNDArray. */
    public NDArray lte(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.lessEqual(this.asOperand(), ((TfNDArray)other).asOperand()).asOutput());
    }

    /**
     * Logical AND over all elements. Values are first cast to TBool, then reduced over every axis
     * {@code [0, rank)} built with {@code tf.range}.
     */
    public NDArray all() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.reduceAll((Operand)this.tf.dtypes.cast(this.asOperand(), TBool.DTYPE, new Cast.Options[0]), (Operand)this.tf.range((Operand)this.tf.constant(0L), (Operand)this.tf.constant((long)this.getRank()), (Operand)this.tf.constant(1L)), new ReduceAll.Options[0]));
    }

    /**
     * Logical OR over all elements. Values are first cast to TBool, then reduced over every axis
     * {@code [0, rank)} built with {@code tf.range}.
     */
    public NDArray any() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.reduceAny((Operand)this.tf.dtypes.cast(this.asOperand(), TBool.DTYPE, new Cast.Options[0]), (Operand)this.tf.range((Operand)this.tf.constant(0L), (Operand)this.tf.constant((long)this.getRank()), (Operand)this.tf.constant(1L)), new ReduceAny.Options[0]));
    }

    // Element-wise binary arithmetic. Each Number overload materializes the scalar with
    // manager.create, casts it to this array's data type, and delegates to the NDArray overload
    // (the temporary is not closed here). NDArray overloads require a TfNDArray argument and
    // return a new array; this array is not modified.

    /** Element-wise {@code this + n}. */
    public NDArray add(Number n) {
        return this.add(this.manager.create(n).toType(this.getDataType(), false));
    }

    /** Element-wise {@code this + other} via {@code tf.math.add}. */
    public NDArray add(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.add(this.asOperand(), ((TfNDArray)other).asOperand()));
    }

    /** Element-wise {@code this - n}. */
    public NDArray sub(Number n) {
        return this.sub(this.manager.create(n).toType(this.getDataType(), false));
    }

    /** Element-wise {@code this - other} via {@code tf.math.sub}. */
    public NDArray sub(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.sub(this.asOperand(), ((TfNDArray)other).asOperand()));
    }

    /** Element-wise {@code this * n}. */
    public NDArray mul(Number n) {
        return this.mul(this.manager.create(n).toType(this.getDataType(), false));
    }

    /** Element-wise {@code this * other} via {@code tf.math.mul}. */
    public NDArray mul(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.mul(this.asOperand(), ((TfNDArray)other).asOperand()));
    }

    /** Element-wise {@code this / n}. */
    public NDArray div(Number n) {
        return this.div(this.manager.create(n).toType(this.getDataType(), false));
    }

    /** Element-wise {@code this / other} via {@code tf.math.div}. */
    public NDArray div(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.div(this.asOperand(), ((TfNDArray)other).asOperand()));
    }

    /** Element-wise {@code this mod n}. */
    public NDArray mod(Number n) {
        return this.mod(this.manager.create(n).toType(this.getDataType(), false));
    }

    /** Element-wise remainder via {@code tf.math.mod}. */
    public NDArray mod(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.mod(this.asOperand(), ((TfNDArray)other).asOperand()));
    }

    /** Element-wise {@code this ^ n}. */
    public NDArray pow(Number n) {
        return this.pow(this.manager.create(n).toType(this.getDataType(), false));
    }

    /** Element-wise power via {@code tf.math.pow}. */
    public NDArray pow(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.pow(this.asOperand(), ((TfNDArray)other).asOperand()));
    }

    /** Element-wise maximum with {@code n}. */
    public NDArray maximum(Number n) {
        return this.maximum(this.manager.create(n).toType(this.getDataType(), false));
    }

    /** Element-wise maximum via {@code tf.math.maximum}. */
    public NDArray maximum(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.maximum(this.asOperand(), ((TfNDArray)other).asOperand()));
    }

    /** Element-wise minimum with {@code n}. */
    public NDArray minimum(Number n) {
        return this.minimum(this.manager.create(n).toType(this.getDataType(), false));
    }

    /** Element-wise minimum via {@code tf.math.minimum}. */
    public NDArray minimum(NDArray other) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.minimum(this.asOperand(), ((TfNDArray)other).asOperand()));
    }

    // In-place arithmetic. The out-of-place result is computed first, then written back into this
    // array through inPlaceHelper, which rejects scalars.

    /** In-place {@code this += n}; the scalar is cast to this array's data type first. */
    public NDArray addi(Number n) {
        NDArray operand = this.manager.create(n).toType(this.getDataType(), false);
        return this.addi(operand);
    }

    /** In-place {@code this += other}; returns this array. */
    public NDArray addi(NDArray other) {
        NDArray outcome = this.add(other);
        return this.inPlaceHelper(outcome, this);
    }

    /** In-place {@code this -= n}; the scalar is cast to this array's data type first. */
    public NDArray subi(Number n) {
        NDArray operand = this.manager.create(n).toType(this.getDataType(), false);
        return this.subi(operand);
    }

    /** In-place {@code this -= other}; returns this array. */
    public NDArray subi(NDArray other) {
        NDArray outcome = this.sub(other);
        return this.inPlaceHelper(outcome, this);
    }

    /** In-place {@code this *= n}; the scalar is cast to this array's data type first. */
    public NDArray muli(Number n) {
        NDArray operand = this.manager.create(n).toType(this.getDataType(), false);
        return this.muli(operand);
    }

    /** In-place {@code this *= other}; returns this array. */
    public NDArray muli(NDArray other) {
        NDArray outcome = this.mul(other);
        return this.inPlaceHelper(outcome, this);
    }

    /** In-place {@code this /= n}; the scalar is cast to this array's data type first. */
    public NDArray divi(Number n) {
        NDArray operand = this.manager.create(n).toType(this.getDataType(), false);
        return this.divi(operand);
    }

    /** In-place {@code this /= other}; returns this array. */
    public NDArray divi(NDArray other) {
        NDArray outcome = this.div(other);
        return this.inPlaceHelper(outcome, this);
    }

    /**
     * Overwrites {@code destination}'s tensor with the rows of {@code source}, emulating an
     * in-place update.
     *
     * <p>Every row index {@code [0, shape[0])} of this array is passed to
     * {@code tf.inplaceUpdate}, so all rows are replaced. Presumably {@code source} must therefore
     * have the destination's shape (TODO confirm for broadcasting results). The destination's
     * cached operand is cleared afterwards. {@code source} is not closed here.
     *
     * @param source the values to write; must be a TfNDArray
     * @param destination the array whose tensor is replaced; must be a TfNDArray
     * @return {@code destination}
     * @throws UnsupportedOperationException if this array is a scalar
     */
    NDArray inPlaceHelper(NDArray source, NDArray destination) {
        // NOTE(review): the scalar check and the row count both use `this`, not `destination`;
        // every caller in view passes `this` as destination, so they coincide.
        if (this.getShape().isScalar()) {
            throw new UnsupportedOperationException("TensorFlow engine does not support inplace operations on scalars yet");
        }
        Range indices = this.tf.range((Operand)this.tf.constant(0), (Operand)this.tf.constant((int)this.getShape().getShape()[0]), (Operand)this.tf.constant(1));
        ((TfNDArray)destination).setTensor(this.tf.inplaceUpdate(((TfNDArray)destination).asOperand(), (Operand)indices, ((TfNDArray)source).asOperand()).asOutput().tensor());
        ((TfNDArray)destination).clearOperand();
        return destination;
    }

    /** Not supported by this engine. */
    public NDArray toSparse(SparseFormat fmt) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** In-place {@code this %= n}; the scalar is cast to this array's data type first. */
    public NDArray modi(Number n) {
        NDArray divisor = this.manager.create(n).toType(this.getDataType(), false);
        return this.modi(divisor);
    }

    /** In-place remainder by {@code other}; returns this array. */
    public NDArray modi(NDArray other) {
        NDArray remainder = this.mod(other);
        return this.inPlaceHelper(remainder, this);
    }

    /** In-place {@code this = this ^ n}; the exponent is cast to this array's data type first. */
    public NDArray powi(Number n) {
        NDArray exponent = this.manager.create(n).toType(this.getDataType(), false);
        return this.powi(exponent);
    }

    /** In-place {@code this = this ^ other}; returns this array. */
    public NDArray powi(NDArray other) {
        NDArray powered = this.pow(other);
        return this.inPlaceHelper(powered, this);
    }

    /** In-place reverse power {@code this = other ^ this}; returns this array. */
    NDArray rpowi(NDArray other) {
        NDArray powered = other.pow((NDArray)this);
        return this.inPlaceHelper(powered, this);
    }

    /** Returns {@code -this} element-wise via {@code tf.math.neg}. */
    public NDArray neg() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.neg(this.asOperand()));
    }

    /** Negates this array in place (via inPlaceHelper, which rejects scalars); returns this array. */
    public NDArray negi() {
        return this.inPlaceHelper(this.neg(), this);
    }

    /** Returns the element-wise absolute value via {@code tf.math.abs}. */
    public NDArray abs() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.abs(this.asOperand()));
    }

    /** Returns the element-wise square via {@code tf.math.square}. */
    public NDArray square() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.square(this.asOperand()));
    }

    /** Returns the element-wise square root via {@code tf.math.sqrt}. */
    public NDArray sqrt() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.sqrt(this.asOperand()));
    }

    /**
     * Returns the element-wise real cube root.
     *
     * <p>A negative base raised to a non-integer power is NaN, so {@code pow(x, 1/3)} alone fails
     * for negative inputs. This computes {@code sign(x) * |x|^(1/3)} instead, so {@code cbrt(-8)}
     * is {@code -2}. The exponent constant is cast to this array's data type.
     *
     * @return a new array of cube roots
     */
    public NDArray cbrt() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.mul((Operand)this.tf.math.sign(this.asOperand()), (Operand)this.tf.math.pow((Operand)this.tf.math.abs(this.asOperand()), this.toConstant(Float.valueOf(0.33333334f), this.getDataType()))));
    }

    // Element-wise unary math. Each method wraps the matching tf.math op and returns a new array.

    /** Element-wise floor via {@code tf.math.floor}. */
    public NDArray floor() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.floor(this.asOperand()));
    }

    /** Element-wise ceiling via {@code tf.math.ceil}. */
    public NDArray ceil() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.ceil(this.asOperand()));
    }

    /** Element-wise rounding via {@code tf.math.round}. */
    public NDArray round() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.round(this.asOperand()));
    }

    /** Not supported by this engine. */
    public NDArray trunc() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Element-wise {@code e^x} via {@code tf.math.exp}. */
    public NDArray exp() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.exp(this.asOperand()));
    }

    /** Element-wise natural logarithm via {@code tf.math.log}. */
    public NDArray log() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.log(this.asOperand()));
    }

    /** Element-wise base-10 logarithm by change of base: {@code log(x) / log(10)}. */
    public NDArray log10() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.div((Operand)this.tf.math.log(this.asOperand()), (Operand)this.tf.math.log(this.toConstant(10, this.getDataType()))));
    }

    /** Element-wise base-2 logarithm by change of base: {@code log(x) / log(2)}. */
    public NDArray log2() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.div((Operand)this.tf.math.log(this.asOperand()), (Operand)this.tf.math.log(this.toConstant(2, this.getDataType()))));
    }

    /** Element-wise sine via {@code tf.math.sin}. */
    public NDArray sin() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.sin(this.asOperand()));
    }

    /** Element-wise cosine via {@code tf.math.cos}. */
    public NDArray cos() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.cos(this.asOperand()));
    }

    /** Element-wise tangent via {@code tf.math.tan}. */
    public NDArray tan() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.tan(this.asOperand()));
    }

    /** Element-wise arcsine via {@code tf.math.asin}. */
    public NDArray asin() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.asin(this.asOperand()));
    }

    /** Element-wise arccosine via {@code tf.math.acos}. */
    public NDArray acos() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.acos(this.asOperand()));
    }

    /** Element-wise arctangent via {@code tf.math.atan}. */
    public NDArray atan() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.atan(this.asOperand()));
    }

    /** Element-wise hyperbolic sine via {@code tf.math.sinh}. */
    public NDArray sinh() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.sinh(this.asOperand()));
    }

    /** Element-wise hyperbolic cosine via {@code tf.math.cosh}. */
    public NDArray cosh() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.cosh(this.asOperand()));
    }

    /** Element-wise hyperbolic tangent via {@code tf.math.tanh}. */
    public NDArray tanh() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.tanh(this.asOperand()));
    }

    /** Element-wise inverse hyperbolic sine via {@code tf.math.asinh}. */
    public NDArray asinh() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.asinh(this.asOperand()));
    }

    /** Element-wise inverse hyperbolic cosine via {@code tf.math.acosh}. */
    public NDArray acosh() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.acosh(this.asOperand()));
    }

    /** Element-wise inverse hyperbolic tangent via {@code tf.math.atanh}. */
    public NDArray atanh() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.atanh(this.asOperand()));
    }

    /**
     * Converts radians to degrees as {@code (this * 180) / PI}.
     *
     * @return a new array in degrees
     */
    public NDArray toDegrees() {
        NDArray scaled = this.mul(180);
        return scaled.div((Number)Math.PI);
    }

    /**
     * Converts degrees to radians as {@code (this * PI) / 180}.
     *
     * @return a new array in radians
     */
    public NDArray toRadians() {
        NDArray scaled = this.mul(Math.PI);
        return scaled.div((Number)180);
    }

    /** Returns the maximum over all elements, reducing every axis {@code [0, rank)}. */
    public NDArray max() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.max(this.asOperand(), this.allAxesOperand(), new Max.Options[0]));
    }

    /**
     * Returns the maximum along {@code axes}.
     *
     * @param axes the axes to reduce
     * @param keepDims whether reduced axes are kept with length 1
     */
    public NDArray max(int[] axes, boolean keepDims) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.max(this.asOperand(), (Operand)this.tf.constant(axes), new Max.Options[]{Max.keepDims((Boolean)keepDims)}));
    }

    /** Returns the minimum over all elements, reducing every axis {@code [0, rank)}. */
    public NDArray min() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.min(this.asOperand(), this.allAxesOperand(), new Min.Options[0]));
    }

    /**
     * Returns the minimum along {@code axes}.
     *
     * @param axes the axes to reduce
     * @param keepDims whether reduced axes are kept with length 1
     */
    public NDArray min(int[] axes, boolean keepDims) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.min(this.asOperand(), (Operand)this.tf.constant(axes), new Min.Options[]{Min.keepDims((Boolean)keepDims)}));
    }

    /** Returns the sum over all elements; boolean arrays are summed as int64 counts. */
    public NDArray sum() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.sum(this.summableOperand(), this.allAxesOperand(), new Sum.Options[0]));
    }

    /**
     * Returns the sum along {@code axes}; boolean arrays are summed as int64 counts.
     *
     * @param axes the axes to reduce
     * @param keepDims whether reduced axes are kept with length 1
     */
    public NDArray sum(int[] axes, boolean keepDims) {
        // Axes as a graph constant like the other axis reductions, rather than a manager-attached
        // NDArray that was never released.
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.sum(this.summableOperand(), (Operand)this.tf.constant(axes), new Sum.Options[]{Sum.keepDims((Boolean)keepDims)}));
    }

    /** Returns the product over all elements, reducing every axis {@code [0, rank)}. */
    public NDArray prod() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.prod(this.asOperand(), this.allAxesOperand(), new Prod.Options[0]));
    }

    /**
     * Returns the product along {@code axes}.
     *
     * @param axes the axes to reduce
     * @param keepDims whether reduced axes are kept with length 1
     */
    public NDArray prod(int[] axes, boolean keepDims) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.prod(this.asOperand(), (Operand)this.tf.constant(axes), new Prod.Options[]{Prod.keepDims((Boolean)keepDims)}).asOutput());
    }

    /** Returns the mean over all elements, reducing every axis {@code [0, rank)}. */
    public NDArray mean() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.mean(this.asOperand(), this.allAxesOperand(), new Mean.Options[0]));
    }

    /**
     * Returns the mean along {@code axes}.
     *
     * @param axes the axes to reduce
     * @param keepDims whether reduced axes are kept with length 1
     */
    public NDArray mean(int[] axes, boolean keepDims) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.mean(this.asOperand(), (Operand)this.tf.constant(axes), new Mean.Options[]{Mean.keepDims((Boolean)keepDims)}).asOutput());
    }

    /**
     * Builds the int64 axis list {@code [0, rank)} as a TF op. Used instead of
     * {@code manager.arange}, whose result NDArray stayed attached to the manager and was never closed.
     */
    private Operand allAxesOperand() {
        return this.tf.range((Operand)this.tf.constant(0L), (Operand)this.tf.constant((long)this.getRank()), (Operand)this.tf.constant(1L));
    }

    /** Returns this array as an operand TF's Sum accepts: booleans are cast to int64 first. */
    private Operand summableOperand() {
        if (this.getDataType() == DataType.BOOLEAN) {
            return this.tf.dtypes.cast(this.asOperand(), TInt64.DTYPE, new Cast.Options[0]);
        }
        return this.asOperand();
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray trace(int offset, int axis1, int axis2) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Splits this array along {@code axis} at the given boundary positions.
     *
     * <p>Requests with more than 8 indices (MAX_OUTPUTS_PER_OP, inlined here by the decompiler)
     * are processed in windows so no single splitV produces more than 8 outputs. Each window uses
     * 7 consecutive indices plus the axis length as an 8th boundary. The window's trailing piece is
     * dropped because the next window re-covers it. After the first window, the leading
     * {@code [0, indices[start])} piece is dropped too, because it was already emitted. Windows
     * advance by 6 so consecutive windows share one boundary.
     *
     * <p>NOTE(review): assumes strictly increasing indices with {@code indices[start] > 0} for every
     * later window. If a leading piece is not produced, dropping element 0 would discard a real
     * piece. Dropped pieces are not closed here.
     *
     * <p>NOTE(review): with exactly 8 indices the direct splitHelper call can yield up to 9 outputs
     * (leading + 7 inner + trailing); confirm whether the threshold should be {@code > 7}.
     *
     * @param indices the split boundaries along {@code axis}
     * @param axis the axis to split
     * @return the pieces, in order
     */
    public NDList split(long[] indices, int axis) {
        if (indices.length > 8) {
            long[] partialIndices;
            int start;
            NDList result = new NDList();
            long totalSize = this.getShape().get(axis);
            for (start = 0; start < indices.length - 8 + 2; start += 6) {
                // Window = 7 user indices + the axis length as closing boundary.
                partialIndices = new long[8];
                for (int i = 0; i < 7; ++i) {
                    partialIndices[i] = indices[start + i];
                }
                partialIndices[7] = totalSize;
                NDList splitted = this.splitHelper(partialIndices, axis);
                // Drop the tail [indices[start + 6], end): the next window starts at that boundary.
                splitted.remove(splitted.get(splitted.size() - 1));
                if (start > 0) {
                    // Drop the head [0, indices[start]): already emitted by earlier windows.
                    splitted.remove(splitted.get(0));
                }
                result.addAll(splitted);
            }
            // Remaining indices (between 1 and 6 of them) form the final window; its tail is kept.
            partialIndices = new long[indices.length - start];
            for (int i = 0; i < partialIndices.length; ++i) {
                partialIndices[i] = indices[start + i];
            }
            NDList splitted = this.splitHelper(partialIndices, axis);
            splitted.remove(splitted.get(0));
            result.addAll(splitted);
            return result;
        }
        return this.splitHelper(indices, axis);
    }

    /**
     * Performs one {@code tf.splitV} along {@code axis} at the given boundaries.
     *
     * <p>Piece sizes are: {@code indices[0]} (only if positive), each consecutive difference, and
     * the remainder up to the axis length (only if the last index is short of it).
     *
     * <p>NOTE(review): an empty {@code indices} array throws ArrayIndexOutOfBoundsException at
     * {@code indices[0]}. Sizes are narrowed to int via {@code Long::intValue}.
     *
     * @param indices the split boundaries; expected to be non-decreasing
     * @param axis the axis to split
     * @return one array per computed size
     * @throws IllegalArgumentException if the sizes do not sum to the axis length
     */
    private NDList splitHelper(long[] indices, int axis) {
        long totalSize;
        NDList result = new NDList();
        ArrayList<Long> sizes = new ArrayList<Long>();
        int lastIndex = indices.length - 1;
        long dimSize = this.getShape().get(axis);
        if (indices[0] > 0L) {
            sizes.add(indices[0]);
        }
        for (int i = 1; i < indices.length; ++i) {
            sizes.add(indices[i] - indices[i - 1]);
        }
        if (indices[lastIndex] < dimSize) {
            sizes.add(dimSize - indices[lastIndex]);
        }
        // Sanity check: the pieces must exactly tile the axis (catches out-of-range or unsorted indices).
        if ((totalSize = sizes.stream().mapToLong(Long::longValue).sum()) != this.getShape().get(axis)) {
            throw new IllegalArgumentException("split sizes :" + totalSize + " must sum to dimension on axis " + axis + ": " + this.getShape().get(axis));
        }
        this.tf.splitV(this.asOperand(), (Operand)this.tf.constant(sizes.stream().mapToInt(Long::intValue).toArray()), (Operand)this.tf.constant(axis), Long.valueOf(sizes.size())).forEach(output -> result.add((Object)new TfNDArray((NDManager)this.manager, (Operand<?>)output)));
        return result;
    }

    /** Returns a 1-D view of all elements, i.e. {@code reshape(-1)}. */
    public NDArray flatten() {
        return this.reshape(new ai.djl.ndarray.types.Shape(new long[]{-1L}));
    }

    /**
     * Reshapes via {@code tf.reshape}; the shape dimensions are passed as a constant operand.
     *
     * @param shape the target shape (TF's reshape accepts one -1 dimension)
     */
    public NDArray reshape(ai.djl.ndarray.types.Shape shape) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.reshape(this.asOperand(), (Operand)this.tf.constant(shape.getShape())));
    }

    /** Not supported by this engine. */
    public NDArray reshapeLike(NDArray array) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Inserts a length-1 axis at {@code axis} via {@code tf.expandDims}.
     *
     * @param axis the position of the new axis
     */
    public NDArray expandDims(int axis) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.expandDims(this.asOperand(), (Operand)this.tf.constant(axis)));
    }

    /**
     * Removes the given length-1 axes via {@code tf.squeeze}.
     *
     * @param axes the axes to remove; ignored for scalars
     */
    public NDArray squeeze(int[] axes) {
        // A scalar has no axes to name, so squeeze with an empty axis list instead.
        if (this.isScalar()) {
            axes = new int[]{};
        }
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.squeeze(this.asOperand(), new Squeeze.Options[]{Squeeze.axis(Arrays.stream(axes).mapToLong(i -> i).boxed().collect(Collectors.toList()))}));
    }

    /** Element-wise logical AND; both operands are cast to TBool first. */
    public NDArray logicalAnd(NDArray n) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.logicalAnd((Operand)this.tf.dtypes.cast(this.asOperand(), TBool.DTYPE, new Cast.Options[0]), (Operand)this.tf.dtypes.cast(((TfNDArray)n).asOperand(), TBool.DTYPE, new Cast.Options[0])));
    }

    /** Element-wise logical OR; both operands are cast to TBool first. */
    public NDArray logicalOr(NDArray n) {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.logicalOr((Operand)this.tf.dtypes.cast(this.asOperand(), TBool.DTYPE, new Cast.Options[0]), (Operand)this.tf.dtypes.cast(((TfNDArray)n).asOperand(), TBool.DTYPE, new Cast.Options[0])));
    }

    /**
     * Element-wise logical XOR, computed as {@code (a OR b) AND NOT (a AND b)} on TBool-cast inputs.
     *
     * <p>Built directly from TF ops so only the result becomes an NDArray. The previous version
     * created three intermediate NDArrays that were never closed.
     */
    public NDArray logicalXor(NDArray n) {
        Operand lhs = this.tf.dtypes.cast(this.asOperand(), TBool.DTYPE, new Cast.Options[0]);
        Operand rhs = this.tf.dtypes.cast(((TfNDArray)n).asOperand(), TBool.DTYPE, new Cast.Options[0]);
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.logicalAnd((Operand)this.tf.math.logicalOr(lhs, rhs), (Operand)this.tf.math.logicalNot((Operand)this.tf.math.logicalAnd(lhs, rhs))));
    }

    /** Element-wise logical NOT; the operand is cast to TBool first. */
    public NDArray logicalNot() {
        return new TfNDArray((NDManager)this.manager, (Operand<?>)this.tf.math.logicalNot((Operand)this.tf.dtypes.cast(this.asOperand(), TBool.DTYPE, new Cast.Options[0])));
    }

    /**
     * Returns the indices that would sort this array along {@code axis}.
     *
     * @param axis the axis to sort along
     * @param ascending whether to sort in ascending order
     * @return int64 indices
     */
    public NDArray argSort(int axis, boolean ascending) {
        final boolean wantIndices = true;
        return this.sortHelper(axis, ascending, wantIndices);
    }

    /**
     * Returns a copy sorted ascending along {@code axis}.
     *
     * @param axis the axis to sort along
     */
    public NDArray sort(int axis) {
        final boolean wantIndices = false;
        return this.sortHelper(axis, true, wantIndices);
    }

    /** Returns a copy sorted ascending along the last axis. */
    public NDArray sort() {
        return this.sort(-1);
    }

    /**
     * Sorts along {@code axis} using {@code tf.nn.topK} with {@code k} equal to the axis length.
     *
     * <p>topK works on the last dimension and returns descending results. A non-last axis is
     * therefore swapped with the last one first and swapped back afterwards. Ascending order is
     * obtained by running topK on the negated input; for values (not indices) the negation is undone.
     *
     * @param axis the axis to sort; negative values count from the end
     * @param ascending whether to sort in ascending order
     * @param returnIndices whether to return int64 indices instead of sorted values
     * @return the sorted values or indices; a scalar is returned as-is
     * @throws IllegalArgumentException if {@code axis} is out of range for this array's rank
     */
    private NDArray sortHelper(int axis, boolean ascending, boolean returnIndices) {
        if (this.isScalar()) {
            return this;
        }
        int rank = this.getRank();
        // Normalize any negative axis, not just -1, so e.g. -2 no longer indexes out of bounds.
        int sortAxis = axis < 0 ? axis + rank : axis;
        if (sortAxis < 0 || sortAxis >= rank) {
            throw new IllegalArgumentException("Invalid axis " + axis + " for an array of rank " + rank);
        }
        int k = (int)this.getShape().get(sortAxis);
        Operand input;
        NDArray transposition = null;
        if (sortAxis == rank - 1) {
            input = this.asOperand();
        } else {
            // Permutation swapping sortAxis with the last axis; it is its own inverse, so the same
            // permutation restores the original layout afterwards.
            NDList parts = new NDList(new NDArray[]{this.manager.arange(0, sortAxis, 1, DataType.INT32, this.getDevice()), this.manager.create(new int[]{rank - 1}), this.manager.arange(sortAxis + 1, rank - 1, 1, DataType.INT32, this.getDevice()), this.manager.create(new int[]{sortAxis})});
            transposition = NDArrays.concat(parts);
            // The pieces are only needed to build the permutation; release them.
            for (NDArray part : parts) {
                part.close();
            }
            input = this.tf.linalg.transpose(this.asOperand(), ((TfNDArray)transposition).asOperand());
        }
        TopK topK = ascending ? this.tf.nn.topK((Operand)this.tf.math.neg(input), (Operand)this.tf.constant(k), new TopK.Options[0]) : this.tf.nn.topK(input, (Operand)this.tf.constant(k), new TopK.Options[0]);
        Operand result = returnIndices ? this.tf.dtypes.cast((Operand)topK.indices(), TInt64.DTYPE, new Cast.Options[0]) : topK.values();
        if (transposition != null) {
            result = this.tf.linalg.transpose(result, ((TfNDArray)transposition).asOperand());
            transposition.close();
        }
        if (ascending && !returnIndices) {
            // Undo the negation applied before topK.
            result = this.tf.math.neg(result);
        }
        return new TfNDArray((NDManager)this.manager, (Operand<?>)result);
    }

    /**
     * Computes softmax along the single axis in {@code axes}.
     *
     * @param axes exactly one axis
     * @param temperature must be {@code 1}; other temperatures are not supported
     * @return the softmax of this array
     * @throws UnsupportedOperationException if {@code temperature != 1}
     */
    public NDArray softmax(int[] axes, float temperature) {
        requireUnitTemperature(temperature);
        return new TfNDArray((NDManager) this.manager, this.softmaxHelper(axes, false));
    }

    /**
     * Computes log-softmax along the single axis in {@code axes}.
     *
     * @param axes exactly one axis
     * @param temperature must be {@code 1}; other temperatures are not supported
     * @return the log-softmax of this array
     * @throws UnsupportedOperationException if {@code temperature != 1}
     */
    public NDArray logSoftmax(int[] axes, float temperature) {
        requireUnitTemperature(temperature);
        return new TfNDArray((NDManager) this.manager, this.softmaxHelper(axes, true));
    }

    /** Rejects any softmax temperature other than 1, which TensorFlow's ops do not take. */
    private static void requireUnitTemperature(float temperature) {
        if ((double) temperature != 1.0) {
            throw new UnsupportedOperationException("TensorFlow softmax does not support temperature");
        }
    }

    /**
     * Builds a softmax or log-softmax op along a single axis.
     *
     * <p>TensorFlow's softmax ops act on the last dimension. A different axis is swapped with the
     * last one, the op is applied, and the same swap is applied again to restore the layout.
     *
     * @param axes exactly one axis; negative values count from the end
     * @param logSoftmax whether to compute log-softmax instead of softmax
     * @return the resulting operand; a rank-0 input is returned unchanged
     * @throws UnsupportedOperationException if more than one axis is given
     * @throws IllegalArgumentException if no axis is given or the axis is out of range
     */
    private Operand softmaxHelper(int[] axes, boolean logSoftmax) {
        long dim = this.getShape().dimension();
        if (axes.length > 1) {
            throw new UnsupportedOperationException("TensorFlow softmax does not support multiple axes");
        }
        if (dim == 0L) {
            return this.asOperand();
        }
        if (axes.length == 0) {
            throw new IllegalArgumentException("TensorFlow softmax requires exactly one axis");
        }
        if ((long) axes[0] < -dim || (long) axes[0] >= dim) {
            throw new IllegalArgumentException("Invalid axes value: " + axes[0] + ", must be in range [" + -dim + ", " + dim + ") where " + dim + " is the number of dimensions in the input.");
        }
        // Normalize negative axes. Previously only -1 was handled, and e.g. -2 produced an invalid permutation.
        long axis = axes[0] < 0 ? axes[0] + dim : axes[0];
        if (axis == dim - 1L) {
            Operand<?> logits = this.asOperand();
            return logSoftmax ? this.tf.nn.logSoftmax((Operand) logits) : this.tf.nn.softmax((Operand) logits);
        }
        // Permutation swapping `axis` with the last axis: [0..axis), dim-1, (axis..dim-1), axis.
        // A single swap is its own inverse, so it is built once and reused to transpose back.
        List<Operand<TInt64>> permParts = new ArrayList<>();
        permParts.add(this.tf.range(this.tf.constant(0L), this.tf.constant(axis), this.tf.constant(1L)));
        permParts.add(this.tf.expandDims(this.tf.constant(dim - 1L), this.tf.constant(0)));
        permParts.add(this.tf.range(this.tf.constant(axis + 1L), this.tf.constant(dim - 1L), this.tf.constant(1L)));
        permParts.add(this.tf.expandDims(this.tf.constant(axis), this.tf.constant(0)));
        Operand<TInt64> permutation = this.tf.concat(permParts, this.tf.constant(0));
        Operand<?> transposed = this.tf.linalg.transpose(this.asOperand(), permutation);
        Operand<?> output = logSoftmax ? this.tf.nn.logSoftmax((Operand) transposed) : this.tf.nn.softmax((Operand) transposed);
        return this.tf.linalg.transpose(output, permutation);
    }

    /**
     * Cumulative sum along {@code axis}.
     *
     * <p>Special cases:
     * <ul>
     *   <li>a scalar becomes a one-element array;</li>
     *   <li>an array with any zero-length dimension yields a new array of shape {@code (0)}.</li>
     * </ul>
     *
     * @param axis the axis to accumulate along
     * @return the cumulative sums
     */
    public NDArray cumSum(int axis) {
        if (this.isScalar()) {
            return this.expandDims(0);
        }
        for (long extent : this.getShape().getShape()) {
            if (extent == 0L) {
                return this.manager.create(new ai.djl.ndarray.types.Shape(new long[] {0L}));
            }
        }
        Operand<?> summed = this.tf.math.cumsum(this.asOperand(), (Operand) this.tf.constant(axis), new Cumsum.Options[0]);
        return new TfNDArray((NDManager) this.manager, summed);
    }

    /**
     * Cumulative sum along axis 0.
     *
     * <p>NOTE(review): this accumulates along axis 0 and does not flatten multi-dimensional input
     * first; confirm that matches the {@code NDArray#cumSum()} contract.
     *
     * @return the cumulative sums along axis 0
     */
    public NDArray cumSum() {
        final int firstAxis = 0;
        return this.cumSum(firstAxis);
    }

    /**
     * Element-wise test for positive or negative infinity.
     *
     * @return a boolean array that is true where this array is infinite
     */
    public NDArray isInfinite() {
        Operand<?> infMask = this.tf.math.isInf(this.asOperand());
        Operand<?> asBool = this.tf.dtypes.cast((Operand) infMask, TBool.DTYPE, new Cast.Options[0]);
        return new TfNDArray((NDManager) this.manager, asBool);
    }

    /**
     * Element-wise test for NaN.
     *
     * @return a boolean array that is true where this array is NaN
     */
    public NDArray isNaN() {
        Operand<?> nanMask = this.tf.math.isNan(this.asOperand());
        Operand<?> asBool = this.tf.dtypes.cast((Operand) nanMask, TBool.DTYPE, new Cast.Options[0]);
        return new TfNDArray((NDManager) this.manager, asBool);
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray createMask(NDIndex index) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray createMask(Predicate<Number> predicate) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Repeats the whole array {@code repeats} times along every dimension.
     *
     * @param repeats the repetition count applied to each dimension
     * @return the tiled array
     */
    public NDArray tile(long repeats) {
        long[] multiples = new long[this.getShape().dimension()];
        Arrays.fill(multiples, repeats);
        return this.tile(multiples);
    }

    /**
     * Repeats the array {@code repeats} times along a single axis.
     *
     * @param axis the axis to tile along; negative values count from the end
     * @param repeats the repetition count for {@code axis}
     * @return the tiled array
     */
    public NDArray tile(int axis, long repeats) {
        int rank = this.getShape().dimension();
        long[] multiples = new long[rank];
        Arrays.fill(multiples, 1L);
        // Previously a negative axis indexed `multiples` out of bounds.
        multiples[axis < 0 ? axis + rank : axis] = repeats;
        return this.tile(multiples);
    }

    /**
     * Tiles the array with a per-dimension repetition count via TensorFlow's {@code Tile} op.
     *
     * @param repeats one repetition count per dimension
     * @return the tiled array
     */
    public NDArray tile(long[] repeats) {
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.tile(this.asOperand(), (Operand) this.tf.constant(repeats)));
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray tile(ai.djl.ndarray.types.Shape desiredShape) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray repeat(long repeats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray repeat(int axis, long repeats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray repeat(long[] repeats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray repeat(ai.djl.ndarray.types.Shape desiredShape) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the TensorFlow engine; see {@code matMul} for 1-D/2-D products.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray dot(NDArray other) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Matrix product of this array and {@code other}.
     *
     * <ul>
     *   <li>If either operand has rank above 2, TensorFlow's {@code BatchMatMul} is used.</li>
     *   <li>Otherwise a 1-D left operand of length n is treated as a row vector {@code (1, n)}.</li>
     *   <li>A 1-D right operand of length n is treated as a column vector {@code (n, 1)}.</li>
     *   <li>The dimension introduced by that promotion is squeezed out of the result.</li>
     * </ul>
     *
     * @param other the right-hand operand; must be a {@link TfNDArray}
     * @return the matrix product
     * @throws IllegalArgumentException if either operand is a scalar
     */
    public NDArray matMul(NDArray other) {
        if (this.isScalar() || other.isScalar()) {
            throw new IllegalArgumentException("scalar is not allowed for matMul()");
        }
        TfNDArray right = (TfNDArray) other;
        int lhsRank = this.getShape().dimension();
        int rhsRank = other.getShape().dimension();
        if (lhsRank > 2 || rhsRank > 2) {
            return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.train.batchMatMul(this.asOperand(), right.asOperand(), new BatchMatMul.Options[0]));
        }
        boolean lhsIsVector = lhsRank == 1;
        boolean rhsIsVector = rhsRank == 1;
        Operand<?> lhs = this.asOperand();
        Operand<?> rhs = right.asOperand();
        if (lhsIsVector) {
            // (n) -> row vector (1, n)
            lhs = this.tf.reshape(lhs, this.tf.constant(new long[] {1L, this.getShape().get(0)}));
        }
        if (rhsIsVector) {
            // (n) -> column vector (n, 1). Previously this used this array's length and a row shape.
            rhs = this.tf.reshape(rhs, this.tf.constant(new long[] {other.getShape().get(0), 1L}));
        }
        TfNDArray product = new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.linalg.matMul((Operand) lhs, (Operand) rhs, new MatMul.Options[0]));
        // Squeeze only the promoted dimension(s), so genuine size-1 dims of the other operand survive.
        if (lhsIsVector && rhsIsVector) {
            return product.squeeze();
        }
        if (lhsIsVector) {
            return product.squeeze(0);
        }
        if (rhsIsVector) {
            return product.squeeze(1);
        }
        return product;
    }

    /**
     * Clamps every element into {@code [min, max]}.
     *
     * <p>The bounds are converted to constants of this array's data type before calling
     * TensorFlow's {@code ClipByValue}.
     *
     * @param min the lower bound
     * @param max the upper bound
     * @return the clipped array
     */
    public NDArray clip(Number min, Number max) {
        DataType type = this.getDataType();
        Operand<?> clipped = this.tf.clipByValue(this.asOperand(), this.toConstant(min, type), this.toConstant(max, type));
        return new TfNDArray((NDManager) this.manager, clipped);
    }

    /**
     * Reverses the order of all axes.
     *
     * @return the transposed array
     */
    public NDArray transpose() {
        int rank = this.getShape().dimension();
        int[] reversedAxes = new int[rank];
        for (int i = 0; i < rank; ++i) {
            reversedAxes[i] = rank - 1 - i;
        }
        return this.transpose(reversedAxes);
    }

    /**
     * Permutes the axes according to {@code dimensions}.
     *
     * @param dimensions a permutation of {@code 0 .. rank-1}
     * @return the transposed array
     * @throws UnsupportedOperationException if any entry is negative
     * @throws IllegalArgumentException if {@code dimensions} is not a permutation of all axes
     */
    public NDArray transpose(int ... dimensions) {
        for (int d : dimensions) {
            if (d < 0) {
                throw new UnsupportedOperationException("Passing -1 for broadcasting the dimension is not currently supported");
            }
        }
        int rank = this.getShape().dimension();
        int[] sortedAxes = dimensions.clone();
        Arrays.sort(sortedAxes);
        if (!Arrays.equals(sortedAxes, IntStream.range(0, rank).toArray())) {
            throw new IllegalArgumentException("You must include each of the dimensions from 0 until " + rank);
        }
        Operand<?> permuted = this.tf.linalg.transpose(this.asOperand(), (Operand) this.tf.constant(dimensions));
        return new TfNDArray((NDManager) this.manager, permuted);
    }

    /**
     * Broadcasts this array to {@code shape} using TensorFlow's {@code BroadcastTo}.
     *
     * @param shape the target shape
     * @return the broadcast array
     */
    public NDArray broadcast(ai.djl.ndarray.types.Shape shape) {
        long[] targetDims = shape.getShape();
        Operand<?> broadcasted = this.tf.broadcastTo(this.asOperand(), (Operand) this.tf.constant(targetDims));
        return new TfNDArray((NDManager) this.manager, broadcasted);
    }

    /**
     * Index of the maximum element of the flattened array.
     *
     * @return the flat index of the maximum
     * @throws IllegalArgumentException if this array is empty
     */
    public NDArray argMax() {
        if (this.isEmpty()) {
            // message previously (wrongly) said "argMin"
            throw new IllegalArgumentException("attempt to get argMax of an empty NDArray");
        }
        return this.flatten().argMax(0);
    }

    /**
     * Indices of the maximum elements along {@code axis}.
     *
     * @param axis the axis to reduce
     * @return the indices of the maxima; a scalar input yields the constant {@code 0L}
     */
    public NDArray argMax(int axis) {
        if (this.isScalar()) {
            return this.manager.create(0L);
        }
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.argMax(this.asOperand(), (Operand) this.tf.constant(axis)));
    }

    /**
     * Index of the minimum element of the flattened array.
     *
     * @return the flat index of the minimum
     * @throws IllegalArgumentException if this array is empty
     */
    public NDArray argMin() {
        if (this.isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMin of an empty NDArray");
        }
        return this.flatten().argMin(0);
    }

    /**
     * Indices of the minimum elements along {@code axis}.
     *
     * @param axis the axis to reduce
     * @return the indices of the minima; a scalar input yields the constant {@code 0L}
     */
    public NDArray argMin(int axis) {
        if (this.isScalar()) {
            return this.manager.create(0L);
        }
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.argMin(this.asOperand(), (Operand) this.tf.constant(axis)));
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray percentile(Number percentile) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray percentile(Number percentile, int[] dimension) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray median() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray median(int[] axes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray toDense() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the TensorFlow engine.
     *
     * @throws UnsupportedOperationException always
     */
    public NDArray nonzero() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns the engine-internal operator helper for this array.
     *
     * <p>The value is {@code null} once {@link #close()} has run.
     *
     * @return this array's {@link NDArrayEx}
     */
    public NDArrayEx getNDArrayInternal() {
        NDArrayEx internal = this.tfNDArrayEx;
        return internal;
    }

    /**
     * Two {@code TfNDArray}s are equal when {@code contentEquals} reports equal content.
     *
     * @param obj the object to compare with
     * @return whether {@code obj} is a {@code TfNDArray} with equal content
     */
    public boolean equals(Object obj) {
        return obj instanceof TfNDArray && this.contentEquals((TfNDArray) obj);
    }

    /**
     * Returns a constant hash code.
     *
     * <p>This trivially satisfies the equals/hashCode contract for the content-based
     * {@link #equals}. However, all instances collide in hash-based collections.
     *
     * @return always {@code 0}
     */
    public int hashCode() {
        final int constantHash = 0;
        return constantHash;
    }

    /**
     * Returns a debug rendering of the array, or a marker message when it has been closed.
     *
     * @return the string representation
     */
    public String toString() {
        boolean closed = this.tensor == null;
        return closed ? "This array is already closed" : this.toDebugString(100, 10, 10, 20);
    }

    /**
     * Closes the backing tensor, if any, and drops all engine references.
     *
     * <p>Calling this more than once is harmless: a second call finds the tensor already
     * {@code null}.
     */
    public void close() {
        Tensor<?> backing = this.tensor;
        if (backing != null) {
            backing.close();
        }
        this.tensor = null;
        this.tf = null;
        this.operand = null;
        this.tfNDArrayEx = null;
    }

    /**
     * Returns this array as a TensorFlow operand.
     *
     * <p>On first use a constant is built from the backing tensor; that constant is cached
     * until {@code clearOperand} or {@code close} resets it.
     *
     * @param <T> the TensorFlow element type
     * @return the operand view of this array
     */
    <T extends TType> Operand<T> asOperand() {
        if (this.operand != null) {
            return this.operand;
        }
        this.operand = this.tf.constant(this.tensor);
        return this.operand;
    }

    /**
     * Returns the backing TensorFlow tensor.
     *
     * @return the tensor, or {@code null} after {@code close()}
     */
    public Tensor<?> getTensor() {
        Tensor<?> current = this.tensor;
        return current;
    }

    /**
     * Replaces the backing tensor. The cached operand is left untouched.
     *
     * @param tensor the new backing tensor
     */
    void setTensor(Tensor<?> tensor) {
        this.tensor = tensor;
    }

    /** Drops the cached operand so the next {@code asOperand()} call rebuilds it from the tensor. */
    void clearOperand() {
        this.operand = null;
    }

    /**
     * Returns the number of dimensions of this array.
     *
     * <p>The value is obtained by running TensorFlow's rank op and reading the first int of its
     * output tensor.
     *
     * @return the rank
     */
    int getRank() {
        Tensor<?> rankTensor = this.tf.rank(this.asOperand()).asOutput().tensor();
        return rankTensor.rawData().asInts().getInt(0L);
    }

    /**
     * Creates a scalar constant of {@code n} in this array's op builder.
     *
     * @param n the value
     * @param jType the DJL data type the constant should take
     * @return the constant
     */
    private <T extends TType> Constant<T> toConstant(Number n, DataType jType) {
        Ops builder = this.tf;
        return getConstant(n, jType, builder);
    }

    /**
     * Converts a DJL shape into a TensorFlow tools {@link Shape} with the same dimensions.
     *
     * @param shape the DJL shape
     * @return the equivalent TensorFlow shape
     */
    public static Shape toTfShape(ai.djl.ndarray.types.Shape shape) {
        long[] dims = shape.getShape();
        return Shape.of(dims);
    }

    /**
     * Copies the remaining floats of {@code buffer} into a new byte-backed data buffer.
     *
     * <p>The floats are written in the platform's native byte order. {@code ByteBuffer.allocate}
     * alone defaults to big-endian, which does not match the host-order layout assumed for raw
     * TensorFlow tensor data. The caller's buffer position is not modified, because a duplicate
     * view is read instead.
     *
     * @param buffer the source floats, from its position to its limit
     * @return a data buffer holding those floats as raw bytes
     */
    public static ByteDataBuffer toDataBuffer(FloatBuffer buffer) {
        ByteBuffer bb = ByteBuffer.allocate(buffer.remaining() * Float.BYTES).order(java.nio.ByteOrder.nativeOrder());
        bb.asFloatBuffer().put(buffer.duplicate());
        return DataBuffers.of((ByteBuffer) bb);
    }

    /**
     * Creates a scalar TensorFlow constant from {@code n}, converted according to {@code jType}.
     *
     * @param n the value
     * @param jType the DJL data type selecting the conversion
     * @param tf the op builder to create the constant with
     * @return the constant
     * @throws EngineException if {@code jType} has no mapping here
     */
    static <T extends TType> Constant<T> getConstant(Number n, DataType jType, Ops tf) {
        switch (jType) {
            case INT8:
            case UINT8: {
                // UINT8 was previously rejected although it uses the same byte constant.
                // NOTE(review): Ops.constant(byte) presumably yields an unsigned 8-bit constant,
                // so INT8 may get a dtype mismatch; confirm.
                return tf.constant(n.byteValue());
            }
            case INT32: {
                return tf.constant(n.intValue());
            }
            case INT64: {
                return tf.constant(n.longValue());
            }
            case FLOAT16: {
                // NOTE(review): this truncates to a short and builds an int constant, not a
                // half-precision one; confirm whether callers rely on this.
                return tf.constant((int)n.shortValue());
            }
            case FLOAT32: {
                return tf.constant(n.floatValue());
            }
            case FLOAT64: {
                return tf.constant(n.doubleValue());
            }
        }
        throw new EngineException("unsupported type");
    }
}

