/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator.callbacks;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.deeplearning4j.datasets.iterator.callbacks.DataSetCallback;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link DataSetCallback} that migrates each incoming {@link DataSet} / {@link MultiDataSet} into
 * one of a set of memory workspaces, selected round-robin by an input counter.
 *
 * <p>The workspaces are created lazily on the first callback. There is one per device reported by
 * the affinity manager. Each is sized initially from the memory footprint of that first dataset.
 */
public class InterleavedDataSetCallback
implements DataSetCallback {
    private static final Logger log = LoggerFactory.getLogger(InterleavedDataSetCallback.class);
    /** Workspaces created by {@link #initializeWorkspaces(long)}; entry {@code i} was created for device {@code i}. */
    private final List<MemoryWorkspace> workspaces = new ArrayList<MemoryWorkspace>();
    /** Used as the overallocation limit of every workspace configuration. */
    private final int bufferSize;
    private int numWorkspaces;
    private boolean isInitialized = false;
    /** Number of datasets handled since construction or the last {@link #reset()}; drives round-robin selection. */
    private final AtomicLong counterInput = new AtomicLong(0L);

    /**
     * @param bufferSize value passed as the workspace overallocation limit
     */
    public InterleavedDataSetCallback(int bufferSize) {
        this.bufferSize = bufferSize;
    }

    /**
     * Creates one workspace per device, named {@code "IDSC-<device>"}.
     *
     * <p>The calling thread is temporarily bound to each device while its workspace is created.
     * Its original device binding is always restored, even if creation fails.
     *
     * @param size initial size for each workspace
     */
    protected void initializeWorkspaces(long size) {
        WorkspaceConfiguration configuration = WorkspaceConfiguration.builder()
                .initialSize(size)
                .overallocationLimit((double) this.bufferSize)
                .policyReset(ResetPolicy.ENDOFBUFFER_REACHED)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .policySpill(SpillPolicy.EXTERNAL)
                .policyLearning(LearningPolicy.NONE)
                .build();
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int cDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        try {
            for (int i = 0; i < numDevices; ++i) {
                // Bind to device i so the workspace is created for that device.
                Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(i));
                this.workspaces.add(Nd4j.getWorkspaceManager().createNewWorkspace(configuration, "IDSC-" + i, Integer.valueOf(i)));
            }
        } finally {
            // Never leave the caller's thread bound to a device it did not choose.
            Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(cDevice));
        }
        this.numWorkspaces = numDevices;
        this.isInitialized = true;
    }

    /**
     * Migrates the given dataset into the next workspace in round-robin order.
     *
     * @param dataSet dataset to migrate; its memory footprint sizes the workspaces on first use
     */
    @Override
    public void call(DataSet dataSet) {
        if (!this.isInitialized) {
            this.initializeWorkspaces(dataSet.getMemoryFootprint());
        }
        MemoryWorkspace previous = this.activateNextWorkspace();
        try {
            dataSet.migrate();
        } finally {
            // Restore the caller's workspace even if migration fails.
            Nd4j.getMemoryManager().setCurrentWorkspace(previous);
        }
    }

    /**
     * Migrates the given multi-dataset into the next workspace in round-robin order.
     *
     * @param multiDataSet dataset to migrate; its memory footprint sizes the workspaces on first use
     */
    @Override
    public void call(MultiDataSet multiDataSet) {
        if (!this.isInitialized) {
            this.initializeWorkspaces(multiDataSet.getMemoryFootprint());
        }
        MemoryWorkspace previous = this.activateNextWorkspace();
        try {
            multiDataSet.migrate();
        } finally {
            // Restore the caller's workspace even if migration fails.
            Nd4j.getMemoryManager().setCurrentWorkspace(previous);
        }
    }

    /**
     * Restarts round-robin selection from the first workspace. Already-created workspaces are kept.
     */
    @Override
    public void reset() {
        this.counterInput.set(0L);
    }

    /**
     * Commits pending executioner work, then selects the next workspace (by input counter modulo
     * the number of workspaces) and makes it the current workspace.
     *
     * @return the workspace that was current before the switch; the caller must restore it
     */
    private MemoryWorkspace activateNextWorkspace() {
        Nd4j.getExecutioner().commit();
        int currIdx = (int) (this.counterInput.getAndIncrement() % (long) this.numWorkspaces);
        MemoryWorkspace previous = Nd4j.getMemoryManager().getCurrentWorkspace();
        Nd4j.getMemoryManager().setCurrentWorkspace(this.workspaces.get(currIdx));
        return previous;
    }
}

