/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxPredictor;
import ai.djl.mxnet.engine.MxSymbolBlock;
import ai.djl.mxnet.engine.MxTrainer;
import ai.djl.mxnet.engine.Symbol;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
import ai.djl.translate.Translator;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * MXNet engine implementation of {@link Model}.
 *
 * <p>The model owns a {@link Block}, an {@link MxNDManager} created as a sub-manager of the system
 * manager, a map of string properties, and a description of the block's inputs. {@link #save}
 * writes a DJL-layout parameter file: a {@code "DJL@"} header, the layout version, the model name,
 * the data type, the input descriptions, the properties, and finally the block's own parameter
 * data. {@link #load} accepts that layout and otherwise falls back to a native MXNet parameter
 * file, whose named arrays are matched against the parameters of an {@link MxSymbolBlock}.
 *
 * <p>Call {@link #close()} when done with the model; {@link #finalize()} logs a warning if the
 * manager is still open when the object is finalized.
 */
public class MxModel implements Model {

    private static final Logger logger = LoggerFactory.getLogger(MxModel.class);

    /** Header bytes (ASCII) that mark a parameter file written in the DJL layout. */
    private static final String DJL_MAGIC = "DJL@";

    /** Version of the DJL parameter-file layout written by {@link #save} and accepted on load. */
    private static final int MODEL_VERSION = 1;

    /** Absolute model location; set by {@link #load} or {@link #save}, null before either. */
    private Path modelDir;
    private String modelName;
    private final MxNDManager manager;
    private Block block;
    private DataType dataType;
    private final Map<String, String> properties;
    /** Cached input description; null until described, saved, or read from a DJL file. */
    private PairList<String, Shape> inputData;
    /** Values produced by {@link #getArtifact(String, Function)}, keyed by artifact name only. */
    private final Map<String, Object> artifacts = new ConcurrentHashMap<>();
    /** {@code true} until {@link #newPredictor} has been called for the first time. */
    private final AtomicBoolean first;

    /**
     * Creates an empty model backed by a sub-manager of the MXNet system manager.
     *
     * @param device the device for the model's manager; {@code null} is replaced via {@link
     *     Device#defaultIfNull(Device)}
     */
    MxModel(Device device) {
        device = Device.defaultIfNull(device);
        this.dataType = DataType.FLOAT32;
        this.properties = new ConcurrentHashMap<>();
        this.manager = MxNDManager.getSystemManager().newSubManager(device);
        this.first = new AtomicBoolean(true);
    }

    /**
     * Loads the model from {@code modelPath}.
     *
     * <p>If no block has been set, {@code <modelName>-symbol.json} inside {@code modelPath} is
     * loaded into an {@link MxSymbolBlock}. Parameters are then read, first trying the DJL layout
     * and falling back to the native MXNet layout.
     *
     * @param modelPath the model directory or, when a block has already been set, the parameter
     *     file itself
     * @param modelName the model name used to locate the symbol and parameter files
     * @param options optional load options (may be null); the {@code "epoch"} entry selects which
     *     {@code <modelName>-NNNN.params} file to read
     * @throws FileNotFoundException if no block is set and the symbol file does not exist
     * @throws IOException if no parameter file is found or a file cannot be read
     * @throws MalformedModelException if the parameter file does not match the block
     */
    public void load(Path modelPath, String modelName, Map<String, String> options)
            throws IOException, MalformedModelException {
        this.modelDir = modelPath.toAbsolutePath();
        this.modelName = modelName;
        if (block == null) {
            Path symbolFile = modelDir.resolve(modelName + "-symbol.json");
            if (Files.notExists(symbolFile)) {
                throw new FileNotFoundException(
                        "Symbol file not found in: "
                                + modelPath
                                + ", please set block manually.");
            }
            Symbol symbol = Symbol.load(manager, symbolFile.toAbsolutePath().toString());
            block = new MxSymbolBlock(manager, symbol);
        }
        loadParameters(modelName, options);
    }

    /**
     * Saves the block's parameters in the DJL layout to {@code
     * <modelPath>/<modelName>-NNNN.params}, where NNNN is the epoch zero-padded to four digits.
     *
     * <p>The epoch is taken from the {@code "Epoch"} property when it is set; otherwise it is one
     * more than {@code Utils.getCurrentEpoch(modelPath, modelName)}. {@code modelPath} is created
     * if it does not exist. On success the model's name and directory are updated.
     *
     * @param modelPath directory to write into
     * @param modelName model name used in the file name and stored in the file
     * @throws IllegalStateException if no block is set or the block is not initialized
     * @throws IOException if the directory or file cannot be written
     */
    public void save(Path modelPath, String modelName) throws IOException {
        // Validate before touching the file system so a failed save leaves no empty directory.
        if (block == null || !block.isInitialized()) {
            throw new IllegalStateException("Model has not been trained or loaded yet.");
        }
        if (Files.notExists(modelPath)) {
            Files.createDirectories(modelPath);
        }
        String epochValue = getProperty("Epoch");
        int epoch =
                epochValue == null
                        ? Utils.getCurrentEpoch(modelPath, modelName) + 1
                        : Integer.parseInt(epochValue);
        Path paramFile = modelPath.resolve(String.format("%s-%04d.params", modelName, epoch));
        // Snapshot so the written count always matches the written entries even if another
        // thread mutates the (concurrent) property map meanwhile.
        Map<String, String> propertySnapshot = new LinkedHashMap<>(properties);
        try (DataOutputStream dos = new DataOutputStream(Files.newOutputStream(paramFile))) {
            dos.writeBytes(DJL_MAGIC);
            dos.writeInt(MODEL_VERSION);
            dos.writeUTF(modelName);
            dos.writeUTF(dataType.name());
            inputData = block.describeInput();
            dos.writeInt(inputData.size());
            for (Pair<String, Shape> desc : inputData) {
                String name = desc.getKey();
                // writeUTF cannot encode null, so unnamed inputs are stored as "".
                dos.writeUTF(name == null ? "" : name);
                dos.write(desc.getValue().getEncoded());
            }
            dos.writeInt(propertySnapshot.size());
            for (Map.Entry<String, String> entry : propertySnapshot.entrySet()) {
                dos.writeUTF(entry.getKey());
                dos.writeUTF(entry.getValue());
            }
            block.saveParameters(dos);
        }
        this.modelName = modelName;
        this.modelDir = modelPath.toAbsolutePath();
    }

    /** Returns the block, or {@code null} if none has been set or loaded. */
    public Block getBlock() {
        return block;
    }

    /** Sets the block; when set before {@link #load}, no symbol file is read. */
    public void setBlock(Block block) {
        this.block = block;
    }

    /** Returns the model name from the last load/save (or read from a DJL parameter file). */
    public String getName() {
        return modelName;
    }

    /**
     * Creates an {@link MxTrainer} for this model after applying the config's initializer to the
     * block.
     *
     * @param trainingConfig the training configuration
     * @return a new trainer
     */
    public Trainer newTrainer(TrainingConfig trainingConfig) {
        Initializer initializer = trainingConfig.getInitializer();
        block.setInitializer(initializer);
        return new MxTrainer(this, trainingConfig);
    }

    /**
     * Creates an {@link MxPredictor} using {@code translator}.
     *
     * <p>The first call passes a should-copy-parameters flag of {@code false}; every later call
     * passes {@code true} unless {@link JnaUtils#useThreadSafePredictor()} returns true.
     *
     * @param translator converts between user types and the model's NDList inputs/outputs
     * @param <I> predictor input type
     * @param <O> predictor output type
     * @return a new predictor
     */
    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
        boolean firstPredictor = first.getAndSet(false);
        boolean shouldCopyParameters = !JnaUtils.useThreadSafePredictor() && !firstPredictor;
        return new MxPredictor<>(this, translator, shouldCopyParameters);
    }

    /** Sets the data type recorded for this model (written by {@link #save}). */
    public void setDataType(DataType dataType) {
        this.dataType = dataType;
    }

    /** Returns the model's data type; {@code FLOAT32} until set or loaded. */
    public DataType getDataType() {
        return dataType;
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Returns the input description, asking the block for it the first time if it is not known
     * yet.
     */
    public PairList<String, Shape> describeInput() {
        if (inputData == null) {
            inputData = block.describeInput();
        }
        return inputData;
    }

    /**
     * Describes the model outputs by asking the block for its output shapes given the input
     * shapes.
     *
     * <p>NOTE(review): output entries reuse the input names; {@code PairList} presumably requires
     * equal-length name and shape lists, so blocks whose output count differs from their input
     * count may fail here — TODO confirm and give outputs their own names.
     */
    public PairList<String, Shape> describeOutput() {
        // Go through describeInput() so this works before load/save has populated inputData.
        PairList<String, Shape> inputs = describeInput();
        List<String> names = inputs.keys();
        Shape[] outputShapes =
                block.getOutputShapes(manager, inputs.values().toArray(new Shape[0]));
        return new PairList<>(names, Arrays.asList(outputShapes));
    }

    /**
     * Lists the regular files below the model directory (recursively) as paths relative to it,
     * skipping parameter files ({@code .params}) and symbol files ({@code -symbol.json}).
     *
     * @return the artifact names; empty when the model has not been loaded or saved yet
     * @throws AssertionError if the directory cannot be listed
     */
    public String[] getArtifactNames() {
        if (modelDir == null) {
            return new String[0];
        }
        // Files.walk holds an open directory handle; it must be closed.
        try (Stream<Path> stream = Files.walk(modelDir)) {
            return stream.filter(Files::isRegularFile)
                    .filter(path -> !isModelFile(path.toFile().getName()))
                    .map(path -> modelDir.relativize(path).toString())
                    .toArray(String[]::new);
        } catch (IOException e) {
            throw new AssertionError("Failed list files", e);
        }
    }

    /** Returns true for file names that hold the model itself rather than an extra artifact. */
    private static boolean isModelFile(String fileName) {
        return fileName.endsWith(".params") || fileName.endsWith("-symbol.json");
    }

    /**
     * Reads artifact {@code name} with {@code function} and caches the result.
     *
     * <p>The cache is keyed by {@code name} only: once a non-null value is cached, later calls
     * with the same name return it without invoking their {@code function}. A {@code null} result
     * is not cached.
     *
     * @param name artifact path relative to the model directory
     * @param function converts the artifact's stream into a value; the stream is closed afterwards
     * @param <T> the value type
     * @return the (possibly cached) value
     * @throws IOException if the artifact cannot be opened or read
     */
    public <T> T getArtifact(String name, Function<InputStream, T> function) throws IOException {
        try {
            Object artifact =
                    artifacts.computeIfAbsent(
                            name,
                            key -> {
                                try (InputStream is = getArtifactAsStream(key)) {
                                    return function.apply(is);
                                } catch (IOException e) {
                                    // The mapping function cannot throw checked exceptions;
                                    // the cause is unwrapped again below.
                                    throw new IllegalStateException(e);
                                }
                            });
            @SuppressWarnings("unchecked")
            T result = (T) artifact;
            return result;
        } catch (RuntimeException e) {
            Throwable cause = e.getCause();
            if (cause instanceof IOException) {
                throw (IOException) cause;
            }
            throw e;
        }
    }

    /**
     * Returns a URL for artifact {@code artifactName}, resolved against the model directory.
     *
     * @param artifactName artifact path relative to the model directory
     * @return the artifact's file URL
     * @throws IllegalArgumentException if {@code artifactName} is null
     * @throws FileNotFoundException if no model directory is known yet, or the file does not
     *     exist or is not readable
     * @throws IOException if the path cannot be converted to a URL
     */
    public URL getArtifact(String artifactName) throws IOException {
        if (artifactName == null) {
            throw new IllegalArgumentException("artifactName cannot be null");
        }
        if (modelDir == null) {
            throw new FileNotFoundException(
                    "Model directory is not set, load or save the model first: " + artifactName);
        }
        Path file = modelDir.resolve(artifactName);
        if (Files.exists(file) && Files.isReadable(file)) {
            return file.toUri().toURL();
        }
        throw new FileNotFoundException("File not found: " + file);
    }

    /**
     * Opens artifact {@code name} for reading; the caller must close the returned stream.
     *
     * @throws IOException if the artifact cannot be located or opened
     */
    public InputStream getArtifactAsStream(String name) throws IOException {
        URL url = getArtifact(name);
        return url.openStream();
    }

    /** Returns the manager that owns this model's arrays. */
    public NDManager getNDManager() {
        return manager;
    }

    /** Sets a model property; properties are written by {@link #save}. */
    public void setProperty(String key, String value) {
        properties.put(key, value);
    }

    /** Returns a model property, or {@code null} if absent. */
    public String getProperty(String key) {
        return properties.get(key);
    }

    /** Calls {@link JnaUtils#waitAll()} and then closes this model's manager. */
    public void close() {
        JnaUtils.waitAll();
        manager.close();
    }

    /** Safety net: warns and closes the manager if {@link #close()} was never called. */
    protected void finalize() throws Throwable {
        // manager can be null if construction failed before it was assigned.
        if (manager != null && manager.isOpen()) {
            logger.warn("MxModel was not closed explicitly.");
            manager.close();
        }
        super.finalize();
    }

    /**
     * Locates the parameter file and loads it, first as DJL layout and then as native MXNet
     * layout.
     */
    private void loadParameters(String modelName, Map<String, String> options)
            throws IOException, MalformedModelException {
        Path paramFile = resolveParamFile(modelName, options);
        logger.debug("Try to load model from {}", paramFile);
        if (readParameters(paramFile)) {
            return;
        }
        logger.debug("DJL formatted model not found, try to find MXNet model");
        loadMxNetParameters(modelName, paramFile);
    }

    /**
     * Picks the parameter file: {@code modelDir} itself if it is a regular file, otherwise
     * {@code <modelName>-NNNN.params} for the {@code "epoch"} option or the current epoch.
     */
    private Path resolveParamFile(String modelName, Map<String, String> options)
            throws IOException {
        if (Files.isRegularFile(modelDir)) {
            return modelDir;
        }
        String epochOption = options == null ? null : options.get("epoch");
        int epoch;
        if (epochOption == null) {
            epoch = Utils.getCurrentEpoch(modelDir, modelName);
            if (epoch == -1) {
                throw new IOException(
                        "Parameter file not found in: "
                                + modelDir
                                + ". If you only specified model path, make sure path name"
                                + " matches your saved model file name.");
            }
        } else {
            epoch = Integer.parseInt(epochOption);
        }
        return modelDir.resolve(String.format("%s-%04d.params", modelName, epoch));
    }

    /**
     * Loads a native MXNet parameter file into the {@link MxSymbolBlock}.
     *
     * <p>Each array name is expected to be {@code <prefix>:<parameterName>}; the part after the
     * first colon is matched to a block parameter. Block parameters left unmatched become the
     * block's input names, and the model data type is taken from the first array.
     */
    private void loadMxNetParameters(String modelName, Path paramFile)
            throws IOException, MalformedModelException {
        if (!(block instanceof MxSymbolBlock)) {
            throw new MalformedModelException(
                    "Parameter file is not in DJL format and the block is not an MxSymbolBlock: "
                            + paramFile);
        }
        MxSymbolBlock symbolBlock = (MxSymbolBlock) block;
        NDList paramNDlist =
                JnaUtils.loadNdArray(manager, paramFile.toAbsolutePath(), manager.getDevice());
        if (paramNDlist.size() == 0) {
            throw new MalformedModelException("No parameters found in: " + paramFile);
        }
        Map<String, Parameter> unmatched = new LinkedHashMap<>();
        for (Parameter parameter : symbolBlock.getAllParameters()) {
            unmatched.put(parameter.getName(), parameter);
        }
        for (NDArray nd : paramNDlist) {
            String key = nd.getName();
            if (key == null) {
                throw new IllegalArgumentException(
                        "Array names must be present in parameter file");
            }
            int separator = key.indexOf(':');
            if (separator < 0) {
                throw new MalformedModelException(
                        "Unexpected array name without prefix: " + key + " in " + paramFile);
            }
            String paramName = key.substring(separator + 1);
            Parameter parameter = unmatched.remove(paramName);
            if (parameter == null) {
                throw new MalformedModelException(
                        "Array " + key + " in " + paramFile
                                + " does not match any unassigned block parameter");
            }
            parameter.setArray(nd);
        }
        symbolBlock.setInputNames(new ArrayList<>(unmatched.keySet()));
        dataType = paramNDlist.head().getDataType();
        logger.debug("MXNet Model {} ({}) loaded successfully.", modelName, dataType);
    }

    /**
     * Tries to read {@code paramFile} in the DJL layout written by {@link #save}.
     *
     * <p>On success this overwrites the model name, data type and input description with the
     * values stored in the file, merges the stored properties, and loads the block parameters.
     * Inputs stored without a name come back with the name {@code ""}.
     *
     * @return {@code false} if the file does not start with the DJL header, {@code true} once it
     *     has been read
     * @throws IOException if the layout version is unsupported or the file cannot be read
     *     (including files shorter than the header)
     */
    private boolean readParameters(Path paramFile) throws IOException, MalformedModelException {
        try (DataInputStream dis = new DataInputStream(Files.newInputStream(paramFile))) {
            byte[] magic = new byte[DJL_MAGIC.length()];
            dis.readFully(magic);
            if (!DJL_MAGIC.equals(new String(magic, StandardCharsets.US_ASCII))) {
                return false;
            }
            int version = dis.readInt();
            if (version != MODEL_VERSION) {
                throw new IOException("Unsupported model version: " + version);
            }
            modelName = dis.readUTF();
            logger.debug("Loading model parameter: {}", modelName);
            dataType = DataType.valueOf(dis.readUTF());
            int numberOfInputs = dis.readInt();
            inputData = new PairList<>();
            for (int i = 0; i < numberOfInputs; ++i) {
                String inputName = dis.readUTF();
                Shape shape = Shape.decode(dis);
                inputData.add(inputName, shape);
            }
            int numberOfProperties = dis.readInt();
            for (int i = 0; i < numberOfProperties; ++i) {
                String key = dis.readUTF();
                String value = dis.readUTF();
                properties.put(key, value);
            }
            block.loadParameters(manager, dis);
            logger.debug("DJL model loaded successfully");
        }
        return true;
    }

    /** Returns a multi-line summary of name, location, data type and properties. */
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Model (\n\tName: ").append(modelName);
        if (modelDir != null) {
            sb.append("\n\tModel location: ").append(modelDir.toAbsolutePath());
        }
        sb.append("\n\tData Type: ").append(dataType);
        for (Map.Entry<String, String> entry : properties.entrySet()) {
            sb.append("\n\t").append(entry.getKey()).append(": ").append(entry.getValue());
        }
        sb.append("\n)");
        return sb.toString();
    }
}

