/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.RecursiveTask;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum.CPUIndexAccumulationTask;

/**
 * Executes an {@link IndexAccumulation} by splitting {@code op.x()} into 1-d tensors along
 * dimension 1, running one {@link CPUIndexAccumulationTask} per tensor, and folding the
 * per-tensor (value, index) pairs together with {@link IndexAccumulation#combineSubResults}.
 *
 * <p>Two execution paths exist:
 * <ul>
 *   <li>{@link #compute()} (fork/join): sub-tasks are forked, then joined, and the combined
 *       pair is returned directly.</li>
 *   <li>{@link #call()} (executor): sub-tasks are started with {@code invokeAsync()} and
 *       recorded in {@link #subTasks}; {@code call()} returns {@code null} and
 *       {@link #blockUntilComplete()} performs the combination.</li>
 * </ul>
 * If there is only a single tensor, one sub-task covering the whole op is run inline instead.
 */
public class CPUIndexAccumulationViaTensorTask
extends BaseCPUTask<Pair<Double, Integer>> {
    /** Dimension along which x (and y, when present) is split into 1-d tensors. */
    private static final int TENSOR_DIM = 1;

    protected final IndexAccumulation op;
    /** When true, {@link #blockUntilComplete()} stores the resulting index on the op. */
    protected final boolean outerTask;
    /** Sub-tasks launched on the executor path; populated by {@link #call()}. */
    protected List<Task<Pair<Double, Integer>>> subTasks;

    /**
     * @param op        the index accumulation to execute
     * @param threshold passed to the base task and to every sub-task (presumably the work-size
     *                  threshold used for further splitting — semantics defined by BaseCPUTask)
     * @param outerTask whether this task should publish the final index via
     *                  {@link IndexAccumulation#setFinalResult}
     */
    public CPUIndexAccumulationViaTensorTask(IndexAccumulation op, int threshold, boolean outerTask) {
        super(op, threshold);
        this.op = op;
        this.outerTask = outerTask;
    }

    /**
     * Waits for this task and returns the combined (value, index) pair, starting it via
     * {@code invokeAsync()} first if it has not been started yet.
     *
     * <p>On the executor path the task itself yields {@code null}, so the recorded sub-tasks are
     * awaited and their results combined here. For an outer task, the winning index is also
     * stored on the op.
     *
     * @return the combined (value, index) pair
     * @throws RuntimeException wrapping any failure while waiting for the result; if the wait
     *                          was interrupted, the thread's interrupt status is restored first
     */
    @Override
    public Pair<Double, Integer> blockUntilComplete() {
        if (this.future == null) {
            this.invokeAsync();
        }
        Pair<Double, Integer> accum;
        try {
            accum = (Pair<Double, Integer>) this.future.get();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        if (accum == null) {
            // Executor path: call() only launched the sub-tasks, so gather their results now.
            accum = this.op.zeroPair();
            for (Task<Pair<Double, Integer>> task : this.subTasks) {
                accum = this.op.combineSubResults(accum, task.blockUntilComplete());
            }
        }
        if (this.outerTask) {
            this.op.setFinalResult(accum.getSecond());
        }
        return accum;
    }

    /** Executor entry point: launches sub-tasks asynchronously and returns {@code null}. */
    @Override
    public Pair<Double, Integer> call() {
        return this.execute(false);
    }

    /** Fork/join entry point: forks and joins sub-tasks and returns the combined pair. */
    @Override
    protected Pair<Double, Integer> compute() {
        return this.execute(true);
    }

    /**
     * Splits the op into per-tensor sub-tasks and launches them.
     *
     * @param forkJoin true to fork/join the sub-tasks and combine their results here; false to
     *                 start them with {@code invokeAsync()} and record them in {@link #subTasks}
     * @return the combined pair on the fork/join path or in the single-tensor case; otherwise
     *         {@code null}, leaving combination to {@link #blockUntilComplete()}
     */
    private Pair<Double, Integer> execute(boolean forkJoin) {
        INDArray x = this.op.x();
        INDArray y = this.op.y();
        int nTensors = x.tensorssAlongDimension(TENSOR_DIM);
        List<CPUIndexAccumulationTask> fjTasks = null;
        if (forkJoin) {
            fjTasks = new ArrayList<>(nTensors);
        } else {
            this.subTasks = new ArrayList<>(nTensors);
        }
        if (nTensors == 1) {
            // Nothing to split: run one sub-task over the whole op in this thread.
            return new CPUIndexAccumulationTask(this.op, this.threshold, false).invoke();
        }
        if (x.rank() == 2) {
            this.launchRank2Tasks(x, y, nTensors, forkJoin, fjTasks);
        } else {
            for (int i = 0; i < nTensors; ++i) {
                this.launch(new CPUIndexAccumulationTask(this.op, this.threshold, i, TENSOR_DIM, false),
                        forkJoin, fjTasks);
            }
        }
        if (!forkJoin) {
            return null;
        }
        Pair<Double, Integer> accum = this.op.zeroPair();
        for (CPUIndexAccumulationTask task : fjTasks) {
            accum = this.op.combineSubResults(accum, task.join());
        }
        return accum;
    }

    /**
     * Launches one sub-task per tensor of a rank-2 x (and y, if non-null), addressing each
     * tensor by an explicit offset and element-wise stride taken from
     * {@link OpExecutionerUtil#get1DTensorStats}. When y is null, the y offset and stride are 0.
     */
    private void launchRank2Tasks(INDArray x, INDArray y, int nTensors, boolean forkJoin,
                                  List<CPUIndexAccumulationTask> fjTasks) {
        OpExecutionerUtil.Tensor1DStats tsx = OpExecutionerUtil.get1DTensorStats(x, TENSOR_DIM);
        OpExecutionerUtil.Tensor1DStats tsy =
                (y == null) ? null : OpExecutionerUtil.get1DTensorStats(y, TENSOR_DIM);
        int n = tsx.getTensorLength();
        int incrX = tsx.getElementWiseStride();
        int incrY = (tsy == null) ? 0 : tsy.getElementWiseStride();
        for (int i = 0; i < nTensors; ++i) {
            int offsetX = tsx.getFirstTensorOffset() + i * tsx.getTensorStartSeparation();
            int offsetY = (tsy == null)
                    ? 0
                    : tsy.getFirstTensorOffset() + i * tsy.getTensorStartSeparation();
            // Index of this tensor's first element counting whole tensors before it; presumably
            // lets the sub-task report indices relative to the full array — confirm in
            // CPUIndexAccumulationTask.
            int elementOffset = i * tsx.getTensorLength();
            this.launch(new CPUIndexAccumulationTask(this.op, this.threshold, n, offsetX, offsetY,
                    incrX, incrY, elementOffset, false), forkJoin, fjTasks);
        }
    }

    /**
     * Starts a sub-task: on the fork/join path it is forked and remembered for joining;
     * otherwise it is started with {@code invokeAsync()} and recorded in {@link #subTasks}.
     */
    private void launch(CPUIndexAccumulationTask task, boolean forkJoin,
                        List<CPUIndexAccumulationTask> fjTasks) {
        if (forkJoin) {
            task.fork();
            fjTasks.add(task);
        } else {
            task.invokeAsync();
            this.subTasks.add(task);
        }
    }
}

