package org.deeplearning4j.nearestneighbor.server;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import java.io.File;
import java.io.IOException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nearestneighbor.model.Base64NDArrayBody;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.serde.binary.BinarySerde;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import play.Mode;
import play.libs.Json;
import play.mvc.Controller;
import play.mvc.Results;
import play.routing.RoutingDsl;
import play.server.Server;

/* loaded from: input_file:org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.class */
/**
 * Command-line entry point for a nearest-neighbours HTTP server.
 *
 * <p>On startup it:
 * <ul>
 *   <li>reads one or more 2D ndarray chunks from disk (via {@link BinarySerde});</li>
 *   <li>stacks the chunks row-wise into a single points matrix;</li>
 *   <li>optionally loads one label per row;</li>
 *   <li>builds a {@link VPTree} over the points;</li>
 *   <li>starts a Play server.</li>
 * </ul>
 *
 * <p>The server exposes two routes:
 * <ul>
 *   <li>{@code POST /knn} accepts a {@link NearestNeighborRequest} JSON body.</li>
 *   <li>{@code POST /knnnew} accepts a {@link Base64NDArrayBody} JSON body carrying a
 *       base64-encoded query ndarray and {@code k}. It returns the neighbours' indices and
 *       distances, plus their labels when labels were loaded.</li>
 * </ul>
 */
public class NearestNeighborsServer {
    private static final Logger log = LoggerFactory.getLogger(NearestNeighborsServer.class);

    /** Size in bytes of a randomly generated Play application secret. */
    private static final int PLAY_SECRET_BYTES = 1024;

    /**
     * Comma-separated list of ndarray chunk files. Every chunk must be 2D, and all chunks must
     * have the same number of columns.
     */
    @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
    private String ndarrayPath = null;

    /**
     * Optional comma-separated list of text files, read as UTF-8. Their lines are concatenated in
     * order, and the total line count must equal the total number of rows in the points matrix.
     */
    @Parameter(names = {"--labelsPath"}, arity = 1, required = false)
    private String labelsPath = null;

    /** HTTP port the Play server listens on. */
    @Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
    private int port = 9000;

    /** Similarity function name passed to the {@link VPTree} constructor. */
    @Parameter(names = {"--similarityFunction"}, arity = 1)
    private String similarityFunction = "euclidean";

    /** Passed to the {@link VPTree} constructor as its {@code invert} flag. */
    @Parameter(names = {"--invert"}, arity = 1)
    private boolean invert = false;

    /** The running server; {@code null} until {@link #runHelper()} has completed. */
    private Server server;

    /**
     * Parses the command-line arguments and starts the server.
     *
     * <p>If the arguments are invalid, the usage text is logged and printed, and the JVM exits
     * with status 1. Any failure while starting the server is logged rather than propagated.
     *
     * @param strArr command-line arguments
     * @throws Exception declared for compatibility; startup failures are logged, not rethrown
     */
    public void runMain(String... strArr) throws Exception {
        JCommander jCommander = new JCommander(this);
        try {
            jCommander.parse(strArr);
        } catch (ParameterException e) {
            log.error("Error in NearestNeighboursServer parameters", e);
            StringBuilder sb = new StringBuilder();
            jCommander.usage(sb);
            log.error("Usage: {}", sb.toString());
            jCommander.usage();
            if (this.ndarrayPath == null) {
                log.error("--ndarrayPath parameter is missing (null)");
            }
            // Short pause before exiting, presumably so that log output gets flushed.
            try {
                Thread.sleep(500L);
            } catch (InterruptedException interrupted) {
                // Preserve the interrupt status instead of silently discarding it.
                Thread.currentThread().interrupt();
            }
            System.exit(1);
        }
        try {
            runHelper();
        } catch (Throwable th) {
            log.error("Error in NearestNeighboursServer run method", th);
        }
    }

