/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow;

import com.google.protobuf.InvalidProtocolBufferException;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Set;
import java.util.stream.Collectors;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.NativeLibrary;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.internal.c_api.GradFunc;
import org.tensorflow.internal.c_api.GradOpRegistry;
import org.tensorflow.internal.c_api.NativeStatus;
import org.tensorflow.internal.c_api.TF_Buffer;
import org.tensorflow.internal.c_api.TF_Library;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.op.CustomGradient;
import org.tensorflow.op.RawCustomGradient;
import org.tensorflow.op.RawOpInputs;
import org.tensorflow.op.annotation.OpInputsMetadata;
import org.tensorflow.op.annotation.OpMetadata;
import org.tensorflow.proto.framework.OpList;

public final class TensorFlow {
    /**
     * Lazily built set of names of ops whose definitions report {@code is_stateful}. Built by
     * {@link #isOpStateful(String)} and reset after {@link #loadLibrary(String)} loads a library.
     * Only read or written while holding the {@code TensorFlow.class} monitor.
     */
    private static Set<String> statefulOps;

    /**
     * Gradient adapters that have been handed to the native gradient registry. Membership is
     * identity-based. NOTE(review): presumably held here so the Java callback objects stay
     * strongly reachable while native code may still invoke them — confirm against GradFunc's
     * lifecycle.
     */
    private static final Set<GradFunc> gradientFuncs;

    /**
     * Returns the version string reported by the native {@code TF_Version} call.
     *
     * @return the TensorFlow runtime version
     */
    public static String version() {
        return tensorflow.TF_Version().getString();
    }

    /**
     * Returns the op definitions reported by {@code TF_GetAllOpList}.
     *
     * @return the parsed list of registered op definitions
     * @throws TensorFlowException if the serialized OpList cannot be parsed
     */
    public static OpList registeredOpList() {
        TF_Buffer buf = tensorflow.TF_GetAllOpList();
        try {
            return OpList.parseFrom(buf.dataAsByteBuffer());
        } catch (InvalidProtocolBufferException e) {
            throw new TensorFlowException("Cannot parse OpList protocol buffer", e);
        } finally {
            // Release the native buffer whether or not parsing succeeded.
            tensorflow.TF_DeleteBuffer(buf);
        }
    }

    /**
     * Returns whether the registered op named {@code opType} is marked stateful.
     *
     * <p>The set of stateful op names is computed once from {@link #registeredOpList()} and cached
     * until a library is loaded via {@link #loadLibrary(String)}.
     *
     * @param opType the op type name to look up
     * @return {@code true} if an op with that name is registered and marked stateful
     */
    public static synchronized boolean isOpStateful(String opType) {
        if (statefulOps == null) {
            statefulOps = registeredOpList().getOpList().stream()
                    .filter(x -> x.getIsStateful())
                    .map(x -> x.getName())
                    .collect(Collectors.toSet());
        }
        return statefulOps.contains(opType);
    }

    /**
     * Loads the native library at {@code filename} via {@code TF_LoadLibrary}. It then returns
     * the op list that {@code TF_GetOpList} reports for that library.
     *
     * @param filename path of the library to load
     * @return the op definitions reported for the loaded library
     * @throws UnsatisfiedLinkError if loading fails; the original exception is attached as cause
     * @throws TensorFlowException if the library's OpList cannot be parsed
     */
    public static OpList loadLibrary(String filename) {
        TF_Library h;
        try {
            h = libraryLoad(filename);
        } catch (RuntimeException e) {
            UnsatisfiedLinkError err = new UnsatisfiedLinkError(e.getMessage());
            // Keep the native status exception so callers can see the underlying failure.
            err.initCause(e);
            throw err;
        }
        // Loading a library may register additional ops (per the TF C API). Drop the cached
        // stateful-op set so isOpStateful reflects them on its next call.
        invalidateStatefulOps();
        try {
            return libraryOpList(h);
        } finally {
            libraryDelete(h);
        }
    }

    /** Clears the cached stateful-op set so it is rebuilt on the next {@link #isOpStateful} call. */
    private static synchronized void invalidateStatefulOps() {
        statefulOps = null;
    }

    /**
     * Calls {@code TF_LoadLibrary} and converts a non-OK status into an exception.
     *
     * <p>The status object is allocated inside a {@link PointerScope} so it is released when this
     * method returns.
     */
    private static TF_Library libraryLoad(String filename) {
        try (PointerScope scope = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            TF_Library handle = tensorflow.TF_LoadLibrary(filename, status);
            status.throwExceptionIfNotOK();
            return handle;
        }
    }

