/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel.tasks.cpu.accumulation;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.BaseTask;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.TaskExecutorProvider;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.accumulation.CPUAccumulationTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.accumulation.CPUAccumulationViaTensorTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.accumulation.CPUAccumulations1dAction;
import org.nd4j.linalg.api.shape.tensor.TensorCalculator;
import org.nd4j.linalg.api.shape.tensor.TensorCalculatorFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Executes an {@link Accumulation} along one or more dimensions of {@code op.x()}, producing one
 * accumulated value per tensor along those dimensions.
 *
 * <p>Two execution paths exist:
 * <ul>
 *   <li>{@link #call()} (executor/Callable path) submits one {@link OpForDimTask} per tensor and
 *       returns {@code null}; {@link #blockUntilComplete()} then gathers the per-tensor results.</li>
 *   <li>{@link #compute()} (fork/join path) either runs a {@link CPUAccumulations1dAction} over
 *       {@link TensorCalculator}s (single dimension, op not pass-through) or forks one
 *       {@link OpForDimTaskFJ} per tensor, and returns the assembled result directly.</li>
 * </ul>
 * In both paths the result array is also stored on the op via {@code op.setZ(...)}.
 */
public class CPUAccumulationAlongDimensionTask extends BaseCPUTask<INDArray> {
    /** The accumulation being executed. */
    protected final Accumulation op;
    /** Dimensions to accumulate along, normalized to non-negative indices. */
    protected final int[] dimensions;
    /** Per-tensor subtasks created by {@link #call()}; {@code null} until call() runs. */
    protected List<Task<Double>> subTasks;

    /**
     * Creates a task that accumulates {@code op} along {@code dimensions}.
     *
     * @param op                the accumulation to execute
     * @param parallelThreshold threshold passed to the superclass and to every subtask
     * @param dimensions        dimensions to accumulate along; negative entries are interpreted
     *                          relative to {@code op.x().rank()} (e.g. -1 is the last dimension).
     *                          The caller's array is copied, never modified.
     */
    public CPUAccumulationAlongDimensionTask(Accumulation op, int parallelThreshold, int ... dimensions) {
        super(op, parallelThreshold);
        // Work on a copy: normalizing in place would silently rewrite the caller's array.
        int[] normalized = dimensions.clone();
        for (int i = 0; i < normalized.length; ++i) {
            if (normalized[i] < 0) {
                normalized[i] += op.x().rank();
            }
        }
        this.op = op;
        this.dimensions = normalized;
    }

    /**
     * Waits for the task (submitting it first if it has not been submitted) and returns the
     * accumulated array.
     *
     * <p>If the task produced its result directly (fork/join path) that array is returned; when a
     * matrix is reduced along dimension 1 it is reshaped into a column vector. Otherwise (Callable
     * path) the per-tensor subtask results are gathered into a new array, which is also set as the
     * op's z.
     *
     * @return the accumulation result
     * @throws RuntimeException wrapping any failure or interruption while waiting
     */
    @Override
    public INDArray blockUntilComplete() {
        if (this.future == null) {
            this.invokeAsync();
        }
        INDArray ret;
        try {
            ret = (INDArray) this.future.get();
        } catch (InterruptedException e) {
            // Preserve the interrupt so callers further up can observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        if (ret != null) {
            if (this.isMatrixReducedAlongDim1()) {
                // NOTE(review): op.z() keeps the un-reshaped array on this path, unlike the gather
                // path below — confirm whether callers rely on op.z()'s shape.
                ret = ret.reshape(ret.length(), 1);
            }
            return ret;
        }
        // Callable path: one result per row for a matrix reduced along dimension 1, returned as a
        // column vector; otherwise the input shape with the reduced dimensions removed.
        int[] retShape = this.isMatrixReducedAlongDim1()
                ? new int[] {this.op.x().shape()[0], 1}
                : ArrayUtil.removeIndex(this.op.x().shape(), this.dimensions);
        INDArray out = Nd4j.create(retShape);
        int i = 0;
        for (Task<Double> task : this.subTasks) {
            out.putScalar(i++, (double) task.blockUntilComplete());
        }
        this.op.setZ(out);
        return out;
    }

    /**
     * Callable entry point: submits one {@link OpForDimTask} per tensor along {@link #dimensions}
     * and records them in {@link #subTasks}. Always returns {@code null}; the result is assembled
     * by {@link #blockUntilComplete()}.
     */
    @Override
    public INDArray call() {
        int nTensors = this.op.x().tensorssAlongDimension(this.dimensions);
        this.subTasks = new ArrayList<>(nTensors);
        for (int i = 0; i < nTensors; ++i) {
            OpForDimTask task = new OpForDimTask(i);
            task.invokeAsync();
            this.subTasks.add(task);
        }
        return null;
    }

    /**
     * Fork/join entry point: computes the accumulation and returns the result, also setting it as
     * the op's z.
     *
     * <p>For a single dimension on a non-pass-through op, a {@link CPUAccumulations1dAction} is
     * run over {@link TensorCalculator}s for x (and y, if present). Otherwise one
     * {@link OpForDimTaskFJ} is forked per tensor and their results are joined in order.
     */
    @Override
    public INDArray compute() {
        if (this.dimensions.length == 1 && !this.op.isPassThrough()) {
            TensorCalculator tCalcx = TensorCalculatorFactory.getTensorCalculator(this.op.x(), this.dimensions[0]);
            TensorCalculator tCalcy = this.op.y() != null ? TensorCalculatorFactory.getTensorCalculator(this.op.y(), this.dimensions[0]) : null;
            INDArray out = Nd4j.create(ArrayUtil.removeIndex(this.op.x().shape(), this.dimensions));
            // NOTE(review): 0 .. getNumTensors() - 1 looks like an inclusive tensor range — confirm
            // against CPUAccumulations1dAction.
            CPUAccumulations1dAction action = new CPUAccumulations1dAction(this.op, this.threshold, tCalcx, tCalcy, 0, tCalcx.getNumTensors() - 1, out);
            action.invoke();
            this.op.setZ(out);
            return out;
        }
        int nTensors = this.op.x().tensorssAlongDimension(this.dimensions);
        List<OpForDimTaskFJ> forkedTasks = new ArrayList<>(nTensors);
        for (int i = 0; i < nTensors; ++i) {
            OpForDimTaskFJ task = new OpForDimTaskFJ(i);
            task.fork();
            forkedTasks.add(task);
        }
        INDArray out = Nd4j.create(ArrayUtil.removeIndex(this.op.x().shape(), this.dimensions));
        int i = 0;
        for (OpForDimTaskFJ task : forkedTasks) {
            out.putScalar(i++, (double) task.join());
        }
        this.op.setZ(out);
        return out;
    }

    /** True when a single dimension 1 is reduced on a matrix input; that result is returned as a column vector. */
    private boolean isMatrixReducedAlongDim1() {
        return this.dimensions.length == 1 && this.dimensions[0] == 1 && this.op.x().isMatrix();
    }

    /**
     * Builds the subtask that accumulates tensor {@code tensorNum} along {@link #dimensions}: a
     * {@link CPUAccumulationTask} when {@link OpExecutionerUtil} reports the op's arrays can be
     * processed directly, otherwise a {@link CPUAccumulationViaTensorTask}.
     */
    private BaseCPUTask<Double> createSubTaskForTensor(int tensorNum) {
        Accumulation opOnDimension = (Accumulation) this.op.opForDimension(tensorNum, this.dimensions);
        INDArray x2 = opOnDimension.x();
        INDArray y2 = opOnDimension.y();
        boolean canDoDirectly = y2 == null ? OpExecutionerUtil.canDoOpDirectly(x2) : OpExecutionerUtil.canDoOpDirectly(x2, y2);
        if (canDoDirectly) {
            return new CPUAccumulationTask(opOnDimension, this.threshold, true);
        }
        return new CPUAccumulationViaTensorTask(opOnDimension, this.threshold, true);
    }

    /** Fork/join subtask that accumulates a single tensor along the enclosing task's dimensions. */
    private class OpForDimTaskFJ
    extends RecursiveTask<Double>
    implements Task<Double> {
        private final int tensorNum;
        /** Set by {@link #invokeAsync()}; {@code null} when the task is only forked/joined. */
        private Future<Double> future;

        public OpForDimTaskFJ(int tensorNum) {
            this.tensorNum = tensorNum;
        }

        @Override
        public Double invokeBlocking() {
            this.invokeAsync();
            return this.blockUntilComplete();
        }

        @Override
        public void invokeAsync() {
            this.future = TaskExecutorProvider.getTaskExecutor().executeAsync(this);
        }

        /**
         * Waits for the executor-submitted run of this task, submitting it first if needed, and
         * returns its result. Tasks that were forked should be joined instead.
         */
        @Override
        public Double blockUntilComplete() {
            if (this.future == null) {
                this.invokeAsync();
            }
            try {
                return this.future.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public Double call() {
            throw new RuntimeException("Callable.call() called as part of ForkJoin task");
        }

        /** Builds the per-tensor subtask and runs it synchronously in the fork/join framework. */
        @Override
        protected Double compute() {
            BaseCPUTask<Double> subTask = createSubTaskForTensor(this.tensorNum);
            return (Double) subTask.invoke();
        }
    }

    /** Executor-based subtask that accumulates a single tensor along the enclosing task's dimensions. */
    private class OpForDimTask
    extends BaseTask<Double> {
        private final int tensorNum;
        /** Created and started by {@link #call()}; read only after {@link #future} completes. */
        private BaseCPUTask<Double> subTask;
        private Future<Double> future;

        public OpForDimTask(int tensorNum) {
            this.tensorNum = tensorNum;
        }

        @Override
        public void invokeAsync() {
            this.future = TaskExecutorProvider.getTaskExecutor().executeAsync(this);
        }

        /**
         * Waits until {@link #call()} has created and started the subtask, then waits for the
         * subtask itself and returns its result.
         */
        @Override
        public Double blockUntilComplete() {
            try {
                this.future.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
            // future.get() establishes happens-before with call(), so subTask is visible here.
            return (Double) this.subTask.blockUntilComplete();
        }

        /** Creates the per-tensor subtask and starts it asynchronously; the result is read later. */
        @Override
        public Double call() {
            this.subTask = createSubTaskForTensor(this.tensorNum);
            this.subTask.invokeAsync();
            return null;
        }
    }
}

