/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.memory.provider;

import com.jakewharton.byteunits.BinaryByteUnit;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.MemoryWorkspaceManager;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.DebugMode;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.memory.abstracts.Nd4jWorkspace;
import org.nd4j.linalg.primitives.SynchronizedObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base {@link MemoryWorkspaceManager} that keeps a separate workspace registry for every thread.
 *
 * <p>Each thread owns a {@code Map<String, MemoryWorkspace>} held in a {@link ThreadLocal} and keyed
 * by workspace id, so workspaces registered on one thread are not visible from another. The map is
 * created lazily by {@link #ensureThreadExistense()}. The configuration-aware lookup
 * {@code getWorkspaceForCurrentThread(WorkspaceConfiguration, String)} is called here but not
 * defined in this class, so concrete subclasses must provide it, together with
 * {@link #pickReference(MemoryWorkspace)}.
 */
public abstract class BasicWorkspaceManager
implements MemoryWorkspaceManager {
    private static final Logger log = LoggerFactory.getLogger(BasicWorkspaceManager.class);

    /** Sequence source for {@link #getUUID()}; shared by all threads using this manager. */
    protected AtomicLong counter = new AtomicLong();

    /** Configuration used when a workspace is requested without an explicit configuration. */
    protected WorkspaceConfiguration defaultConfiguration;

    /** Per-thread registry of workspaces keyed by id; populated lazily via {@link #ensureThreadExistense()}. */
    protected ThreadLocal<Map<String, MemoryWorkspace>> backingMap = new ThreadLocal<>();

    /** Current debug mode; {@link #setDebugMode(DebugMode)} substitutes {@code DISABLED} for {@code null}. */
    protected SynchronizedObject<DebugMode> debugMode = new SynchronizedObject<>(DebugMode.DISABLED);

    /** @deprecated not referenced by this class; retained for backward compatibility only. */
    @Deprecated
    public static final String WorkspaceDeallocatorThreadName = "Workspace deallocator thread";

    /**
     * Creates a manager whose default configuration is built with: initialSize 0, maxSize 0,
     * overallocationLimit 0.3, allocation policy OVERALLOCATE, learning policy FIRST_LOOP,
     * mirroring policy FULL and spill policy EXTERNAL.
     */
    public BasicWorkspaceManager() {
        this(WorkspaceConfiguration.builder()
                .initialSize(0L)
                .maxSize(0L)
                .overallocationLimit(0.3)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .policyLearning(LearningPolicy.FIRST_LOOP)
                .policyMirroring(MirroringPolicy.FULL)
                .policySpill(SpillPolicy.EXTERNAL)
                .build());
    }

    /**
     * Creates a manager using the given default configuration.
     *
     * @param defaultConfiguration configuration for workspaces requested without one
     * @throws NullPointerException if {@code defaultConfiguration} is {@code null}
     */
    public BasicWorkspaceManager(@NonNull WorkspaceConfiguration defaultConfiguration) {
        if (defaultConfiguration == null) {
            throw new NullPointerException("defaultConfiguration is marked @NonNull but is null");
        }
        this.defaultConfiguration = defaultConfiguration;
    }

    /**
     * Returns a new identifier of the form {@code "Workspace_<n>"}, where {@code n} is taken from
     * an atomic counter and is therefore unique per manager instance, even across threads.
     */
    public String getUUID() {
        return "Workspace_" + this.counter.incrementAndGet();
    }

    /**
     * Replaces the configuration used for workspaces requested without an explicit one.
     *
     * @param configuration the new default configuration
     * @throws NullPointerException if {@code configuration} is {@code null}
     */
    public void setDefaultWorkspaceConfiguration(@NonNull WorkspaceConfiguration configuration) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        this.defaultConfiguration = configuration;
    }

    /** Returns the current thread's workspace with id {@code "DefaultWorkspace"}, using the default configuration. */
    public MemoryWorkspace getWorkspaceForCurrentThread() {
        return this.getWorkspaceForCurrentThread("DefaultWorkspace");
    }

    /**
     * Returns the current thread's workspace with the given id, using the default configuration.
     *
     * @param id workspace id
     * @throws NullPointerException if {@code id} is {@code null}
     */
    public MemoryWorkspace getWorkspaceForCurrentThread(@NonNull String id) {
        if (id == null) {
            throw new NullPointerException("id is marked @NonNull but is null");
        }
        return this.getWorkspaceForCurrentThread(this.defaultConfiguration, id);
    }

    /** Returns the current debug mode. */
    public DebugMode getDebugMode() {
        return this.debugMode.get();
    }

    /**
     * Sets the debug mode.
     *
     * @param mode new mode; {@code null} is treated as {@link DebugMode#DISABLED}
     */
    public void setDebugMode(DebugMode mode) {
        this.debugMode.set(mode == null ? DebugMode.DISABLED : mode);
    }

    /**
     * Subclass hook; not invoked anywhere in this class.
     * NOTE(review): semantics are defined by the concrete manager — confirm there.
     */
    protected abstract void pickReference(MemoryWorkspace var1);

    /** Registers {@code workspace} for the current thread under the id {@code "DefaultWorkspace"}. */
    public void setWorkspaceForCurrentThread(MemoryWorkspace workspace) {
        this.setWorkspaceForCurrentThread(workspace, "DefaultWorkspace");
    }

    /**
     * Registers {@code workspace} for the current thread under {@code id}, replacing any previous
     * entry with that id.
     *
     * @throws NullPointerException if {@code workspace} or {@code id} is {@code null}
     */
    public void setWorkspaceForCurrentThread(@NonNull MemoryWorkspace workspace, @NonNull String id) {
        if (workspace == null) {
            throw new NullPointerException("workspace is marked @NonNull but is null");
        }
        if (id == null) {
            throw new NullPointerException("id is marked @NonNull but is null");
        }
        this.ensureThreadExistense();
        this.backingMap.get().put(id, workspace);
    }

    /**
     * Removes {@code workspace} from the current thread's registry, keyed by
     * {@code workspace.getId()}. {@code null} and {@link DummyWorkspace} arguments are ignored.
     * This method does not otherwise act on the workspace.
     *
     * <p>NOTE(review): an entry registered under an id different from the workspace's own
     * {@code getId()} (possible via {@link #setWorkspaceForCurrentThread(MemoryWorkspace, String)})
     * is not removed by this method.
     */
    public void destroyWorkspace(MemoryWorkspace workspace) {
        if (workspace == null || workspace instanceof DummyWorkspace) {
            return;
        }
        // A thread that never registered a workspace has no map yet; without this the
        // remove below would throw NullPointerException.
        this.ensureThreadExistense();
        this.backingMap.get().remove(workspace.getId());
    }

    /** Removes the current thread's {@code "DefaultWorkspace"} entry, if any. */
    public void destroyWorkspace() {
        this.ensureThreadExistense();
        this.backingMap.get().remove("DefaultWorkspace");
    }

    /**
     * Calls {@link #destroyWorkspace(MemoryWorkspace)} for every workspace registered on the current
     * thread, then asks the memory manager to run GC via {@code invokeGc()}.
     */
    public void destroyAllWorkspacesForCurrentThread() {
        this.ensureThreadExistense();
        // Iterate over a snapshot: destroyWorkspace(...) removes entries from the live map.
        List<MemoryWorkspace> workspaces = new ArrayList<>(this.backingMap.get().values());
        for (MemoryWorkspace workspace : workspaces) {
            this.destroyWorkspace(workspace);
        }
        Nd4j.getMemoryManager().invokeGc();
    }

    /** Lazily creates the current thread's workspace map. */
    protected void ensureThreadExistense() {
        if (this.backingMap.get() == null) {
            this.backingMap.set(new HashMap<>());
        }
    }

    /** Returns the default workspace for the current thread after calling {@code notifyScopeEntered()} on it. */
    public MemoryWorkspace getAndActivateWorkspace() {
        return this.getWorkspaceForCurrentThread().notifyScopeEntered();
    }

    /**
     * Returns the workspace with {@code id} (default configuration) after calling
     * {@code notifyScopeEntered()} on it.
     *
     * @throws NullPointerException if {@code id} is {@code null}
     */
    public MemoryWorkspace getAndActivateWorkspace(@NonNull String id) {
        if (id == null) {
            throw new NullPointerException("id is marked @NonNull but is null");
        }
        return this.getWorkspaceForCurrentThread(id).notifyScopeEntered();
    }

    /**
     * Returns the workspace with {@code id} for {@code configuration} after calling
     * {@code notifyScopeEntered()} on it.
     *
     * @throws NullPointerException if {@code configuration} or {@code id} is {@code null}
     */
    public MemoryWorkspace getAndActivateWorkspace(@NonNull WorkspaceConfiguration configuration, @NonNull String id) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        if (id == null) {
            throw new NullPointerException("id is marked @NonNull but is null");
        }
        return this.getWorkspaceForCurrentThread(configuration, id).notifyScopeEntered();
    }

    /**
     * Returns whether a workspace with {@code id} is registered on the current thread.
     *
     * @throws NullPointerException if {@code id} is {@code null}
     */
    public boolean checkIfWorkspaceExists(@NonNull String id) {
        if (id == null) {
            throw new NullPointerException("id is marked @NonNull but is null");
        }
        this.ensureThreadExistense();
        return this.backingMap.get().containsKey(id);
    }

    /**
     * Returns whether a workspace with {@code id} is registered on the current thread and reports
     * {@code isScopeActive()}.
     *
     * @throws NullPointerException if {@code id} is {@code null}
     */
    public boolean checkIfWorkspaceExistsAndActive(@NonNull String id) {
        if (id == null) {
            throw new NullPointerException("id is marked @NonNull but is null");
        }
        if (!this.checkIfWorkspaceExists(id)) {
            return false;
        }
        return this.backingMap.get().get(id).isScopeActive();
    }

    /**
     * Detaches the memory manager's current workspace (sets it to {@code null}) and returns the
     * result of {@code tagOutOfScopeUse()} on it. If there is no current workspace, returns a new
     * {@link DummyWorkspace} and leaves the memory manager untouched.
     */
    public MemoryWorkspace scopeOutOfWorkspaces() {
        MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (workspace == null) {
            return new DummyWorkspace();
        }
        Nd4j.getMemoryManager().setCurrentWorkspace(null);
        return workspace.tagOutOfScopeUse();
    }

    /**
     * Logs allocated / spilled / pinned sizes for every workspace registered on the current thread.
     * Entries that are not {@link Nd4jWorkspace} instances are reported without size figures.
     *
     * <p>Synchronized on the manager — presumably so dumps from different threads are not
     * interleaved in the log; the map itself is thread-local.
     */
    public synchronized void printAllocationStatisticsForCurrentThread() {
        this.ensureThreadExistense();
        Map<String, MemoryWorkspace> map = this.backingMap.get();
        log.info("Workspace statistics: ---------------------------------");
        log.info("Number of workspaces in current thread: {}", map.size());
        log.info("Workspace name: Allocated / external (spilled) / external (pinned)");
        for (Map.Entry<String, MemoryWorkspace> entry : map.entrySet()) {
            String key = entry.getKey();
            MemoryWorkspace registered = entry.getValue();
            if (!(registered instanceof Nd4jWorkspace)) {
                // Size counters exist only on Nd4jWorkspace; any other MemoryWorkspace may be
                // registered via setWorkspaceForCurrentThread, so a blind cast would throw.
                log.info("{} size statistics unavailable for {}", key + ":", registered.getClass().getName());
                continue;
            }
            Nd4jWorkspace workspace = (Nd4jWorkspace) registered;
            long current = workspace.getCurrentSize();
            long spilled = workspace.getSpilledSize();
            long pinned = workspace.getPinnedSize();
            log.info(String.format("%-26s %8s / %8s / %8s (%11d / %11d / %11d)", key + ":",
                    BinaryByteUnit.format(current, "#.00"),
                    BinaryByteUnit.format(spilled, "#.00"),
                    BinaryByteUnit.format(pinned, "#.00"),
                    current, spilled, pinned));
        }
    }

    /** Returns a new mutable list with the ids of all workspaces registered on the current thread. */
    public List<String> getAllWorkspacesIdsForCurrentThread() {
        this.ensureThreadExistense();
        return new ArrayList<>(this.backingMap.get().keySet());
    }

    /** Returns a new mutable list with all workspaces registered on the current thread. */
    public List<MemoryWorkspace> getAllWorkspacesForCurrentThread() {
        this.ensureThreadExistense();
        return new ArrayList<>(this.backingMap.get().values());
    }

    /** Returns {@code true} if any workspace registered on the current thread reports {@code isScopeActive()}. */
    public boolean anyWorkspaceActiveForCurrentThread() {
        this.ensureThreadExistense();
        for (MemoryWorkspace ws : this.backingMap.get().values()) {
            if (ws.isScopeActive()) {
                return true;
            }
        }
        return false;
    }
}

