/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;

/**
 * A batch of input data and labels, together with the {@link Batchifier}s needed to split it
 * across devices and optional iteration-progress information.
 *
 * <p>The constructors attach {@code data} and {@code labels} to the supplied {@link NDManager};
 * {@link #close()} closes that manager.
 */
public class Batch implements AutoCloseable {
    private NDManager manager;
    private NDList data;
    private NDList labels;
    private Batchifier dataBatchifier;
    private Batchifier labelBatchifier;
    private int size;
    private long progress;
    private long progressTotal;

    /**
     * Creates a batch without progress information; {@link #getProgress()} and
     * {@link #getProgressTotal()} both return 0.
     *
     * @param manager the manager that {@code data} and {@code labels} are attached to
     * @param data the input data of the batch
     * @param labels the labels of the batch
     * @param size the number of examples in the batch
     * @param dataBatchifier the batchifier used to split {@code data}; may be {@code null}, in
     *     which case {@link #split(Device[], boolean)} across several devices fails
     * @param labelBatchifier the batchifier used to split {@code labels}; same nullability as
     *     {@code dataBatchifier}
     */
    public Batch(NDManager manager, NDList data, NDList labels, int size, Batchifier dataBatchifier, Batchifier labelBatchifier) {
        this(manager, data, labels, size, dataBatchifier, labelBatchifier, 0L, 0L);
    }

    /**
     * Creates a batch with progress information.
     *
     * @param manager the manager that {@code data} and {@code labels} are attached to
     * @param data the input data of the batch
     * @param labels the labels of the batch
     * @param size the number of examples in the batch
     * @param dataBatchifier the batchifier used to split {@code data}
     * @param labelBatchifier the batchifier used to split {@code labels}
     * @param progress the progress value reported by {@link #getProgress()}
     * @param progressTotal the total reported by {@link #getProgressTotal()}
     */
    public Batch(NDManager manager, NDList data, NDList labels, int size, Batchifier dataBatchifier, Batchifier labelBatchifier, long progress, long progressTotal) {
        this.manager = manager;
        data.attach(manager);
        labels.attach(manager);
        this.data = data;
        this.labels = labels;
        this.size = size;
        this.dataBatchifier = dataBatchifier;
        this.labelBatchifier = labelBatchifier;
        this.progress = progress;
        this.progressTotal = progressTotal;
    }

    /**
     * Returns the manager that owns this batch's arrays.
     *
     * @return the manager, or {@code null} once {@link #close()} has been called
     */
    public NDManager getManager() {
        return this.manager;
    }

    /**
     * Returns the input data of this batch.
     *
     * @return the data list
     */
    public NDList getData() {
        return this.data;
    }

    /**
     * Returns the labels of this batch.
     *
     * @return the label list
     */
    public NDList getLabels() {
        return this.labels;
    }

    /**
     * Returns the number of examples in this batch, as supplied at construction.
     *
     * @return the batch size
     */
    public int getSize() {
        return this.size;
    }

    /**
     * Returns the progress value supplied at construction (0 if none was given).
     *
     * @return the progress value
     */
    public long getProgress() {
        return this.progress;
    }

    /**
     * Returns the progress total supplied at construction (0 if none was given).
     *
     * @return the progress total
     */
    public long getProgressTotal() {
        return this.progressTotal;
    }

    /**
     * Closes the manager owning this batch's arrays. Calling this method more than once has no
     * further effect.
     */
    @Override
    public void close() {
        NDManager toClose = this.manager;
        if (toClose == null) {
            // Already closed: keep close() idempotent instead of throwing an NPE.
            return;
        }
        // Clear the field first so a failing close() is not retried on the next call.
        this.manager = null;
        toClose.close();
    }

    /**
     * Splits this batch into one sub-batch per slice and moves each slice to its device.
     *
     * <p>With a single device, the whole batch is returned, copied to that device only if the head
     * data array is on a different device. With several devices, the data and labels are split by
     * their batchifiers into at most {@code devices.length} slices, and slice {@code i} is moved to
     * {@code devices[i]}.
     *
     * <p>NOTE(review): every returned batch shares this batch's {@link NDManager}, so closing any
     * of them closes the manager for all of them (and for this batch).
     *
     * <p>Sub-batch sizes are {@code size / sliceCount} each, with the last slice taking the
     * remainder; the sizes always sum to {@link #getSize()}.
     *
     * @param devices the target devices; must contain at least one device
     * @param evenSplit forwarded to {@link Batchifier#split}
     * @return the sub-batches, one per produced slice
     * @throws IllegalArgumentException if {@code devices} is empty
     * @throws IllegalStateException if a batchifier is missing, or if data and labels were split
     *     into a different number of slices
     */
    public Batch[] split(Device[] devices, boolean evenSplit) {
        int deviceCount = devices.length;
        if (deviceCount == 0) {
            throw new IllegalArgumentException("At least one device is required to split a batch");
        }
        if (deviceCount == 1) {
            Device target = devices[0];
            if (this.data.head().getDevice().equals(target)) {
                // Already on the target device: reuse the existing arrays without copying.
                return new Batch[]{new Batch(this.manager, this.data, this.labels, this.size, this.dataBatchifier, this.labelBatchifier, this.progress, this.progressTotal)};
            }
            NDList d = this.data.asInDevice(target, true);
            NDList l = this.labels.asInDevice(target, true);
            return new Batch[]{new Batch(this.manager, d, l, this.size, this.dataBatchifier, this.labelBatchifier, this.progress, this.progressTotal)};
        }
        NDList[] splittedData = this.split(this.data, this.dataBatchifier, deviceCount, evenSplit);
        NDList[] splittedLabels = this.split(this.labels, this.labelBatchifier, deviceCount, evenSplit);
        int sliceCount = splittedData.length;
        if (splittedLabels.length != sliceCount) {
            throw new IllegalStateException("Data was split into " + sliceCount + " slices but labels were split into " + splittedLabels.length);
        }
        if (sliceCount == 0) {
            return new Batch[0];
        }
        Batch[] splitted = new Batch[sliceCount];
        // Divide by the number of slices actually produced, not the device count: a batchifier may
        // return fewer slices than requested, and dividing by deviceCount would then undersize the
        // leading slices and dump the rest onto the last one.
        int baseSplitSize = this.size / sliceCount;
        for (int i = 0; i < sliceCount; ++i) {
            NDList d = splittedData[i].asInDevice(devices[i], true);
            NDList l = splittedLabels[i].asInDevice(devices[i], true);
            int subSize = i == sliceCount - 1 ? this.size - i * baseSplitSize : baseSplitSize;
            splitted[i] = new Batch(this.manager, d, l, subSize, this.dataBatchifier, this.labelBatchifier, this.progress, this.progressTotal);
        }
        return splitted;
    }

    /**
     * Splits {@code list} with the given batchifier.
     *
     * @param list the list to split
     * @param batchifier the batchifier to split with
     * @param numOfSlices the requested number of slices
     * @param evenSplit forwarded to {@link Batchifier#split}
     * @return the slices produced by the batchifier
     * @throws IllegalStateException if {@code batchifier} is {@code null}
     */
    private NDList[] split(NDList list, Batchifier batchifier, int numOfSlices, boolean evenSplit) {
        if (batchifier == null) {
            throw new IllegalStateException("Split can only be called on a batch containing a batchifier");
        }
        return batchifier.split(list, numOfSlices, evenSplit);
    }
}

