/*
 * Decompiled with CFR 0.152.
 */
package org.mlflow.tracking;

import java.io.File;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.mlflow.api.proto.Service;
import org.mlflow.artifacts.ArtifactRepository;
import org.mlflow.artifacts.ArtifactRepositoryFactory;
import org.mlflow.tracking.MlflowClientException;
import org.mlflow.tracking.MlflowHttpCaller;
import org.mlflow.tracking.MlflowProtobufMapper;
import org.mlflow.tracking.RunsPage;
import org.mlflow.tracking.creds.BasicMlflowHostCreds;
import org.mlflow.tracking.creds.DatabricksConfigHostCredsProvider;
import org.mlflow.tracking.creds.DatabricksDynamicHostCredsProvider;
import org.mlflow.tracking.creds.HostCredsProviderChain;
import org.mlflow.tracking.creds.MlflowHostCredsProvider;
import org.mlflow_project.apachehttp.client.utils.URIBuilder;

/**
 * Client for an MLflow tracking server.
 *
 * <p>REST calls are issued through an {@link MlflowHttpCaller} built from the supplied
 * {@link MlflowHostCredsProvider}. Request and response JSON is converted by
 * {@link MlflowProtobufMapper}. Artifact operations look up the run's artifact URI and then
 * delegate to an {@link ArtifactRepository} obtained from an {@link ArtifactRepositoryFactory}.
 */
public class MlflowClient {
    /** Experiment ID used by {@link #createRun()} when no experiment is given. */
    protected static final String DEFAULT_EXPERIMENT_ID = "0";
    /** Page size used by the search overloads that do not take an explicit {@code maxResults}. */
    private static final int DEFAULT_MAX_RESULTS = 1000;
    private final MlflowProtobufMapper mapper = new MlflowProtobufMapper();
    private final ArtifactRepositoryFactory artifactRepositoryFactory;
    private final MlflowHttpCaller httpCaller;
    private final MlflowHostCredsProvider hostCredsProvider;

    /**
     * Creates a client for the tracking URI stored in the {@code MLFLOW_TRACKING_URI}
     * environment variable.
     *
     * @throws IllegalStateException if {@code MLFLOW_TRACKING_URI} is unset or empty
     * @throws IllegalArgumentException if the URI is local or uses an unsupported scheme
     */
    public MlflowClient() {
        this(MlflowClient.getDefaultTrackingUri());
    }

    /**
     * Creates a client for the given tracking URI.
     *
     * <p>Accepted forms are {@code http://...}, {@code https://...}, the literal
     * {@code databricks}, and {@code databricks://<host>}.
     *
     * @param trackingUri tracking server URI; must not be null
     * @throws NullPointerException if {@code trackingUri} is null
     * @throws IllegalArgumentException if the URI is local ({@code file:} or scheme-less) or
     *     otherwise unsupported
     */
    public MlflowClient(String trackingUri) {
        this(MlflowClient.getHostCredsProviderFromTrackingUri(trackingUri));
    }

    /**
     * Creates a client that obtains host and credentials from the given provider.
     *
     * @param hostCredsProvider provider shared by the HTTP caller and the artifact repositories
     */
    public MlflowClient(MlflowHostCredsProvider hostCredsProvider) {
        this.hostCredsProvider = hostCredsProvider;
        this.httpCaller = new MlflowHttpCaller(hostCredsProvider);
        this.artifactRepositoryFactory = new ArtifactRepositoryFactory(hostCredsProvider);
    }

    /**
     * Fetches a run via {@code GET runs/get}.
     *
     * @param runId ID of the run; sent as both {@code run_uuid} and {@code run_id}
     *     (presumably for compatibility with older servers — TODO confirm)
     * @return the run returned by the server
     */
    public Service.Run getRun(String runId) {
        URIBuilder builder = this.newURIBuilder("runs/get")
                .setParameter("run_uuid", runId)
                .setParameter("run_id", runId);
        return this.mapper.toGetRunResponse(this.httpCaller.get(builder.toString())).getRun();
    }

