/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow;

import java.util.function.Consumer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.DataType;
import org.tensorflow.DataTypes;
import org.tensorflow.EagerSession;
import org.tensorflow.TensorFlow;
import org.tensorflow.internal.buffer.TensorBuffers;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.tools.Shape;
import org.tensorflow.tools.buffer.ByteDataBuffer;
import org.tensorflow.tools.buffer.DataBuffer;
import org.tensorflow.types.family.TType;

/**
 * A typed Java wrapper around a native {@code TF_Tensor} handle.
 *
 * <p>The native handle is attached to a JavaCPP {@link PointerScope} owned by this object and is
 * released through that scope when {@link #close()} is called. The exception is tensors created
 * with {@link #fromHandle(TF_Tensor, EagerSession)}, whose handle is detached from this scope and
 * attached to the session instead. The lazily computed {@link #data()} view and
 * {@link #numBytes()} value are cached without any synchronization.
 *
 * @param <T> the element type family of this tensor
 */
public final class Tensor<T extends TType>
implements AutoCloseable {
    // Scope that owns the native handle; closed by close(). Assigned by the static factories.
    private PointerScope tensorScope;
    // Native tensor; always accessed through nativeHandle() so a closed tensor fails fast.
    private TF_Tensor tensorHandle;
    private final DataType<T> dtype;
    private final Shape shape;
    // Lazily created typed view of the tensor memory (see data()).
    private T data = null;
    // Lazily fetched byte size of the native buffer (see numBytes()).
    private Long numBytes = null;

    /**
     * Allocates a tensor of the given data type and shape, backed by a native buffer of exactly
     * {@code shape.size() * dtype.byteSize()} bytes. This method does not write any element values.
     *
     * @param dtype data type of the tensor elements
     * @param shape shape of the tensor
     * @return a new tensor that the caller must close
     * @throws IllegalStateException if native allocation fails
     */
    public static <T extends TType> Tensor<T> of(DataType<T> dtype, Shape shape) {
        return of(dtype, shape, shape.size() * (long) dtype.byteSize());
    }

    /**
     * Allocates a tensor of the given data type and shape, backed by a native buffer of
     * {@code size} bytes.
     *
     * @param dtype data type of the tensor elements
     * @param shape shape of the tensor
     * @param size number of bytes to allocate for the native buffer
     * @return a new tensor that the caller must close
     * @throws IllegalArgumentException if {@code dtype} is not variable-length and {@code size} is
     *     smaller than {@code shape.size() * dtype.byteSize()}
     * @throws IllegalStateException if native allocation fails
     */
    public static <T extends TType> Tensor<T> of(DataType<T> dtype, Shape shape, long size) {
        // The shape alone does not give a minimum size for variable-length types, so only
        // fixed-length types are validated here.
        if (!dtype.isVariableLength() && shape.size() * (long) dtype.byteSize() > size) {
            throw new IllegalArgumentException("Tensor size is not large enough to contain all scalar values");
        }
        Tensor<T> t = new Tensor<>(dtype, shape);
        TF_Tensor nativeHandle = allocate(t.dtype.nativeCode(), shape.asArray(), size);
        try (PointerScope scope = new PointerScope()) {
            scope.attach(nativeHandle);
            t.tensorHandle = nativeHandle;
            // extend() keeps the scope alive past the try-with-resources close, so the handle is
            // only released when Tensor.close() closes the scope again.
            t.tensorScope = scope.extend();
            return t;
        }
    }

    /**
     * Allocates a tensor of the given data type and shape, sized as in
     * {@link #of(DataType, Shape)}, and passes its {@link #data()} view to
     * {@code dataInitializer}.
     *
     * @param dtype data type of the tensor elements
     * @param shape shape of the tensor
     * @param dataInitializer callback that fills the tensor data
     * @return a new tensor that the caller must close
     */
    public static <T extends TType> Tensor<T> of(DataType<T> dtype, Shape shape, Consumer<T> dataInitializer) {
        return of(dtype, shape, shape.size() * (long) dtype.byteSize(), dataInitializer);
    }

    /**
     * Allocates a tensor of {@code size} bytes and passes its {@link #data()} view to
     * {@code dataInitializer}. If the initializer throws, the tensor is closed and the exception
     * is rethrown.
     *
     * @param dtype data type of the tensor elements
     * @param shape shape of the tensor
     * @param size number of bytes to allocate for the native buffer
     * @param dataInitializer callback that fills the tensor data
     * @return a new tensor that the caller must close
     * @throws IllegalArgumentException see {@link #of(DataType, Shape, long)}
     */
    public static <T extends TType> Tensor<T> of(DataType<T> dtype, Shape shape, long size, Consumer<T> dataInitializer) {
        Tensor<T> tensor = of(dtype, shape, size);
        try {
            dataInitializer.accept(tensor.data());
            return tensor;
        } catch (Throwable t) {
            // Do not leak the native allocation when initialization fails.
            tensor.close();
            throw t;
        }
    }

    /**
     * Allocates a tensor sized to {@code rawData.size()} bytes and copies all of
     * {@code rawData} into its native memory. If the copy throws, the tensor is closed and the
     * exception is rethrown.
     *
     * @param dtype data type of the tensor elements
     * @param shape shape of the tensor
     * @param rawData bytes to copy into the tensor
     * @return a new tensor that the caller must close
     * @throws IllegalArgumentException see {@link #of(DataType, Shape, long)}
     */
    public static <T extends TType> Tensor<T> of(DataType<T> dtype, Shape shape, ByteDataBuffer rawData) {
        Tensor<T> t = of(dtype, shape, rawData.size());
        try {
            rawData.copyTo(TensorBuffers.toBytes(t.nativeHandle()), rawData.size());
            return t;
        } catch (Throwable e) {
            // Same cleanup contract as the Consumer-based factory: never leak on failure.
            t.close();
            throw e;
        }
    }

    /**
     * Returns this tensor viewed with element type {@code U}, after checking that {@code dt}
     * equals this tensor's data type.
     *
     * @param dt expected data type
     * @return this same instance, typed as {@code Tensor<U>}
     * @throws IllegalArgumentException if {@code dt} does not equal this tensor's data type
     */
    public <U extends TType> Tensor<U> expect(DataType<U> dt) {
        if (!dt.equals(this.dtype)) {
            throw new IllegalArgumentException("Cannot cast from tensor of " + this.dtype + " to tensor of " + dt);
        }
        // Safe: the equality check above guarantees that U and T describe the same data type.
        @SuppressWarnings("unchecked")
        Tensor<U> typed = (Tensor<U>) this;
        return typed;
    }

    /**
     * Closes this tensor's pointer scope, releasing the native resources attached to it.
     * Handles that were detached from the scope (see {@link #fromHandle(TF_Tensor, EagerSession)})
     * are not released here.
     */
    @Override
    public void close() {
        this.tensorScope.close();
    }

    /** Returns the data type of this tensor. */
    public DataType<T> dataType() {
        return this.dtype;
    }

    /**
     * Returns the size in bytes of the native tensor buffer, as reported by
     * {@code TF_TensorByteSize}. The value is fetched once and then cached.
     *
     * @throws IllegalStateException if this tensor has been closed
     */
    public long numBytes() {
        // Validate on every call so a closed tensor never reaches native code and never returns a
        // stale cached value.
        TF_Tensor handle = nativeHandle();
        if (this.numBytes == null) {
            this.numBytes = tensorflow.TF_TensorByteSize(handle);
        }
        return this.numBytes;
    }

    /** Returns the shape of this tensor. */
    public Shape shape() {
        return this.shape;
    }

    /**
     * Returns the typed view of this tensor's data. The view is created on first use by
     * {@code dtype.map(this)} and then cached.
     *
     * @throws IllegalStateException if this tensor has been closed
     */
    public T data() {
        // Fail fast on a closed tensor, whether or not the view has already been created.
        nativeHandle();
        if (this.data == null) {
            this.data = this.dtype.map(this);
        }
        return this.data;
    }

    /**
     * Returns a byte buffer view over this tensor's native memory. The second argument to
     * {@code TensorBuffers.toBytes} is {@code true} here; presumably this requests a read-only
     * view (TODO: confirm).
     *
     * @throws IllegalStateException if this tensor has been closed
     */
    public ByteDataBuffer rawData() {
        return TensorBuffers.toBytes(nativeHandle(), true);
    }

    @Override
    public String toString() {
        return String.format("%s tensor with shape %s", this.dtype.toString(), this.shape);
    }

    /**
     * Wraps an existing native handle. The data type and shape are read from the handle, and the
     * handle is attached to a new scope owned by the returned tensor.
     *
     * @throws IllegalStateException if {@code handle} is null or points to null
     */
    static Tensor<?> fromHandle(TF_Tensor handle) {
        Tensor<?> t = new Tensor<>(DataTypes.fromNativeCode(dtype(handle)), Shape.of(shape(handle)));
        try (PointerScope scope = new PointerScope()) {
            scope.attach(handle);
            t.tensorHandle = handle;
            t.tensorScope = scope.extend();
        }
        return t;
    }

    /**
     * Wraps an existing native handle whose lifetime is managed by {@code session} rather than by
     * the returned tensor: the handle is attached to the session and detached from the tensor's
     * own scope, so {@link #close()} does not release it.
     */
    static Tensor<?> fromHandle(TF_Tensor handle, EagerSession session) {
        Tensor<?> t = fromHandle(handle);
        session.attach(handle);
        t.tensorScope.detach(handle);
        return t;
    }

    /**
     * Returns the native handle.
     *
     * @throws IllegalStateException if the handle is null or points to null, e.g. after
     *     {@link #close()}
     */
    TF_Tensor nativeHandle() {
        return requireHandle(this.tensorHandle);
    }

    /** Returns {@code handle}, or throws if it is null or points to null. */
    private static TF_Tensor requireHandle(TF_Tensor handle) {
        if (handle == null || handle.isNull()) {
            throw new IllegalStateException("close() was called on the Tensor");
        }
        return handle;
    }

    /**
     * Allocates a native tensor through {@code TF_Tensor.allocateTensor}.
     *
     * @throws IllegalStateException if the allocation returns a null handle
     */
    private static TF_Tensor allocate(int dtype, long[] shape, long byteSize) {
        TF_Tensor t = TF_Tensor.allocateTensor(dtype, shape, byteSize);
        if (t == null || t.isNull()) {
            throw new IllegalStateException("unable to allocate memory for the Tensor");
        }
        return t;
    }

    /** Returns the native data type code of {@code handle} ({@code TF_TensorType}). */
    private static int dtype(TF_Tensor handle) {
        requireHandle(handle);
        return tensorflow.TF_TensorType(handle);
    }

    /** Reads every dimension of {@code handle} into a new array. */
    private static long[] shape(TF_Tensor handle) {
        requireHandle(handle);
        int numDims = tensorflow.TF_NumDims(handle);
        long[] dims = new long[numDims];
        for (int i = 0; i < numDims; ++i) {
            dims[i] = tensorflow.TF_Dim(handle, i);
        }
        return dims;
    }

    // Instances are created only by the static factories, which also assign the native handle and scope.
    private Tensor(DataType<T> dtype, Shape shape) {
        this.dtype = dtype;
        this.shape = shape;
    }

    static {
        // NOTE(review): presumably loads/initializes the TensorFlow native library before any
        // native call in this class; confirm against TensorFlow.init().
        TensorFlow.init();
    }
}

