/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.stat.distribution.ExponentialFamilyMixture;
import smile.stat.distribution.GaussianDistribution;
import smile.stat.distribution.Mixture;

/**
 * Finite univariate Gaussian mixture model fitted with the EM algorithm.
 * <p>
 * Every component must wrap a {@link GaussianDistribution}; this is checked
 * when the mixture is constructed.
 */
public class GaussianMixture extends ExponentialFamilyMixture {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(GaussianMixture.class);

    /**
     * Creates a Gaussian mixture from the given components.
     * <p>
     * Passes {@code L = 0.0} and {@code n = 1} to the superclass
     * (presumably the log-likelihood and sample size, which are unknown for a
     * mixture that was not fitted to data — confirm against the superclass).
     *
     * @param components the mixture components; each must wrap a
     *                   {@link GaussianDistribution}.
     * @throws IllegalArgumentException if a component is not Gaussian.
     */
    public GaussianMixture(Mixture.Component ... components) {
        this(0.0, 1, components);
    }

    /**
     * Creates a Gaussian mixture with explicit superclass state.
     *
     * @param L          passed through to the superclass (presumably the
     *                   log-likelihood of the fitted model).
     * @param n          passed through to the superclass (presumably the
     *                   number of samples the model was fitted on).
     * @param components the mixture components; each must wrap a
     *                   {@link GaussianDistribution}.
     * @throws IllegalArgumentException if a component is not Gaussian.
     */
    private GaussianMixture(double L, int n, Mixture.Component ... components) {
        super(L, n, components);
        for (Mixture.Component component : components) {
            if (!(component.distribution instanceof GaussianDistribution)) {
                throw new IllegalArgumentException("Component " + component + " is not of Gaussian distribution.");
            }
        }
    }

    /**
     * Fits a Gaussian mixture with exactly {@code k} components using EM.
     * <p>
     * The initial components have equal priors {@code 1/k}, a common standard
     * deviation equal to {@code step = (max(x) - min(x)) / (k + 1)}, and means
     * {@code min(x) + step, min(x) + 2*step, ..., min(x) + k*step}.
     *
     * @param k the number of components; must be at least 2.
     * @param x the training samples.
     * @return the fitted mixture.
     * @throws IllegalArgumentException if {@code k < 2}, or if the samples do
     *         not span a positive, finite range (for example, all values are
     *         equal), which would seed components with a degenerate
     *         standard deviation.
     */
    public static GaussianMixture fit(int k, double[] x) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of components in the mixture.");
        }
        double min = MathEx.min(x);
        double max = MathEx.max(x);
        double step = (max - min) / (double)(k + 1);
        // step doubles as the initial standard deviation: a zero, NaN or
        // infinite value would make every seed Gaussian degenerate.
        if (!(step > 0.0) || Double.isInfinite(step)) {
            throw new IllegalArgumentException("Samples must span a positive, finite range: [" + min + ", " + max + "]");
        }
        Mixture.Component[] components = new Mixture.Component[k];
        double mean = min;
        for (int i = 0; i < k; ++i) {
            // Incremental accumulation keeps the seed means exactly min + step + step + ...
            mean += step;
            components[i] = new Mixture.Component(1.0 / (double)k, new GaussianDistribution(mean, step));
        }
        ExponentialFamilyMixture model = GaussianMixture.fit(x, components);
        return new GaussianMixture(model.L, x.length, model.components);
    }

    /**
     * Fits a Gaussian mixture and selects the number of components by BIC.
     * <p>
     * Starts from a single Gaussian fitted to {@code x}. It then fits mixtures
     * with {@code k = 2, 3, ...} components (while {@code k < x.length / 10})
     * via {@link #fit(int, double[])}. The search stops at the first {@code k}
     * whose BIC is not strictly greater than the best BIC so far.
     *
     * @param x the training samples; at least 20 are required.
     * @return the mixture with the highest BIC among those tried.
     * @throws IllegalArgumentException if {@code x} has fewer than 20 samples,
     *         or if {@link #fit(int, double[])} rejects the data.
     */
    public static GaussianMixture fit(double[] x) {
        if (x.length < 20) {
            throw new IllegalArgumentException("Too few samples.");
        }
        GaussianMixture mixture = new GaussianMixture(new Mixture.Component(1.0, GaussianDistribution.fit(x)));
        double bic = mixture.bic(x);
        // Guard so the String.format work is skipped when INFO is disabled.
        if (logger.isInfoEnabled()) {
            logger.info(String.format("The BIC of %s = %.4f", mixture, bic));
        }
        for (int k = 2; k < x.length / 10; ++k) {
            GaussianMixture model = GaussianMixture.fit(k, x);
            if (logger.isInfoEnabled()) {
                logger.info(String.format("The BIC of %s = %.4f", model, model.bic));
            }
            if (model.bic <= bic) {
                break;
            }
            // fit(int, double[]) already returns a GaussianMixture built from
            // the EM result and x.length; no need to wrap it again.
            mixture = model;
            bic = model.bic;
        }
        return mixture;
    }

    /**
     * Splits the component with the largest standard deviation into two.
     * <p>
     * The chosen component (with mean {@code mu} and standard deviation
     * {@code sd}) is replaced at its index by a Gaussian with mean
     * {@code mu + sd/2}. A Gaussian with mean {@code mu - sd/2} is appended.
     * Both keep standard deviation {@code sd} and receive half of the
     * original prior.
     * <p>
     * NOTE(review): this method is not referenced anywhere in this class.
     *
     * @param components the current mixture components.
     * @return a new array with one more component than {@code components}.
     * @throws IllegalArgumentException if no component's standard deviation
     *         exceeds negative infinity (for example, the array is empty or
     *         every standard deviation is NaN).
     */
    private static Mixture.Component[] split(Mixture.Component ... components) {
        int k = components.length;
        int index = -1;
        double maxSigma = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < k; ++i) {
            double sigma = components[i].distribution.sd();
            if (sigma > maxSigma) {
                maxSigma = sigma;
                index = i;
            }
        }
        if (index < 0) {
            throw new IllegalArgumentException("No component to split.");
        }
        Mixture.Component component = components[index];
        double priori = component.priori / 2.0;
        double delta = component.distribution.sd();
        double mu = component.distribution.mean();
        Mixture.Component[] mixture = new Mixture.Component[k + 1];
        System.arraycopy(components, 0, mixture, 0, k);
        mixture[index] = new Mixture.Component(priori, new GaussianDistribution(mu + delta / 2.0, delta));
        mixture[k] = new Mixture.Component(priori, new GaussianDistribution(mu - delta / 2.0, delta));
        return mixture;
    }
}