    /**
     * Fetches all logged values of one metric via {@code GET metrics/get-history}.
     *
     * @param runId ID of the run; sent as both {@code run_uuid} and {@code run_id}
     * @param key metric name
     * @return the metric history returned by the server
     */
    public List<Service.Metric> getMetricHistory(String runId, String key) {
        URIBuilder builder = this.newURIBuilder("metrics/get-history")
                .setParameter("run_uuid", runId)
                .setParameter("run_id", runId)
                .setParameter("metric_key", key);
        return this.mapper.toGetMetricHistoryResponse(this.httpCaller.get(builder.toString())).getMetricsList();
    }

    /**
     * Creates a run in the default experiment ({@value #DEFAULT_EXPERIMENT_ID}).
     *
     * @return info of the created run
     */
    public Service.RunInfo createRun() {
        return this.createRun(DEFAULT_EXPERIMENT_ID);
    }

    /**
     * Creates a run in the given experiment.
     *
     * <p>The start time is set to the current wall-clock time. The user ID is set from the
     * {@code user.name} system property when that property is defined.
     *
     * @param experimentId experiment to create the run in
     * @return info of the created run
     */
    public Service.RunInfo createRun(String experimentId) {
        Service.CreateRun.Builder request = Service.CreateRun.newBuilder();
        request.setExperimentId(experimentId);
        request.setStartTime(System.currentTimeMillis());
        String username = System.getProperty("user.name");
        if (username != null) {
            // Reuse the value already checked for null instead of reading the property again.
            request.setUserId(username);
        }
        return this.createRun(request.build());
    }

    /**
     * Creates a run from a fully specified request via {@code POST runs/create}.
     *
     * @param request create-run request
     * @return info of the created run
     */
    public Service.RunInfo createRun(Service.CreateRun request) {
        String ijson = this.mapper.toJson(request);
        String ojson = this.sendPost("runs/create", ijson);
        return this.mapper.toCreateRunResponse(ojson).getRun().getInfo();
    }

    /**
     * Lists active runs of one experiment.
     *
     * <p>Only the first page (up to {@value #DEFAULT_MAX_RESULTS} runs) is returned.
     *
     * @param experimentId experiment to list
     * @return infos of the runs on the first result page
     */
    public List<Service.RunInfo> listRunInfos(String experimentId) {
        ArrayList<String> experimentIds = new ArrayList<String>();
        experimentIds.add(experimentId);
        return this.searchRuns(experimentIds, null);
    }

    /**
     * Searches active runs.
     *
     * <p>Only the first page (up to {@value #DEFAULT_MAX_RESULTS} runs) is returned.
     *
     * @param experimentIds experiments to search
     * @param searchFilter filter expression, or null for no filter
     * @return infos of the runs on the first result page
     */
    public List<Service.RunInfo> searchRuns(List<String> experimentIds, String searchFilter) {
        return MlflowClient.toRunInfos(this.searchRuns(experimentIds, searchFilter, Service.ViewType.ACTIVE_ONLY, DEFAULT_MAX_RESULTS));
    }

    /**
     * Searches runs with the given view type.
     *
     * <p>Only the first page (up to {@value #DEFAULT_MAX_RESULTS} runs) is returned.
     *
     * @param experimentIds experiments to search
     * @param searchFilter filter expression, or null for no filter
     * @param runViewType which runs to include, or null to leave it unset in the request
     * @return infos of the runs on the first result page
     */
    public List<Service.RunInfo> searchRuns(List<String> experimentIds, String searchFilter, Service.ViewType runViewType) {
        return MlflowClient.toRunInfos(this.searchRuns(experimentIds, searchFilter, runViewType, DEFAULT_MAX_RESULTS));
    }

    /**
     * Searches runs and returns the first page with no explicit ordering.
     *
     * @param experimentIds experiments to search
     * @param searchFilter filter expression, or null for no filter
     * @param runViewType which runs to include, or null to leave it unset in the request
     * @param maxResults maximum number of runs on the page
     * @return the first page of results
     */
    public RunsPage searchRuns(List<String> experimentIds, String searchFilter, Service.ViewType runViewType, int maxResults) {
        return this.searchRuns(experimentIds, searchFilter, runViewType, maxResults, new ArrayList<String>(), null);
    }