    /** Releases {@code handle} via {@code TF_DeleteLibraryHandle}; null or empty handles are ignored. */
    private static void libraryDelete(TF_Library handle) {
        if (handle != null && !handle.isNull()) {
            tensorflow.TF_DeleteLibraryHandle(handle);
        }
    }

    /**
     * Parses the op list that {@code TF_GetOpList} reports for {@code handle}.
     *
     * <p>Unlike {@link #registeredOpList()}, this does not call {@code TF_DeleteBuffer}.
     * NOTE(review): presumably the buffer's data is owned by the library handle — confirm
     * against the TF C API. Either way, the caller must parse before deleting the handle.
     */
    private static OpList libraryOpList(TF_Library handle) {
        TF_Buffer buf = tensorflow.TF_GetOpList(handle);
        try {
            return OpList.parseFrom(buf.dataAsByteBuffer());
        } catch (InvalidProtocolBufferException e) {
            throw new TensorFlowException("Cannot parse OpList protocol buffer", e);
        }
    }

    private TensorFlow() {
    }

    /**
     * Returns whether the global gradient registry's {@code Lookup} succeeds for {@code opType}.
     *
     * <p>The lookup writes into a scratch one-element pointer allocated in a
     * {@link PointerScope}. That pointer is released when this method returns.
     */
    private static synchronized boolean hasGradient(String opType) {
        try (PointerScope scope = new PointerScope()) {
            NativeStatus status =
                    GradOpRegistry.Global().Lookup(opType, new GradFunc((Pointer) new PointerPointer(1L)));
            return status.ok();
        }
    }

    /**
     * Registers {@code g} for {@code opType} in the global registry and keeps a Java reference
     * to it in {@link #gradientFuncs}.
     */
    private static void registerGradFunc(String opType, GradFunc g) {
        GradOpRegistry.Global().Register(opType, g);
        gradientFuncs.add(g);
    }

    /**
     * Registers a raw custom gradient for {@code opType}, unless a gradient is already registered.
     *
     * @param opType the op type name
     * @param gradient the gradient implementation
     * @return {@code true} if the gradient was registered, {@code false} if one already existed
     */
    public static synchronized boolean registerCustomGradient(String opType, RawCustomGradient gradient) {
        if (hasGradient(opType)) {
            return false;
        }
        registerGradFunc(opType, RawCustomGradient.adapter(gradient));
        return true;
    }

    /**
     * Registers a typed custom gradient for the op described by {@code inputClass}'s annotations,
     * unless a gradient is already registered for that op type.
     *
     * @param inputClass the generated op-inputs class; must carry {@link OpInputsMetadata}, whose
     *     outputs class must carry {@link OpMetadata}
     * @param gradient the gradient implementation
     * @param <T> the op-inputs type
     * @return {@code true} if the gradient was registered, {@code false} if one already existed
     * @throws IllegalArgumentException if either required annotation is missing
     */
    public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGradient(Class<T> inputClass, CustomGradient<T> gradient) {
        OpInputsMetadata metadata = inputClass.getAnnotation(OpInputsMetadata.class);
        if (metadata == null) {
            throw new IllegalArgumentException("Inputs Class " + inputClass + " does not have a OpInputsMetadata annotation.  Was it generated by tensorflow/java?  If it was, this is a bug.");
        }
        OpMetadata outputMetadata = metadata.outputsClass().getAnnotation(OpMetadata.class);
        if (outputMetadata == null) {
            throw new IllegalArgumentException("Op Class " + metadata.outputsClass() + " does not have a OpMetadata annotation.  Was it generated by tensorflow/java?  If it was, this is a bug.");
        }
        String opType = outputMetadata.opType();
        if (hasGradient(opType)) {
            return false;
        }
        registerGradFunc(opType, CustomGradient.adapter(gradient, inputClass));
        return true;
    }

    static {
        try {
            NativeLibrary.load();
        } catch (RuntimeException | Error e) {
            // A failure here fails class initialization. Later uses of the class typically
            // surface as NoClassDefFoundError, which may not carry this original cause, so
            // print it now. Errors such as UnsatisfiedLinkError are the most likely failure
            // from native loading, which is why they are caught alongside runtime exceptions.
            System.err.println("Failed to load TensorFlow native library");
            e.printStackTrace();
            throw e;
        }
        gradientFuncs = Collections.newSetFromMap(new IdentityHashMap<>());
    }
}

