/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.tensorflow.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDManager;
import ai.djl.tensorflow.engine.LibUtils;
import ai.djl.tensorflow.engine.TfModel;
import ai.djl.tensorflow.engine.TfNDManager;
import ai.djl.training.GradientCollector;
import ai.djl.util.RandomUtils;
import java.util.Locale;
import org.tensorflow.EagerSession;
import org.tensorflow.TensorFlow;
import org.tensorflow.internal.c_api.TFE_Context;
import org.tensorflow.internal.c_api.TFE_ContextOptions;
import org.tensorflow.internal.c_api.TF_DeviceList;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.global.tensorflow;

public final class TfEngine
extends Engine {
    public static final String ENGINE_NAME = "TensorFlow";

    private TfEngine() {
    }

    static TfEngine newInstance() {
        try {
            LibUtils.loadLibrary();
            EagerSession.getDefault();
            return new TfEngine();
        }
        catch (Throwable t) {
            throw new EngineException("Failed to load TensorFlow native library", t);
        }
    }

    public Model newModel(String name, Device device) {
        return new TfModel(name, device);
    }

    public String getEngineName() {
        return ENGINE_NAME;
    }

    public String getVersion() {
        return TensorFlow.version();
    }

    public boolean hasCapability(String capability) {
        if ("MKL".equals(capability)) {
            return true;
        }
        if ("CUDA".equals(capability)) {
            TF_Status status = tensorflow.TF_NewStatus();
            TF_DeviceList deviceList = tensorflow.TFE_ContextListDevices((TFE_Context)tensorflow.TFE_NewContext((TFE_ContextOptions)tensorflow.TFE_NewContextOptions(), (TF_Status)status), (TF_Status)status);
            int deviceCount = tensorflow.TF_DeviceListCount((TF_DeviceList)deviceList);
            for (int i = 0; i < deviceCount; ++i) {
                if (!tensorflow.TF_DeviceListName((TF_DeviceList)deviceList, (int)i, (TF_Status)status).getString().toLowerCase().contains("gpu")) continue;
                return true;
            }
            return false;
        }
        return false;
    }

    public NDManager newBaseManager() {
        return TfNDManager.getSystemManager().newSubManager();
    }

    public NDManager newBaseManager(Device device) {
        return TfNDManager.getSystemManager().newSubManager(device);
    }

    public GradientCollector newGradientCollector() {
        throw new UnsupportedOperationException("TensorFlow does not support training yet");
    }

    public void setRandomSeed(int seed) {
        TfNDManager.setRandomSeed(seed);
        RandomUtils.RANDOM.setSeed(seed);
    }

    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append(this.getEngineName()).append(':').append(this.getVersion()).append(", capabilities: [\n\tMKL,\n");
        if (this.hasCapability("CUDA")) {
            sb.append("\t").append("CUDA").append(",\n");
        }
        sb.append("]\nTensorFlow Library: ").append(LibUtils.getLibName());
        return sb.toString();
    }
}