    /**
     * Searches runs and returns the first page.
     *
     * @param experimentIds experiments to search
     * @param searchFilter filter expression, or null for no filter
     * @param runViewType which runs to include, or null to leave it unset in the request
     * @param maxResults maximum number of runs on the page
     * @param orderBy ordering clauses; null is treated as "no ordering"
     * @return the first page of results
     */
    public RunsPage searchRuns(List<String> experimentIds, String searchFilter, Service.ViewType runViewType, int maxResults, List<String> orderBy) {
        return this.searchRuns(experimentIds, searchFilter, runViewType, maxResults, orderBy, null);
    }

    /**
     * Searches runs via {@code POST runs/search} and returns one page of results.
     *
     * @param experimentIds experiments to search
     * @param searchFilter filter expression, or null for no filter
     * @param runViewType which runs to include, or null to leave it unset in the request
     * @param maxResults maximum number of runs on the page
     * @param orderBy ordering clauses; null is treated as "no ordering"
     * @param pageToken token of the page to fetch, or null for the first page
     * @return the requested page. The returned {@link RunsPage} keeps the query arguments and
     *     this client, presumably so it can fetch subsequent pages.
     */
    public RunsPage searchRuns(List<String> experimentIds, String searchFilter, Service.ViewType runViewType, int maxResults, List<String> orderBy, String pageToken) {
        // addAllOrderBy rejects null, so normalize here; the same list is handed to RunsPage.
        List<String> effectiveOrderBy = orderBy != null ? orderBy : new ArrayList<String>();
        Service.SearchRuns.Builder builder = Service.SearchRuns.newBuilder()
                .addAllExperimentIds(experimentIds)
                .addAllOrderBy(effectiveOrderBy)
                .setMaxResults(maxResults);
        if (searchFilter != null) {
            builder.setFilter(searchFilter);
        }
        if (runViewType != null) {
            builder.setRunViewType(runViewType);
        }
        if (pageToken != null) {
            builder.setPageToken(pageToken);
        }
        Service.SearchRuns request = builder.build();
        String ijson = this.mapper.toJson(request);
        String ojson = this.sendPost("runs/search", ijson);
        Service.SearchRuns.Response response = this.mapper.toSearchRunsResponse(ojson);
        return new RunsPage(response.getRunsList(), response.getNextPageToken(), experimentIds, searchFilter, runViewType, maxResults, effectiveOrderBy, this);
    }

    /**
     * Lists experiments via {@code GET experiments/list}.
     *
     * @return experiments returned by the server
     */
    public List<Service.Experiment> listExperiments() {
        return this.mapper.toListExperimentsResponse(this.httpCaller.get("experiments/list")).getExperimentsList();
    }

    /**
     * Fetches one experiment via {@code GET experiments/get}.
     *
     * @param experimentId experiment to fetch
     * @return the full get-experiment response
     */
    public Service.GetExperiment.Response getExperiment(String experimentId) {
        URIBuilder builder = this.newURIBuilder("experiments/get").setParameter("experiment_id", experimentId);
        return this.mapper.toGetExperimentResponse(this.httpCaller.get(builder.toString()));
    }

    /**
     * Finds an experiment by exact name.
     *
     * <p>This lists all experiments and filters them client-side.
     *
     * @param experimentName name to match with {@link String#equals}
     * @return the first experiment with that name, or empty if none matches
     */
    public Optional<Service.Experiment> getExperimentByName(String experimentName) {
        return this.listExperiments().stream().filter(e -> e.getName().equals(experimentName)).findFirst();
    }

    /**
     * Creates an experiment via {@code POST experiments/create}.
     *
     * @param experimentName name of the new experiment
     * @return ID of the created experiment
     */
    public String createExperiment(String experimentName) {
        String ijson = this.mapper.makeCreateExperimentRequest(experimentName);
        String ojson = this.httpCaller.post("experiments/create", ijson);
        return this.mapper.toCreateExperimentResponse(ojson).getExperimentId();
    }

