/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel.tasks.cpu.accumulation;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.RecursiveTask;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.accumulation.CPUAccumulationTask;

public class CPUAccumulationViaTensorTask
extends BaseCPUTask<Double> {
    protected final Accumulation op;
    protected final boolean outerTask;
    protected List<Task<Double>> subTasks;

    public CPUAccumulationViaTensorTask(Accumulation op, int threshold, boolean outerTask) {
        super(op, threshold);
        this.op = op;
        this.outerTask = outerTask;
    }

    @Override
    public Double blockUntilComplete() {
        Double accum;
        if (this.future == null) {
            this.invokeAsync();
        }
        try {
            accum = (Double)this.future.get();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        if (this.subTasks != null) {
            accum = this.op.zeroDouble();
            for (Task<Double> task : this.subTasks) {
                double subAccum = task.blockUntilComplete();
                accum = this.op.combineSubResults(accum, subAccum);
            }
        }
        if (this.outerTask && this.subTasks != null) {
            return this.op.getAndSetFinalResult(accum);
        }
        return accum;
    }

    @Override
    public Double call() {
        return this.execute(false);
    }

    @Override
    protected Double compute() {
        double out = this.execute(true);
        if (this.outerTask) {
            return this.op.getAndSetFinalResult(out);
        }
        return out;
    }

    /*
     * WARNING - void declaration
     */
    private Double execute(boolean forkJoin) {
        INDArray x = this.op.x();
        INDArray y = this.op.y();
        int tensorDim = y == null ? OpExecutionerUtil.chooseElementWiseTensorDimension(x) : OpExecutionerUtil.chooseElementWiseTensorDimension(x, y);
        int nTensors = x.tensorssAlongDimension(tensorDim);
        ArrayList<CPUAccumulationTask> fjTasks = null;
        if (forkJoin) {
            fjTasks = new ArrayList<CPUAccumulationTask>(nTensors);
        } else {
            this.subTasks = new ArrayList<Task<Double>>(nTensors);
        }
        if (nTensors == 1) {
            CPUAccumulationTask task = new CPUAccumulationTask(this.op, this.threshold, false);
            if (forkJoin) {
                return (Double)task.invoke();
            }
            task.invokeAsync();
            this.subTasks.add(task);
            return null;
        }
        if (x.rank() == 2) {
            OpExecutionerUtil.Tensor1DStats tsx = OpExecutionerUtil.get1DTensorStats(x, tensorDim);
            int n = tsx.getTensorLength();
            int incrX = tsx.getElementWiseStride();
            if (y == null) {
                void var10_16;
                boolean bl = false;
                while (var10_16 < nTensors) {
                    int offsetX = tsx.getFirstTensorOffset() + var10_16 * tsx.getTensorStartSeparation();
                    CPUAccumulationTask task = new CPUAccumulationTask(this.op, this.threshold, n, offsetX, 0, incrX, 0, false);
                    if (forkJoin) {
                        task.fork();
                        fjTasks.add(task);
                    } else {
                        task.invokeAsync();
                        this.subTasks.add(task);
                    }
                    ++var10_16;
                }
            } else {
                OpExecutionerUtil.Tensor1DStats tensor1DStats = OpExecutionerUtil.get1DTensorStats(y, tensorDim);
                int incrY = tensor1DStats.getElementWiseStride();
                for (int i = 0; i < nTensors; ++i) {
                    int offsetX = tsx.getFirstTensorOffset() + i * tsx.getTensorStartSeparation();
                    int offsetY = tensor1DStats.getFirstTensorOffset() + i * tensor1DStats.getTensorStartSeparation();
                    CPUAccumulationTask task = new CPUAccumulationTask(this.op, this.threshold, n, offsetX, offsetY, incrX, incrY, false);
                    if (forkJoin) {
                        task.fork();
                        fjTasks.add(task);
                        continue;
                    }
                    task.invokeAsync();
                    this.subTasks.add(task);
                }
            }
        } else {
            for (int i = 0; i < nTensors; ++i) {
                CPUAccumulationTask task = new CPUAccumulationTask(this.op, this.threshold, i, tensorDim, false);
                if (forkJoin) {
                    task.fork();
                    fjTasks.add(task);
                    continue;
                }
                task.invokeAsync();
                this.subTasks.add(task);
            }
        }
        if (forkJoin) {
            double accum = this.op.zeroDouble();
            for (RecursiveTask recursiveTask : fjTasks) {
                accum = this.op.combineSubResults(accum, (Double)recursiveTask.join());
            }
            return accum;
        }
        return null;
    }
}

