/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel.tasks.cpu.vector;

import io.netty.buffer.ByteBuf;
import java.util.ArrayList;
import java.util.concurrent.RecursiveAction;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.TaskExecutorProvider;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUAction;
import org.nd4j.linalg.util.ArrayUtil;

public class CpuBroadcastOp
extends BaseCPUAction {
    protected final BroadcastOp op;

    public CpuBroadcastOp(BroadcastOp op, int threshold) {
        super(op, threshold);
        this.op = op;
    }

    @Override
    public Void blockUntilComplete() {
        if (this.future == null) {
            this.invokeAsync();
        }
        try {
            this.future.get();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        if (this.subTasks != null) {
            for (Task task : this.subTasks) {
                task.blockUntilComplete();
            }
        }
        return null;
    }

    @Override
    public Void call() {
        int nVectorOps;
        INDArray x = this.op.x();
        INDArray y = this.op.y();
        INDArray z = this.op.z();
        if (x.rank() == 2) {
            nVectorOps = this.op.getDimension()[0] == 0 ? x.columns() : x.rows();
        } else {
            int[] shape = x.shape();
            nVectorOps = ArrayUtil.prod((int[])ArrayUtil.removeIndex((int[])shape, (int[])this.op.getDimension()));
        }
        this.subTasks = new ArrayList(nVectorOps);
        int[] dimension = this.op.getDimension();
        if (x.size(dimension[0]) != y.length()) {
            throw new UnsupportedOperationException("Array length " + y.length() + " does not match x.shape(" + dimension + ")=" + x.size(dimension[0]));
        }
        if (x.rank() == 2) {
            OpExecutionerUtil.Tensor1DStats t1dx = OpExecutionerUtil.get1DTensorStats(x, dimension);
            if (y != null) {
                int offsetY = y.offset();
                int ewsy = y.elementWiseStride();
                if (x == z) {
                    for (int i = 0; i < nVectorOps; ++i) {
                        int offsetX = t1dx.getFirstTensorOffset() + i * t1dx.getTensorStartSeparation();
                        SingleVectorAction task = new SingleVectorAction(this.threshold, t1dx.getTensorLength(), offsetX, offsetY, offsetX, t1dx.getElementWiseStride(), ewsy, t1dx.getElementWiseStride());
                        task.invokeAsync();
                        this.subTasks.add(task);
                    }
                } else {
                    OpExecutionerUtil.Tensor1DStats t1dz = OpExecutionerUtil.get1DTensorStats(z, dimension);
                    for (int i = 0; i < nVectorOps; ++i) {
                        int offsetX = t1dx.getFirstTensorOffset() + i * t1dx.getTensorStartSeparation();
                        int offsetZ = t1dz.getFirstTensorOffset() + i * t1dz.getTensorStartSeparation();
                        SingleVectorAction task = new SingleVectorAction(this.threshold, t1dx.getTensorLength(), offsetX, offsetY, offsetZ, t1dx.getElementWiseStride(), ewsy, t1dz.getElementWiseStride());
                        task.invokeAsync();
                        this.subTasks.add(task);
                    }
                }
            } else if (x == z) {
                for (int i = 0; i < nVectorOps; ++i) {
                    int offsetX = t1dx.getFirstTensorOffset() + i * t1dx.getTensorStartSeparation();
                    SingleVectorAction task = new SingleVectorAction(this.threshold, t1dx.getTensorLength(), offsetX, 0, offsetX, t1dx.getElementWiseStride(), 0, t1dx.getElementWiseStride());
                    task.invokeAsync();
                    this.subTasks.add(task);
                }
            } else {
                OpExecutionerUtil.Tensor1DStats t1dz = OpExecutionerUtil.get1DTensorStats(z, dimension);
                for (int i = 0; i < nVectorOps; ++i) {
                    int offsetX = t1dx.getFirstTensorOffset() + i * t1dx.getTensorStartSeparation();
                    int offsetZ = t1dz.getFirstTensorOffset() + i * t1dz.getTensorStartSeparation();
                    SingleVectorAction task = new SingleVectorAction(this.threshold, t1dx.getTensorLength(), offsetX, 0, offsetZ, t1dx.getElementWiseStride(), 0, t1dz.getElementWiseStride());
                    task.invokeAsync();
                    this.subTasks.add(task);
                }
            }
        } else {
            for (int i = 0; i < nVectorOps; ++i) {
                SingleVectorAction task = new SingleVectorAction(this.threshold, i, dimension[0]);
                task.invokeAsync();
                this.subTasks.add(task);
            }
        }
        return null;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    protected void compute() {
        int nVectorOps;
        INDArray x = this.op.x();
        INDArray y = this.op.y();
        INDArray z = this.op.z();
        if (x.rank() == 2) {
            nVectorOps = this.op.getDimension()[0] == 0 ? x.columns() : x.rows();
        } else {
            int[] shape = x.shape();
            nVectorOps = ArrayUtil.prod((int[])ArrayUtil.removeIndex((int[])shape, (int[])this.op.getDimension()));
        }
        ArrayList<SingleVectorAction> subTasks = new ArrayList<SingleVectorAction>(nVectorOps);
        int[] dimension = this.op.getDimension();
        if (x.size(dimension[0]) != y.length()) {
            throw new UnsupportedOperationException("Vector length " + y.length() + " does not match x.shape(" + dimension[0] + ")= " + x.size(dimension[0]));
        }
        if (x.rank() == 2) {
            SingleVectorAction task;
            OpExecutionerUtil.Tensor1DStats t1dx = OpExecutionerUtil.get1DTensorStats(x, dimension);
            if (y != null) {
                int n = y.offset();
                int ewsy = y.elementWiseStride();
                if (x == z) {
                    for (int i = 0; i < nVectorOps; ++i) {
                        int offsetX = t1dx.getFirstTensorOffset() + i * t1dx.getTensorStartSeparation();
                        task = new SingleVectorAction(this.threshold, t1dx.getTensorLength(), offsetX, n, offsetX, t1dx.getElementWiseStride(), ewsy, t1dx.getElementWiseStride());
                        task.fork();
                        subTasks.add(task);
                    }
                } else {
                    OpExecutionerUtil.Tensor1DStats t1dz = OpExecutionerUtil.get1DTensorStats(z, dimension);
                    for (int i = 0; i < nVectorOps; ++i) {
                        int offsetX = t1dx.getFirstTensorOffset() + i * t1dx.getTensorStartSeparation();
                        int offsetZ = t1dz.getFirstTensorOffset() + i * t1dz.getTensorStartSeparation();
                        SingleVectorAction task2 = new SingleVectorAction(this.threshold, t1dx.getTensorLength(), offsetX, n, offsetZ, t1dx.getElementWiseStride(), ewsy, t1dz.getElementWiseStride());
                        task2.fork();
                        subTasks.add(task2);
                    }
                }
            } else if (x == z) {
                void var8_11;
                boolean bl = false;
                while (var8_11 < nVectorOps) {
                    int offsetX = t1dx.getFirstTensorOffset() + var8_11 * t1dx.getTensorStartSeparation();
                    SingleVectorAction task3 = new SingleVectorAction(this.threshold, t1dx.getTensorLength(), offsetX, 0, offsetX, t1dx.getElementWiseStride(), 0, t1dx.getElementWiseStride());
                    task3.fork();
                    subTasks.add(task3);
                    ++var8_11;
                }
            } else {
                OpExecutionerUtil.Tensor1DStats tensor1DStats = OpExecutionerUtil.get1DTensorStats(z, dimension);
                for (int i = 0; i < nVectorOps; ++i) {
                    int offsetX = t1dx.getFirstTensorOffset() + i * t1dx.getTensorStartSeparation();
                    int offsetZ = tensor1DStats.getFirstTensorOffset() + i * tensor1DStats.getTensorStartSeparation();
                    task = new SingleVectorAction(this.threshold, t1dx.getTensorLength(), offsetX, 0, offsetZ, t1dx.getElementWiseStride(), 0, tensor1DStats.getElementWiseStride());
                    task.fork();
                    subTasks.add(task);
                }
            }
        } else {
            for (int i = 0; i < nVectorOps; ++i) {
                SingleVectorAction singleVectorAction = new SingleVectorAction(this.threshold, i, dimension[0]);
                singleVectorAction.fork();
                subTasks.add(singleVectorAction);
            }
        }
        for (RecursiveAction recursiveAction : subTasks) {
            recursiveAction.join();
        }
    }

    private class SingleVectorAction
    extends BaseCPUAction {
        private SingleVectorAction(int threshold, int n, int offsetX, int offsetY, int offsetZ, int incrX, int incrY, int incrZ) {
            super(threshold, n, offsetX, offsetY, offsetZ, incrX, incrY, incrZ);
        }

        private SingleVectorAction(int threshold, int tadIdx, int tadDim) {
            super(CpuBroadcastOp.this.op, threshold, tadIdx, tadDim);
        }

        @Override
        public void invokeAsync() {
            this.future = TaskExecutorProvider.getTaskExecutor().executeAsync(this);
        }

        @Override
        protected void compute() {
            if (this.doTensorFirst) {
                this.doTensorFirst(CpuBroadcastOp.this.op);
            }
            if (this.n > this.threshold) {
                int nFirst = this.n / 2;
                SingleVectorAction first = new SingleVectorAction(this.threshold, nFirst, this.offsetX, this.offsetY, this.offsetZ, this.incrX, this.incrY, this.incrZ);
                first.fork();
                int nSecond = this.n - nFirst;
                int offsetX2 = this.offsetX + nFirst * this.incrX;
                int offsetY2 = this.offsetY + nFirst * this.incrY;
                int offsetZ2 = this.offsetZ + nFirst * this.incrZ;
                SingleVectorAction second = new SingleVectorAction(this.threshold, nSecond, offsetX2, offsetY2, offsetZ2, this.incrX, this.incrY, this.incrZ);
                second.fork();
                first.join();
                second.join();
            } else {
                this.execute();
            }
        }

        @Override
        public Void blockUntilComplete() {
            if (this.future != null) {
                try {
                    this.future.get();
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            } else {
                for (Task t : this.subTasks) {
                    t.blockUntilComplete();
                }
            }
            return null;
        }

        @Override
        public Void call() {
            if (this.n > this.threshold) {
                if (this.doTensorFirst) {
                    this.doTensorFirst(CpuBroadcastOp.this.op);
                }
                int nSubTasks = 1 + this.n / this.threshold;
                this.subTasks = new ArrayList(nSubTasks);
                int taskSize = this.n / nSubTasks;
                int soFar = 0;
                for (int i = 0; i < nSubTasks; ++i) {
                    int nInTask = i == nSubTasks - 1 ? this.n - soFar : taskSize;
                    int offsetXNew = this.offsetX + soFar * this.incrX;
                    int offsetYNew = this.offsetY + soFar * this.incrY;
                    int offsetZNew = this.offsetZ + soFar * this.incrZ;
                    SingleVectorAction task = new SingleVectorAction(this.threshold, nInTask, offsetXNew, offsetYNew, offsetZNew, this.incrX, this.incrY, this.incrZ);
                    task.invokeAsync();
                    this.subTasks.add(task);
                    soFar += nInTask;
                }
            } else {
                this.execute();
            }
            return null;
        }

        @Override
        public void doTensorFirst(Op op) {
            INDArray x = op.x();
            INDArray y = op.y();
            INDArray z = op.z();
            INDArray tadx = x.tensorAlongDimension(this.tensorIdx, this.tensorDim);
            this.n = tadx.length();
            this.offsetX = tadx.offset();
            this.incrX = tadx.elementWiseStride();
            if (y == null) {
                this.offsetY = 0;
                this.incrY = 0;
            } else {
                this.offsetY = y.offset();
                this.incrY = y.elementWiseStride();
            }
            if (z == null) {
                this.offsetZ = 0;
                this.incrZ = 0;
            } else if (z == x) {
                this.offsetZ = this.offsetX;
                this.incrZ = this.incrX;
            } else {
                INDArray tadz = z.tensorAlongDimension(this.tensorIdx, this.tensorDim);
                this.offsetZ = tadz.offset();
                this.incrZ = tadz.elementWiseStride();
            }
        }

        private Void execute() {
            DataBuffer x = CpuBroadcastOp.this.op.x().data();
            DataBuffer y = CpuBroadcastOp.this.op.y().data();
            DataBuffer z = CpuBroadcastOp.this.op.z().data();
            if (x.allocationMode() == DataBuffer.AllocationMode.HEAP) {
                if (x.dataType() == DataBuffer.Type.FLOAT) {
                    float[] xf = (float[])x.array();
                    float[] yf = (float[])y.array();
                    if (this.incrX == 1 && this.incrY == 1 && (x == z || this.incrZ == 1)) {
                        if (x == z) {
                            for (int i = 0; i < this.n; ++i) {
                                int xIdx = this.offsetX + i;
                                xf[xIdx] = CpuBroadcastOp.this.op.op(xf[xIdx], yf[this.offsetY + i]);
                            }
                        } else {
                            float[] zf = (float[])z.array();
                            for (int i = 0; i < this.n; ++i) {
                                zf[this.offsetZ + i] = CpuBroadcastOp.this.op.op(xf[this.offsetX + i], yf[this.offsetY + i]);
                            }
                        }
                    } else if (x == z) {
                        for (int i = 0; i < this.n; ++i) {
                            int xIdx = this.offsetX + i * this.incrX;
                            xf[xIdx] = CpuBroadcastOp.this.op.op(xf[xIdx], yf[this.offsetY + i * this.incrY]);
                        }
                    } else {
                        float[] zf = (float[])z.array();
                        for (int i = 0; i < this.n; ++i) {
                            zf[this.offsetZ + i * this.incrZ] = CpuBroadcastOp.this.op.op(xf[this.offsetX + i * this.incrX], yf[this.offsetY + i * this.incrY]);
                        }
                    }
                } else {
                    double[] xd = (double[])x.array();
                    double[] yd = (double[])y.array();
                    if (this.incrX == 1 && this.incrY == 1 && (x == z || this.incrZ == 1)) {
                        if (x == z) {
                            for (int i = 0; i < this.n; ++i) {
                                int xIdx = this.offsetX + i;
                                xd[xIdx] = CpuBroadcastOp.this.op.op(xd[xIdx], yd[this.offsetY + i]);
                            }
                        } else {
                            double[] zd = (double[])z.array();
                            for (int i = 0; i < this.n; ++i) {
                                zd[this.offsetZ + i] = CpuBroadcastOp.this.op.op(xd[this.offsetX + i], yd[this.offsetY + i]);
                            }
                        }
                    } else if (x == z) {
                        for (int i = 0; i < this.n; ++i) {
                            int xIdx = this.offsetX + i * this.incrX;
                            xd[xIdx] = CpuBroadcastOp.this.op.op(xd[xIdx], yd[this.offsetY + i * this.incrY]);
                        }
                    } else {
                        double[] zd = (double[])z.array();
                        for (int i = 0; i < this.n; ++i) {
                            zd[this.offsetZ + i * this.incrZ] = CpuBroadcastOp.this.op.op(xd[this.offsetX + i * this.incrX], yd[this.offsetY + i * this.incrY]);
                        }
                    }
                }
            } else {
                ByteBuf nbbx = x.asNetty();
                ByteBuf nbby = y.asNetty();
                ByteBuf nbbz = z.asNetty();
                if (x.dataType() == DataBuffer.Type.FLOAT) {
                    int byteOffsetX = 4 * this.offsetX;
                    int byteOffsetY = 4 * this.offsetY;
                    int byteOffsetZ = 4 * this.offsetZ;
                    if (this.incrX == 1 && this.incrY == 1 && (x == z || this.incrZ == 1)) {
                        if (x == z) {
                            for (int i = 0; i < 4 * this.n; i += 4) {
                                int xbOffset = byteOffsetX + i;
                                nbbx.setFloat(xbOffset, CpuBroadcastOp.this.op.op(nbbx.getFloat(xbOffset), nbby.getFloat(byteOffsetY + i)));
                            }
                        } else {
                            for (int i = 0; i < 4 * this.n; i += 4) {
                                nbbz.setFloat(byteOffsetZ + i, CpuBroadcastOp.this.op.op(nbbx.getFloat(byteOffsetX + i), nbby.getFloat(byteOffsetY + i)));
                            }
                        }
                    } else if (x == z) {
                        for (int i = 0; i < 4 * this.n; i += 4) {
                            int xbOffset = byteOffsetX + i * this.incrX;
                            nbbx.setFloat(xbOffset, CpuBroadcastOp.this.op.op(nbbx.getFloat(xbOffset), nbby.getFloat(byteOffsetY + i * this.incrY)));
                        }
                    } else {
                        for (int i = 0; i < 4 * this.n; i += 4) {
                            nbbz.setFloat(byteOffsetZ + i * this.incrZ, CpuBroadcastOp.this.op.op(nbbx.getFloat(byteOffsetX + i * this.incrX), nbby.getFloat(byteOffsetY + i * this.incrY)));
                        }
                    }
                } else {
                    int byteOffsetX = 8 * this.offsetX;
                    int byteOffsetY = 8 * this.offsetY;
                    int byteOffsetZ = 8 * this.offsetZ;
                    if (this.incrX == 1 && this.incrY == 1 && (x == z || this.incrZ == 1)) {
                        if (x == z) {
                            for (int i = 0; i < 8 * this.n; i += 8) {
                                int xbOffset = byteOffsetX + i;
                                nbbx.setDouble(xbOffset, CpuBroadcastOp.this.op.op(nbbx.getDouble(xbOffset), nbby.getDouble(byteOffsetY + i)));
                            }
                        } else {
                            for (int i = 0; i < 8 * this.n; i += 8) {
                                nbbz.setDouble(byteOffsetZ + i, CpuBroadcastOp.this.op.op(nbbx.getDouble(byteOffsetX + i), nbby.getDouble(byteOffsetY + i)));
                            }
                        }
                    } else if (x == z) {
                        for (int i = 0; i < 8 * this.n; i += 8) {
                            int xbOffset = byteOffsetX + i * this.incrX;
                            nbbx.setDouble(xbOffset, CpuBroadcastOp.this.op.op(nbbx.getDouble(xbOffset), nbby.getDouble(byteOffsetY + i * this.incrY)));
                        }
                    } else {
                        for (int i = 0; i < 8 * this.n; i += 8) {
                            nbbz.setDouble(byteOffsetZ + i * this.incrZ, CpuBroadcastOp.this.op.op(nbbx.getDouble(byteOffsetX + i * this.incrX), nbby.getDouble(byteOffsetY + i * this.incrY)));
                        }
                    }
                }
            }
            return null;
        }
    }
}

