/*
 * Decompiled with CFR 0.152.
 */
package org.apache.commons.rng.examples.jmh.sampling.distribution;

import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler;
import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler;
import org.apache.commons.rng.simple.RandomSource;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

@BenchmarkMode(value={Mode.AverageTime})
@OutputTimeUnit(value=TimeUnit.NANOSECONDS)
@Warmup(iterations=5, time=1, timeUnit=TimeUnit.SECONDS)
@Measurement(iterations=5, time=1, timeUnit=TimeUnit.SECONDS)
@State(value=Scope.Benchmark)
@Fork(value=1, jvmArgs={"-server", "-Xms128M", "-Xmx128M"})
public class EnumeratedDistributionSamplersPerformance {
    private int value;

    @Benchmark
    public int baselineInt() {
        return this.value;
    }

    @Benchmark
    public int baselineNextDouble(LocalRandomSources sources) {
        return sources.getGenerator().nextDouble() < 0.5 ? 1 : 0;
    }

    @Benchmark
    public int sampleKnown(KnownDistributionSources sources) {
        return sources.getSampler().sample();
    }

    @Benchmark
    public int singleSampleKnown(KnownDistributionSources sources) {
        return sources.createSampler().sample();
    }

    @Benchmark
    public int sampleRandom(RandomDistributionSources sources) {
        return sources.getSampler().sample();
    }

    @Benchmark
    public int singleSampleRandom(RandomDistributionSources sources) {
        return sources.createSampler().sample();
    }

    static final class BinarySearchDiscreteSampler
    implements DiscreteSampler {
        private final UniformRandomProvider rng;
        private final double[] cumulativeProbabilities;

        BinarySearchDiscreteSampler(UniformRandomProvider rng, double[] probabilities) {
            if (probabilities == null || probabilities.length == 0) {
                throw new IllegalArgumentException("Probabilities must not be empty.");
            }
            int size = probabilities.length;
            this.cumulativeProbabilities = new double[size];
            double sumProb = 0.0;
            int count = 0;
            for (double prob : probabilities) {
                if (prob < 0.0 || Double.isInfinite(prob) || Double.isNaN(prob)) {
                    throw new IllegalArgumentException("Invalid probability: " + prob);
                }
                this.cumulativeProbabilities[count++] = sumProb += prob;
            }
            if (Double.isInfinite(sumProb) || sumProb <= 0.0) {
                throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb);
            }
            this.rng = rng;
            for (int i = 0; i < size; ++i) {
                double norm = this.cumulativeProbabilities[i] / sumProb;
                this.cumulativeProbabilities[i] = norm < 1.0 ? norm : 1.0;
            }
        }

