/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxSymbolBlock;
import ai.djl.mxnet.engine.Symbol;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.nn.Parameter;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
import ai.djl.translate.Translator;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * MXNet-backed {@link Model} implementation.
 *
 * <p>Each instance owns a child {@link MxNDManager} created from the system manager. It can
 * load parameters either in DJL's own format (via {@code readParameters}) or as a native MXNet
 * {@code <prefix>-symbol.json} plus parameter-file pair.
 */
public class MxModel extends BaseModel {

    private static final Logger logger = LoggerFactory.getLogger(MxModel.class);

    /**
     * Creates an empty model whose arrays live on the given device.
     *
     * @param name the model name
     * @param device the target device; {@code null} is replaced by
     *     {@link Device#defaultIfNull(Device)}
     */
    MxModel(String name, Device device) {
        super(name);
        Device target = Device.defaultIfNull(device);
        this.dataType = DataType.FLOAT32;
        this.properties = new ConcurrentHashMap<>();
        this.manager = MxNDManager.getSystemManager().newSubManager(target);
    }

    /**
     * Loads the model's parameters (and, if no block has been set, its symbol) from a directory.
     *
     * <p>The parameter file is resolved with {@code prefix} first (defaulting to the model name).
     * If that fails, the directory's own name is used as the prefix. Whichever prefix succeeds is
     * also used to locate {@code <prefix>-symbol.json} when a block must be built from a symbol.
     *
     * @param modelPath directory containing the model files
     * @param prefix file-name prefix, or {@code null} to use the model name
     * @param options loading options forwarded to the parameter resolver and reader
     * @throws IOException if no parameter file can be resolved or reading fails
     * @throws FileNotFoundException if no block was set and the symbol file is missing
     * @throws MalformedModelException if the parameter file does not match the block
     */
    public void load(Path modelPath, String prefix, Map<String, Object> options)
            throws IOException, MalformedModelException {
        this.modelDir = modelPath.toAbsolutePath();
        String resolvedPrefix = prefix == null ? this.modelName : prefix;

        Path paramFile = this.paramPathResolver(resolvedPrefix, options);
        if (paramFile == null) {
            // Fall back to the directory name as prefix (e.g. "resnet/resnet-0000.params").
            resolvedPrefix = this.modelDir.toFile().getName();
            paramFile = this.paramPathResolver(resolvedPrefix, options);
            if (paramFile == null) {
                throw new IOException("Parameter file not found in: " + this.modelDir);
            }
        }

        if (this.block == null) {
            Path symbolFile = this.modelDir.resolve(resolvedPrefix + "-symbol.json");
            if (Files.notExists(symbolFile)) {
                throw new FileNotFoundException(
                        "Symbol file not found: "
                                + symbolFile
                                + ", please set block manually for imperative model.");
            }
            Symbol symbol =
                    Symbol.load((MxNDManager) this.manager, symbolFile.toAbsolutePath().toString());
            this.block = new MxSymbolBlock(this.manager, symbol);
        }
        this.loadParameters(paramFile, options);
    }

    /**
     * Creates a trainer for this model after installing the config's initializer on the block.
     *
     * @param trainingConfig the training configuration
     * @return a new {@link Trainer} bound to this model
     * @throws IllegalStateException if no block has been set
     */
    public Trainer newTrainer(TrainingConfig trainingConfig) {
        if (this.block == null) {
            throw new IllegalStateException(
                    "You must set a block for the model before creating a new trainer");
        }
        Initializer initializer = trainingConfig.getInitializer();
        this.block.setInitializer(initializer);
        return new Trainer(this, trainingConfig);
    }

