/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.tensorflow.engine;

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.BlockList;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.tensorflow.engine.TfNDArray;
import ai.djl.tensorflow.engine.TfNDManager;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.List;
import java.util.Map;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.proto.framework.TensorShapeProto;

/**
 * A {@link SymbolBlock} backed by a TensorFlow {@link SavedModelBundle}.
 *
 * <p>Inputs and outputs are described by the bundle's {@code serving_default} signature when it
 * exists, otherwise by the first signature found in the meta graph. Training-related operations
 * (initializers, parameters, casting, parameter serialization) are not supported and throw
 * {@link UnsupportedOperationException}.
 */
public class TfSymbolBlock implements SymbolBlock {

    /** Signature key of the default serving signature in a SavedModel. */
    private static final String SERVING_DEFAULT = "serving_default";

    private static final String NOT_SUPPORTED = "Not supported for TensorFlow Engine";

    private SavedModelBundle bundle;
    private MetaGraphDef metaGraphDef;
    private Session session;
    // Lazily computed from the selected signature; see describeInput()/describeOutput().
    private PairList<String, Shape> inputDescriptions;
    private PairList<String, Shape> outputDescriptions;

    /**
     * Wraps a loaded SavedModel.
     *
     * @param bundle the loaded bundle; its session and meta graph are captured immediately
     */
    public TfSymbolBlock(SavedModelBundle bundle) {
        this.bundle = bundle;
        this.session = bundle.session();
        this.metaGraphDef = bundle.metaGraphDef();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void removeLastBlock() {
        throw new UnsupportedOperationException(NOT_SUPPORTED);
    }

    /**
     * Runs the model's session on the given inputs.
     *
     * <p>The i-th array of {@code inputs} is fed to the i-th entry of {@link #describeInput()};
     * extra inputs beyond that count are ignored. Every output listed by {@code describeOutput()}
     * is fetched, and the results are returned in the same order.
     *
     * @param parameterStore unused
     * @param inputs the input arrays; they are assumed to be {@link TfNDArray}s
     * @param training unused
     * @param params unused
     * @return the fetched outputs, created on the manager of the first input
     * @throws IllegalStateException if {@link #clear()} has already closed the session
     * @throws IllegalArgumentException if {@code inputs} is empty or has fewer arrays than the
     *     model signature declares
     */
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        if (session == null) {
            throw new IllegalStateException("The TensorFlow model has already been closed");
        }
        PairList<String, Shape> inputDesc = describeInput();
        PairList<String, Shape> outputDesc = describeOutput();
        // The first input's manager owns the results, so at least one input is required.
        if (inputs.isEmpty()) {
            throw new IllegalArgumentException("At least one input is required");
        }
        if (inputs.size() < inputDesc.size()) {
            throw new IllegalArgumentException(
                    "Expected " + inputDesc.size() + " inputs, but got " + inputs.size());
        }

        Session.Runner runner = session.runner();
        for (int i = 0; i < inputDesc.size(); ++i) {
            runner.feed(inputDesc.get(i).getKey(), ((TfNDArray) inputs.get(i)).getTensor());
        }
        for (int i = 0; i < outputDesc.size(); ++i) {
            runner.fetch(outputDesc.get(i).getKey());
        }
        List<? extends Tensor> result = runner.run();

        TfNDManager manager = (TfNDManager) inputs.head().getManager();
        NDList resultNDList = new NDList();
        try {
            for (Tensor tensor : result) {
                resultNDList.add(manager.create(tensor));
            }
        } finally {
            // Close every fetched tensor, even if a conversion above failed part-way, so the
            // remaining native tensors are not leaked.
            for (Tensor tensor : result) {
                tensor.close();
            }
        }
        return resultNDList;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void setInitializer(Initializer initializer) {
        throw new UnsupportedOperationException(NOT_SUPPORTED);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void setInitializer(Initializer initializer, String paramName) {
        throw new UnsupportedOperationException(NOT_SUPPORTED);
    }

    /**
     * Does nothing.
     *
     * @return an empty array
     */
    public Shape[] initialize(
            NDManager manager, ai.djl.ndarray.types.DataType dataType, Shape... inputShapes) {
        return new Shape[0];
    }

    /**
     * Reports whether this block still holds its SavedModel bundle.
     *
     * @return {@code true} until {@link #clear()} has been called
     */
    public boolean isInitialized() {
        return bundle != null;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void cast(ai.djl.ndarray.types.DataType dataType) {
        throw new UnsupportedOperationException(NOT_SUPPORTED);
    }

    /**
     * Closes the session and the bundle.
     *
     * <p>Safe to call more than once. Later calls do nothing because the references are
     * dropped after closing.
     */
    public void clear() {
        if (session != null) {
            session.close();
            session = null;
        }
        if (bundle != null) {
            bundle.close();
            bundle = null;
        }
    }

    /**
     * Describes the model inputs from the selected signature (see {@link #getSignatureDef()}).
     *
     * <p>Each key is the tensor name that is fed in {@link #forward}. Unknown dimensions keep the
     * size stored in the proto.
     *
     * @return input tensor names paired with their declared shapes; the result is cached
     * @throws IllegalStateException if the meta graph has no signature definitions
     */
    public PairList<String, Shape> describeInput() {
        if (inputDescriptions == null) {
            inputDescriptions = describe(getSignatureDef().getInputsMap(), false);
        }
        return inputDescriptions;
    }

    /**
     * Describes the model outputs from the selected signature, excluding {@code DT_STRING}
     * outputs, which are never fetched by {@link #forward}.
     *
     * @return output tensor names paired with their declared shapes; the result is cached
     * @throws IllegalStateException if the meta graph has no signature definitions
     */
    PairList<String, Shape> describeOutput() {
        if (outputDescriptions == null) {
            outputDescriptions = describe(getSignatureDef().getOutputsMap(), true);
        }
        return outputDescriptions;
    }

    /**
     * Selects the signature that describes the model's inputs and outputs.
     *
     * @return the {@code serving_default} signature if present, else the first one in the map
     * @throws IllegalStateException if the meta graph declares no signatures
     */
    private SignatureDef getSignatureDef() {
        Map<String, SignatureDef> signatureDefMap = metaGraphDef.getSignatureDefMap();
        SignatureDef servingDefault = signatureDefMap.get(SERVING_DEFAULT);
        if (servingDefault != null) {
            return servingDefault;
        }
        if (signatureDefMap.isEmpty()) {
            throw new IllegalStateException("The model does not contain any signature definition");
        }
        return signatureDefMap.values().iterator().next();
    }

    /**
     * Converts a signature's tensor map into (tensor name, shape) pairs, in map iteration order.
     *
     * @param tensorInfos the signature's inputs or outputs
     * @param skipStrings whether to drop {@code DT_STRING} tensors
     * @return the described tensors
     */
    private static PairList<String, Shape> describe(
            Map<String, TensorInfo> tensorInfos, boolean skipStrings) {
        PairList<String, Shape> descriptions = new PairList<>();
        for (TensorInfo info : tensorInfos.values()) {
            if (skipStrings && info.getDtype() == DataType.DT_STRING) {
                continue;
            }
            descriptions.add(info.getName(), toShape(info.getTensorShape()));
        }
        return descriptions;
    }

    /** Builds a DJL {@link Shape} from the dimension sizes listed in a TensorFlow shape proto. */
    private static Shape toShape(TensorShapeProto shapeProto) {
        return new Shape(
                shapeProto.getDimList().stream().mapToLong(TensorShapeProto.Dim::getSize).toArray());
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public BlockList getChildren() {
        throw new UnsupportedOperationException(NOT_SUPPORTED);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public ParameterList getDirectParameters() {
        throw new UnsupportedOperationException(NOT_SUPPORTED);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public ParameterList getParameters() {
        throw new UnsupportedOperationException(NOT_SUPPORTED);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        throw new UnsupportedOperationException(NOT_SUPPORTED);
    }

    /**
     * Output shapes are not inferred for TensorFlow models.
     *
     * @return an empty array
     */
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[0];
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void saveParameters(DataOutputStream os) {
        throw new UnsupportedOperationException(NOT_SUPPORTED);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void loadParameters(NDManager manager, DataInputStream is) {
        throw new UnsupportedOperationException(NOT_SUPPORTED);
    }
}