        public int sample() {
            double u = this.rng.nextDouble();
            int lower = 0;
            int upper = this.cumulativeProbabilities.length - 1;
            while (lower < upper) {
                int mid = lower + upper >>> 1;
                double midVal = this.cumulativeProbabilities[mid];
                if (u > midVal) {
                    lower = mid + 1;
                    continue;
                }
                upper = mid;
            }
            return upper;
        }
    }

    @State(value=Scope.Benchmark)
    public static class RandomDistributionSources
    extends SamplerSources {
        @Param(value={"6", "96", "3072"})
        private int randomNonUniformSize;

        @Override
        protected double[] createProbabilities() {
            double[] probabilities = new double[this.randomNonUniformSize];
            ThreadLocalRandom rng = ThreadLocalRandom.current();
            for (int i = 0; i < probabilities.length; ++i) {
                probabilities[i] = rng.nextDouble();
            }
            return probabilities;
        }
    }

    @State(value=Scope.Benchmark)
    public static class KnownDistributionSources
    extends SamplerSources {
        private static final double CUMULATIVE_PROBABILITY_LIMIT = 0.999999999;
        @Param(value={"Binomial_N67_P0.7", "Geometric_P0.2", "4SidedLoadedDie", "Poisson_Mean3.14", "Poisson_Mean10_Mean20"})
        private String distribution;

        @Override
        protected double[] createProbabilities() {
            if ("Binomial_N67_P0.7".equals(this.distribution)) {
                int trials = 67;
                double probabilityOfSuccess = 0.7;
                BinomialDistribution dist = new BinomialDistribution(null, 67, 0.7);
                return KnownDistributionSources.createProbabilities((IntegerDistribution)dist, 0, 67);
            }
            if ("Geometric_P0.2".equals(this.distribution)) {
                double probabilityOfSuccess = 0.2;
                double probabilityOfFailure = 0.8;
                double p = 1.0;
                double[] probabilities = new double[100];
                double sum = 0.0;
                int k = 0;
                while (k < probabilities.length) {
                    probabilities[k] = p * 0.2;
                    if ((sum += probabilities[k++]) > 0.999999999) break;
                    p *= 0.8;
                }
                return Arrays.copyOf(probabilities, k);
            }
            if ("4SidedLoadedDie".equals(this.distribution)) {
                return new double[]{0.5, 0.3333333333333333, 0.08333333333333333, 0.08333333333333333};
            }
            if ("Poisson_Mean3.14".equals(this.distribution)) {
                double mean = 3.14;
                IntegerDistribution dist = KnownDistributionSources.createPoissonDistribution(3.14);
                int max = dist.inverseCumulativeProbability(0.999999999);
                return KnownDistributionSources.createProbabilities(dist, 0, max);
            }
            if ("Poisson_Mean10_Mean20".equals(this.distribution)) {
                double mean1 = 10.0;
                double mean2 = 20.0;
                IntegerDistribution dist1 = KnownDistributionSources.createPoissonDistribution(20.0);
                int max = dist1.inverseCumulativeProbability(0.999999999);
                double[] p1 = KnownDistributionSources.createProbabilities(dist1, 0, max);
                double[] p2 = KnownDistributionSources.createProbabilities(KnownDistributionSources.createPoissonDistribution(10.0), 0, max);
                for (int i = 0; i < p1.length; ++i) {
                    int n = i;
                    p1[n] = p1[n] + p2[i];
                }
                return p1;
            }
            throw new IllegalStateException();
        }

        private static IntegerDistribution createPoissonDistribution(double mean) {
            return new PoissonDistribution(null, mean, 1.0E-12, 10000000);
        }

        private static double[] createProbabilities(IntegerDistribution dist, int lower, int upper) {
            double[] probabilities = new double[upper - lower + 1];
            int index = 0;
            for (int x = lower; x <= upper; ++x) {
                probabilities[index++] = dist.probability(x);
            }
            return probabilities;
        }
    }

    @State(value=Scope.Benchmark)
    public static abstract class SamplerSources
    extends LocalRandomSources {
        @Param(value={"BinarySearchDiscreteSampler", "AliasMethodDiscreteSampler", "GuideTableDiscreteSampler", "MarsagliaTsangWangDiscreteSampler"})
        private String samplerType;
        private DiscreteSamplerFactory factory;
        private DiscreteSampler sampler;

        public DiscreteSampler getSampler() {
            return this.sampler;
        }

        @Override
        @Setup(value=Level.Iteration)
        public void setup() {
            super.setup();
            double[] probabilities = this.createProbabilities();
            this.createSamplerFactory(this.getGenerator(), probabilities);
            this.sampler = this.factory.create();
        }

        protected abstract double[] createProbabilities();

        private void createSamplerFactory(final UniformRandomProvider rng, final double[] probabilities) {
            if ("BinarySearchDiscreteSampler".equals(this.samplerType)) {
                this.factory = new DiscreteSamplerFactory(){

                    @Override
                    public DiscreteSampler create() {
                        return new BinarySearchDiscreteSampler(rng, probabilities);
                    }
                };
            } else if ("AliasMethodDiscreteSampler".equals(this.samplerType)) {
                this.factory = new DiscreteSamplerFactory(){

                    @Override
                    public DiscreteSampler create() {
                        return AliasMethodDiscreteSampler.of((UniformRandomProvider)rng, (double[])probabilities);
                    }
                };
            } else if ("AliasMethodDiscreteSamplerNoPad".equals(this.samplerType)) {
                this.factory = new DiscreteSamplerFactory(){

                    @Override
                    public DiscreteSampler create() {
                        return AliasMethodDiscreteSampler.of((UniformRandomProvider)rng, (double[])probabilities, (int)-1);
                    }
                };
            } else if ("AliasMethodDiscreteSamplerAlpha1".equals(this.samplerType)) {
                this.factory = new DiscreteSamplerFactory(){

                    @Override
                    public DiscreteSampler create() {
                        return AliasMethodDiscreteSampler.of((UniformRandomProvider)rng, (double[])probabilities, (int)1);
                    }
                };
            } else if ("AliasMethodDiscreteSamplerAlpha2".equals(this.samplerType)) {
                this.factory = new DiscreteSamplerFactory(){

                    @Override
                    public DiscreteSampler create() {
                        return AliasMethodDiscreteSampler.of((UniformRandomProvider)rng, (double[])probabilities, (int)2);
                    }
                };
            } else if ("GuideTableDiscreteSampler".equals(this.samplerType)) {
                this.factory = new DiscreteSamplerFactory(){

                    @Override
                    public DiscreteSampler create() {
                        return GuideTableDiscreteSampler.of((UniformRandomProvider)rng, (double[])probabilities);
                    }
                };
            } else if ("GuideTableDiscreteSamplerAlpha2".equals(this.samplerType)) {
                this.factory = new DiscreteSamplerFactory(){

                    @Override
                    public DiscreteSampler create() {
                        return GuideTableDiscreteSampler.of((UniformRandomProvider)rng, (double[])probabilities, (double)2.0);
                    }
                };
            } else if ("GuideTableDiscreteSamplerAlpha8".equals(this.samplerType)) {
                this.factory = new DiscreteSamplerFactory(){

                    @Override
                    public DiscreteSampler create() {
                        return GuideTableDiscreteSampler.of((UniformRandomProvider)rng, (double[])probabilities, (double)8.0);
                    }
                };
            } else if ("MarsagliaTsangWangDiscreteSampler".equals(this.samplerType)) {
                this.factory = new DiscreteSamplerFactory(){

                    @Override
                    public DiscreteSampler create() {
                        return MarsagliaTsangWangDiscreteSampler.Enumerated.of((UniformRandomProvider)rng, (double[])probabilities);
                    }
                };
            } else {
                throw new IllegalStateException();
            }
        }

        public DiscreteSampler createSampler() {
            return this.factory.create();
        }

        static interface DiscreteSamplerFactory {
            public DiscreteSampler create();
        }
    }

    @State(value=Scope.Benchmark)
    public static class LocalRandomSources {
        @Param(value={"WELL_44497_B", "ISAAC", "XO_RO_SHI_RO_128_PLUS"})
        private String randomSourceName;
        private UniformRandomProvider generator;

        public UniformRandomProvider getGenerator() {
            return this.generator;
        }

        @Setup
        public void setup() {
            RandomSource randomSource = RandomSource.valueOf((String)this.randomSourceName);
            this.generator = RandomSource.create((RandomSource)randomSource);
        }
    }
}

