/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering.cluster;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.cluster.PointClassification;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A mutable set of {@link Cluster}s plus a record of which cluster each classified point
 * was last assigned to (point id to cluster id).
 *
 * <p>The default constructor wraps the cluster list and the point-distribution map with
 * {@link Collections#synchronizedList} and {@link Collections#synchronizedMap}. Iteration
 * over them in this class is not additionally synchronized. The setters may replace either
 * container with an arbitrary instance.
 */
public class ClusterSet
implements Serializable {
    /**
     * Initial "best distance" for {@link #nearestCluster(Point)}. It equals
     * {@code Float.MAX_VALUE} widened to double, the value the original implementation used.
     * It is also the distance reported when no cluster qualifies.
     */
    private static final double NO_CLUSTER_DISTANCE = Float.MAX_VALUE;

    private String distanceFunction;
    private List<Cluster> clusters;
    private Map<String, String> pointDistribution;

    /** Creates an empty cluster set with no distance function configured. */
    public ClusterSet() {
        this(null);
    }

    /**
     * Creates an empty cluster set.
     *
     * @param distanceFunction name of the distance function handed to each new
     *     {@link Cluster} and to {@link #getDistance(Point, Point)}; may be {@code null}
     */
    public ClusterSet(String distanceFunction) {
        this.distanceFunction = distanceFunction;
        this.clusters = Collections.synchronizedList(new ArrayList<Cluster>());
        this.pointDistribution = Collections.synchronizedMap(new HashMap<String, String>());
    }

    /**
     * Creates a new cluster around {@code center}, adds it to this set and records the
     * center point as belonging to it.
     *
     * @param center the point used as the new cluster's center
     * @return the newly created cluster
     */
    public Cluster addNewClusterWithCenter(Point center) {
        Cluster newCluster = new Cluster(center, this.distanceFunction);
        this.getClusters().add(newCluster);
        this.setPointLocation(center, newCluster);
        return newCluster;
    }

    /**
     * Classifies {@code point}, passing {@code true} for {@code moveClusterCenter}.
     *
     * @see #classifyPoint(Point, boolean)
     */
    public PointClassification classifyPoint(Point point) {
        return this.classifyPoint(point, true);
    }

    /**
     * Classifies each point in order, passing {@code true} for {@code moveClusterCenter}.
     *
     * @see #classifyPoint(Point, boolean)
     */
    public void classifyPoints(List<Point> points) {
        this.classifyPoints(points, true);
    }

    /**
     * Classifies each point in {@code points} in iteration order.
     *
     * @param points the points to classify
     * @param moveClusterCenter forwarded to {@link #classifyPoint(Point, boolean)}
     */
    public void classifyPoints(List<Point> points, boolean moveClusterCenter) {
        for (Point point : points) {
            this.classifyPoint(point, moveClusterCenter);
        }
    }

    /**
     * Assigns {@code point} to its nearest cluster and records the assignment.
     *
     * @param point the point to classify
     * @param moveClusterCenter forwarded to {@code Cluster.addPoint}
     * @return the chosen cluster, its distance to the point, and whether the point's recorded
     *     cluster changed (or it had none before)
     * @throws IllegalStateException if no cluster could be chosen, i.e. the set is empty or
     *     no cluster distance was below {@code Float.MAX_VALUE}
     */
    public PointClassification classifyPoint(Point point, boolean moveClusterCenter) {
        Pair<Cluster, Double> nearest = this.nearestCluster(point);
        Cluster newCluster = nearest.getFirst();
        if (newCluster == null) {
            // Previously this surfaced as an unexplained NullPointerException further down.
            throw new IllegalStateException("No nearest cluster found for point " + point.getId()
                    + ": the cluster set is empty or no distance was below Float.MAX_VALUE");
        }
        boolean locationChange = this.isPointLocationChange(point, newCluster);
        this.addPointToCluster(point, newCluster, moveClusterCenter);
        return new PointClassification(newCluster, nearest.getSecond(), locationChange);
    }

    /**
     * Reports whether assigning {@code point} to {@code newCluster} differs from its recorded
     * cluster. A point with no recorded cluster counts as a change.
     */
    private boolean isPointLocationChange(Point point, Cluster newCluster) {
        // A single lookup avoids a containsKey/get race on the shared map.
        String currentClusterId = this.getPointDistribution().get(point.getId());
        return currentClusterId == null || !currentClusterId.equals(newCluster.getId());
    }

    /** Adds {@code point} to {@code cluster} and records the assignment. */
    private void addPointToCluster(Point point, Cluster cluster, boolean moveClusterCenter) {
        cluster.addPoint(point, moveClusterCenter);
        this.setPointLocation(point, cluster);
    }

    /** Records {@code cluster}'s id as the location of {@code point}'s id. */
    private void setPointLocation(Point point, Cluster cluster) {
        this.pointDistribution.put(point.getId(), cluster.getId());
    }

    /**
     * Finds the cluster whose center is closest to {@code point}.
     *
     * <p>Only distances strictly below {@code Float.MAX_VALUE} are considered. If the set is
     * empty or nothing qualifies, the returned pair holds a {@code null} cluster and the
     * sentinel distance {@code Float.MAX_VALUE}.
     *
     * @param point the point to locate
     * @return the nearest cluster (possibly {@code null}) and its distance
     */
    public Pair<Cluster, Double> nearestCluster(Point point) {
        Cluster nearestCluster = null;
        double minDistance = NO_CLUSTER_DISTANCE;
        for (Cluster cluster : this.getClusters()) {
            double currentDistance = cluster.getDistanceToCenter(point);
            if (currentDistance < minDistance) {
                minDistance = currentDistance;
                nearestCluster = cluster;
            }
        }
        return new Pair<Cluster, Double>(nearestCluster, minDistance);
    }

    /**
     * Computes the configured distance between the arrays of two points using the ND4J
     * executioner. The distance function name is passed to
     * {@code Nd4j.getOpFactory().createAccum}; presumably it must name an ND4J accumulation
     * op (TODO confirm the accepted names).
     */
    public double getDistance(Point m1, Point m2) {
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createAccum(this.distanceFunction, m1.getArray(), m2.getArray())).currentResult().doubleValue();
    }

    /** Returns the distance component of {@link #nearestCluster(Point)}. */
    public double getDistanceFromNearestCluster(Point point) {
        return this.nearestCluster(point).getSecond();
    }

    /**
     * Returns the id of the center point of the cluster with id {@code clusterId}, or
     * {@code null} if there is no such cluster or it has no center.
     */
    public String getClusterCenterId(String clusterId) {
        Point clusterCenter = this.getClusterCenter(clusterId);
        return clusterCenter == null ? null : clusterCenter.getId();
    }

    /**
     * Returns the center of the cluster with id {@code clusterId}, or {@code null} if there
     * is no such cluster.
     */
    public Point getClusterCenter(String clusterId) {
        Cluster cluster = this.getCluster(clusterId);
        return cluster == null ? null : cluster.getCenter();
    }

    /**
     * Looks up a cluster by id.
     *
     * @param id the cluster id; {@code null} matches nothing
     * @return the first cluster whose id equals {@code id}, or {@code null} if none does
     */
    public Cluster getCluster(String id) {
        if (id == null) {
            return null;
        }
        for (Cluster cluster : this.clusters) {
            if (id.equals(cluster.getId())) {
                return cluster;
            }
        }
        return null;
    }

    /** Returns the number of clusters, or 0 if the cluster list has been set to {@code null}. */
    public int getClusterCount() {
        return this.getClusters() == null ? 0 : this.getClusters().size();
    }

    /**
     * Calls {@code removePoints()} on every cluster. The point-distribution map is not
     * modified.
     */
    public void removePoints() {
        for (Cluster cluster : this.getClusters()) {
            cluster.removePoints();
        }
    }

    /**
     * Returns up to {@code count} clusters, ordered from most to least populated.
     *
     * @param count the maximum number of clusters to return. Values larger than the cluster
     *     count are clamped; non-positive values yield an empty list.
     * @return a view over a private sorted copy of the cluster list
     */
    public List<Cluster> getMostPopulatedClusters(int count) {
        List<Cluster> mostPopulated = new ArrayList<Cluster>(this.clusters);
        Collections.sort(mostPopulated, new Comparator<Cluster>(){

            @Override
            public int compare(Cluster o1, Cluster o2) {
                // Descending by size so the most populated clusters come first.
                return Integer.compare(o2.getPoints().size(), o1.getPoints().size());
            }
        });
        int limit = Math.max(0, Math.min(count, mostPopulated.size()));
        return mostPopulated.subList(0, limit);
    }

    /**
     * Removes every cluster for which {@code isEmpty()} is true.
     *
     * @return the clusters that were removed
     */
    public List<Cluster> removeEmptyClusters() {
        List<Cluster> emptyClusters = new ArrayList<Cluster>();
        for (Cluster cluster : this.clusters) {
            if (cluster.isEmpty()) {
                emptyClusters.add(cluster);
            }
        }
        this.clusters.removeAll(emptyClusters);
        return emptyClusters;
    }

    /** Returns the live cluster list (not a copy). */
    public List<Cluster> getClusters() {
        return this.clusters;
    }

    /** Replaces the cluster list; the given list is stored as-is. */
    public void setClusters(List<Cluster> clusters) {
        this.clusters = clusters;
    }

    /** Returns the configured distance function name. */
    public String getAccumulation() {
        return this.distanceFunction;
    }

    /**
     * Sets the distance function name used by {@link #getDistance(Point, Point)} and by
     * clusters created afterwards.
     */
    public void setAccumulation(String distanceFunction) {
        this.distanceFunction = distanceFunction;
    }

    /** Returns the live point-id to cluster-id map (not a copy). */
    public Map<String, String> getPointDistribution() {
        return this.pointDistribution;
    }

    /** Replaces the point-id to cluster-id map; the given map is stored as-is. */
    public void setPointDistribution(Map<String, String> pointDistribution) {
        this.pointDistribution = pointDistribution;
    }
}