    /**
     * Deletes an experiment via {@code POST experiments/delete}.
     *
     * @param experimentId experiment to delete
     */
    public void deleteExperiment(String experimentId) {
        String ijson = this.mapper.makeDeleteExperimentRequest(experimentId);
        this.httpCaller.post("experiments/delete", ijson);
    }

    /**
     * Restores a deleted experiment via {@code POST experiments/restore}.
     *
     * @param experimentId experiment to restore
     */
    public void restoreExperiment(String experimentId) {
        String ijson = this.mapper.makeRestoreExperimentRequest(experimentId);
        this.httpCaller.post("experiments/restore", ijson);
    }

    /**
     * Renames an experiment via {@code POST experiments/update}.
     *
     * @param experimentId experiment to rename
     * @param newName new name
     */
    public void renameExperiment(String experimentId, String newName) {
        String ijson = this.mapper.makeUpdateExperimentRequest(experimentId, newName);
        this.httpCaller.post("experiments/update", ijson);
    }

    /**
     * Deletes a run via {@code POST runs/delete}.
     *
     * @param runId run to delete
     */
    public void deleteRun(String runId) {
        String ijson = this.mapper.makeDeleteRun(runId);
        this.httpCaller.post("runs/delete", ijson);
    }

    /**
     * Restores a deleted run via {@code POST runs/restore}.
     *
     * @param runId run to restore
     */
    public void restoreRun(String runId) {
        String ijson = this.mapper.makeRestoreRun(runId);
        this.httpCaller.post("runs/restore", ijson);
    }

    /**
     * Logs a parameter via {@code POST runs/log-parameter}.
     *
     * @param runId target run
     * @param key parameter name
     * @param value parameter value
     */
    public void logParam(String runId, String key, String value) {
        this.sendPost("runs/log-parameter", this.mapper.makeLogParam(runId, key, value));
    }

    /**
     * Logs a metric value timestamped with the current wall-clock time, at step 0.
     *
     * @param runId target run
     * @param key metric name
     * @param value metric value
     */
    public void logMetric(String runId, String key, double value) {
        this.logMetric(runId, key, value, System.currentTimeMillis(), 0L);
    }

    /**
     * Logs a metric value via {@code POST runs/log-metric}.
     *
     * @param runId target run
     * @param key metric name
     * @param value metric value
     * @param timestamp timestamp sent with the value (callers in this class pass epoch millis)
     * @param step step index
     */
    public void logMetric(String runId, String key, double value, long timestamp, long step) {
        this.sendPost("runs/log-metric", this.mapper.makeLogMetric(runId, key, value, timestamp, step));
    }

    /**
     * Sets a tag on an experiment via {@code POST experiments/set-experiment-tag}.
     *
     * @param experimentId target experiment
     * @param key tag name
     * @param value tag value
     */
    public void setExperimentTag(String experimentId, String key, String value) {
        this.sendPost("experiments/set-experiment-tag", this.mapper.makeSetExperimentTag(experimentId, key, value));
    }

    /**
     * Sets a tag on a run via {@code POST runs/set-tag}.
     *
     * @param runId target run
     * @param key tag name
     * @param value tag value
     */
    public void setTag(String runId, String key, String value) {
        this.sendPost("runs/set-tag", this.mapper.makeSetTag(runId, key, value));
    }

    /**
     * Deletes a tag from a run via {@code POST runs/delete-tag}.
     *
     * @param runId target run
     * @param key tag name
     */
    public void deleteTag(String runId, String key) {
        this.sendPost("runs/delete-tag", this.mapper.makeDeleteTag(runId, key));
    }

