package io.spokestack.spokestack.tensorflow;

import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.tensorflow.lite.Interpreter;

/**
 * Wraps a TensorFlow Lite {@link Interpreter} together with one pre-allocated
 * direct {@link ByteBuffer} (in native byte order) per model input and output
 * tensor.
 *
 * <p>Typical use: write data into {@link #inputs(int)}, call {@link #run()},
 * then read results from {@link #outputs(int)}. If the {@link Loader} was given
 * a state position, the input and output buffers at that index are swapped
 * after every run, so the state written by one run is the state read by the
 * next one without copying.
 *
 * <p>Create instances with {@link Loader#load()} and {@link #close()} them when
 * finished to release the interpreter. NOTE(review): this class performs no
 * synchronization; presumably each instance is confined to a single thread.
 */
public class TensorflowModel implements AutoCloseable {
    private final Interpreter interpreter;
    // Buffers handed to the interpreter; entries at statePosition are swapped after each run.
    private final List<ByteBuffer> inputBuffers = new ArrayList<>();
    private final List<ByteBuffer> outputBuffers = new ArrayList<>();
    // Bytes per input element, as configured on the loader.
    private final int inputSize;
    // Reusable argument containers for runForMultipleInputsOutputs, refreshed on every run
    // because the state swap changes which buffer sits at statePosition.
    private final Object[] inputArray;
    private final Map<Integer, Object> outputMap = new HashMap<>();
    // Index of the input/output pair that carries state between runs, or null if none.
    private final Integer statePosition;

    /**
     * Fluent builder that configures and creates {@link TensorflowModel} instances.
     * After {@link #load()} succeeds the loader is reset to its defaults so it can be reused.
     */
    public static class Loader {
        /**
         * Tensor element data types. Only {@code FLOAT} is defined.
         * NOTE(review): not currently consulted when sizing buffers.
         */
        public enum DType {
            FLOAT
        }

        private String path;
        // Bytes per element used to size the input and output buffers.
        private int inputSize;
        private int outputSize;
        private Integer statePosition;

        /** Creates a loader with default settings (see {@link #reset()}). */
        public Loader() {
            reset();
        }

        /**
         * Restores the defaults: no model path, {@link Float#BYTES}-sized input and
         * output elements, and no state position.
         *
         * @return this loader
         */
        public Loader reset() {
            this.path = null;
            this.inputSize = Float.BYTES;
            this.outputSize = Float.BYTES;
            this.statePosition = null;
            return this;
        }

        /**
         * Sets the filesystem path of the TensorFlow Lite model to load.
         *
         * @param str path to the model file
         * @return this loader
         */
        public Loader setPath(String str) {
            this.path = str;
            return this;
        }

        /**
         * Sets the tensor index whose output buffer is fed back as the input buffer
         * on the next run. The index must be valid for both the model's inputs and
         * its outputs.
         *
         * @param i the state tensor index
         * @return this loader
         */
        public Loader setStatePosition(int i) {
            this.statePosition = i;
            return this;
        }

        /**
         * Creates a model from the current settings and then resets this loader.
         *
         * @return the loaded model
         * @throws NullPointerException if no path was set
         * @throws IllegalArgumentException if a buffer cannot be sized or the state
         *     position is out of range
         */
        public TensorflowModel load() {
            TensorflowModel model = new TensorflowModel(this);
            reset();
            return model;
        }
    }

    /**
     * Loads the model at the loader's path and allocates one direct, native-order
     * buffer per input and output tensor, sized as the product of the tensor's
     * shape times the loader's element size.
     *
     * @param loader the configured loader; its path must be set
     * @throws NullPointerException if the loader's path is {@code null}
     * @throws IllegalArgumentException if a tensor's byte size is negative or does not
     *     fit in an {@code int}, or if the state position is not a valid index for
     *     both inputs and outputs
     */
    public TensorflowModel(Loader loader) {
        Objects.requireNonNull(loader.path, "model path must be set before loading");
        this.inputSize = loader.inputSize;
        this.statePosition = loader.statePosition;
        this.interpreter = new Interpreter(new File(loader.path));
        try {
            for (int i = 0; i < this.interpreter.getInputTensorCount(); i++) {
                this.inputBuffers.add(
                    allocate(this.interpreter.getInputTensor(i).shape(), loader.inputSize));
            }
            for (int i = 0; i < this.interpreter.getOutputTensorCount(); i++) {
                this.outputBuffers.add(
                    allocate(this.interpreter.getOutputTensor(i).shape(), loader.outputSize));
            }
            validateStatePosition();
        } catch (RuntimeException | Error e) {
            // Don't leak the native interpreter when construction fails part-way.
            this.interpreter.close();
            throw e;
        }
        this.inputArray = new Object[this.inputBuffers.size()];
    }

