/*
 * Decompiled with CFR 0.152.
 */
package org.mlflow.artifacts;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import org.mlflow.api.proto.Service;
import org.mlflow.artifacts.ArtifactRepository;
import org.mlflow.tracking.MlflowClientException;
import org.mlflow.tracking.creds.MlflowHostCreds;
import org.mlflow.tracking.creds.MlflowHostCredsProvider;
import org.mlflow_project.apachecommons.io.IOUtils;
import org.mlflow_project.google.common.annotations.VisibleForTesting;
import org.mlflow_project.google.common.collect.Lists;
import org.mlflow_project.google.gson.Gson;
import org.mlflow_project.google.gson.reflect.TypeToken;
import org.mlflow_project.google.protobuf.InvalidProtocolBufferException;
import org.mlflow_project.google.protobuf.Message;
import org.mlflow_project.google.protobuf.util.JsonFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CliBasedArtifactRepository
implements ArtifactRepository {
    private static final Logger logger = LoggerFactory.getLogger(CliBasedArtifactRepository.class);
    private static final AtomicBoolean mlflowSuccessfullyLoaded = new AtomicBoolean(false);
    private final String PYTHON_EXECUTABLE = Optional.ofNullable(System.getenv("MLFLOW_PYTHON_EXECUTABLE")).orElse("python");
    private final String artifactBaseDir;
    private final String runId;
    private final MlflowHostCredsProvider hostCredsProvider;

    public CliBasedArtifactRepository(String artifactBaseDir, String runId, MlflowHostCredsProvider hostCredsProvider) {
        this.artifactBaseDir = artifactBaseDir;
        this.runId = runId;
        this.hostCredsProvider = hostCredsProvider;
    }

    @Override
    public void logArtifact(File localFile, String artifactPath) {
        this.checkMlflowAccessible();
        if (!localFile.exists()) {
            throw new MlflowClientException("Local file does not exist: " + localFile);
        }
        if (localFile.isDirectory()) {
            throw new MlflowClientException("Local path points to a directory. Use logArtifacts instead: " + localFile);
        }
        ArrayList<String> baseCommand = Lists.newArrayList("artifacts", "log-artifact", "--local-file", localFile.toString());
        List<String> command = this.appendRunIdArtifactPath(baseCommand, this.runId, artifactPath);
        String tag = "log file " + localFile + " to " + this.getTargetIdentifier(artifactPath);
        this.forkMlflowProcess(command, tag);
    }

    @Override
    public void logArtifact(File localFile) {
        this.logArtifact(localFile, null);
    }

    @Override
    public void logArtifacts(File localDir, String artifactPath) {
        this.checkMlflowAccessible();
        if (!localDir.exists()) {
            throw new MlflowClientException("Local file does not exist: " + localDir);
        }
        if (localDir.isFile()) {
            throw new MlflowClientException("Local path points to a file. Use logArtifact instead: " + localDir);
        }
        ArrayList<String> baseCommand = Lists.newArrayList("artifacts", "log-artifacts", "--local-dir", localDir.toString());
        List<String> command = this.appendRunIdArtifactPath(baseCommand, this.runId, artifactPath);
        String tag = "log dir " + localDir + " to " + this.getTargetIdentifier(artifactPath);
        this.forkMlflowProcess(command, tag);
    }

    @Override
    public void logArtifacts(File localDir) {
        this.logArtifacts(localDir, null);
    }

    @Override
    public File downloadArtifacts(String artifactPath) {
        this.checkMlflowAccessible();
        String tag = "download artifacts for " + this.getTargetIdentifier(artifactPath);
        List<String> command = this.appendRunIdArtifactPath(Lists.newArrayList("artifacts", "download"), this.runId, artifactPath);
        String localPath = this.forkMlflowProcess(command, tag).trim();
        return new File(localPath);
    }

    @Override
    public File downloadArtifacts() {
        return this.downloadArtifacts(null);
    }

    @Override
    public List<Service.FileInfo> listArtifacts(String artifactPath) {
        this.checkMlflowAccessible();
        String tag = "list artifacts in " + this.getTargetIdentifier(artifactPath);
        List<String> command = this.appendRunIdArtifactPath(Lists.newArrayList("artifacts", "list"), this.runId, artifactPath);
        String jsonOutput = this.forkMlflowProcess(command, tag);
        return this.parseFileInfos(jsonOutput);
    }

    @Override
    public List<Service.FileInfo> listArtifacts() {
        return this.listArtifacts(null);
    }

    private List<Service.FileInfo> parseFileInfos(String json) {
        Gson gson = new Gson();
        Type type = new TypeToken<List<Map<String, Object>>>(){}.getType();
        List listOfDicts = (List)gson.fromJson(json, type);
        ArrayList<Service.FileInfo> fileInfos = new ArrayList<Service.FileInfo>();
        for (Map dict : listOfDicts) {
            String fileInfoJson = gson.toJson(dict);
            try {
                Service.FileInfo.Builder builder = Service.FileInfo.newBuilder();
                JsonFormat.parser().merge(fileInfoJson, (Message.Builder)builder);
                fileInfos.add(builder.build());
            }
            catch (InvalidProtocolBufferException e) {
                throw new MlflowClientException("Failed to deserialize JSON into FileInfo: " + json, e);
            }
        }
        return fileInfos;
    }

    private void checkMlflowAccessible() {
        if (mlflowSuccessfullyLoaded.get()) {
            return;
        }
        try {
            String tag = "get mlflow version";
            this.forkMlflowProcess(Lists.newArrayList("--help"), tag);
            logger.info("Found local mlflow executable");
            mlflowSuccessfullyLoaded.set(true);
        }
        catch (MlflowClientException e) {
            String errorMessage = String.format("Failed to exec '%s -m mlflow.cli', needed to access artifacts within the non-Java-native artifact store at '%s'. Please make sure mlflow is available on your local system path (e.g., from 'pip install mlflow')", this.PYTHON_EXECUTABLE, this.artifactBaseDir);
            throw new MlflowClientException(errorMessage, e);
        }
    }

    private String forkMlflowProcess(List<String> mlflowCommand, String tag) {
        String stdout;
        Process process = null;
        try {
            MlflowHostCreds hostCreds = this.hostCredsProvider.getHostCreds();
            ArrayList<String> fullCommand = Lists.newArrayList(this.PYTHON_EXECUTABLE, "-m", "mlflow.cli");
            fullCommand.addAll(mlflowCommand);
            ProcessBuilder pb = new ProcessBuilder(fullCommand);
            this.setProcessEnvironment(pb.environment(), hostCreds);
            process = pb.start();
            stdout = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
            int exitValue = process.waitFor();
            if (exitValue != 0) {
                throw new MlflowClientException("Failed to " + tag + ". Error: " + this.getErrorBestEffort(process));
            }
        }
        catch (IOException | InterruptedException e) {
            throw new MlflowClientException("Failed to fork mlflow process to " + tag + ". Process stderr: " + this.getErrorBestEffort(process), e);
        }
        return stdout;
    }

    @VisibleForTesting
    void setProcessEnvironment(Map<String, String> environment, MlflowHostCreds hostCreds) {
        environment.put("MLFLOW_TRACKING_URI", hostCreds.getHost());
        if (hostCreds.getUsername() != null) {
            environment.put("MLFLOW_TRACKING_USERNAME", hostCreds.getUsername());
        }
        if (hostCreds.getPassword() != null) {
            environment.put("MLFLOW_TRACKING_PASSWORD", hostCreds.getPassword());
        }
        if (hostCreds.getToken() != null) {
            environment.put("MLFLOW_TRACKING_TOKEN", hostCreds.getToken());
        }
        if (hostCreds.shouldIgnoreTlsVerification()) {
            environment.put("MLFLOW_TRACKING_INSECURE_TLS", "true");
        }
    }

    private String getErrorBestEffort(Process process) {
        if (process == null) {
            return "<process not started>";
        }
        try {
            return IOUtils.toString(process.getErrorStream(), StandardCharsets.UTF_8);
        }
        catch (IOException e) {
            return "<error unknown>";
        }
    }

    private List<String> appendRunIdArtifactPath(List<String> baseCommand, String runId, String artifactPath) {
        baseCommand.add("--run-id");
        baseCommand.add(runId);
        if (artifactPath != null) {
            baseCommand.add("--artifact-path");
            baseCommand.add(artifactPath);
        }
        return baseCommand;
    }

    private String getTargetIdentifier(String artifactPath) {
        String identifier = "runId=" + this.runId;
        if (artifactPath != null) {
            return identifier + ", artifactPath=" + artifactPath;
        }
        return identifier;
    }
}

