/*
 * Decompiled with CFR 0.152.
 */
package org.mlflow.sagemaker;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Optional;
import javax.servlet.Servlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.io.IOUtils;
import org.eclipse.jetty.server.ConnectionFactory;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.HttpConnectionFactory;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.eclipse.jetty.util.thread.ThreadPool;
import org.mlflow.mleap.MLeapLoader;
import org.mlflow.models.Model;
import org.mlflow.sagemaker.Predictor;
import org.mlflow.sagemaker.PredictorDataWrapper;
import org.mlflow.sagemaker.PredictorEvaluationException;
import org.mlflow.sagemaker.PredictorLoadingException;
import org.mlflow.utils.EnvironmentUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * HTTP scoring server that exposes an MLflow {@link Predictor} over Jetty.
 *
 * <p>Endpoints:
 * <ul>
 *   <li>{@code GET /ping} - liveness check, always returns 200 with an empty body.</li>
 *   <li>{@code GET /version} - returns the hard-coded server version string.</li>
 *   <li>{@code POST /invocations} - evaluates the predictor on a JSON or CSV request body.</li>
 * </ul>
 *
 * <p>The Jetty worker pool size is read from the {@value #ENV_VAR_MINIMUM_SERVER_THREADS} and
 * {@value #ENV_VAR_MAXIMUM_SERVER_THREADS} environment variables, falling back to
 * {@value #DEFAULT_MINIMUM_SERVER_THREADS} and {@value #DEFAULT_MAXIMUM_SERVER_THREADS}.
 */
public class ScoringServer {
    public static final String RESPONSE_KEY_ERROR_MESSAGE = "Error";
    private static final String REQUEST_CONTENT_TYPE_JSON = "application/json";
    private static final String REQUEST_CONTENT_TYPE_CSV = "text/csv";
    static final String ENV_VAR_MINIMUM_SERVER_THREADS = "MLFLOW_SCORING_SERVER_MIN_THREADS";
    static final String ENV_VAR_MAXIMUM_SERVER_THREADS = "MLFLOW_SCORING_SERVER_MAX_THREADS";
    static final int DEFAULT_MINIMUM_SERVER_THREADS = 1;
    static final int DEFAULT_MAXIMUM_SERVER_THREADS = 16;
    /** Port used by {@link #main} when no port argument is supplied. */
    private static final int DEFAULT_CLI_PORT = 8080;
    private static final Logger logger = LoggerFactory.getLogger(ScoringServer.class);
    private final Server server;
    private final ServerConnector httpConnector;

    /**
     * Builds (but does not start) a server that serves the given predictor.
     *
     * @param predictor the model used to answer {@code /invocations} requests
     */
    public ScoringServer(Predictor predictor) {
        int minThreads = EnvironmentUtils.getIntegerValue(
                ENV_VAR_MINIMUM_SERVER_THREADS, DEFAULT_MINIMUM_SERVER_THREADS);
        int maxThreads = EnvironmentUtils.getIntegerValue(
                ENV_VAR_MAXIMUM_SERVER_THREADS, DEFAULT_MAXIMUM_SERVER_THREADS);
        this.server = new Server(new QueuedThreadPool(maxThreads, minThreads));
        this.server.setStopAtShutdown(true);
        this.httpConnector = new ServerConnector(this.server, new ConnectionFactory[]{new HttpConnectionFactory()});
        this.server.addConnector(this.httpConnector);
        ServletContextHandler rootContextHandler = new ServletContextHandler(null, "/");
        rootContextHandler.addServlet(new ServletHolder(new PingServlet()), "/ping");
        rootContextHandler.addServlet(new ServletHolder(new VersionServlet()), "/version");
        rootContextHandler.addServlet(new ServletHolder(new InvocationsServlet(predictor)), "/invocations");
        this.server.setHandler(rootContextHandler);
    }

    /**
     * Builds (but does not start) a server for the MLflow model stored at {@code modelPath}.
     *
     * @param modelPath root path of the MLflow model
     * @throws PredictorLoadingException if the model configuration cannot be read or loaded
     */
    public ScoringServer(String modelPath) throws PredictorLoadingException {
        this(ScoringServer.loadPredictorFromPath(modelPath));
    }

    /** Reads the MLflow model configuration at {@code modelPath} and loads it via MLeap. */
    private static Predictor loadPredictorFromPath(String modelPath) throws PredictorLoadingException {
        try {
            Model config = Model.fromRootPath(modelPath);
            return new MLeapLoader().load(config);
        } catch (IOException e) {
            throw new PredictorLoadingException(
                    "Failed to load the configuration for the MLflow model at the specified path.", e);
        }
    }

    /** Starts the server on a port chosen by the OS; query it afterwards with {@link #getPort()}. */
    public void start() {
        this.start(0);
    }

    /**
     * Starts the server on the given port.
     *
     * @param portNumber the port to bind, or 0 to let the OS choose one
     * @throws IllegalStateException if the server is already active
     * @throws ServerStateChangeException if Jetty fails to start
     */
    public void start(int portNumber) {
        if (this.isActive()) {
            int activePort = this.httpConnector.getLocalPort();
            throw new IllegalStateException(String.format(
                    "Attempted to start a server that is already active on port %d", activePort));
        }
        this.httpConnector.setPort(portNumber);
        try {
            this.server.start();
        } catch (Exception e) {
            throw new ServerStateChangeException(e);
        }
        // Log the port actually bound: when portNumber is 0 the OS picks an ephemeral port.
        logger.info("Started scoring server on port: {}", this.httpConnector.getLocalPort());
    }

    /**
     * Stops the server and waits for Jetty's threads to finish.
     *
     * @throws ServerStateChangeException if stopping or joining the server fails
     */
    public void stop() {
        try {
            this.server.stop();
            this.server.join();
        } catch (Exception e) {
            throw new ServerStateChangeException(e);
        }
        logger.info("Stopped the scoring server successfully.");
    }

    /** Returns whether the underlying Jetty server reports itself as started. */
    public boolean isActive() {
        return this.server.isStarted();
    }

    /**
     * Returns the port the HTTP connector is bound to, if any.
     *
     * @return the bound port, or empty when the connector reports a negative port (not bound)
     */
    public Optional<Integer> getPort() {
        int boundPort = this.httpConnector.getLocalPort();
        return boundPort >= 0 ? Optional.of(boundPort) : Optional.empty();
    }

    /**
     * Command-line entry point.
     *
     * <p>Usage: {@code ScoringServer <model_path> [port]}. The port defaults to
     * {@value #DEFAULT_CLI_PORT}.
     *
     * @throws IllegalArgumentException if no model path is given
     * @throws NumberFormatException if the port argument is not an integer
     */
    public static void main(String[] args) throws IOException, PredictorLoadingException {
        if (args.length < 1) {
            throw new IllegalArgumentException("Usage: ScoringServer <model_path> [port]");
        }
        String modelPath = args[0];
        int portNumber = args.length > 1 ? Integer.parseInt(args[1]) : DEFAULT_CLI_PORT;
        ScoringServer server = new ScoringServer(modelPath);
        try {
            server.start(portNumber);
        } catch (ServerStateChangeException e) {
            logger.error("Encountered an error while starting the prediction server.", e);
        }
    }

    /** Servlet handling {@code POST /invocations} by delegating to the predictor. */
    static class InvocationsServlet
    extends HttpServlet {
        private final Predictor predictor;

        InvocationsServlet(Predictor predictor) {
            this.predictor = predictor;
        }

        /**
         * Evaluates the request body with the predictor and writes the result.
         *
         * <p>A successful response is serialized in the same format as the request: CSV for
         * {@code text/csv}, JSON otherwise. Errors are always reported as a JSON object keyed by
         * {@link ScoringServer#RESPONSE_KEY_ERROR_MESSAGE}. The status is 400 for an
         * unsupported or missing content type and 500 for evaluation failures.
         */
        @Override
        public void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
            String requestContentType = request.getHeader("Content-type");
            String requestBody = IOUtils.toString(request.getInputStream(), StandardCharsets.UTF_8);
            String responseContent;
            String responseMediaType = ScoringServer.REQUEST_CONTENT_TYPE_JSON;
            try {
                responseContent = this.evaluateRequest(requestBody, requestContentType);
                if (hasMediaType(requestContentType, ScoringServer.REQUEST_CONTENT_TYPE_CSV)) {
                    responseMediaType = ScoringServer.REQUEST_CONTENT_TYPE_CSV;
                }
            } catch (PredictorEvaluationException e) {
                logger.error("Encountered a failure when evaluating the predictor.", e);
                response.setStatus(500);
                responseContent = this.getErrorResponseJson(e.getMessage());
            } catch (InvalidRequestTypeException e) {
                logger.info("Received a request with an unsupported content type: {}", requestContentType);
                response.setStatus(400);
                responseContent = this.getErrorResponseJson(
                        "Requests must have a content header of type `application/json` or `text/csv`");
            } catch (Exception e) {
                logger.error("An unknown error occurred while evaluating the prediction request.", e);
                response.setStatus(500);
                responseContent = this.getErrorResponseJson("An unknown error occurred while evaluating the model!");
            }
            if (responseContent != null) {
                // Declare UTF-8 explicitly. Otherwise the servlet writer defaults to ISO-8859-1
                // and corrupts any non-Latin-1 characters in the predictions.
                response.setContentType(responseMediaType + "; charset=UTF-8");
                response.getWriter().print(responseContent);
                response.getWriter().close();
            }
        }

        /**
         * Runs the predictor on {@code requestContent} interpreted according to its content type.
         *
         * @param requestContent the raw request body
         * @param requestContentType the raw Content-Type header; may be null or carry parameters
         * @return the prediction serialized as JSON or CSV, matching the request format
         * @throws InvalidRequestTypeException if the content type is missing or unsupported
         * @throws PredictorEvaluationException if the predictor fails
         */
        private String evaluateRequest(String requestContent, String requestContentType)
                throws PredictorEvaluationException, InvalidRequestTypeException {
            if (hasMediaType(requestContentType, ScoringServer.REQUEST_CONTENT_TYPE_JSON)) {
                PredictorDataWrapper predictorInput =
                        new PredictorDataWrapper(requestContent, PredictorDataWrapper.ContentType.Json);
                return this.predictor.predict(predictorInput).toJson();
            }
            if (hasMediaType(requestContentType, ScoringServer.REQUEST_CONTENT_TYPE_CSV)) {
                PredictorDataWrapper predictorInput =
                        new PredictorDataWrapper(requestContent, PredictorDataWrapper.ContentType.Csv);
                return this.predictor.predict(predictorInput).toCsv();
            }
            // The caller logs this case. Logging here as well would duplicate every entry.
            throw new InvalidRequestTypeException(
                    "Invocations content must be of content type `application/json` or `text/csv`");
        }

        /**
         * Returns whether a Content-Type header denotes {@code expected}. Parameters such as
         * {@code ; charset=utf-8} are ignored and the comparison is case-insensitive.
         * A null header never matches.
         */
        private static boolean hasMediaType(String contentType, String expected) {
            if (contentType == null) {
                return false;
            }
            int paramStart = contentType.indexOf(';');
            String mediaType = paramStart >= 0 ? contentType.substring(0, paramStart) : contentType;
            return mediaType.trim().equalsIgnoreCase(expected);
        }

        /** Builds a single-key JSON error object, escaping the message so the JSON stays valid. */
        private String getErrorResponseJson(String errorMessage) {
            return String.format("{ \"%s\" : \"%s\" }",
                    ScoringServer.RESPONSE_KEY_ERROR_MESSAGE, escapeJsonString(errorMessage));
        }

        /** Escapes text for use inside a JSON string literal; null becomes the empty string. */
        private static String escapeJsonString(String value) {
            if (value == null) {
                return "";
            }
            StringBuilder escaped = new StringBuilder(value.length() + 16);
            for (int i = 0; i < value.length(); i++) {
                char c = value.charAt(i);
                switch (c) {
                    case '"':
                        escaped.append("\\\"");
                        break;
                    case '\\':
                        escaped.append("\\\\");
                        break;
                    case '\n':
                        escaped.append("\\n");
                        break;
                    case '\r':
                        escaped.append("\\r");
                        break;
                    case '\t':
                        escaped.append("\\t");
                        break;
                    default:
                        // Other control characters must be \\u-escaped in JSON.
                        if (c < 0x20) {
                            escaped.append(String.format("\\u%04x", (int) c));
                        } else {
                            escaped.append(c);
                        }
                }
            }
            return escaped.toString();
        }

        /** Raised when the request's content type is neither JSON nor CSV. */
        static class InvalidRequestTypeException
        extends Exception {
            InvalidRequestTypeException(String message) {
                super(message);
            }
        }
    }

    /** Servlet handling {@code GET /version}; responds with the hard-coded version string. */
    static class VersionServlet
    extends HttpServlet {
        VersionServlet() {
        }

        @Override
        public void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException {
            response.setStatus(200);
            response.getWriter().print("2.4.0");
            response.getWriter().close();
        }
    }

    /** Servlet handling {@code GET /ping}; responds 200 with an empty body. */
    static class PingServlet
    extends HttpServlet {
        PingServlet() {
        }

        @Override
        public void doGet(HttpServletRequest request, HttpServletResponse response) {
            response.setStatus(200);
        }
    }

    /** Unchecked wrapper for failures while starting or stopping the Jetty server. */
    public static class ServerStateChangeException
    extends RuntimeException {
        ServerStateChangeException(Exception e) {
            super(e);
        }
    }
}