    /**
     * Allocates a native-order direct buffer large enough for a tensor of the given
     * shape, computing the size in {@code long} so overflow is reported instead of
     * silently wrapping.
     */
    private static ByteBuffer allocate(int[] shape, int elementSize) {
        long bytes = elementSize;
        for (int dim : shape) {
            bytes *= dim;
            // bytes <= Integer.MAX_VALUE before the multiply, so the long product cannot overflow.
            if (bytes < 0 || bytes > Integer.MAX_VALUE) {
                throw new IllegalArgumentException(
                    "cannot allocate buffer for tensor shape " + Arrays.toString(shape)
                        + " with " + elementSize + "-byte elements");
            }
        }
        return ByteBuffer.allocateDirect((int) bytes).order(ByteOrder.nativeOrder());
    }

    /** Fails fast if a configured state position cannot be swapped between inputs and outputs. */
    private void validateStatePosition() {
        if (this.statePosition == null) {
            return;
        }
        int pos = this.statePosition;
        if (pos < 0 || pos >= this.inputBuffers.size() || pos >= this.outputBuffers.size()) {
            throw new IllegalArgumentException(
                "state position " + pos + " is out of range for a model with "
                    + this.inputBuffers.size() + " inputs and "
                    + this.outputBuffers.size() + " outputs");
        }
    }

    /**
     * Returns the size in bytes of each input element, as configured on the loader.
     *
     * @return bytes per input element
     */
    public int getInputSize() {
        return this.inputSize;
    }

    /** Releases the underlying interpreter. */
    @Override
    public void close() {
        this.interpreter.close();
    }

    /**
     * Returns the buffer backing input tensor {@code i}.
     *
     * @param i input tensor index
     * @return the input buffer
     * @throws IndexOutOfBoundsException if {@code i} is not a valid input index
     */
    public ByteBuffer inputs(int i) {
        return this.inputBuffers.get(i);
    }

    /**
     * Returns the input buffer at the state position. After {@link #run()} it holds
     * the state produced by that run.
     *
     * @return the state buffer, or {@code null} if no state position is configured
     */
    public ByteBuffer states() {
        if (this.statePosition == null) {
            return null;
        }
        return this.inputBuffers.get(this.statePosition);
    }

    /**
     * Returns the buffer backing output tensor {@code i}.
     *
     * @param i output tensor index
     * @return the output buffer
     * @throws IndexOutOfBoundsException if {@code i} is not a valid output index
     */
    public ByteBuffer outputs(int i) {
        return this.outputBuffers.get(i);
    }

    /**
     * Runs inference on the current input buffers, writing into the output buffers.
     *
     * <p>All buffers are rewound before and after the call. If a state position is
     * set, the input and output buffers at that index are swapped after inference:
     * the state just written becomes the next run's state input, and the previous
     * state input buffer is reused as the next run's output target.
     */
    public void run() {
        rewindAll();
        for (int i = 0; i < this.inputBuffers.size(); i++) {
            this.inputArray[i] = this.inputBuffers.get(i);
        }
        for (int i = 0; i < this.outputBuffers.size(); i++) {
            this.outputMap.put(i, this.outputBuffers.get(i));
        }
        this.interpreter.runForMultipleInputsOutputs(this.inputArray, this.outputMap);
        if (this.statePosition != null) {
            int pos = this.statePosition;
            ByteBuffer previousState = this.inputBuffers.get(pos);
            this.inputBuffers.set(pos, this.outputBuffers.get(pos));
            this.outputBuffers.set(pos, previousState);
        }
        rewindAll();
    }

    /** Rewinds every input and output buffer to position zero. */
    private void rewindAll() {
        for (ByteBuffer buffer : this.inputBuffers) {
            buffer.rewind();
        }
        for (ByteBuffer buffer : this.outputBuffers) {
            buffer.rewind();
        }
    }
}
