package org.deeplearning4j.nearestneighbor.server;

import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nearestneighbor/server/NearestNeighbor.class */
/**
 * Runs a k-nearest-neighbour query against a {@link VPTree}.
 *
 * <p>The query point is the slice of {@code points} selected by the request's input index. The
 * tree is asked for the request's {@code k} nearest neighbours of that slice.
 */
public class NearestNeighbor {
    private final NearestNeighborRequest record;
    private final VPTree tree;
    private final INDArray points;

    /** Fluent builder for {@link NearestNeighbor}; obtain one via {@link NearestNeighbor#builder()}. */
    public static class NearestNeighborBuilder {
        private NearestNeighborRequest record;
        private VPTree tree;
        private INDArray points;

        NearestNeighborBuilder() {
        }

        /** Sets the request supplying the query index and {@code k}. */
        public NearestNeighborBuilder record(NearestNeighborRequest nearestNeighborRequest) {
            this.record = nearestNeighborRequest;
            return this;
        }

        /** Sets the tree that is searched. */
        public NearestNeighborBuilder tree(VPTree vPTree) {
            this.tree = vPTree;
            return this;
        }

        /** Sets the array from which the query slice is taken. */
        public NearestNeighborBuilder points(INDArray iNDArray) {
            this.points = iNDArray;
            return this;
        }

        /** Creates a {@link NearestNeighbor} from the values set so far (unset values are null). */
        public NearestNeighbor build() {
            return new NearestNeighbor(this.record, this.tree, this.points);
        }

        @Override
        public String toString() {
            return "NearestNeighbor.NearestNeighborBuilder(record=" + this.record + ", tree=" + this.tree + ", points=" + this.points + ")";
        }
    }

    /**
     * Finds the nearest neighbours of the query slice {@code points.slice(record.getInputIndex())}.
     *
     * @return one result per neighbour reported by the tree, pairing the neighbour's
     *     {@code DataPoint} index with its distance, in the order the tree returned them; an empty
     *     list if the selected slice is not a vector
     * @throws IllegalStateException if the tree returned a different number of points than distances
     */
    public List<NearestNeighborsResult> search() {
        INDArray query = this.points.slice(this.record.getInputIndex());
        if (!query.isVector()) {
            // Only vector queries are supported; anything else yields no results.
            return new ArrayList<>();
        }

        // The tree fills these two lists in parallel: neighbours and their distances.
        List<DataPoint> neighbors = new ArrayList<>();
        List<Double> distances = new ArrayList<>();
        this.tree.search(query, this.record.getK(), neighbors, distances);
        if (neighbors.size() != distances.size()) {
            throw new IllegalStateException(String.format("add.size == %d != %d == distances.size", neighbors.size(), distances.size()));
        }

        List<NearestNeighborsResult> results = new ArrayList<>(neighbors.size());
        for (int i = 0; i < neighbors.size(); i++) {
            results.add(new NearestNeighborsResult(neighbors.get(i).getIndex(), distances.get(i).doubleValue()));
        }
        return results;
    }

    /** Returns a new, empty {@link NearestNeighborBuilder}. */
    public static NearestNeighborBuilder builder() {
        return new NearestNeighborBuilder();
    }

    /**
     * Creates a searcher. No argument is validated here; null values fail only when
     * {@link #search()} is called.
     *
     * @param nearestNeighborRequest supplies the query slice index and {@code k}
     * @param vPTree the tree to search (presumably built over {@code iNDArray} — verify with callers)
     * @param iNDArray the array from which the query slice is taken
     */
    public NearestNeighbor(NearestNeighborRequest nearestNeighborRequest, VPTree vPTree, INDArray iNDArray) {
        this.record = nearestNeighborRequest;
        this.tree = vPTree;
        this.points = iNDArray;
    }
}
