/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.executioner;

import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.complex.LinearViewComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.LinearViewNDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Default {@link OpExecutioner} that evaluates ops in plain Java, one element at a time.
 *
 * <p>Pass-through ops are delegated to the op's own {@code exec} methods. All other ops are
 * applied element-wise (generally through the operands' linear views), and the dimension-wise
 * variants split the op with {@code opForDimension} and execute every piece separately.
 */
public class DefaultOpExecutioner implements OpExecutioner {

    /** Mode reported by {@link #executionMode()}; starts out as {@code JAVA}. */
    protected OpExecutioner.ExecutionMode executionMode = OpExecutioner.ExecutionMode.JAVA;

    /**
     * Executes {@code op} over all {@code op.n()} elements.
     *
     * <p>Transform and accumulation ops are applied index by index; scalar ops write
     * {@code op.op(x[i])} into {@code z}. Ops of any other type are returned untouched.
     *
     * @param op the op to run
     * @return the same {@code op} instance
     * @throws IllegalArgumentException if a transform op's {@code x} and {@code z} have different
     *     classes and neither is a {@link LinearViewNDArray}
     */
    @Override
    public Op exec(Op op) {
        checkOp(op);
        if (op.isPassThrough()) {
            op.exec();
            return op;
        }

        if (op instanceof TransformOp) {
            TransformOp t = (TransformOp) op;
            boolean compatible = op.x().getClass().equals(t.z().getClass())
                    || op.x() instanceof LinearViewNDArray
                    || t.z() instanceof LinearViewNDArray;
            if (!compatible) {
                throw new IllegalArgumentException("Illegal operation. Origin and output ndarray must be same types. op.x was " + op.x().getClass().getName() + " while t.z was " + t.z().getClass().getName());
            }
            for (int c = 0; c < op.n(); ++c) {
                apply(t, c);
            }
        } else if (op instanceof Accumulation) {
            Accumulation accumulation = (Accumulation) op;
            for (int c = 0; c < op.n(); ++c) {
                apply(accumulation, c);
            }
        } else if (op instanceof ScalarOp) {
            if (op.x() instanceof IComplexNDArray) {
                // Complex operands are indexed directly, not through a linear view.
                IComplexNDArray complexX = (IComplexNDArray) op.x();
                IComplexNDArray complexZ = (IComplexNDArray) op.z();
                for (int c = 0; c < op.n(); ++c) {
                    complexZ.putScalar(c, op.op(complexX.getComplex(c)));
                }
            } else {
                INDArray zLinear = op.z().linearView();
                INDArray xLinear = op.x().linearView();
                for (int c = 0; c < op.n(); ++c) {
                    zLinear.putScalar(c, op.op(xLinear.getDouble(c)));
                }
            }
        }
        return op;
    }

    /**
     * Executes {@code op} and returns its result as an array.
     *
     * <p>Transform and scalar ops yield their {@code z}; an accumulation yields a scalar array
     * wrapping its {@code currentResult()}.
     *
     * @param op the op to run
     * @return the result array
     * @throws IllegalArgumentException if {@code op} is none of the three supported op kinds
     */
    @Override
    public INDArray execAndReturn(Op op) {
        // The casts select the type-specific overloads below.
        if (op instanceof TransformOp) {
            return execAndReturn((TransformOp) op);
        }
        if (op instanceof ScalarOp) {
            return execAndReturn((ScalarOp) op);
        }
        if (op instanceof Accumulation) {
            return Nd4j.scalar(execAndReturn((Accumulation) op).currentResult());
        }
        throw new IllegalArgumentException("Illegal type of op " + op.getClass());
    }

    /**
     * Executes {@code op} once per row of its {@code x} operand.
     *
     * <p>A vector is executed directly. A matrix is processed row by row on copies of the rows,
     * with each result assigned back into the matching row of the original {@code z}. Higher-rank
     * input is split into slices and handled recursively. The op's operands are replaced while
     * iterating and are not restored afterwards.
     *
     * @param op the op to run; its {@code x}/{@code y}/{@code z} are reassigned during the call
     */
    @Override
    public void iterateOverAllRows(Op op) {
        if (op.x().isVector()) {
            // Re-apply the operands through the setters before executing; presumably this lets
            // the op refresh state derived from them -- confirm against the Op implementations.
            op.setX(op.x());
            if (op.y() != null) {
                op.setY(op.y());
            }
            op.setZ(op.z());
            exec(op);
            return;
        }

        if (op.x().isMatrix()) {
            if (op.x() instanceof IComplexNDArray) {
                IComplexNDArray original = (IComplexNDArray) op.x();
                IComplexNDArray originalZ = (IComplexNDArray) op.z();
                IComplexNDArray y = (IComplexNDArray) op.y();
                for (int i = 0; i < original.rows(); ++i) {
                    op.setX(original.slice(i).dup());
                    op.setZ(originalZ.slice(i).dup());
                    if (y != null) {
                        op.setY(y.slice(i));
                    }
                    exec(op);
                    // The op ran on a copy, so write the row result back explicitly.
                    originalZ.slice(i).assign(op.z());
                }
            } else {
                INDArray original = op.x();
                INDArray originalZ = op.z();
                INDArray y = op.y();
                for (int i = 0; i < original.rows(); ++i) {
                    INDArray zRow = originalZ.getRow(i);
                    op.setX(original.getRow(i).dup());
                    op.setZ(zRow.dup());
                    if (y != null) {
                        op.setY(y.getRow(i).dup());
                    }
                    exec(op);
                    // The op ran on a copy, so write the row result back explicitly.
                    zRow.assign(op.z());
                }
            }
            return;
        }

        INDArray originalX = op.x();
        INDArray originalZ = op.z();
        INDArray originalY = op.y();
        for (int i = 0; i < originalX.slices(); ++i) {
            op.setX(originalX.slice(i));
            op.setZ(originalZ.slice(i));
            if (originalY != null) {
                // y has to be narrowed together with x and z; otherwise the recursive pass would
                // take rows from the un-sliced y (or from a row left over from a previous slice).
                op.setY(originalY.slice(i));
            }
            iterateOverAllRows(op);
        }
    }

    /**
     * Executes {@code op} column-wise over its {@code x} operand.
     *
     * <p>A vector is executed directly and a matrix is executed along dimension 1. For anything
     * else the operands are replaced by their {@code i}-th columns and the method recurses.
     *
     * <p>NOTE(review): in the higher-rank branches the loop bound re-reads
     * {@code op.x().slices()} after {@code x} has been replaced, and columns are indexed by a
     * slice count. This looks unintended but is kept unchanged pending confirmation.
     *
     * @param op the op to run; its {@code x}/{@code y}/{@code z} are reassigned during the call
     */
    @Override
    public void iterateOverAllColumns(Op op) {
        if (op.x().isVector()) {
            exec(op);
        } else if (op.x().isMatrix() || op.x().isColumnVector()) {
            exec(op, 1);
        } else if (op.x() instanceof IComplexNDArray) {
            IComplexNDArray originalX = (IComplexNDArray) op.x();
            IComplexNDArray originalZ = (IComplexNDArray) op.z();
            IComplexNDArray originalY = (IComplexNDArray) op.y();
            for (int i = 0; i < op.x().slices(); ++i) {
                op.setX(originalX.getColumn(i));
                op.setZ(originalZ.getColumn(i));
                if (originalY != null) {
                    op.setY(originalY.getColumn(i));
                }
                iterateOverAllColumns(op);
            }
        } else {
            INDArray originalX = op.x();
            INDArray originalZ = op.z();
            INDArray originalY = op.y();
            for (int i = 0; i < op.x().slices(); ++i) {
                op.setX(originalX.getColumn(i));
                op.setZ(originalZ.getColumn(i));
                if (originalY != null) {
                    op.setY(originalY.getColumn(i));
                }
                iterateOverAllColumns(op);
            }
        }
    }

    /**
     * Executes a transform op and returns its output array.
     *
     * @param op the transform to run
     * @return the {@code z} of the op returned by {@link #exec(Op)}
     */
    @Override
    public INDArray execAndReturn(TransformOp op) {
        TransformOp executed = (TransformOp) exec(op);
        return executed.z();
    }

    /**
     * Executes an accumulation over all of its elements.
     *
     * @param op the accumulation to run
     * @return the op returned by {@link #exec(Op)}, cast back to {@link Accumulation}
     */
    @Override
    public Accumulation execAndReturn(Accumulation op) {
        return (Accumulation) exec(op);
    }

    /**
     * Executes a scalar op and returns its output array.
     *
     * @param op the scalar op to run
     * @return the {@code z} of the op returned by {@link #exec(Op)}
     */
    @Override
    public INDArray execAndReturn(ScalarOp op) {
        return exec(op).z();
    }

    /**
     * Executes {@code op} along the given dimensions.
     *
     * <p>When as many dimensions are given as {@code x} has ranks, they are replaced by the single
     * marker {@code Integer.MAX_VALUE}. A single dimension is routed to {@link #exec(Op, int)}.
     * Accumulations ignore the dimensions here and are run over the whole array. Otherwise the op
     * is executed per tensor along the dimensions, and for transform ops each partial result is
     * assigned into the matching tensor of {@code z}.
     *
     * @param op the op to run
     * @param dimension the dimensions to run along
     * @return the same {@code op} instance (or whatever {@link #exec(Op)} returns for accumulations)
     */
    @Override
    public Op exec(Op op, int... dimension) {
        if (dimension.length == op.x().rank()) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        if (op.isPassThrough()) {
            op.exec(dimension);
            return op;
        }
        if (dimension.length == 1) {
            // An int argument binds to the single-dimension overload rather than to varargs.
            return exec(op, dimension[0]);
        }
        if (op instanceof Accumulation) {
            return exec(op);
        }

        // "tensorssAlongDimension" (sic) is the spelling used by the INDArray API.
        for (int i = 0; i < op.x().tensorssAlongDimension(dimension); ++i) {
            Op piece = op.opForDimension(i, dimension);
            exec(piece);
            if (op instanceof TransformOp) {
                TransformOp whole = (TransformOp) op;
                TransformOp part = (TransformOp) piece;
                whole.z().tensorAlongDimension(i, dimension).assign(part.z());
            }
        }
        return op;
    }

    /**
     * Executes {@code op} along one dimension, vector by vector.
     *
     * <p>Accumulations ignore the dimension and are run over the whole array. For transform ops
     * each partial result is assigned into the matching vector of {@code z}.
     *
     * @param op the op to run
     * @param dimension the dimension to run along
     * @return the same {@code op} instance (or whatever {@link #exec(Op)} returns for accumulations)
     */
    protected Op exec(Op op, int dimension) {
        if (op.isPassThrough()) {
            op.exec();
            return op;
        }
        if (op instanceof Accumulation) {
            return exec(op);
        }

        for (int i = 0; i < op.x().vectorsAlongDimension(dimension); ++i) {
            Op piece = op.opForDimension(i, dimension);
            exec(piece);
            if (op instanceof TransformOp) {
                TransformOp whole = (TransformOp) op;
                TransformOp part = (TransformOp) piece;
                whole.z().vectorAlongDimension(i, dimension).assign(part.z());
            }
        }
        return op;
    }

    /**
     * Validation hook invoked at the start of {@link #exec(Op)}.
     *
     * <p>This implementation performs no validation; subclasses may override it.
     *
     * @param op the op about to be executed
     */
    protected void checkOp(Op op) {
        // Intentionally empty. The earlier version inspected the operands for linear views and
        // scalars but returned in every case, so it never rejected anything.
    }

    /**
     * Reduces {@code op} along the given dimensions.
     *
     * <p>When as many dimensions are given as {@code x} has ranks, the whole array is reduced to
     * a scalar. Otherwise one accumulation is run per tensor along the dimensions and the results
     * are collected into a new array whose shape is {@code x}'s shape with the reduced dimensions
     * removed (padded to two dimensions when fewer remain).
     *
     * @param op the accumulation to run
     * @param dimension the dimensions to reduce over; must contain at least one entry unless
     *     {@code x} has rank 0
     * @return the reduction result; complex when {@code x} is an {@link IComplexNDArray}
     */
    @Override
    public INDArray exec(Accumulation op, int... dimension) {
        if (dimension.length == op.x().rank()) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        if (op.isPassThrough()) {
            op.exec(dimension);
            return op.z();
        }

        boolean complex = op.x() instanceof IComplexNDArray;
        if (dimension[0] == Integer.MAX_VALUE) {
            if (complex) {
                return Nd4j.scalar(execAndReturn(op).currentResultComplex());
            }
            return Nd4j.scalar(execAndReturn(op).currentResult().doubleValue());
        }

        int[] retShape = ArrayUtil.removeIndex(op.x().shape(), dimension);
        if (retShape.length == 1) {
            // Reducing over dimension 0 yields a row vector, anything else a column vector.
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }

        // The type test must look at the input array: the op itself is never an ndarray, so
        // testing "op" would always send complex input down the real-valued path below.
        if (complex) {
            IComplexNDArray ret = Nd4j.createComplex(retShape);
            IComplexNDArray linear = ret.linearView();
            for (int i = 0; i < op.x().tensorssAlongDimension(dimension); ++i) {
                Op piece = op.opForDimension(i, dimension);
                IComplexNumber result = execAndReturn((Accumulation) piece).currentResultComplex();
                linear.putScalar(i, result);
            }
            if (ret.ordering() == 'c') {
                ret.setStride(ArrayUtil.reverseCopy(ret.stride()));
            }
            return ret;
        }

        INDArray ret = Nd4j.create(retShape);
        INDArray linear = ret.linearView();
        for (int i = 0; i < op.x().tensorssAlongDimension(dimension); ++i) {
            Op piece = op.opForDimension(i, dimension);
            double result = execAndReturn((Accumulation) piece).currentResult().doubleValue();
            linear.putScalar(i, result);
        }
        return ret;
    }

    /**
     * Reduces {@code op} along a single dimension.
     *
     * <p>{@code Integer.MAX_VALUE} means "reduce everything": the operands are replaced by their
     * linear views and a scalar is returned. Scalars and several row/column-vector cases are
     * answered directly; otherwise one accumulation is run per vector along {@code dimension}.
     *
     * <p>NOTE(review): the complex row/column-vector shortcuts wrap {@code currentResult()}
     * rather than {@code currentResultComplex()}, and the real column-vector case for dimension 1
     * returns {@code op.z().transpose()}. Both are kept as found -- confirm the intent.
     *
     * @param op the accumulation to run
     * @param dimension the dimension to reduce over, or {@code Integer.MAX_VALUE} for all
     * @return the reduction result
     */
    protected INDArray execVector(Accumulation op, int dimension) {
        if (op.isPassThrough()) {
            op.exec();
            return op.z();
        }

        if (dimension == Integer.MAX_VALUE) {
            op.setX(op.x().linearView());
            if (op.y() != null) {
                op.setY(op.y().linearView());
            }
            op.setZ(op.z().linearView());
            if (op.x() instanceof IComplexNDArray) {
                return Nd4j.scalar(execAndReturn(op).currentResultComplex());
            }
            return Nd4j.scalar(execAndReturn(op).currentResult());
        }

        if (op.x().isScalar()) {
            return op.x();
        }

        if (op.x() instanceof IComplexNDArray) {
            if (op.x().isRowVector()) {
                if (dimension == 0) {
                    return op.x();
                }
                if (dimension == 1) {
                    return Nd4j.scalar(execAndReturn(op).currentResult());
                }
            } else if (op.x().isColumnVector()) {
                if (dimension == 0 || dimension == 1) {
                    return Nd4j.scalar(execAndReturn(op).currentResult());
                }
            }

            // Allocate only once none of the shortcuts above has returned.
            IComplexNDArray ret = Nd4j.createComplex(ArrayUtil.removeIndex(op.x().shape(), dimension));
            IComplexNDArray linear = ret.linearView();
            for (int i = 0; i < op.x().vectorsAlongDimension(dimension); ++i) {
                Op piece = op.opForDimension(i, dimension);
                IComplexNumber result = execAndReturn((Accumulation) piece).currentResultComplex();
                linear.putScalar(i, result);
            }
            return ret;
        }

        if (op.x().isRowVector()) {
            if (dimension == 0) {
                return op.x();
            }
            if (dimension == 1) {
                return Nd4j.scalar(execAndReturn(op).currentResult());
            }
        } else if (op.x().isColumnVector()) {
            if (dimension == 0) {
                return Nd4j.scalar(execAndReturn(op).currentResult());
            }
            if (dimension == 1) {
                return op.z().transpose();
            }
        }

        boolean atMostTwoDimensional = op.x().isMatrix() || op.x().isVector();
        int[] shape = ArrayUtil.removeIndex(op.x().shape(), dimension);
        if (atMostTwoDimensional && shape.length < 2) {
            // Pad the result of a matrix/vector reduction to a row vector.
            shape = new int[]{1, shape[0]};
        }
        INDArray ret = Nd4j.create(shape);
        INDArray linear = ret.linearView();
        for (int i = 0; i < op.x().vectorsAlongDimension(dimension); ++i) {
            Op piece = op.opForDimension(i, dimension);
            Number result = execAndReturn((Accumulation) piece).currentResult();
            linear.putScalar(i, result.doubleValue());
        }
        return ret;
    }

    /**
     * Executes a transform op along the given dimensions and returns its output.
     *
     * <p>When as many dimensions are given as {@code x} has ranks, they are replaced by the single
     * marker {@code Integer.MAX_VALUE}. A single dimension is routed to
     * {@link #execAndReturnVector(TransformOp, int)}; otherwise the op is executed per tensor and
     * each partial result is assigned into the matching tensor of {@code z}.
     *
     * @param op the transform to run
     * @param dimension the dimensions to run along
     * @return {@code op.z()}
     */
    @Override
    public INDArray execAndReturn(TransformOp op, int... dimension) {
        if (dimension.length == op.x().rank()) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        if (dimension.length == 1) {
            return execAndReturnVector(op, dimension[0]);
        }
        for (int i = 0; i < op.x().tensorssAlongDimension(dimension); ++i) {
            Op piece = op.opForDimension(i, dimension);
            exec(piece);
            op.z().tensorAlongDimension(i, dimension).assign(piece.z());
        }
        return op.z();
    }

    /**
     * Executes a transform op along one dimension, vector by vector.
     *
     * @param op the transform to run
     * @param dimension the dimension to run along
     * @return {@code op.z()}
     */
    protected INDArray execAndReturnVector(TransformOp op, int dimension) {
        if (op.isPassThrough()) {
            op.exec(dimension);
            return op.z();
        }
        for (int i = 0; i < op.x().vectorsAlongDimension(dimension); ++i) {
            Op piece = op.opForDimension(i, dimension);
            exec(piece);
            op.z().vectorAlongDimension(i, dimension).assign(piece.z());
        }
        return op.z();
    }

    /**
     * Executes a scalar op along the given dimensions and returns its output.
     *
     * @param op the scalar op to run
     * @param dimension the dimensions to run along
     * @return the {@code z} of the op returned by {@link #exec(Op, int...)}
     */
    @Override
    public INDArray execAndReturn(ScalarOp op, int... dimension) {
        return exec((Op) op, dimension).z();
    }

    /**
     * Returns the current execution mode.
     *
     * @return the mode last set, {@code JAVA} by default
     */
    @Override
    public OpExecutioner.ExecutionMode executionMode() {
        return executionMode;
    }

    /**
     * Sets the execution mode reported by {@link #executionMode()}.
     *
     * @param executionMode the new mode
     */
    @Override
    public void setExecutionMode(OpExecutioner.ExecutionMode executionMode) {
        this.executionMode = executionMode;
    }

    /**
     * Applies a transform op to element {@code c}, writing the result into {@code z}.
     *
     * <p>Does nothing for pass-through ops. Uses the unary {@code op.op(x)} when there is no
     * {@code y} and the binary {@code op.op(x, y)} otherwise, in complex or real arithmetic
     * depending on the operand types.
     */
    private void apply(TransformOp op, int c) {
        if (op.isPassThrough()) {
            return;
        }

        boolean complexInput = op.x() instanceof IComplexNDArray;

        if (op.y() == null) {
            if (complexInput) {
                IComplexNDArray complexX = (IComplexNDArray) op.x().linearView();
                IComplexNDArray complexZ = (IComplexNDArray) op.z().linearView();
                complexZ.putScalar(c, op.op(complexX.getComplex(c)));
            } else {
                op.z().linearView().putScalar(c, op.op(op.x().linearView().getDouble(c)));
            }
            return;
        }

        if (complexInput) {
            IComplexNDArray complexX = (IComplexNDArray) op.x().linearView();
            IComplexNDArray complexZ = (IComplexNDArray) op.z().linearView();
            IComplexNumber curr = complexX.getComplex(c);
            if (op.y() instanceof IComplexNDArray) {
                IComplexNDArray complexY = (IComplexNDArray) op.y().linearView();
                complexZ.putScalar(c, op.op(curr, complexY.getComplex(c)));
            } else {
                // A real y is read directly here, not through its linear view (kept as found).
                complexZ.putScalar(c, op.op(curr, op.y().getDouble(c)));
            }
        } else {
            INDArray zLinear = op.z().linearView();
            INDArray xLinear = op.x().linearView();
            INDArray yLinear = op.y().linearView();
            zLinear.putScalar(c, op.op(xLinear.getDouble(c), yLinear.getDouble(c)));
        }
    }

    /**
     * Feeds element {@code x} of the operands into an accumulation via {@code op.update}.
     *
     * <p>Does nothing for pass-through ops. Uses the unary {@code op.op(x)} when there is no
     * {@code y} and the binary {@code op.op(x, y)} otherwise, in complex or real arithmetic
     * depending on the operand types.
     */
    private void apply(Accumulation op, int x) {
        if (op.isPassThrough()) {
            return;
        }

        boolean complexInput = op.x() instanceof IComplexNDArray;

        if (op.y() == null) {
            if (complexInput) {
                IComplexNDArray complexX = (IComplexNDArray) op.x().linearView();
                op.update(op.op(complexX.getComplex(x)));
            } else {
                op.update(op.op(op.x().linearView().getDouble(x)));
            }
            return;
        }

        if (complexInput) {
            IComplexNDArray complexX = (IComplexNDArray) op.x().linearView();
            IComplexNumber curr = complexX.getComplex(x);
            if (op.y() instanceof IComplexNDArray) {
                // Cast only after the type check: casting a real y up front would throw a
                // ClassCastException and make the real-y branch below unreachable.
                IComplexNDArray complexY = (IComplexNDArray) op.y().linearView();
                op.update(op.op(curr, complexY.getComplex(x)));
            } else {
                op.update(op.op(curr, op.y().linearView().getDouble(x)));
            }
        } else {
            op.update(op.op(op.x().linearView().getDouble(x), op.y().linearView().getDouble(x)));
        }
    }
}