    /**
     * Loads the data, builds the VP-tree, registers the HTTP routes and starts the Play server on
     * {@link #port}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Read the chunk shapes and validate them.</li>
     *   <li>Load the optional labels.</li>
     *   <li>Copy every chunk into a single {@code rows x cols} matrix.</li>
     * </ol>
     *
     * @throws DL4JInvalidInputException if a chunk is not 2D, if the chunks disagree on column
     *     count, or if the number of labels does not match the number of rows
     * @throws Exception if reading a chunk or a label file fails
     */
    protected void runHelper() throws Exception {
        String[] chunkPaths = this.ndarrayPath.split(",");

        // First pass: read only the shapes, so the full matrix can be allocated exactly once.
        int rows = 0;
        int cols = 0;
        for (int c = 0; c < chunkPaths.length; c++) {
            DataBuffer shape = BinarySerde.readShapeFromDisk(new File(chunkPaths[c]));
            // Validate the rank before reading dimension 1, which a lower-rank shape does not have.
            if (Shape.rank(shape) != 2) {
                throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");
            }
            int chunkRows = Shape.size(shape, 0);
            int chunkCols = Shape.size(shape, 1);
            log.info("Loading shape {} of {}; Shape: [{} x {}]", c + 1, chunkPaths.length, chunkRows, chunkCols);

            rows += chunkRows;
            if (cols == 0) {
                cols = chunkCols;
            } else if (cols != chunkCols) {
                throw new DL4JInvalidInputException("NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
            }
        }

        List<String> labels = loadLabels();
        if (!labels.isEmpty() && labels.size() != rows) {
            throw new DL4JInvalidInputException(String.format(
                    "Number of labels must match number of rows in points matrix (expected %d, found %d)",
                    rows, labels.size()));
        }

        INDArray points = loadPoints(chunkPaths, rows, cols);
        VPTree tree = new VPTree(points, this.similarityFunction, this.invert);

        RoutingDsl routingDsl = new RoutingDsl();
        routingDsl.POST("/knn").routeTo(FunctionUtil.function0(() -> {
            try {
                NearestNeighborRequest knnRequest =
                        Json.fromJson(Controller.request().body().asJson(), NearestNeighborRequest.class);
                if (knnRequest == null) {
                    return Results.badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
                }
                NearestNeighbor nearestNeighbor =
                        NearestNeighbor.builder().points(points).record(knnRequest).tree(tree).build();
                return Results.ok(Json.toJson(
                        NearestNeighborsResults.builder().results(nearestNeighbor.search()).build()));
            } catch (Throwable th) {
                log.error("Error in POST /knn", th);
                return Results.internalServerError(errorMessage(th));
            }
        }));
        routingDsl.POST("/knnnew").routeTo(FunctionUtil.function0(() -> {
            try {
                Base64NDArrayBody body =
                        Json.fromJson(Controller.request().body().asJson(), Base64NDArrayBody.class);
                if (body == null) {
                    return Results.badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
                }
                INDArray query = Nd4jBase64.fromBase64(body.getNdarray());

                List<DataPoint> results;
                List<Double> distances;
                if (body.isForceFillK()) {
                    // forceFillK requested: delegate to VPTreeFillSearch.
                    VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, body.getK(), query);
                    fillSearch.search();
                    results = fillSearch.getResults();
                    distances = fillSearch.getDistances();
                } else {
                    results = new ArrayList<>();
                    distances = new ArrayList<>();
                    tree.search(query, body.getK(), results, distances);
                }

                if (results.size() != distances.size()) {
                    return Results.internalServerError(String.format(
                            "results.size == %d != %d == distances.size", results.size(), distances.size()));
                }

                List<NearestNeighborsResult> neighbors = new ArrayList<>(results.size());
                for (int r = 0; r < results.size(); r++) {
                    int index = results.get(r).getIndex();
                    double distance = distances.get(r);
                    neighbors.add(labels.isEmpty()
                            ? new NearestNeighborsResult(index, distance)
                            : new NearestNeighborsResult(index, distance, labels.get(index)));
                }
                return Results.ok(Json.toJson(NearestNeighborsResults.builder().results(neighbors).build()));
            } catch (Throwable th) {
                log.error("Error in POST /knnnew", th);
                return Results.internalServerError(errorMessage(th));
            }
        }));

        ensurePlayCryptoSecret();
        this.server = Server.forRouter(routingDsl.build(), Mode.PROD, this.port);
    }

    /**
     * Reads every file listed in {@link #labelsPath}, in order, as UTF-8 lines.
     *
     * @return the concatenated lines; empty when no labels path was given
     * @throws IOException if a label file cannot be read
     */
    private List<String> loadLabels() throws IOException {
        List<String> labels = new ArrayList<>();
        if (this.labelsPath != null) {
            for (String path : this.labelsPath.split(",")) {
                labels.addAll(FileUtils.readLines(new File(path), "utf-8"));
            }
        }
        return labels;
    }

    /**
     * Copies every chunk, in order, into one freshly allocated {@code rows x cols} matrix.
     *
     * @param chunkPaths chunk files, already validated as 2D with {@code cols} columns
     * @param rows total number of rows across all chunks
     * @param cols number of columns shared by all chunks
     * @return the stacked points matrix
     * @throws IOException if a chunk cannot be read
     */
    private static INDArray loadPoints(String[] chunkPaths, int rows, int cols) throws IOException {
        INDArray points = Nd4j.createUninitialized(rows, cols);
        int offset = 0;
        for (int c = 0; c < chunkPaths.length; c++) {
            log.info("Loading chunk {} of {}", c + 1, chunkPaths.length);
            INDArray chunk = BinarySerde.readFromDisk(new File(chunkPaths[c]));
            points.get(NDArrayIndex.interval(offset, offset + chunk.rows())).assign(chunk);
            offset += chunk.rows();
            // Hint the JVM to reclaim the chunk just copied before the next one is read.
            System.gc();
        }
        return points;
    }

    /**
     * Ensures the {@code play.crypto.secret} system property holds a usable secret.
     *
     * <p>If the property is missing, empty or the placeholder {@code "changeme"}, it is replaced
     * with a base64-encoded random value. The value comes from {@link SecureRandom} because it
     * is a cryptographic key, and {@link Random} is predictable.
     */
    private static void ensurePlayCryptoSecret() {
        String secret = System.getProperty("play.crypto.secret");
        if (secret == null || "changeme".equals(secret) || secret.isEmpty()) {
            byte[] bytes = new byte[PLAY_SECRET_BYTES];
            new SecureRandom().nextBytes(bytes);
            System.setProperty("play.crypto.secret", Base64.getEncoder().encodeToString(bytes));
        }
    }

    /**
     * Returns a non-null description of a failure, for use as an HTTP 500 body.
     *
     * <p>{@link Throwable#getMessage()} may return {@code null}; in that case the class name is
     * returned instead.
     */
    private static String errorMessage(Throwable th) {
        String message = th.getMessage();
        return message != null ? message : th.getClass().getName();
    }

    /** Stops the Play server if {@link #runHelper()} has started one; otherwise does nothing. */
    public void stop() {
        if (this.server != null) {
            log.info("Attempting to stop server");
            this.server.stop();
        }
    }

    /**
     * Runs the server from the command line.
     *
     * @param strArr command-line arguments, as parsed by {@link #runMain(String...)}
     * @throws Exception propagated from {@link #runMain(String...)}
     */
    public static void main(String[] strArr) throws Exception {
        new NearestNeighborsServer().runMain(strArr);
    }
}
