/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering.cluster;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.algorithm.optimisation.ClusteringOptimizationType;
import org.deeplearning4j.clustering.algorithm.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.cluster.PointClassification;
import org.deeplearning4j.clustering.cluster.info.ClusterInfo;
import org.deeplearning4j.clustering.cluster.info.ClusterSetInfo;
import org.deeplearning4j.util.MathUtils;
import org.deeplearning4j.util.MultiThreadUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Static helpers used by the clustering algorithms: classifying points against a {@link ClusterSet},
 * recomputing cluster centers, gathering distance statistics, and splitting clusters.
 *
 * <p>Most methods fan work out as one {@link Runnable} per point or per cluster and hand the batch to
 * {@link MultiThreadUtils#parallelTasks}. The code reads task results right after that call, so it relies on
 * {@code parallelTasks} blocking until every task has finished. Tasks mutate shared {@link ClusterSet} /
 * {@link ClusterSetInfo} objects concurrently; NOTE(review): this assumes those types are safe for concurrent
 * updates — confirm against their implementations.
 */
public class ClusterUtils {

    /** Upper bound on how many of the farthest points are considered when picking a new cluster seed. */
    private static final int MAX_SPLIT_CANDIDATES = 3;

    /**
     * Classifies every point against {@code clusterSet} in parallel and records, for each point, its distance
     * from the center of the cluster it was assigned to.
     *
     * @param clusterSet      clusters to classify against
     * @param points          points to classify; one task is created per point
     * @param executorService executor that runs the tasks
     * @return a fresh {@link ClusterSetInfo} whose point-location-change counter was incremented once for every
     *         classification reporting {@code isNewLocation()}
     */
    public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List<Point> points, ExecutorService executorService) {
        final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true);
        List<Runnable> tasks = new ArrayList<Runnable>(points.size());
        for (final Point point : points) {
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    try {
                        PointClassification result = classifyPoint(clusterSet, point);
                        if (result.isNewLocation()) {
                            clusterSetInfo.getPointLocationChange().incrementAndGet();
                        }
                        clusterSetInfo.getClusterInfo(result.getCluster().getId())
                                .getPointDistancesFromCenter()
                                .put(point.getId(), result.getDistanceFromCenter());
                    } catch (Exception e) {
                        // Best effort: a failure on one point must not abort the whole classification pass.
                        e.printStackTrace();
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
        return clusterSetInfo;
    }

    /**
     * Classifies a single point by delegating to {@code clusterSet.classifyPoint(point, false)}.
     * The meaning of the {@code false} flag is defined by {@link ClusterSet}, which is not visible here.
     *
     * @param clusterSet clusters to classify against
     * @param point      point to classify
     * @return the classification produced by the cluster set
     */
    public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) {
        return clusterSet.classifyPoint(point, false);
    }

    /**
     * Recomputes every cluster's center in parallel, then refreshes the distance statistics stored in that
     * cluster's {@link ClusterInfo}.
     *
     * @param clusterSet      clusters whose centers are recomputed
     * @param clusterSetInfo  per-cluster statistics to refresh
     * @param executorService executor that runs one task per cluster
     */
    public static void refreshClustersCenters(ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, ExecutorService executorService) {
        List<Cluster> clusters = clusterSet.getClusters();
        int clusterCount = clusterSet.getClusterCount();
        List<Runnable> tasks = new ArrayList<Runnable>(clusterCount);
        for (int i = 0; i < clusterCount; i++) {
            final Cluster cluster = clusters.get(i);
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    try {
                        ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                        refreshClusterCenter(cluster, clusterInfo);
                        deriveClusterInfoDistanceStatistics(clusterInfo);
                    } catch (Exception e) {
                        // Best effort: keep refreshing the remaining clusters.
                        e.printStackTrace();
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
    }

    /**
     * Replaces the cluster's center with the element-wise mean of its points' arrays.
     * Clusters without points are left untouched.
     *
     * @param cluster     cluster whose center is recomputed
     * @param clusterInfo currently unused; kept for API compatibility
     */
    public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
        List<Point> points = cluster.getPoints();
        int pointsCount = points.size();
        if (pointsCount == 0) {
            return;
        }
        Point center = new Point(Nd4j.create((int) points.get(0).getArray().length()));
        for (Point point : points) {
            center.getArray().addi(point.getArray());
        }
        center.getArray().divi((Number) pointsCount);
        cluster.setCenter(center);
    }

    /**
     * Fills in max, total, average and variance of the point-to-center distances recorded in {@code info}.
     * Does nothing when no distances are recorded.
     *
     * @param info cluster statistics to update in place
     */
    public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) {
        int pointCount = info.getPointDistancesFromCenter().size();
        if (pointCount == 0) {
            return;
        }
        double[] distances = ArrayUtils.toPrimitive(info.getPointDistancesFromCenter().values().toArray(new Double[0]));
        double max = MathUtils.max(distances);
        double total = MathUtils.sum(distances);
        info.setMaxPointDistanceFromCenter(max);
        info.setTotalPointDistanceFromCenter(total);
        info.setAveragePointDistanceFromCenter(total / (double) pointCount);
        info.setPointDistanceFromCenterVariance(MathUtils.variance(distances));
    }

    /**
     * For each point, computes the squared distance to the most recently added cluster (the last element of
     * {@code clusterSet.getClusters()}). It then keeps the smaller of that value and the matching entry in
     * {@code previousDxs}.
     *
     * @param clusterSet      cluster set whose last cluster is the one just added; must be non-empty
     * @param points          points to measure
     * @param previousDxs     previous minimum squared distances; must have at least {@code points.size()} entries
     * @param executorService executor that runs one task per point
     * @return a new array of per-point minimum squared distances
     */
    public static INDArray computeSquareDistancesFromNearestCluster(ClusterSet clusterSet, final List<Point> points, INDArray previousDxs, ExecutorService executorService) {
        int pointsCount = points.size();
        final INDArray dxs = Nd4j.create(pointsCount);
        final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);
        List<Runnable> tasks = new ArrayList<Runnable>(pointsCount);
        for (int i = 0; i < pointsCount; i++) {
            final int pointIdx = i;
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    Point point = points.get(pointIdx);
                    dxs.putScalar(pointIdx, Math.pow(newCluster.getDistanceToCenter(point), 2.0));
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
        // Keep the running minimum of the squared distance to any cluster seen so far.
        for (int i = 0; i < pointsCount; i++) {
            double previousMinDistance = previousDxs.getDouble(i);
            if (dxs.getDouble(i) > previousMinDistance) {
                dxs.putScalar(i, previousMinDistance);
            }
        }
        return dxs;
    }

    /**
     * Computes cluster statistics using a temporary executor, which is always shut down afterwards,
     * even if computation fails.
     *
     * @param clusterSet clusters to analyse
     * @return the computed statistics
     */
    public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) {
        ExecutorService executor = MultiThreadUtils.newExecutorService();
        try {
            return computeClusterSetInfo(clusterSet, executor);
        } finally {
            executor.shutdownNow();
        }
    }

    /**
     * Computes per-cluster point-to-center statistics and the distances between every pair of cluster centers.
     * Both use {@code clusterSet.getAccumulation()} as the distance function.
     *
     * <p>Each unordered pair of clusters is stored once, keyed (lower index, higher index).
     *
     * @param clusterSet      clusters to analyse
     * @param executorService executor that runs the per-cluster tasks
     * @return the computed statistics
     */
    public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) {
        final ClusterSetInfo info = new ClusterSetInfo(true);
        final List<Cluster> clusters = clusterSet.getClusters();
        int clusterCount = clusterSet.getClusterCount();

        // Phase 1: per-cluster point-to-center distances.
        List<Runnable> tasks = new ArrayList<Runnable>(clusterCount);
        for (int i = 0; i < clusterCount; i++) {
            final Cluster cluster = clusters.get(i);
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    info.getClustersInfos().put(cluster.getId(), computeClusterInfos(cluster, clusterSet.getAccumulation()));
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);

        // Phase 2: pairwise center-to-center distances (upper triangle only).
        tasks = new ArrayList<Runnable>(clusterCount);
        for (int i = 0; i < clusterCount; i++) {
            final int fromIdx = i;
            final Cluster fromCluster = clusters.get(i);
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    try {
                        int count = clusterSet.getClusterCount();
                        for (int k = fromIdx + 1; k < count; k++) {
                            Cluster toCluster = clusterSet.getClusters().get(k);
                            double distance = computeDistance(clusterSet.getAccumulation(),
                                    fromCluster.getCenter().getArray(), toCluster.getCenter().getArray());
                            info.getDistancesBetweenClustersCenters().put(fromCluster.getId(), toCluster.getId(), distance);
                        }
                    } catch (Exception e) {
                        // Best effort: a failure for one cluster must not abort the other pairs.
                        e.printStackTrace();
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
        return info;
    }

    /**
     * Computes each point's distance from the cluster center, plus the total and average of those distances.
     *
     * @param cluster          cluster to analyse
     * @param distanceFunction name of the Nd4j accumulation op used as the distance function
     * @return a new {@link ClusterInfo}; for an empty cluster no distances or average are set
     */
    public static ClusterInfo computeClusterInfos(Cluster cluster, String distanceFunction) {
        ClusterInfo info = new ClusterInfo(true);
        List<Point> points = cluster.getPoints();
        if (points.isEmpty()) {
            return info;
        }
        INDArray center = cluster.getCenter().getArray();
        double total = info.getTotalPointDistanceFromCenter();
        for (Point point : points) {
            double distance = computeDistance(distanceFunction, center, point.getArray());
            info.getPointDistancesFromCenter().put(point.getId(), distance);
            total += distance;
        }
        info.setTotalPointDistanceFromCenter(total);
        info.setAveragePointDistanceFromCenter(total / (double) points.size());
        return info;
    }

    /** Runs the named Nd4j accumulation op on {@code x} and {@code y} and returns its scalar result. */
    private static double computeDistance(String distanceFunction, INDArray x, INDArray y) {
        return Nd4j.getExecutioner()
                .execAndReturn(Nd4j.getOpFactory().createAccum(distanceFunction, x, y))
                .currentResult()
                .doubleValue();
    }

    /**
     * Applies the split step matching the optimisation type of {@code optimization}.
     *
     * @param optimization   optimisation strategy, supplying type and threshold
     * @param clusterSet     clusters that may be split
     * @param clusterSetInfo current statistics for {@code clusterSet}
     * @param executor       executor used for the split tasks
     * @return {@code true} if at least one cluster was selected for splitting
     */
    public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, ExecutorService executor) {
        if (optimization.isClusteringOptimizationType(ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) {
            int splitCount = splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
            return splitCount > 0;
        }
        if (optimization.isClusteringOptimizationType(ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) {
            int splitCount = splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
            return splitCount > 0;
        }
        return false;
    }

    /**
     * Returns up to {@code count} clusters with the largest total point-to-center distance, largest first.
     * Clusters without recorded info are ranked last.
     *
     * @param clusterSet clusters to rank
     * @param info       statistics used for ranking
     * @param count      maximum number of clusters to return; clamped to {@code [0, cluster count]}
     * @return the selected clusters
     */
    public static List<Cluster> getMostSpreadOutClusters(ClusterSet clusterSet, final ClusterSetInfo info, int count) {
        List<Cluster> clusters = new ArrayList<Cluster>(clusterSet.getClusters());
        Collections.sort(clusters, new Comparator<Cluster>() {
            @Override
            public int compare(Cluster o1, Cluster o2) {
                // Reversed arguments => descending order.
                return Double.compare(totalDistanceFromCenter(info, o2), totalDistanceFromCenter(info, o1));
            }
        });
        int limit = Math.max(0, Math.min(count, clusters.size()));
        return clusters.subList(0, limit);
    }

    /**
     * Returns the cluster's total point-to-center distance, or {@code -Infinity} when no info is recorded.
     * The {@code -Infinity} value makes such clusters sort last.
     */
    private static double totalDistanceFromCenter(ClusterSetInfo info, Cluster cluster) {
        ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
        return clusterInfo == null ? Double.NEGATIVE_INFINITY : clusterInfo.getTotalPointDistanceFromCenter();
    }

    /**
     * Returns the clusters whose average point-to-center distance is strictly greater than
     * {@code maximumAverageDistance}. Clusters without recorded info are skipped.
     */
    public static List<Cluster> getClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet, ClusterSetInfo info, double maximumAverageDistance) {
        List<Cluster> clusters = new ArrayList<Cluster>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
            if (clusterInfo != null && clusterInfo.getAveragePointDistanceFromCenter() > maximumAverageDistance) {
                clusters.add(cluster);
            }
        }
        return clusters;
    }

    /**
     * Returns the clusters whose maximum point-to-center distance is strictly greater than
     * {@code maximumDistance}. Clusters without recorded info are skipped.
     */
    public static List<Cluster> getClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet, ClusterSetInfo info, double maximumDistance) {
        List<Cluster> clusters = new ArrayList<Cluster>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
            if (clusterInfo != null && clusterInfo.getMaxPointDistanceFromCenter() > maximumDistance) {
                clusters.add(cluster);
            }
        }
        return clusters;
    }

    /**
     * Splits up to {@code count} of the most spread-out clusters.
     *
     * @return the number of clusters selected for splitting
     */
    public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, ExecutorService executorService) {
        List<Cluster> clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
        return clustersToSplit.size();
    }

    /**
     * Splits every cluster whose average point-to-center distance exceeds {@code maxWithinClusterDistance}.
     *
     * @return the number of clusters selected for splitting
     */
    public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
        List<Cluster> clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, maxWithinClusterDistance);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
        return clustersToSplit.size();
    }

    /**
     * Splits every cluster whose maximum point-to-center distance exceeds {@code maxWithinClusterDistance}.
     *
     * @return the number of clusters selected for splitting
     */
    public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
        List<Cluster> clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, maxWithinClusterDistance);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
        return clustersToSplit.size();
    }

    /** Splits the {@code count} most populated clusters, as chosen by {@link ClusterSet#getMostPopulatedClusters}. */
    public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, ExecutorService executorService) {
        List<Cluster> clustersToSplit = clusterSet.getMostPopulatedClusters(count);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
    }

    /**
     * For each cluster, removes one point lying farther than {@code maxDistance} from the center and makes it
     * the center of a new cluster in {@code clusterSet}.
     *
     * <p>The point is picked at random among the first {@value #MAX_SPLIT_CANDIDATES} entries returned by
     * {@code getPointsFartherFromCenterThan}. NOTE(review): presumably those are ordered farthest-first —
     * confirm against {@link ClusterInfo}. Clusters with no such point are left unchanged.
     *
     * @param clusterSet      set that receives the new clusters
     * @param clusterSetInfo  statistics used to find far-away points
     * @param clusters        clusters to split
     * @param maxDistance     distance threshold
     * @param executorService executor that runs one task per cluster
     */
    public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, List<Cluster> clusters, final double maxDistance, ExecutorService executorService) {
        final Random random = new Random();
        List<Runnable> tasks = new ArrayList<Runnable>(clusters.size());
        for (final Cluster cluster : clusters) {
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    try {
                        ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                        List<String> fartherPoints = clusterInfo.getPointsFartherFromCenterThan(maxDistance);
                        if (fartherPoints.isEmpty()) {
                            // Nothing to split off; Random.nextInt(0) would throw IllegalArgumentException.
                            return;
                        }
                        int rank = Math.min(fartherPoints.size(), MAX_SPLIT_CANDIDATES);
                        String pointId = fartherPoints.get(random.nextInt(rank));
                        Point point = cluster.removePoint(pointId);
                        clusterSet.addNewClusterWithCenter(point);
                    } catch (Exception e) {
                        // Best effort: keep splitting the remaining clusters.
                        e.printStackTrace();
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
    }

    /**
     * For each non-empty cluster, removes a uniformly random point and makes it the center of a new cluster in
     * {@code clusterSet}. Empty clusters are skipped.
     *
     * @param clusterSet      set that receives the new clusters
     * @param clusterSetInfo  currently unused; kept for API compatibility
     * @param clusters        clusters to split
     * @param executorService executor that runs one task per cluster
     */
    public static void splitClusters(final ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, List<Cluster> clusters, ExecutorService executorService) {
        final Random random = new Random();
        List<Runnable> tasks = new ArrayList<Runnable>(clusters.size());
        for (final Cluster cluster : clusters) {
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    List<Point> points = cluster.getPoints();
                    if (points.isEmpty()) {
                        // Random.nextInt(0) would throw IllegalArgumentException.
                        return;
                    }
                    Point point = points.remove(random.nextInt(points.size()));
                    clusterSet.addNewClusterWithCenter(point);
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
    }
}

