/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel.tasks.cpu;

import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.TaskExecutorProvider;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Base class for CPU tasks that run an {@link Op} over a strided range of the op's x, y and z
 * arrays and are submitted through the executor returned by
 * {@link TaskExecutorProvider#getTaskExecutor()}.
 *
 * <p>The protected fields describe the range to process: an element count {@link #n} plus an
 * offset and increment for each of x, y and z. They are filled either directly by a constructor
 * or, for the tensor-along-dimension constructor, by {@link #doTensorFirst(Op)}. They are
 * presumably consumed by subclasses (not visible here).
 *
 * @param <V> the result type of the task
 */
public abstract class BaseCPUTask<V> extends RecursiveTask<V> implements Task<V> {
    /** Threshold supplied at construction; how it is used (presumably a split size) is up to subclasses. */
    protected final int threshold;
    /** Number of elements to process. */
    protected int n;
    /** Offset of the first x element. */
    protected int offsetX;
    /** Offset of the first y element (0 when the op has no y). */
    protected int offsetY;
    /** Offset of the first z element (0 when the op has no z). */
    protected int offsetZ;
    /** Increment between consecutive x elements. */
    protected int incrX;
    /** Increment between consecutive y elements (0 when the op has no y). */
    protected int incrY;
    /** Increment between consecutive z elements (0 when the op has no z). */
    protected int incrZ;
    /** True when the range must first be derived from a tensor along a dimension via {@link #doTensorFirst(Op)}. */
    protected boolean doTensorFirst;
    /** Index of the tensor along {@link #tensorDim} to process (TAD constructor only). */
    protected int tensorIdx;
    /** Dimension along which tensors are taken (TAD constructor only). */
    protected int tensorDim;
    /** Completion flag; initialized to false and not modified in this class. */
    protected boolean executed = false;
    /** Handle returned by the executor when this task was submitted via {@link #invokeAsync()}. */
    protected Future<V> future;

    /**
     * Creates a task over an explicitly described strided range.
     *
     * @param threshold threshold value stored for subclasses
     * @param n         number of elements to process
     * @param offsetX   offset of the first x element
     * @param offsetY   offset of the first y element
     * @param offsetZ   offset of the first z element
     * @param incrX     increment between consecutive x elements
     * @param incrY     increment between consecutive y elements
     * @param incrZ     increment between consecutive z elements
     */
    public BaseCPUTask(int threshold, int n, int offsetX, int offsetY, int offsetZ, int incrX, int incrY, int incrZ) {
        this.threshold = threshold;
        this.n = n;
        this.offsetX = offsetX;
        this.offsetY = offsetY;
        this.offsetZ = offsetZ;
        this.incrX = incrX;
        this.incrY = incrY;
        this.incrZ = incrZ;
        this.doTensorFirst = false;
    }

    /**
     * Creates a task covering all of {@code op.x()} together with {@code op.y()} and {@code op.z()}.
     *
     * <p>Offsets and increments come from each array's {@code offset()} and
     * {@code elementWiseStride()}; a null y or z gets offset and increment 0. Where
     * {@code elementWiseStride()} returns -1, the increment is recomputed with
     * {@link #linearStride(INDArray)}. A y or z that is the same instance as x reuses x's
     * increment instead.
     *
     * @param op        the op whose arrays are processed
     * @param threshold threshold value stored for subclasses
     */
    public BaseCPUTask(Op op, int threshold) {
        this.threshold = threshold;
        INDArray x = op.x();
        INDArray y = op.y();
        INDArray z = op.z();
        this.n = x.length();
        this.offsetX = x.offset();
        this.offsetY = y != null ? y.offset() : 0;
        this.offsetZ = z != null ? z.offset() : 0;
        this.incrX = x.elementWiseStride();
        this.incrY = y != null ? y.elementWiseStride() : 0;
        this.incrZ = z != null ? z.elementWiseStride() : 0;
        this.doTensorFirst = false;
        if (this.incrX == -1) {
            this.incrX = linearStride(x);
        }
        if (this.incrY == -1) {
            this.incrY = (y == x) ? this.incrX : linearStride(y);
        }
        if (this.incrZ == -1) {
            // z's recomputed increment belongs in incrZ; incrY must not be touched here.
            this.incrZ = (z == x) ? this.incrX : linearStride(z);
        }
    }

    /**
     * Creates a task over a single tensor along a dimension (TAD) of the op's arrays.
     *
     * <p>Only the threshold, TAD index and TAD dimension are stored; {@code op} is not read here.
     * {@link #n}, the offsets and the increments keep their defaults until
     * {@link #doTensorFirst(Op)} is called.
     *
     * @param op        the op (unused by this constructor)
     * @param threshold threshold value stored for subclasses
     * @param tadIdx    index of the tensor to process
     * @param tadDim    dimension along which tensors are taken
     */
    public BaseCPUTask(Op op, int threshold, int tadIdx, int tadDim) {
        this.doTensorFirst = true;
        this.threshold = threshold;
        this.tensorIdx = tadIdx;
        this.tensorDim = tadDim;
    }

    /**
     * Returns {@code stride(1)} of a {@code 1 x prod(shape)} reshape of {@code arr}. It serves
     * as a replacement increment when {@code elementWiseStride()} reports -1.
     */
    private static int linearStride(INDArray arr) {
        INDArray row = arr.reshape(new int[]{1, ArrayUtil.prod((int[]) arr.shape())});
        return row.stride(1);
    }

    /**
     * Fills {@link #n}, the offsets and the increments from tensor {@link #tensorIdx} along
     * dimension {@link #tensorDim} of the op's x, y and z arrays.
     *
     * <p>If x's TAD has a negative element-wise stride, x is duplicated and the TAD is taken from
     * the copy instead. A null y or z gets offset and increment 0. A y or z that is the same
     * instance as x (or a z that is the same instance as y) reuses that array's offset and
     * increment.
     *
     * @param op the op whose arrays are sliced
     * @throws IllegalStateException if x's TAD still has a negative element-wise stride after duplication
     */
    protected void doTensorFirst(Op op) {
        INDArray x = op.x();
        INDArray y = op.y();
        INDArray z = op.z();
        INDArray tadx = x.tensorAlongDimension(this.tensorIdx, this.tensorDim);
        this.n = tadx.length();
        this.offsetX = tadx.offset();
        this.incrX = tadx.elementWiseStride();
        if (this.incrX < 0) {
            x = op.x().dup();
            tadx = x.tensorAlongDimension(this.tensorIdx, this.tensorDim);
            // Keep offset and increment describing the same (copied) TAD.
            this.offsetX = tadx.offset();
            this.incrX = tadx.elementWiseStride();
            if (this.incrX < 0) {
                throw new IllegalStateException("Illegal x input unable to use element wise stride for dimension");
            }
            // NOTE(review): the copy is local to this method and is not written back to op, so code
            // that reads op.x()'s buffer with offsetX/incrX will not see it. Also, a y or z that was
            // the same instance as the original x no longer compares equal to x below.
            // TODO confirm how subclasses locate x's data.
        }
        if (y == null) {
            this.offsetY = 0;
            this.incrY = 0;
        } else if (y == x) {
            this.offsetY = this.offsetX;
            this.incrY = this.incrX;
        } else {
            INDArray tady = y.tensorAlongDimension(this.tensorIdx, this.tensorDim);
            this.offsetY = tady.offset();
            this.incrY = tady.elementWiseStride();
        }
        if (z == null) {
            this.offsetZ = 0;
            this.incrZ = 0;
        } else if (z == x) {
            this.offsetZ = this.offsetX;
            this.incrZ = this.incrX;
        } else if (z == y) {
            this.offsetZ = this.offsetY;
            this.incrZ = this.incrY;
        } else {
            INDArray tadz = z.tensorAlongDimension(this.tensorIdx, this.tensorDim);
            this.offsetZ = tadz.offset();
            this.incrZ = tadz.elementWiseStride();
        }
    }

    /**
     * Submits this task to the executor from {@link TaskExecutorProvider#getTaskExecutor()} and
     * stores the returned handle in {@link #future}.
     */
    @Override
    public void invokeAsync() {
        this.future = TaskExecutorProvider.getTaskExecutor().executeAsync(this);
    }

    /**
     * Submits this task via {@link #invokeAsync()} and then returns the result of
     * {@code blockUntilComplete()}.
     *
     * @return the task result
     */
    @Override
    public V invokeBlocking() {
        this.invokeAsync();
        return this.blockUntilComplete();
    }
}

