/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel.tasks.cpu.transform;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.RecursiveAction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUAction;
import org.nd4j.linalg.api.parallel.tasks.cpu.transform.CPUTransformOpAction;

/**
 * Executes a {@link TransformOp} by splitting it into one {@link CPUTransformOpAction} per tensor.
 *
 * <p>A tensor dimension is picked with {@link OpExecutionerUtil#chooseElementWiseTensorDimension}
 * from the arrays taking part in the op (x, plus y when present and z when it is not the same
 * object as x). One sub-task is then created for every tensor of x along that dimension. The task
 * can run as a fork/join action ({@link #compute()}) or as a callable ({@link #call()}).
 */
public class CPUTransformOpViaTensorTask
extends BaseCPUAction {
    protected final TransformOp op;

    /**
     * @param op        the transform op to execute
     * @param threshold forwarded to the superclass and to every sub-task this task creates
     */
    public CPUTransformOpViaTensorTask(TransformOp op, int threshold) {
        // The remaining BaseCPUAction arguments are zero: lengths/offsets/strides are computed
        // per tensor when the sub-tasks are built in execute().
        super(threshold, 0, 0, 0, 0, 0, 0, 0);
        this.op = op;
    }

    /**
     * Callable entry point: starts every sub-task via {@code invokeAsync()} and records it in the
     * inherited {@code subTasks} list; this method itself does not wait for them to finish.
     *
     * @return always {@code null}
     */
    @Override
    public Void call() {
        this.execute(false);
        return null;
    }

    /** Fork/join entry point: forks every sub-task and joins all of them before returning. */
    @Override
    protected void compute() {
        this.execute(true);
    }

    /**
     * Splits the op into per-tensor sub-tasks and starts them.
     *
     * @param forkJoin {@code true} to fork and then join the sub-tasks; {@code false} to start
     *                 them with {@code invokeAsync()} and publish them through {@code subTasks}
     */
    private void execute(boolean forkJoin) {
        INDArray x = this.op.x();
        INDArray y = this.op.y();
        INDArray z = this.op.z();
        int tensorDim = chooseTensorDimension(x, y, z);
        int nTensors = x.tensorssAlongDimension(tensorDim);

        // Fork/join mode keeps its tasks locally so they can be joined at the end; callable mode
        // exposes them through the inherited subTasks list instead.
        List<CPUTransformOpAction> fjTasks = null;
        if (forkJoin) {
            fjTasks = new ArrayList<>(nTensors);
        } else {
            this.subTasks = new ArrayList<>(nTensors);
        }

        if (nTensors == 1) {
            // A single tensor covers the whole op: use the offset-free constructor and, in
            // fork/join mode, run it directly in the current thread via invoke().
            CPUTransformOpAction task = new CPUTransformOpAction(this.op, this.threshold);
            if (forkJoin) {
                task.invoke();
            } else {
                task.invokeAsync();
                this.subTasks.add(task);
            }
            return;
        }

        if (x.rank() == 2) {
            startRank2SubTasks(x, y, z, tensorDim, nTensors, forkJoin, fjTasks);
        } else {
            // Other ranks: each sub-task is told its tensor index and dimension.
            for (int i = 0; i < nTensors; ++i) {
                startSubTask(new CPUTransformOpAction(this.op, this.threshold, i, tensorDim), forkJoin, fjTasks);
            }
        }

        if (forkJoin) {
            for (RecursiveAction task : fjTasks) {
                task.join();
            }
        }
    }

    /** Picks the element-wise tensor dimension from the distinct arrays taking part in the op. */
    private static int chooseTensorDimension(INDArray x, INDArray y, INDArray z) {
        if (y == null) {
            return x == z
                    ? OpExecutionerUtil.chooseElementWiseTensorDimension(x)
                    : OpExecutionerUtil.chooseElementWiseTensorDimension(x, z);
        }
        return x == z
                ? OpExecutionerUtil.chooseElementWiseTensorDimension(x, y)
                : OpExecutionerUtil.chooseElementWiseTensorDimension(x, y, z);
    }

    /** Creates one sub-task per 1D tensor of a rank-2 op, passing explicit offsets and strides. */
    private void startRank2SubTasks(INDArray x, INDArray y, INDArray z, int tensorDim, int nTensors,
                                    boolean forkJoin, List<CPUTransformOpAction> fjTasks) {
        OpExecutionerUtil.Tensor1DStats tsx = OpExecutionerUtil.get1DTensorStats(x, tensorDim);
        int n = tsx.getTensorLength();
        int incrX = tsx.getElementWiseStride();

        // No y: its offset and stride are passed as 0.
        OpExecutionerUtil.Tensor1DStats tsy = y == null ? null : OpExecutionerUtil.get1DTensorStats(y, tensorDim);
        int incrY = tsy == null ? 0 : tsy.getElementWiseStride();

        // z is the same array as x: reuse x's offset and stride instead of computing z's stats.
        OpExecutionerUtil.Tensor1DStats tsz = x == z ? null : OpExecutionerUtil.get1DTensorStats(z, tensorDim);
        int incrZ = tsz == null ? incrX : tsz.getElementWiseStride();

        for (int i = 0; i < nTensors; ++i) {
            int offsetX = tsx.getFirstTensorOffset() + i * tsx.getTensorStartSeparation();
            int offsetY = tsy == null ? 0 : tsy.getFirstTensorOffset() + i * tsy.getTensorStartSeparation();
            int offsetZ = tsz == null ? offsetX : tsz.getFirstTensorOffset() + i * tsz.getTensorStartSeparation();
            startSubTask(new CPUTransformOpAction(this.op, this.threshold, n, offsetX, offsetY, offsetZ, incrX, incrY, incrZ),
                    forkJoin, fjTasks);
        }
    }

    /**
     * Starts {@code task} in the requested mode and records it.
     *
     * @param task     the sub-task to start
     * @param forkJoin {@code true}: fork and add to {@code fjTasks}; {@code false}: invokeAsync and
     *                 add to {@code subTasks}
     * @param fjTasks  local list of forked tasks (only used, and only non-null, in fork/join mode)
     */
    private void startSubTask(CPUTransformOpAction task, boolean forkJoin, List<CPUTransformOpAction> fjTasks) {
        if (forkJoin) {
            task.fork();
            fjTasks.add(task);
        } else {
            task.invokeAsync();
            this.subTasks.add(task);
        }
    }
}

