/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.repository.zoo;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.ndarray.NDList;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Metadata;
import ai.djl.repository.Repository;
import ai.djl.repository.VersionRange;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelLoader;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * Base {@link ModelLoader} that resolves artifacts for a single {@link MRL} in a {@link Repository}
 * and wraps the loaded {@link Model} in a {@link ZooModel} together with a {@link Translator}.
 *
 * <p>Subclasses may override {@link #createModel(Device, Artifact, Map)} to control how the {@link
 * Model} instance is created, and may register additional default translator factories in {@link
 * #factories}.
 *
 * @param <I> the input type parameter of the implemented {@link ModelLoader}
 * @param <O> the output type parameter of the implemented {@link ModelLoader}
 */
public abstract class BaseModelLoader<I, O> implements ModelLoader<I, O> {

    /** Repository in which artifacts are located, prepared and read from. */
    protected Repository repository;

    /** Resource locator identifying the group/artifact this loader serves. */
    protected MRL mrl;

    /**
     * Version filter. {@link #listModels()} treats {@code null} as "any version"; otherwise only
     * artifacts whose version equals this string are listed. {@link #match(Map)} parses it with
     * {@link VersionRange#parse(String)}.
     */
    protected String version;

    /**
     * Default translator factories keyed by (input class, output class). Consulted by {@link
     * #loadModel(Criteria)} when the criteria does not carry its own translator.
     */
    protected Map<Pair<Type, Type>, TranslatorFactory<?, ?>> factories;

    /** Repository metadata, located lazily on first use; {@code null} until then. */
    private Metadata metadata;

    /**
     * Creates a loader for one artifact of a repository.
     *
     * @param repository the repository to load artifacts from
     * @param mrl the resource locator of the artifact
     * @param version the version filter; {@code null} makes {@link #listModels()} return every
     *     version
     */
    protected BaseModelLoader(Repository repository, MRL mrl, String version) {
        this.repository = repository;
        this.mrl = mrl;
        this.version = version;
        this.factories = new ConcurrentHashMap<>();
        // An NDList -> NDList model needs no pre/post-processing, so a no-op translator is always
        // registered. The explicit cast fixes the lambda's type; a wildcard target
        // (TranslatorFactory<?, ?>) would require a Translator<Object, Object> instead.
        this.factories.put(
                new Pair<>(NDList.class, NDList.class),
                (TranslatorFactory<NDList, NDList>) arguments -> new NoopTranslator());
    }

    /** {@inheritDoc} */
    @Override
    public String getArtifactId() {
        return mrl.getArtifactId();
    }

    /**
     * Loads the first artifact matching the criteria's filters and pairs it with a translator.
     *
     * <p>The translator comes from {@link Criteria#getTranslator()} when set. Otherwise it is built
     * by the factory registered in {@link #factories} for the criteria's input/output classes. Once
     * an artifact has been matched, a non-null progress tracker is always ended, whether loading
     * succeeds or fails. If {@code Model.load} fails, the partially created model is closed before
     * the exception propagates.
     *
     * @param criteria the filters, arguments, device, options and optional translator/progress
     * @param <S> the model input type
     * @param <T> the model output type
     * @return the loaded model wrapped with its translator
     * @throws IOException if preparing or reading the artifact fails
     * @throws ModelNotFoundException if no artifact matches the filters, or no translator is
     *     supplied and no default factory is registered for the requested types
     * @throws MalformedModelException if the model files cannot be loaded
     */
    @Override
    public <S, T> ZooModel<S, T> loadModel(Criteria<S, T> criteria)
            throws IOException, ModelNotFoundException, MalformedModelException {
        Artifact artifact = match(criteria.getFilters());
        if (artifact == null) {
            throw new ModelNotFoundException("Model not found.");
        }
        Map<String, Object> override = criteria.getArguments();
        Progress progress = criteria.getProgress();
        Map<String, Object> arguments = artifact.getArguments(override);
        try {
            Translator<S, T> translator = criteria.getTranslator();
            if (translator == null) {
                TranslatorFactory<S, T> factory = getTranslatorFactory(criteria);
                if (factory == null) {
                    throw new ModelNotFoundException("No matching default translator found.");
                }
                translator = factory.newInstance(arguments);
            }
            repository.prepare(artifact, progress);
            if (progress != null) {
                progress.reset("Loading", 2L);
                progress.update(1L);
            }
            Path modelPath = repository.getResourceDirectory(artifact);
            Model model = createModel(criteria.getDevice(), artifact, arguments);
            try {
                model.load(modelPath, artifact.getName(), criteria.getOptions());
            } catch (IOException | MalformedModelException | RuntimeException e) {
                // The caller never receives the model on failure, so release it here; a failure
                // while closing must not hide the original load error.
                try {
                    model.close();
                } catch (RuntimeException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
                throw e;
            }
            return new ZooModel<>(model, translator);
        } finally {
            if (progress != null) {
                progress.end();
            }
        }
    }

    /**
     * Lists the artifacts in the repository metadata, restricted to {@link #version} when it is
     * non-null (exact string match).
     *
     * @return the matching artifacts, as a new list
     * @throws IOException if the metadata cannot be read
     * @throws ModelNotFoundException if the repository cannot locate this loader's MRL
     */
    @Override
    public List<Artifact> listModels() throws IOException, ModelNotFoundException {
        List<Artifact> list = getMetadata().getArtifacts();
        return list.stream()
                .filter(a -> version == null || version.equals(a.getVersion()))
                .collect(Collectors.toList());
    }

    /**
     * Creates the (not yet loaded) model instance for an artifact. The default ignores the artifact
     * and arguments and returns {@code Model.newInstance(device)}.
     *
     * @param device the device requested by the criteria
     * @param artifact the matched artifact
     * @param arguments the artifact arguments merged with the criteria overrides
     * @return a new model instance
     * @throws IOException if subclasses fail to create the model
     */
    protected Model createModel(Device device, Artifact artifact, Map<String, Object> arguments)
            throws IOException {
        return Model.newInstance(device);
    }

    /**
     * Returns the first artifact matching the filters within this loader's version range.
     *
     * @param criteria the artifact filters
     * @return the first match, or {@code null} when nothing matches
     * @throws IOException if the metadata cannot be read
     * @throws ModelNotFoundException if the repository cannot locate this loader's MRL
     */
    protected Artifact match(Map<String, String> criteria)
            throws IOException, ModelNotFoundException {
        List<Artifact> list = search(criteria);
        if (list.isEmpty()) {
            return null;
        }
        return list.get(0);
    }

    /** Searches the metadata for artifacts in {@code VersionRange.parse(version)} matching the filters. */
    private List<Artifact> search(Map<String, String> criteria)
            throws IOException, ModelNotFoundException {
        return getMetadata().search(VersionRange.parse(version), criteria);
    }

    /**
     * Returns the cached metadata, locating it in the repository on first call. Not synchronized:
     * concurrent first calls may each invoke {@code Repository.locate}.
     *
     * @throws ModelNotFoundException if the repository returns no metadata for the MRL
     */
    private Metadata getMetadata() throws IOException, ModelNotFoundException {
        if (metadata == null) {
            metadata = repository.locate(mrl);
            if (metadata == null) {
                throw new ModelNotFoundException(mrl.getArtifactId() + " Models not found.");
            }
        }
        return metadata;
    }

    /**
     * Describes this loader as {@code repo:group:artifact [ ... ]} with one tab-indented line per
     * listed artifact, or a failure note when the metadata cannot be loaded.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append(repository.getName())
                .append(':')
                .append(mrl.getGroupId())
                .append(':')
                .append(mrl.getArtifactId())
                .append(" [\n");
        try {
            for (Artifact artifact : listModels()) {
                sb.append('\t').append(artifact).append('\n');
            }
        } catch (ModelNotFoundException | IOException e) {
            sb.append("\tFailed load metadata.");
        }
        sb.append("\n]");
        return sb.toString();
    }

    /**
     * Looks up the default translator factory for the criteria's input/output classes.
     *
     * @return the registered factory, or {@code null} if none is registered for that pair
     */
    @SuppressWarnings("unchecked")
    private <S, T> TranslatorFactory<S, T> getTranslatorFactory(Criteria<S, T> criteria) {
        // NOTE(review): unchecked — relies on every entry's factory matching its key's classes
        // (true for the NDList entry registered in the constructor; subclasses must keep it so).
        return (TranslatorFactory<S, T>)
                factories.get(
                        new Pair<Type, Type>(criteria.getInputClass(), criteria.getOutputClass()));
    }
}

