/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.deeplearning4j.distancefunction.DistanceFunction;
import org.deeplearning4j.distancefunction.EuclideanDistance;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Online k-means clustering over jblas feature vectors.
 *
 * <p>The first {@code 10 * nbCluster} vectors passed to {@link #update(DoubleMatrix)} are buffered.
 * Once enough are collected, centroids are seeded k-means++ style: the first centroid is drawn
 * uniformly at random, and each subsequent one is drawn with probability proportional to its
 * squared distance to the nearest already-chosen centroid. After seeding, each update assigns
 * the vector to its nearest centroid and moves that centroid by a running-mean step.
 *
 * <p>Distances are computed by reflectively instantiating the configured {@link DistanceFunction}
 * class through its public {@code (DoubleMatrix)} constructor (the reference point) and applying
 * it to the other point. Feature vectors are assumed to be row vectors, because centroids are
 * stored one per row; NOTE(review): confirm with callers.
 */
public class KMeansClustering
implements Serializable {
    private static final long serialVersionUID = 338231277453149972L;
    private static final Logger log = LoggerFactory.getLogger(KMeansClustering.class);
    /** Number of buffered vectors per cluster required before centroids are seeded. */
    private static final int INIT_SAMPLES_PER_CLUSTER = 10;

    /** Number of vectors assigned to each centroid since seeding; null until seeded. */
    private List<Long> counts = null;
    /** One centroid per row; null until seeding has started. */
    private DoubleMatrix centroids;
    /** Vectors buffered before seeding; cleared once centroids are initialized. */
    private List<DoubleMatrix> initFeatures = new ArrayList<DoubleMatrix>();
    /** Distance implementation instantiated per distance computation. */
    private Class<? extends DistanceFunction> clazz;
    private Integer nbCluster;

    /**
     * Creates a clusterer with a custom distance function.
     *
     * @param nbCluster number of clusters; must be positive
     * @param clazz distance implementation with a public {@code (DoubleMatrix)} constructor
     * @throws IllegalArgumentException if {@code nbCluster} is null or not positive
     * @throws NullPointerException if {@code clazz} is null
     */
    public KMeansClustering(Integer nbCluster, Class<? extends DistanceFunction> clazz) {
        if (nbCluster == null || nbCluster < 1) {
            throw new IllegalArgumentException("nbCluster must be a positive integer, got " + nbCluster);
        }
        if (clazz == null) {
            throw new NullPointerException("clazz");
        }
        this.nbCluster = nbCluster;
        this.clazz = clazz;
    }

    /**
     * Creates a clusterer that uses {@link EuclideanDistance}.
     *
     * @param nbCluster number of clusters; must be positive
     */
    public KMeansClustering(Integer nbCluster) {
        this(nbCluster, EuclideanDistance.class);
    }

    /**
     * Returns the index of the centroid nearest to {@code features}.
     *
     * @throws IllegalStateException if centroids have not been seeded yet
     */
    public Integer classify(DoubleMatrix features) {
        if (!this.isReady()) {
            throw new IllegalStateException("KMeans is not ready yet");
        }
        return this.nearestCentroid(features);
    }

    /**
     * Feeds one vector to the model.
     *
     * <p>Before seeding, the vector is buffered, and {@code null} is returned. After seeding, the
     * vector is assigned to its nearest centroid, and that centroid moves toward it by
     * {@code (x - c) / n}, where {@code n} is the centroid's updated assignment count.
     *
     * @return the index of the updated centroid, or {@code null} while still buffering
     */
    public Integer update(DoubleMatrix features) {
        if (!this.isReady()) {
            this.initIfPossible(features);
            log.info("Initializing feature vector with length of {}", features.length);
            return null;
        }
        int nearest = this.classify(features);
        long count = this.counts.get(nearest) + 1L;
        this.counts.set(nearest, count);
        DoubleMatrix centroid = this.centroids.getRow(nearest);
        // Running-mean step. Counts start at 0 after seeding, so the first assignment replaces
        // the seed point with this vector.
        DoubleMatrix step = features.sub(centroid).mul(1.0 / (double) count);
        this.centroids.putRow(nearest, centroid.add(step));
        return nearest;
    }

    /**
     * Returns a {@code 1 x nbCluster} row vector holding the distance from each centroid to
     * {@code features}.
     *
     * @throws IllegalStateException if centroids have not been seeded yet
     */
    public DoubleMatrix distribution(DoubleMatrix features) {
        if (!this.isReady()) {
            throw new IllegalStateException("KMeans is not ready yet");
        }
        DoubleMatrix distribution = new DoubleMatrix(1, this.nbCluster.intValue());
        for (int i = 0; i < this.nbCluster; ++i) {
            distribution.put(i, this.getDistance(this.centroids.getRow(i), features));
        }
        return distribution;
    }

    /** Instantiates the distance function anchored at {@code m1} and applies it to {@code m2}. */
    private double getDistance(DoubleMatrix m1, DoubleMatrix m2) {
        DistanceFunction function;
        try {
            function = this.clazz.getConstructor(DoubleMatrix.class).newInstance(m1);
        }
        catch (Exception e) {
            throw new IllegalStateException("Unable to instantiate distance function " + this.clazz.getName(), e);
        }
        return (Double) function.apply(m2);
    }

    /** Returns the live centroid matrix (one centroid per row), or null before seeding. */
    public DoubleMatrix getCentroids() {
        return this.centroids;
    }

    /** Returns the index of the nearest centroid, considering every row of the centroid matrix. */
    protected Integer nearestCentroid(DoubleMatrix features) {
        return this.nearestCentroidAmong(features, this.centroids.rows);
    }

    /**
     * Returns the index of the nearest centroid among rows {@code [0, centroidCount)}.
     * NaN distances never win; index 0 is returned if no distance beats {@code Double.MAX_VALUE}.
     */
    private int nearestCentroidAmong(DoubleMatrix features, int centroidCount) {
        int nearestIndex = 0;
        double minDistance = Double.MAX_VALUE;
        for (int i = 0; i < centroidCount; ++i) {
            double distance = this.getDistance(this.centroids.getRow(i), features);
            if (distance < minDistance) {
                minDistance = distance;
                nearestIndex = i;
            }
        }
        return nearestIndex;
    }

    /** True once both the centroids and their assignment counts exist. */
    protected boolean isReady() {
        return this.counts != null && this.centroids != null;
    }

    /** Buffers {@code features} and seeds centroids once enough vectors have been collected. */
    protected void initIfPossible(DoubleMatrix features) {
        this.initFeatures.add(features);
        log.info("Added feature vector of length {}", features.length);
        if (this.initFeatures.size() >= INIT_SAMPLES_PER_CLUSTER * this.nbCluster) {
            this.initCentroids();
        }
    }

    /**
     * Seeds {@code nbCluster} centroids from the buffered vectors using D² (k-means++) sampling,
     * then clears the buffer and resets the assignment counts.
     *
     * @throws IllegalStateException if fewer than {@code nbCluster} vectors are buffered
     */
    protected void initCentroids() {
        if (this.initFeatures.size() < this.nbCluster) {
            throw new IllegalStateException("Need at least " + this.nbCluster
                    + " buffered feature vectors to seed centroids, have " + this.initFeatures.size());
        }
        Random random = new Random();
        DoubleMatrix firstCentroid = this.initFeatures.remove(random.nextInt(this.initFeatures.size()));
        this.centroids = new DoubleMatrix(this.nbCluster.intValue(), firstCentroid.columns);
        this.centroids.putRow(0, firstCentroid);
        log.info("Added initial centroid");
        for (int j = 1; j < this.nbCluster; ++j) {
            // Only rows [0, j) hold real centroids; later rows are still zero-filled.
            DoubleMatrix dxs = this.computeDxs(j);
            double r = random.nextDouble() * dxs.get(dxs.length - 1);
            int chosen = firstIndexAtLeast(dxs, r);
            this.centroids.putRow(j, this.initFeatures.remove(chosen));
        }
        this.initFeatures.clear();
        List<Long> freshCounts = new ArrayList<Long>(this.nbCluster);
        for (int i = 0; i < this.nbCluster; ++i) {
            freshCounts.add(0L);
        }
        // Assigned last so isReady() only becomes true once seeding has fully succeeded.
        this.counts = freshCounts;
    }

    /** Returns the first index whose cumulative value is {@code >= threshold}, else the last index. */
    private static int firstIndexAtLeast(DoubleMatrix cumulative, double threshold) {
        for (int i = 0; i < cumulative.length; ++i) {
            if (cumulative.get(i) >= threshold) {
                return i;
            }
        }
        // Fallback for floating-point edge cases (e.g. NaN distances).
        return cumulative.length - 1;
    }

    /**
     * Returns a column vector whose i-th entry is the cumulative sum of squared distances from
     * buffered vectors {@code 0..i} to their nearest centroid, considering every centroid row.
     */
    protected DoubleMatrix computeDxs() {
        return this.computeDxs(this.centroids.rows);
    }

    /** As {@link #computeDxs()}, but only considers centroid rows {@code [0, centroidCount)}. */
    private DoubleMatrix computeDxs(int centroidCount) {
        int size = this.initFeatures.size();
        DoubleMatrix dxs = new DoubleMatrix(size, 1);
        double sum = 0.0;
        for (int i = 0; i < size; ++i) {
            DoubleMatrix features = this.initFeatures.get(i);
            int nearest = this.nearestCentroidAmong(features, centroidCount);
            double distance = this.getDistance(this.centroids.getRow(nearest), features);
            sum += distance * distance;
            dxs.put(i, sum);
        }
        return dxs;
    }

    /** Discards centroids, counts and buffered vectors, returning the model to its initial state. */
    public void reset() {
        this.counts = null;
        this.centroids = null;
        this.initFeatures = new ArrayList<DoubleMatrix>();
    }
}

