/*
 * Decompiled with CFR 0.152.
 */
package org.mlflow.sagemaker;

import com.fasterxml.jackson.core.JsonProcessingException;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import ml.combust.mleap.runtime.MleapContext;
import ml.combust.mleap.runtime.frame.DefaultLeapFrame;
import ml.combust.mleap.runtime.frame.FrameBuilder;
import ml.combust.mleap.runtime.frame.Row;
import ml.combust.mleap.runtime.frame.Transformer;
import ml.combust.mleap.runtime.javadsl.BundleBuilder;
import ml.combust.mleap.runtime.javadsl.ContextBuilder;
import org.mlflow.sagemaker.InvalidSchemaException;
import org.mlflow.sagemaker.LeapFrameSchema;
import org.mlflow.sagemaker.PandasRecordOrientedDataFrame;
import org.mlflow.sagemaker.Predictor;
import org.mlflow.sagemaker.PredictorDataWrapper;
import org.mlflow.sagemaker.PredictorEvaluationException;
import org.mlflow.sagemaker.PredictorLoadingException;
import org.mlflow.utils.SerializationUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.collection.Iterable;
import scala.collection.Iterator;
import scala.collection.JavaConverters;
import scala.collection.Seq;

public class MLeapPredictor
extends Predictor {
    private final Transformer pipelineTransformer;
    private final LeapFrameSchema inputSchema;
    private static final String PREDICTION_COLUMN_NAME = "prediction";
    private static final Logger logger = LoggerFactory.getLogger(MLeapPredictor.class);

    public MLeapPredictor(String modelDataPath, String inputSchemaPath) {
        MleapContext mleapContext = new ContextBuilder().createMleapContext();
        BundleBuilder bundleBuilder = new BundleBuilder();
        this.pipelineTransformer = (Transformer)bundleBuilder.load(new File(modelDataPath), mleapContext).root();
        try {
            this.inputSchema = LeapFrameSchema.fromPath(inputSchemaPath);
        }
        catch (IOException e) {
            logger.error("Could not read the model input schema from the specified path", (Throwable)e);
            throw new PredictorLoadingException(String.format("Failed to load model input schema from specified path: %s", inputSchemaPath));
        }
    }

    @Override
    protected PredictorDataWrapper predict(PredictorDataWrapper input) throws PredictorEvaluationException {
        PandasRecordOrientedDataFrame pandasFrame = null;
        try {
            pandasFrame = PandasRecordOrientedDataFrame.fromJson(input.toJson());
        }
        catch (IOException e) {
            logger.error("Encountered a JSON conversion error during conversion of Pandas dataframe to LeapFrame.", (Throwable)e);
            throw new PredictorEvaluationException("Failed to transform input into a JSON representation of an MLeap dataframe. Please ensure that the input is a JSON-serialized Pandas Dataframe with the `record` orientation.", e);
        }
        DefaultLeapFrame leapFrame = null;
        try {
            leapFrame = pandasFrame.toLeapFrame(this.inputSchema);
        }
        catch (InvalidSchemaException e) {
            throw new PredictorEvaluationException("Encountered a schema mismatch when converting the input dataframe to a LeapFrame.");
        }
        catch (Exception e) {
            logger.error("Encountered an unknown error during conversion of Pandas dataframe to LeapFrame.", (Throwable)e);
            throw new PredictorEvaluationException("An unknown error occurred while converting the input dataframe to a LeapFrame. Original exception text: %s", e);
        }
        Seq predictionColumnSelectionArgs = ((Iterator)JavaConverters.asScalaIteratorConverter(Arrays.asList(PREDICTION_COLUMN_NAME).iterator()).asScala()).toSeq();
        DefaultLeapFrame predictionsFrame = (DefaultLeapFrame)((DefaultLeapFrame)this.pipelineTransformer.transform((FrameBuilder)leapFrame).get()).select(predictionColumnSelectionArgs).get();
        Seq predictionRows = predictionsFrame.collect();
        java.lang.Iterable predictionRowsIterable = (java.lang.Iterable)JavaConverters.asJavaIterableConverter((Iterable)predictionRows).asJava();
        ArrayList<Object> predictions = new ArrayList<Object>();
        for (Row row : predictionRowsIterable) {
            predictions.add(row.getRaw(0));
        }
        String predictionsJson = null;
        try {
            predictionsJson = SerializationUtils.toJson(predictions);
        }
        catch (JsonProcessingException e) {
            logger.error("Encountered an error while serializing the output dataframe.", (Throwable)e);
            throw new PredictorEvaluationException("Failed to serialize prediction results as a JSON list!");
        }
        return new PredictorDataWrapper(predictionsJson, PredictorDataWrapper.ContentType.Json);
    }

    public Transformer getPipeline() {
        return this.pipelineTransformer;
    }
}

