/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops;

import java.util.Arrays;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Wraps an {@link Accumulation} or {@link IndexAccumulation} and evaluates it along
 * {@link #originalDimension} in two stages.
 *
 * <ol>
 *   <li>The wrapped op is (optionally) executed along a smaller set of dimensions.</li>
 *   <li>The per-TAD partial results are combined into one value per TAD along the
 *       original dimensions.</li>
 * </ol>
 *
 * <p>Most {@link Op} methods simply delegate to the wrapped op.
 */
public class TadCollapseAccumulation extends BaseOp {
    /** The wrapped op; expected to be an {@link Accumulation} or {@link IndexAccumulation}. */
    protected Op accum;
    /** When true, {@link #exec()} first runs {@link #accum} along {@link #smallerDimension}. */
    protected boolean performSmallerDimension;
    /** Dimensions for the first-stage reduction; defaults to the last of {@link #originalDimension}. */
    protected int[] smallerDimension;
    /** Dimensions the caller ultimately wants reduced. */
    protected int[] originalDimension;
    /** Cached TAD count of accum.x() along {@link #smallerDimension} (set only by the 4-arg constructor). */
    protected int tadsForSmallerDimension;
    /** Cached TAD count of accum.x() along {@link #originalDimension} (set only by the 4-arg constructor). */
    protected int tadsForLargerDimension;
    /** Name reported by {@link #name()} when no op is wrapped. */
    public static final String DEFAULT_NAME = "collapseTad";

    /** Creates an empty instance; configure via the setters before executing. */
    public TadCollapseAccumulation() {
    }

    /**
     * Creates a fully configured instance and caches the TAD counts for both dimension sets.
     *
     * @param accum the op to collapse; its {@code x()} must already be set
     * @param originalDimension the dimensions ultimately reduced
     * @param smallerDimension the dimensions used for the first-stage reduction
     * @param performSmallerDimension whether {@link #exec()} runs the first stage itself
     */
    public TadCollapseAccumulation(Op accum, int[] originalDimension, int[] smallerDimension, boolean performSmallerDimension) {
        this.accum = accum;
        this.performSmallerDimension = performSmallerDimension;
        this.originalDimension = originalDimension;
        this.smallerDimension = smallerDimension;
        INDArray input = accum.x();
        this.tadsForSmallerDimension = input.tensorssAlongDimension(smallerDimension);
        this.tadsForLargerDimension = input.tensorssAlongDimension(originalDimension);
    }

    /** Same as the 4-arg constructor with {@code performSmallerDimension = true}. */
    public TadCollapseAccumulation(Op accum, int[] originalDimension, int[] smallerDimension) {
        this(accum, originalDimension, smallerDimension, true);
    }

    /**
     * Wraps {@code accum} for the given original dimensions.
     *
     * <p>No smaller dimension is set here, and {@code performSmallerDimension} stays false.
     */
    public TadCollapseAccumulation(Op accum, int[] originalDimension) {
        this.accum = accum;
        this.originalDimension = originalDimension;
    }

    /** Wraps {@code accum} with no dimensions configured. */
    public TadCollapseAccumulation(Op accum) {
        this.accum = accum;
    }

    /** Wraps {@code accum}, passing {@code x} to the {@link BaseOp} constructor. */
    public TadCollapseAccumulation(INDArray x, Op accum) {
        super(x);
        this.accum = accum;
    }

    /** Wraps {@code accum}, passing {@code x, y, z, n} to the {@link BaseOp} constructor. */
    public TadCollapseAccumulation(INDArray x, INDArray y, INDArray z, long n, Op accum) {
        super(x, y, z, n);
        this.accum = accum;
    }

    /** Wraps {@code accum}, passing {@code x, z} to the {@link BaseOp} constructor. */
    public TadCollapseAccumulation(INDArray x, INDArray z, Op accum) {
        super(x, z);
        this.accum = accum;
    }

    /** Wraps {@code accum}, passing {@code x, z, n} to the {@link BaseOp} constructor. */
    public TadCollapseAccumulation(INDArray x, INDArray z, long n, Op accum) {
        super(x, z, n);
        this.accum = accum;
    }

    /** @return the wrapped op */
    public Op getAccum() {
        return this.accum;
    }

    /** Always {@code true}: this op is executed through {@link #exec()} rather than element-wise. */
    @Override
    public boolean isPassThrough() {
        return true;
    }

    /**
     * Evaluates the wrapped op along {@link #originalDimension} and stores the result via
     * {@code accum.setZ(...)}.
     *
     * <p>The steps are:
     * <ol>
     *   <li>If {@link #smallerDimension} is null, it is set to the last entry of
     *       {@link #originalDimension}.</li>
     *   <li>If {@link #performSmallerDimension} is set, the wrapped op is executed along
     *       {@link #smallerDimension}. When no first stage is run here, {@code accum.z()} is
     *       presumably expected to already hold those partial results.</li>
     *   <li>The partial results are combined into an array whose shape is {@code x}'s shape
     *       with the original dimensions removed.</li>
     * </ol>
     *
     * @throws IllegalStateException if {@link #originalDimension} has not been set, or is
     *     empty while {@link #smallerDimension} is also unset
     */
    @Override
    public void exec() {
        if (this.originalDimension == null) {
            throw new IllegalStateException("originalDimension must be set before exec()");
        }
        if (this.smallerDimension == null) {
            if (this.originalDimension.length == 0) {
                throw new IllegalStateException("Cannot derive smallerDimension from an empty originalDimension");
            }
            this.smallerDimension = new int[] {this.originalDimension[this.originalDimension.length - 1]};
        }
        if (this.performSmallerDimension) {
            this.runSmallerDimensionReduction();
        }
        INDArray input = this.accum.x();
        // NOTE(review): `aggregated` comes straight from Nd4j.create and is presumably
        // zero-filled. Combining with 0.0 is only correct when 0 is the identity of
        // combineSubResults (e.g. sums, but not max over negatives) -- confirm.
        INDArray aggregated = Nd4j.create(ArrayUtil.removeIndex(input.shape(), this.originalDimension));
        int smallerProblem = input.tensorssAlongDimension(this.smallerDimension);
        int biggerProblem = input.tensorssAlongDimension(this.originalDimension);
        if (this.accum instanceof Accumulation) {
            this.collapseAccumulation((Accumulation) this.accum, aggregated, smallerProblem, biggerProblem);
        } else if (this.accum instanceof IndexAccumulation) {
            this.collapseIndexAccumulation((IndexAccumulation) this.accum, aggregated, smallerProblem, biggerProblem);
        }
        this.accum.setZ(aggregated);
    }

    /**
     * Runs the wrapped op along {@link #smallerDimension} through the default executioner.
     *
     * <p>For an {@link Accumulation}, the final transform is disabled first so that it can be
     * applied once, after collapsing.
     */
    private void runSmallerDimensionReduction() {
        if (this.accum instanceof Accumulation) {
            Accumulation accumulation = (Accumulation) this.accum;
            accumulation.setApplyFinalTransform(false);
            Nd4j.getExecutioner().exec(accumulation, this.smallerDimension);
        } else if (this.accum instanceof IndexAccumulation) {
            Nd4j.getExecutioner().exec((IndexAccumulation) this.accum, this.smallerDimension);
        }
    }

    /**
     * Folds the partial results in {@code accumulation.z()} into {@code aggregated}.
     *
     * <p>After folding, it sets n to the length of one TAD along {@link #originalDimension},
     * re-enables the final transform, and applies {@code calculateFinalResult} to every
     * aggregated value.
     */
    private void collapseAccumulation(Accumulation accumulation, INDArray aggregated, int smallerProblem, int biggerProblem) {
        int biggerTadLength = accumulation.x().tensorAlongDimension(0, this.originalDimension).length();
        INDArray partials = accumulation.z();
        for (int tad = 0; tad < smallerProblem; ++tad) {
            int reductionIndex = reductionIndexForTad(tad, biggerProblem, smallerProblem);
            aggregated.putScalar(reductionIndex, accumulation.combineSubResults(aggregated.getDouble(reductionIndex), partials.getDouble(tad)));
        }
        accumulation.setN(biggerTadLength);
        accumulation.setApplyFinalTransform(true);
        for (int i = 0; i < aggregated.length(); ++i) {
            aggregated.putScalar(i, accumulation.calculateFinalResult(aggregated.getDouble(i), (long) biggerTadLength));
        }
    }

    /** Folds per-TAD values into {@code aggregated} using the index accumulation's combiner. */
    private void collapseIndexAccumulation(IndexAccumulation indexAccumulation, INDArray aggregated, int smallerProblem, int biggerProblem) {
        INDArray input = indexAccumulation.x();
        for (int tad = 0; tad < smallerProblem; ++tad) {
            int reductionIndex = reductionIndexForTad(tad, biggerProblem, smallerProblem);
            // NOTE(review): this reads raw input element `tad` from x(), not the tad-th partial
            // result from z() as the Accumulation path does -- confirm this is intended.
            aggregated.putScalar(reductionIndex, indexAccumulation.combineSubResults(input.getDouble(tad), tad, aggregated.getDouble(reductionIndex), reductionIndex));
        }
    }

    /** Delegates to the wrapped op. */
    @Override
    public INDArray x() {
        return this.accum.x();
    }

    /** Delegates to the wrapped op. */
    @Override
    public INDArray y() {
        return this.accum.y();
    }

    /** Delegates to the wrapped op. */
    @Override
    public INDArray z() {
        return this.accum.z();
    }

    /**
     * Sets {@link #originalDimension} to {@code dimensions} and runs {@link #exec()}.
     *
     * @param dimensions the dimensions to reduce along
     */
    @Override
    public void exec(int ... dimensions) {
        this.originalDimension = dimensions;
        this.exec();
    }

    /** @return always 0 for this wrapper op */
    @Override
    public int opNum() {
        return 0;
    }

    /** @return the wrapped op's name, or {@link #DEFAULT_NAME} when no op is wrapped */
    @Override
    public String name() {
        return this.accum == null ? DEFAULT_NAME : this.accum.name();
    }

    /** Delegates to the wrapped op. */
    @Override
    public IComplexNumber op(IComplexNumber origin, double other) {
        return this.accum.op(origin, other);
    }

    /** Delegates to the wrapped op. */
    @Override
    public IComplexNumber op(IComplexNumber origin, float other) {
        return this.accum.op(origin, other);
    }

    /** Delegates to the wrapped op. */
    @Override
    public IComplexNumber op(IComplexNumber origin, IComplexNumber other) {
        return this.accum.op(origin, other);
    }

    /** Delegates to the wrapped op. */
    @Override
    public float op(float origin, float other) {
        return this.accum.op(origin, other);
    }

    /** Delegates to the wrapped op. */
    @Override
    public double op(double origin, double other) {
        return this.accum.op(origin, other);
    }

    /** Delegates to the wrapped op. */
    @Override
    public double op(double origin) {
        return this.accum.op(origin);
    }

    /** Delegates to the wrapped op. */
    @Override
    public float op(float origin) {
        return this.accum.op(origin);
    }

    /** Delegates to the wrapped op. */
    @Override
    public IComplexNumber op(IComplexNumber origin) {
        return this.accum.op(origin);
    }

    /** Delegates to the wrapped op. */
    @Override
    public Op opForDimension(int index, int dimension) {
        return this.accum.opForDimension(index, dimension);
    }

    /** Delegates to the wrapped op. */
    @Override
    public Op opForDimension(int index, int ... dimension) {
        return this.accum.opForDimension(index, dimension);
    }

    /**
     * Maps a linear element offset to the index of the TAD containing it.
     *
     * @param i linear element offset
     * @param elementWiseStride stride between consecutive TAD elements
     * @param numElementsPerTad number of elements in each TAD
     * @return {@code i / (numElementsPerTad * elementWiseStride)}; throws ArithmeticException
     *     if that product is 0
     */
    public static int tadIndex(int i, int elementWiseStride, int numElementsPerTad) {
        return i / (numElementsPerTad * elementWiseStride);
    }

    /**
     * Maps a TAD index in the finer (more-TADs) decomposition to the index of the coarser TAD
     * it folds into.
     *
     * @param tadIndexForOriginal TAD index in the finer decomposition
     * @param tadsForReduced number of TADs in the coarser decomposition
     * @param tadsForOriginal number of TADs in the finer decomposition
     * @return {@code tadIndexForOriginal / (tadsForOriginal / tadsForReduced)}, or 0 for index 0.
     *     For non-zero indices this throws ArithmeticException when
     *     {@code tadsForOriginal < tadsForReduced}.
     */
    public static int reductionIndexForTad(int tadIndexForOriginal, int tadsForReduced, int tadsForOriginal) {
        if (tadIndexForOriginal == 0) {
            return 0;
        }
        return tadIndexForOriginal / tadsPerReduceIndex(tadsForReduced, tadsForOriginal);
    }

    /**
     * @param tadsForReduce number of TADs in the coarser decomposition
     * @param tadsForOriginal number of TADs in the finer decomposition
     * @return how many finer TADs fold into each coarser TAD (integer division)
     */
    public static int tadsPerReduceIndex(int tadsForReduce, int tadsForOriginal) {
        return tadsForOriginal / tadsForReduce;
    }

    /**
     * Combines {@link #tadIndex} and {@link #reductionIndexForTad}.
     *
     * @return the coarser reduction index for linear element offset {@code i}
     */
    public static int reductionIndexForLinear(int i, int elementWiseStride, int numElementsPerTad, int tadNum, int originalTadNum) {
        int tad = tadIndex(i, elementWiseStride, numElementsPerTad);
        return reductionIndexForTad(tad, tadNum, originalTadNum);
    }

    public boolean isPerformSmallerDimension() {
        return this.performSmallerDimension;
    }

    public int[] getSmallerDimension() {
        return this.smallerDimension;
    }

    public int[] getOriginalDimension() {
        return this.originalDimension;
    }

    public int getTadsForSmallerDimension() {
        return this.tadsForSmallerDimension;
    }

    public int getTadsForLargerDimension() {
        return this.tadsForLargerDimension;
    }

    public void setAccum(Op accum) {
        this.accum = accum;
    }

    public void setPerformSmallerDimension(boolean performSmallerDimension) {
        this.performSmallerDimension = performSmallerDimension;
    }

    public void setSmallerDimension(int[] smallerDimension) {
        this.smallerDimension = smallerDimension;
    }

    public void setOriginalDimension(int[] originalDimension) {
        this.originalDimension = originalDimension;
    }

    public void setTadsForSmallerDimension(int tadsForSmallerDimension) {
        this.tadsForSmallerDimension = tadsForSmallerDimension;
    }

    public void setTadsForLargerDimension(int tadsForLargerDimension) {
        this.tadsForLargerDimension = tadsForLargerDimension;
    }

    /**
     * Field-wise equality over the wrapped op, flags, dimension arrays and cached TAD counts.
     *
     * <p>Uses the {@link #canEqual} pattern so that subclasses can stay symmetric.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof TadCollapseAccumulation)) {
            return false;
        }
        TadCollapseAccumulation other = (TadCollapseAccumulation) o;
        if (!other.canEqual(this)) {
            return false;
        }
        Op thisAccum = this.getAccum();
        Op otherAccum = other.getAccum();
        if (thisAccum == null ? otherAccum != null : !thisAccum.equals(otherAccum)) {
            return false;
        }
        return this.isPerformSmallerDimension() == other.isPerformSmallerDimension()
                && Arrays.equals(this.getSmallerDimension(), other.getSmallerDimension())
                && Arrays.equals(this.getOriginalDimension(), other.getOriginalDimension())
                && this.getTadsForSmallerDimension() == other.getTadsForSmallerDimension()
                && this.getTadsForLargerDimension() == other.getTadsForLargerDimension();
    }

    /** @return whether {@code other} may be considered equal to this instance */
    protected boolean canEqual(Object other) {
        return other instanceof TadCollapseAccumulation;
    }

    /** Hash consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        Op thisAccum = this.getAccum();
        result = result * prime + (thisAccum == null ? 0 : thisAccum.hashCode());
        result = result * prime + (this.isPerformSmallerDimension() ? 79 : 97);
        result = result * prime + Arrays.hashCode(this.getSmallerDimension());
        result = result * prime + Arrays.hashCode(this.getOriginalDimension());
        result = result * prime + this.getTadsForSmallerDimension();
        result = result * prime + this.getTadsForLargerDimension();
        return result;
    }

    @Override
    public String toString() {
        return "TadCollapseAccumulation(accum=" + this.getAccum() + ", performSmallerDimension=" + this.isPerformSmallerDimension() + ", smallerDimension=" + Arrays.toString(this.getSmallerDimension()) + ", originalDimension=" + Arrays.toString(this.getOriginalDimension()) + ", tadsForSmallerDimension=" + this.getTadsForSmallerDimension() + ", tadsForLargerDimension=" + this.getTadsForLargerDimension() + ")";
    }
}