    /**
     * Logs metrics, params and tags in one request via {@code POST runs/log-batch}.
     *
     * @param runId target run
     * @param metrics metrics to log
     * @param params params to log
     * @param tags tags to set
     */
    public void logBatch(String runId, Iterable<Service.Metric> metrics, Iterable<Service.Param> params, Iterable<Service.RunTag> tags) {
        this.sendPost("runs/log-batch", this.mapper.makeLogBatch(runId, metrics, params, tags));
    }

    /**
     * Marks a run {@code FINISHED} with the current time as its end time.
     *
     * @param runId target run
     */
    public void setTerminated(String runId) {
        this.setTerminated(runId, Service.RunStatus.FINISHED);
    }

    /**
     * Sets a run's status, using the current time as its end time.
     *
     * @param runId target run
     * @param status terminal status to record
     */
    public void setTerminated(String runId, Service.RunStatus status) {
        this.setTerminated(runId, status, System.currentTimeMillis());
    }

    /**
     * Sets a run's status and end time via {@code POST runs/update}.
     *
     * @param runId target run
     * @param status status to record
     * @param endTime end time to record (callers in this class pass epoch millis)
     */
    public void setTerminated(String runId, Service.RunStatus status, long endTime) {
        this.sendPost("runs/update", this.mapper.makeUpdateRun(runId, status, endTime));
    }

    /**
     * Issues a raw GET against the tracking server.
     *
     * @param path API path, as passed to {@link MlflowHttpCaller#get}
     * @return raw response body
     */
    public String sendGet(String path) {
        return this.httpCaller.get(path);
    }

    /**
     * Issues a raw POST against the tracking server.
     *
     * @param path API path, as passed to {@link MlflowHttpCaller#post}
     * @param json request body
     * @return raw response body
     */
    public String sendPost(String path, String json) {
        return this.httpCaller.post(path, json);
    }

    /** Returns the credentials provider this client was built with (package-private). */
    MlflowHostCredsProvider getInternalHostCredsProvider() {
        return this.hostCredsProvider;
    }

    /**
     * Creates a URI builder for an API path.
     *
     * @throws MlflowClientException wrapping the {@link URISyntaxException} if {@code base} is malformed
     */
    private URIBuilder newURIBuilder(String base) {
        try {
            return new URIBuilder(base);
        }
        catch (URISyntaxException e) {
            throw new MlflowClientException("Failed to construct URI for " + base, e);
        }
    }

    /**
     * Reads {@code MLFLOW_TRACKING_URI} from the environment.
     *
     * @throws IllegalStateException if the variable is unset or empty
     */
    private static String getDefaultTrackingUri() {
        String defaultTrackingUri = System.getenv("MLFLOW_TRACKING_URI");
        // An empty value would otherwise be parsed as a scheme-less (local) URI and produce a
        // misleading "local tracking URIs" error.
        if (defaultTrackingUri == null || defaultTrackingUri.isEmpty()) {
            throw new IllegalStateException("Default client requires MLFLOW_TRACKING_URI is set. Use MlflowClient(String trackingUri) instead.");
        }
        return defaultTrackingUri;
    }

    /**
     * Chooses a credentials provider for a tracking URI.
     *
     * <ul>
     *   <li>{@code http}/{@code https}: {@link BasicMlflowHostCreds} for the URI itself.</li>
     *   <li>literal {@code databricks}: the dynamic provider chained before the config provider
     *       when {@link DatabricksDynamicHostCredsProvider#createIfAvailable()} returns non-null,
     *       otherwise just the config provider.</li>
     *   <li>{@code databricks://<host>}: a config provider built from the URI's host part
     *       (presumably a profile name — TODO confirm).</li>
     * </ul>
     *
     * @throws NullPointerException if {@code trackingUri} is null
     * @throws IllegalArgumentException for local ({@code file:} or scheme-less), malformed, or
     *     otherwise unsupported URIs
     */
    private static MlflowHostCredsProvider getHostCredsProviderFromTrackingUri(String trackingUri) {
        Objects.requireNonNull(trackingUri, "trackingUri must not be null");
        URI uri = URI.create(trackingUri);
        String scheme = uri.getScheme();
        if ("http".equals(scheme) || "https".equals(scheme)) {
            return new BasicMlflowHostCreds(trackingUri);
        }
        if (trackingUri.equals("databricks")) {
            DatabricksConfigHostCredsProvider profileProvider = new DatabricksConfigHostCredsProvider();
            DatabricksDynamicHostCredsProvider dynamicProvider = DatabricksDynamicHostCredsProvider.createIfAvailable();
            MlflowHostCredsProvider provider = dynamicProvider != null ? new HostCredsProviderChain(dynamicProvider, profileProvider) : profileProvider;
            return provider;
        }
        if ("databricks".equals(scheme)) {
            return new DatabricksConfigHostCredsProvider(uri.getHost());
        }
        if (scheme == null || "file".equals(scheme)) {
            throw new IllegalArgumentException("Java Client currently does not support local tracking URIs. Please point to a Tracking Server.");
        }
        throw new IllegalArgumentException("Invalid tracking server uri: " + trackingUri);
    }

