/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel.tasks.cpu.scalar;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.RecursiveAction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUAction;
import org.nd4j.linalg.api.parallel.tasks.cpu.scalar.CPUScalarOpAction;

/**
 * Runs a {@link ScalarOp} by splitting it into {@link CPUScalarOpAction} sub-tasks, one per tensor
 * along the dimension picked by {@code OpExecutionerUtil.chooseElementWiseTensorDimension}.
 *
 * <p>Two execution modes are supported:
 * <ul>
 *   <li>{@link #compute()} (fork/join): sub-tasks are forked and then joined before returning.</li>
 *   <li>{@link #call()}: sub-tasks are started with {@code invokeAsync()} and recorded in
 *       {@code subTasks}; this class does not wait for them itself.</li>
 * </ul>
 */
public class CPUScalarOpViaTensorAction extends BaseCPUAction {
    /** The scalar op being executed; its x/y/z arrays determine how the work is split. */
    protected final ScalarOp op;

    /**
     * @param op        the scalar op to execute
     * @param threshold passed to the superclass and to every sub-task that is created
     */
    public CPUScalarOpViaTensorAction(ScalarOp op, int threshold) {
        super(op, threshold);
        this.op = op;
    }

    /** Starts the sub-tasks via {@code invokeAsync()} (non-fork/join mode) and returns {@code null}. */
    @Override
    public Void call() {
        this.execute(false);
        return null;
    }

    /** Fork/join entry point: forks the sub-tasks and joins them all before returning. */
    @Override
    protected void compute() {
        this.execute(true);
    }

    /**
     * Splits the op into per-tensor sub-tasks and dispatches them.
     *
     * @param forkJoin {@code true} to fork the sub-tasks and join them before returning;
     *                 {@code false} to start them via {@code invokeAsync()} and store them in
     *                 {@code subTasks}
     */
    private void execute(boolean forkJoin) {
        INDArray x = this.op.x();
        INDArray y = this.op.y();
        INDArray z = this.op.z();

        // x == z is a reference-identity check: when the op writes into x, z need not be
        // considered separately when choosing the split dimension.
        int tensorDim;
        if (y == null) {
            tensorDim = x == z
                    ? OpExecutionerUtil.chooseElementWiseTensorDimension(x)
                    : OpExecutionerUtil.chooseElementWiseTensorDimension(x, z);
        } else {
            tensorDim = x == z
                    ? OpExecutionerUtil.chooseElementWiseTensorDimension(x, y)
                    : OpExecutionerUtil.chooseElementWiseTensorDimension(x, y, z);
        }
        int nTensors = x.tensorssAlongDimension(tensorDim);

        if (!forkJoin) {
            this.subTasks = new ArrayList<>(nTensors);
        }

        // Only one tensor: run a single task over the whole op instead of splitting.
        if (nTensors == 1) {
            CPUScalarOpAction task = new CPUScalarOpAction(this.op, this.threshold);
            if (forkJoin) {
                task.invoke();
            } else {
                task.invokeAsync();
                this.subTasks.add(task);
            }
            return;
        }

        // Forked tasks are collected only so they can be joined at the end.
        List<CPUScalarOpAction> fjTasks = forkJoin ? new ArrayList<CPUScalarOpAction>(nTensors) : null;

        if (x.rank() == 2) {
            // Rank 2: each sub-task is described by tensor length, start offsets and element strides.
            // NOTE(review): no offset/stride for y is passed on this path; presumably ops reaching
            // here have y == null — confirm.
            OpExecutionerUtil.Tensor1DStats tsx = OpExecutionerUtil.get1DTensorStats(x, tensorDim);
            // In-place op: z's tensors are x's tensors, so reuse x's stats.
            OpExecutionerUtil.Tensor1DStats tsz = x == z ? tsx : OpExecutionerUtil.get1DTensorStats(z, tensorDim);
            int n = tsx.getTensorLength();
            int incrX = tsx.getElementWiseStride();
            int incrZ = tsz.getElementWiseStride();
            for (int i = 0; i < nTensors; i++) {
                int offsetX = tsx.getFirstTensorOffset() + i * tsx.getTensorStartSeparation();
                int offsetZ = tsz.getFirstTensorOffset() + i * tsz.getTensorStartSeparation();
                launchTensorTask(
                        new CPUScalarOpAction(this.op, this.threshold, n, offsetX, offsetZ, incrX, incrZ),
                        forkJoin, fjTasks);
            }
        } else {
            // Other ranks: each sub-task is identified by its tensor index along tensorDim.
            for (int i = 0; i < nTensors; i++) {
                launchTensorTask(new CPUScalarOpAction(this.op, this.threshold, i, tensorDim), forkJoin, fjTasks);
            }
        }

        if (forkJoin) {
            for (CPUScalarOpAction task : fjTasks) {
                task.join();
            }
        }
    }

    /**
     * Dispatches one sub-task: in fork/join mode it is forked and added to {@code fjTasks};
     * otherwise it is started via {@code invokeAsync()} and added to {@code subTasks}.
     */
    private void launchTensorTask(CPUScalarOpAction task, boolean forkJoin, List<CPUScalarOpAction> fjTasks) {
        if (forkJoin) {
            task.fork();
            fjTasks.add(task);
        } else {
            task.invokeAsync();
            this.subTasks.add(task);
        }
    }
}

