/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset.utils;

import ai.djl.basicdataset.TextDataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Sampler;
import ai.djl.util.RandomUtils;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.TreeSet;

/**
 * A {@link Sampler} for {@link TextDataset}s that groups samples into buckets of similar processed
 * text length and fills each batch from consecutive buckets, so that a batch mostly contains
 * sentences of similar length.
 *
 * <p>Samples are sorted by the size of {@code getProcessedText(index, true)} and split into
 * contiguous buckets of {@code datasetSize / numBuckets} samples (at least one); a remainder forms
 * one extra, smaller bucket. The bucket layout is computed once per dataset instance and reused;
 * every call to {@link #sample(RandomAccessDataset)} starts a fresh epoch in which each sample is
 * emitted at most once (exactly once unless {@code dropLast} discards the final partial batch).
 */
public class FixedBucketSampler
implements Sampler {
    private static final int DEFAULT_NUM_BUCKETS = 10;

    private final int numBuckets;
    private final int batchSize;
    private final boolean dropLast;
    private final boolean shuffle;

    // Immutable bucket layout (dataset indices, ascending by text length) and the dataset instance
    // it was computed for. Rebuilt only when sample() receives a different dataset.
    private List<List<Long>> buckets;
    private RandomAccessDataset bucketsSource;

    /**
     * Creates a sampler.
     *
     * @param batchSize the number of indices per batch
     * @param numBuckets the number of length buckets to split the dataset into
     * @param dropLast whether to drop a final batch smaller than {@code batchSize}
     * @param shuffle whether to start each batch at a random bucket and shuffle samples within
     *     each bucket every epoch
     * @throws IllegalArgumentException if {@code batchSize} or {@code numBuckets} is not positive
     */
    public FixedBucketSampler(int batchSize, int numBuckets, boolean dropLast, boolean shuffle) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        if (numBuckets <= 0) {
            throw new IllegalArgumentException("numBuckets must be positive, got " + numBuckets);
        }
        this.numBuckets = numBuckets;
        this.batchSize = batchSize;
        this.dropLast = dropLast;
        this.shuffle = shuffle;
    }

    /**
     * Creates a shuffling sampler that keeps the final partial batch.
     *
     * @param batchSize the number of indices per batch
     * @param numBuckets the number of length buckets to split the dataset into
     */
    public FixedBucketSampler(int batchSize, int numBuckets) {
        this(batchSize, numBuckets, false, true);
    }

    /**
     * Creates a shuffling sampler with {@value #DEFAULT_NUM_BUCKETS} buckets that keeps the final
     * partial batch.
     *
     * @param batchSize the number of indices per batch
     */
    public FixedBucketSampler(int batchSize) {
        this(batchSize, DEFAULT_NUM_BUCKETS);
    }

    /**
     * Returns an iterator over one epoch of batches of dataset indices.
     *
     * @param dataset the dataset to sample; must be a {@link TextDataset}
     * @return an iterator of index batches
     * @throws IllegalStateException if {@code dataset} is not a {@link TextDataset}
     */
    public Iterator<List<Long>> sample(RandomAccessDataset dataset) {
        return new Iterate(dataset);
    }

    /**
     * Returns the configured batch size.
     *
     * @return the batch size
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /** Returns the cached bucket layout for {@code dataset}, computing it on first use. */
    private List<List<Long>> bucketsFor(RandomAccessDataset dataset) {
        if (buckets == null || bucketsSource != dataset) {
            buckets = buildBuckets((TextDataset) dataset, dataset.size());
            bucketsSource = dataset;
        }
        return buckets;
    }

    /** Sorts all indices by processed text length and splits them into contiguous buckets. */
    private List<List<Long>> buildBuckets(TextDataset dataset, long count) {
        List<Sample> samples = new ArrayList<>();
        for (int i = 0; i < count; ++i) {
            samples.add(new Sample(i, dataset.getProcessedText(i, true).size()));
        }
        samples.sort(Comparator.comparingInt(s -> s.sentenceLength));

        // Clamp to 1 so that a dataset smaller than numBuckets cannot produce a zero step.
        int bucketSize = Math.max(1, samples.size() / numBuckets);
        List<List<Long>> result = new ArrayList<>();
        for (int start = 0; start < samples.size(); start += bucketSize) {
            int end = Math.min(start + bucketSize, samples.size());
            List<Long> bucket = new ArrayList<>(end - start);
            for (Sample sample : samples.subList(start, end)) {
                bucket.add(sample.index);
            }
            result.add(Collections.unmodifiableList(bucket));
        }
        return Collections.unmodifiableList(result);
    }

    /** In-place Fisher-Yates shuffle driven by {@link RandomUtils}. */
    private static void shuffleInPlace(List<Long> list) {
        for (int i = list.size() - 1; i > 0; --i) {
            Collections.swap(list, i, RandomUtils.nextInt(i + 1));
        }
    }

    /** One epoch over the buckets; consumes a private copy so the cached layout stays intact. */
    private class Iterate
    implements Iterator<List<Long>> {
        // Per-epoch working copies of the buckets; indices are polled off as they are emitted.
        private final List<Deque<Long>> pending;
        private final long size;
        private long current;
        private long remaining;

        Iterate(RandomAccessDataset dataset) {
            if (!(dataset instanceof TextDataset)) {
                throw new IllegalStateException("FixedBucketSampler can only be used with TextDataset");
            }
            List<List<Long>> layout = bucketsFor(dataset);
            pending = new ArrayList<>(layout.size());
            for (List<Long> bucket : layout) {
                List<Long> copy = new ArrayList<>(bucket);
                if (shuffle) {
                    shuffleInPlace(copy);
                }
                pending.add(new ArrayDeque<>(copy));
                remaining += copy.size();
            }
            size = dropLast ? remaining / batchSize : (remaining + batchSize - 1L) / batchSize;
        }

        @Override
        public boolean hasNext() {
            return this.current < this.size;
        }

        @Override
        public List<Long> next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            // The final batch is partial when dropLast is false and samples run out; capping the
            // target at `remaining` also guarantees the loop below terminates.
            int target = (int) Math.min(batchSize, remaining);
            List<Long> batch = new ArrayList<>(target);
            int bucketIndex = firstBucketIndex();
            while (batch.size() < target) {
                Deque<Long> bucket = pending.get(bucketIndex);
                while (batch.size() < target && !bucket.isEmpty()) {
                    batch.add(bucket.pollFirst());
                }
                // Move on to the next bucket, wrapping around after the last one.
                bucketIndex = (bucketIndex + 1) % pending.size();
            }
            remaining -= target;
            ++this.current;
            return batch;
        }

        /** Uniformly random bucket when shuffling, otherwise the first (shortest) bucket. */
        private int firstBucketIndex() {
            return shuffle ? RandomUtils.nextInt(pending.size()) : 0;
        }
    }

    private static class Sample {
        int sentenceLength;
        long index;

        public Sample(int index, int sentenceLength) {
            this.index = index;
            this.sentenceLength = sentenceLength;
        }
    }
}