    /**
     * Uploads one local file to the run's artifact root.
     *
     * @param runId target run
     * @param localFile file to upload
     */
    public void logArtifact(String runId, File localFile) {
        this.getArtifactRepository(runId).logArtifact(localFile);
    }

    /**
     * Uploads one local file under {@code artifactPath} in the run's artifact root.
     *
     * @param runId target run
     * @param localFile file to upload
     * @param artifactPath destination subpath
     */
    public void logArtifact(String runId, File localFile, String artifactPath) {
        this.getArtifactRepository(runId).logArtifact(localFile, artifactPath);
    }

    /**
     * Uploads the contents of a local directory to the run's artifact root.
     *
     * @param runId target run
     * @param localDir directory to upload
     */
    public void logArtifacts(String runId, File localDir) {
        this.getArtifactRepository(runId).logArtifacts(localDir);
    }

    /**
     * Uploads the contents of a local directory under {@code artifactPath}.
     *
     * @param runId target run
     * @param localDir directory to upload
     * @param artifactPath destination subpath
     */
    public void logArtifacts(String runId, File localDir, String artifactPath) {
        this.getArtifactRepository(runId).logArtifacts(localDir, artifactPath);
    }

    /**
     * Lists artifacts at the run's artifact root.
     *
     * @param runId target run
     * @return file infos returned by the artifact repository
     */
    public List<Service.FileInfo> listArtifacts(String runId) {
        return this.getArtifactRepository(runId).listArtifacts();
    }

    /**
     * Lists artifacts under {@code artifactPath}.
     *
     * @param runId target run
     * @param artifactPath subpath to list
     * @return file infos returned by the artifact repository
     */
    public List<Service.FileInfo> listArtifacts(String runId, String artifactPath) {
        return this.getArtifactRepository(runId).listArtifacts(artifactPath);
    }

    /**
     * Downloads all of a run's artifacts.
     *
     * @param runId target run
     * @return local file returned by the artifact repository
     */
    public File downloadArtifacts(String runId) {
        return this.getArtifactRepository(runId).downloadArtifacts();
    }

    /**
     * Downloads the artifacts under {@code artifactPath}.
     *
     * @param runId target run
     * @param artifactPath subpath to download
     * @return local file returned by the artifact repository
     */
    public File downloadArtifacts(String runId, String artifactPath) {
        return this.getArtifactRepository(runId).downloadArtifacts(artifactPath);
    }

    /**
     * Resolves the artifact repository for a run. This costs one {@code runs/get} call to
     * read the run's artifact URI.
     */
    private ArtifactRepository getArtifactRepository(String runId) {
        URI baseArtifactUri = URI.create(this.getRun(runId).getInfo().getArtifactUri());
        return this.artifactRepositoryFactory.getArtifactRepository(baseArtifactUri, runId);
    }

    /** Extracts the {@link Service.RunInfo} of every run on a page, preserving order. */
    private static List<Service.RunInfo> toRunInfos(RunsPage page) {
        return page.getItems().stream().map(Service.Run::getInfo).collect(Collectors.toList());
    }
}

