/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFunction;

/**
 * A {@link TensorFunction} that is executed by running a {@link Session}, using a
 * {@link Signature} to map argument names to the tensors that are fed and fetched.
 */
public class SessionFunction
implements TensorFunction {
    private final Signature signature;
    private final Session session;

    /**
     * Creates a function backed by {@code session} and described by {@code signature}.
     *
     * <p>Every input and every output description of the signature is checked against
     * {@code session.graph()} through {@link TensorFunction#validateDescription}.
     *
     * @param signature the signature naming the inputs and outputs of this function
     * @param session the session used to run the function
     */
    public SessionFunction(Signature signature, Session session) {
        this.signature = signature;
        this.session = session;
        signature.getInputs().forEach((name, description) ->
                TensorFunction.validateDescription(description, session.graph(), name, "Input"));
        // Validate the outputs here; iterating the inputs a second time would leave the
        // outputs unchecked and mislabel input problems as "Output".
        signature.getOutputs().forEach((name, description) ->
                TensorFunction.validateDescription(description, session.graph(), name, "Output"));
    }

    /**
     * Static factory equivalent to {@link #SessionFunction(Signature, Session)}.
     *
     * @param signature the signature naming the inputs and outputs of the function
     * @param session the session used to run the function
     * @return a new {@code SessionFunction}
     */
    public static SessionFunction create(Signature signature, Session session) {
        return new SessionFunction(signature, session);
    }

    /**
     * Exports this function through {@link SavedModelBundle#exporter(String)}.
     *
     * @param exportDir the directory handed to the exporter
     * @throws IOException if the export fails with an I/O error
     */
    public void save(String exportDir) throws IOException {
        SavedModelBundle.exporter(exportDir).withFunction(this).export();
    }

    /** Returns the signature this function was created with. */
    @Override
    public Signature signature() {
        return this.signature;
    }

    /** Returns the session this function runs on. */
    public Session session() {
        return this.session;
    }

    /**
     * Returns a new function with the same signature that runs on {@code session}.
     * The signature is validated again against the new session's graph.
     *
     * @param session the session the returned function should use
     * @return a new {@code SessionFunction}; this instance is not modified
     */
    public SessionFunction withNewSession(Session session) {
        return new SessionFunction(this.signature, session);
    }

    /**
     * Runs the function: feeds one tensor per signature input, fetches every signature
     * output, and returns the fetched tensors keyed by output name.
     *
     * @param arguments the tensors to feed, keyed by the signature's input names
     * @return the fetched tensors keyed by output name, in {@code signature.outputNames()} order
     * @throws IllegalArgumentException if an input name has no entry in {@code arguments},
     *     or if its entry is {@code null}
     */
    @Override
    public Map<String, Tensor> call(Map<String, Tensor> arguments) {
        Session.Runner runner = this.session.runner();
        this.signature.getInputs().forEach((argName, operand) -> {
            // A missing key and an explicit null value are reported separately.
            if (!arguments.containsKey(argName)) {
                throw new IllegalArgumentException("No argument found for parameter \"" + argName + "\"");
            }
            Tensor value = arguments.get(argName);
            if (value == null) {
                throw new IllegalArgumentException("Can't pass null as an argument to a function.  Argument \"" + argName + "\" was null.");
            }
            runner.feed(operand.name, value);
        });
        this.signature.getOutputs().values().forEach(x -> runner.fetch(x.name));
        List<Tensor> results = runner.run();
        // NOTE(review): the pairing below assumes outputNames() iterates in the same order
        // as getOutputs().values() used for the fetches above -- confirm against Signature.
        Map<String, Tensor> outputs = new LinkedHashMap<String, Tensor>(results.size());
        int i = 0;
        for (String outputName : this.signature.outputNames()) {
            outputs.put(outputName, results.get(i));
            ++i;
        }
        return outputs;
    }
}

