/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;

/**
 * A unit of training/inference input: a list of data arrays, a list of label arrays, and the
 * {@link NDManager} that both lists are attached to.
 *
 * <p>Closing the batch closes that manager. A batch may optionally carry a {@link Batchifier},
 * which is required to {@link #split(Device[], boolean) split} the batch across more than one
 * device.
 */
public class Batch implements AutoCloseable {

    /** Manager the data and labels are attached to; set to {@code null} once closed. */
    private NDManager manager;

    private final NDList data;
    private final NDList labels;

    /** May be {@code null}; only needed when splitting across several devices. */
    private final Batchifier batchifier;

    /**
     * Creates a batch without a batchifier.
     *
     * <p>Both lists are attached to {@code manager}.
     *
     * @param manager the manager to attach the data and labels to
     * @param data the data of the batch
     * @param labels the labels of the batch
     */
    public Batch(NDManager manager, NDList data, NDList labels) {
        this(manager, data, labels, null);
    }

    /**
     * Creates a batch that remembers the batchifier, so it can later be split.
     *
     * <p>Both lists are attached to {@code manager}.
     *
     * @param manager the manager to attach the data and labels to
     * @param data the data of the batch
     * @param labels the labels of the batch
     * @param batchifier the batchifier used by {@link #split(Device[], boolean)}; may be
     *     {@code null}
     */
    public Batch(NDManager manager, NDList data, NDList labels, Batchifier batchifier) {
        this.manager = manager;
        data.attach(manager);
        labels.attach(manager);
        this.data = data;
        this.labels = labels;
        this.batchifier = batchifier;
    }

    /**
     * Returns the manager this batch's data and labels are attached to.
     *
     * @return the manager, or {@code null} after {@link #close()} has been called
     */
    public NDManager getManager() {
        return manager;
    }

    /**
     * Returns the data of this batch.
     *
     * @return the data list passed at construction
     */
    public NDList getData() {
        return data;
    }

    /**
     * Returns the labels of this batch.
     *
     * @return the label list passed at construction
     */
    public NDList getLabels() {
        return labels;
    }

    /**
     * Closes the manager of this batch and drops the reference to it.
     *
     * <p>Calling this method more than once is safe: only the first call closes the manager,
     * later calls do nothing.
     */
    @Override
    public void close() {
        // Guard against repeated close(): the reference is cleared after the first call, so an
        // unguarded second call would dereference null.
        if (manager != null) {
            manager.close();
            manager = null;
        }
    }

    /**
     * Distributes this batch over the given devices.
     *
     * <p>With exactly one device no slicing happens: the result is a single batch that reuses
     * the current data and labels if the first data array already lives on that device, and
     * otherwise holds them as returned by {@code asInDevice(device, true)}. With any other
     * number of devices the data and labels are sliced by the batchifier and slice {@code i} is
     * moved to {@code devices[i]}.
     *
     * <p>Every returned batch shares this batch's manager and batchifier.
     *
     * @param devices the target devices
     * @param evenSplit passed through to {@link Batchifier#split}; presumably requests slices
     *     of equal size (confirm against the batchifier implementation)
     * @return one batch per slice produced
     * @throws IllegalStateException if slicing is needed but this batch has no batchifier
     */
    public Batch[] split(Device[] devices, boolean evenSplit) {
        int deviceCount = devices.length;
        if (deviceCount == 1) {
            Device target = devices[0];
            if (data.head().getDevice().equals(target)) {
                // Already on the requested device: wrap the existing lists as they are.
                return new Batch[] {new Batch(manager, data, labels, batchifier)};
            }
            NDList movedData = data.asInDevice(target, true);
            NDList movedLabels = labels.asInDevice(target, true);
            return new Batch[] {new Batch(manager, movedData, movedLabels, batchifier)};
        }

        NDList[] dataSlices = split(data, deviceCount, evenSplit);
        NDList[] labelSlices = split(labels, deviceCount, evenSplit);
        // Sized by the slices actually produced, not by the number of devices requested.
        Batch[] result = new Batch[dataSlices.length];
        for (int i = 0; i < dataSlices.length; ++i) {
            NDList movedData = dataSlices[i].asInDevice(devices[i], true);
            NDList movedLabels = labelSlices[i].asInDevice(devices[i], true);
            result[i] = new Batch(manager, movedData, movedLabels, batchifier);
        }
        return result;
    }

    /**
     * Slices a list with this batch's batchifier.
     *
     * @param list the list to slice
     * @param numOfSlices the number of slices requested
     * @param evenSplit passed through to {@link Batchifier#split}
     * @return the slices produced by the batchifier
     * @throws IllegalStateException if this batch was created without a batchifier
     */
    private NDList[] split(NDList list, int numOfSlices, boolean evenSplit) {
        if (batchifier == null) {
            throw new IllegalStateException(
                    "Split can only be called on a batch containing a batchifier");
        }
        return batchifier.split(list, numOfSlices, evenSplit);
    }
}