    /**
     * Creates a predictor that uses {@code translator} for pre/post-processing.
     *
     * <p>Passes {@code false} as the predictor's third constructor flag (presumably the
     * parameter-copy flag; confirm against {@link Predictor}).
     *
     * @param translator converts between user objects and {@link NDList}s
     * @param <I> predictor input type
     * @param <O> predictor output type
     * @return a new predictor bound to this model
     */
    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
        return new Predictor<>(this, translator, false);
    }

    /**
     * Casting the model to another data type is not supported yet.
     *
     * @param dataType the requested data type (ignored)
     * @throws UnsupportedOperationException always
     */
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Lists the extra artifacts stored alongside the model.
     *
     * <p>Walks {@link #modelDir} recursively and returns every regular file, excluding
     * {@code .params} and {@code -symbol.json} files. Paths are relative to the model directory.
     *
     * @return artifact paths relative to the model directory; empty if the model has no
     *     directory yet
     * @throws AssertionError if the directory cannot be walked
     */
    public String[] getArtifactNames() {
        if (this.modelDir == null) {
            return new String[0];
        }
        List<Path> files;
        // Files.walk holds open directory handles until the stream is closed.
        try (Stream<Path> walk = Files.walk(this.modelDir)) {
            files = walk.filter(Files::isRegularFile).collect(Collectors.toList());
        } catch (IOException e) {
            throw new AssertionError("Failed list files", e);
        }
        List<String> ret = new ArrayList<>(files.size());
        for (Path path : files) {
            String fileName = path.toFile().getName();
            if (fileName.endsWith(".params") || fileName.endsWith("-symbol.json")) {
                continue;
            }
            ret.add(this.modelDir.relativize(path).toString());
        }
        return ret.toArray(new String[0]);
    }

    /** Waits for pending engine work ({@link JnaUtils#waitAll()}) and closes the manager. */
    public void close() {
        JnaUtils.waitAll();
        this.manager.close();
    }

    /**
     * Loads parameters, preferring DJL's format and falling back to a native MXNet file.
     *
     * <p>In the MXNet fallback, each stored array's name is stripped of everything up to and
     * including its first {@code ':'} (presumably an {@code arg:}/{@code aux:} type tag). The
     * array is then bound to the symbol block's parameter of that name. Block parameters that
     * receive no array become the block's input names.
     *
     * @param paramFile the resolved parameter file
     * @param options loading options forwarded to {@code readParameters}
     * @throws IOException if the file cannot be read
     * @throws MalformedModelException if the block is not an {@link MxSymbolBlock} for a
     *     non-DJL file, or a stored array does not match any (remaining) block parameter
     * @throws IllegalArgumentException if a stored array has no name
     */
    private void loadParameters(Path paramFile, Map<String, Object> options)
            throws IOException, MalformedModelException {
        if (this.readParameters(paramFile, options)) {
            return;
        }
        logger.debug("DJL formatted model not found, try to find MXNet model");
        if (!(this.block instanceof MxSymbolBlock)) {
            throw new MalformedModelException(
                    "Parameter file is not in DJL format and the block is not an MxSymbolBlock: "
                            + paramFile);
        }
        MxSymbolBlock symbolBlock = (MxSymbolBlock) this.block;
        NDList paramNDlist = this.manager.load(paramFile);

        // Insertion order is kept so the leftover (input) names follow the block's order.
        Map<String, Parameter> byName = new LinkedHashMap<>();
        for (Parameter p : symbolBlock.getAllParameters()) {
            byName.put(p.getName(), p);
        }

        for (NDArray nd : paramNDlist) {
            String key = nd.getName();
            if (key == null) {
                throw new IllegalArgumentException("Array names must be present in parameter file");
            }
            int sep = key.indexOf(':');
            String paramName = sep >= 0 ? key.substring(sep + 1) : key;
            Parameter parameter = byName.remove(paramName);
            if (parameter == null) {
                throw new MalformedModelException(
                        "Parameter \""
                                + paramName
                                + "\" in "
                                + paramFile
                                + " does not match any parameter of the block (or is duplicated)");
            }
            parameter.setArray(nd);
        }
        symbolBlock.setInputNames(new ArrayList<>(byName.keySet()));

        // Infer the model data type from the first stored array, when there is one.
        if (paramNDlist.size() > 0) {
            this.dataType = paramNDlist.head().getDataType();
        }
        logger.debug("MXNet Model {} ({}) loaded successfully.", paramFile, this.dataType);
    }

    /**
     * Returns a multi-line description with the name, location, data type and properties.
     *
     * @return a human-readable summary of this model
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Model (\n\tName: ").append(this.modelName);
        if (this.modelDir != null) {
            sb.append("\n\tModel location: ").append(this.modelDir.toAbsolutePath());
        }
        sb.append("\n\tData Type: ").append(this.dataType);
        for (Map.Entry<String, String> entry : this.properties.entrySet()) {
            sb.append("\n\t").append(entry.getKey()).append(": ").append(entry.getValue());
        }
        sb.append("\n)");
        return sb.toString();
    }
}

