/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.indexing;

import com.google.common.primitives.Ints;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.IntervalIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndexAll;
import org.nd4j.linalg.indexing.NewAxis;
import org.nd4j.linalg.indexing.PointIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Resolves a set of {@link INDArrayIndex} objects against an array into the shape, strides,
 * per-dimension offsets and linear offset describing the indexed view.
 *
 * <p>Usage: construct with the source array, call {@link #exec(INDArrayIndex...)}, then read
 * {@link #getShapes()}, {@link #getStrides()}, {@link #getOffsets()} and {@link #getOffset()}.
 * Instances hold mutable state and perform no synchronization.
 */
public class ShapeOffsetResolution implements Serializable {
    private INDArray arr;
    private int[] offsets;
    private int[] shapes;
    private int[] strides;
    /** Linear offset of the view; stays -1 until a resolution has been computed. */
    private int offset = -1;

    /**
     * @param arr the array whose shape and strides the indexes are resolved against
     */
    public ShapeOffsetResolution(INDArray arr) {
        this.arr = arr;
    }

    /**
     * Attempts to resolve {@code indexes} through one of three fast paths. All of them require
     * no {@link SpecifiedIndex}, no {@link NewAxis} and at least one {@link NDArrayIndexAll}:
     * <ul>
     *   <li>point indexes but no (non-all) interval indexes,</li>
     *   <li>interval indexes but no point indexes,</li>
     *   <li>neither point nor interval indexes.</li>
     * </ul>
     *
     * <p>NOTE(review): the fast paths index per-dimension arrays by index position, so they
     * presumably rely on {@code indexes} having one entry per dimension (as produced by
     * {@link NDArrayIndex#resolve}) — confirm for direct callers.
     *
     * @param indexes the indexes to resolve
     * @return {@code true} if a fast path applied and the result fields were populated,
     *     {@code false} if the general algorithm in {@link #exec(INDArrayIndex...)} is needed
     */
    public boolean tryShortCircuit(INDArrayIndex ... indexes) {
        int pointIndex = 0;
        int interval = 0;
        int newAxis = 0;
        int numAll = 0;
        int numSpecified = 0;
        for (INDArrayIndex idx : indexes) {
            // Points are counted independently of the classification chain below.
            if (idx instanceof PointIndex) {
                ++pointIndex;
            }
            if (idx instanceof SpecifiedIndex) {
                ++numSpecified;
            } else if (idx instanceof IntervalIndex && !(idx instanceof NDArrayIndexAll)) {
                // NDArrayIndexAll is an IntervalIndex subtype; it is counted separately.
                ++interval;
            } else if (idx instanceof NewAxis) {
                ++newAxis;
            } else if (idx instanceof NDArrayIndexAll) {
                ++numAll;
            }
        }
        if (numSpecified > 0 || newAxis > 0 || numAll < 1) {
            return false;
        }
        if (interval < 1 && pointIndex > 0) {
            this.resolvePointsAndAll(pointIndex, indexes);
            return true;
        }
        if (interval > 0 && pointIndex < 1) {
            this.resolveIntervalsAndAll(indexes);
            return true;
        }
        if (interval < 1 && pointIndex < 1) {
            this.resolveAllOnly(indexes);
            return true;
        }
        return false;
    }

    /**
     * Fast path for point + all indexes: each point drops its dimension and contributes
     * {@code offset * stride} to the linear offset; each "all" keeps its dimension.
     */
    private void resolvePointsAndAll(int numPoints, INDArrayIndex[] indexes) {
        // The result keeps at least two dimensions.
        int resultRank = Math.max(this.arr.rank() - numPoints, 2);
        int[] shape = new int[resultRank];
        Arrays.fill(shape, 1);
        int[] stride = new int[resultRank];
        Arrays.fill(stride, this.arr.elementStride());
        int linearOffset = 0;
        int outDim = 0;
        for (int i = 0; i < indexes.length; ++i) {
            if (indexes[i] instanceof NDArrayIndexAll) {
                shape[outDim] = this.arr.size(i);
                stride[outDim] = this.arr.stride(i);
                ++outDim;
            } else {
                linearOffset += indexes[i].offset() * this.arr.stride(i);
            }
        }
        // A point on the first dimension of a matrix: the kept dimension is reported in
        // reversed order.
        if (this.arr.isMatrix() && indexes[0] instanceof PointIndex) {
            shape = ArrayUtil.reverseCopy(shape);
            stride = ArrayUtil.reverseCopy(stride);
        }
        this.strides = stride;
        this.shapes = shape;
        this.offsets = new int[resultRank];
        this.offset = linearOffset;
    }

    /**
     * Fast path for interval + all indexes: the dimension count is unchanged; each interval
     * narrows its dimension and scales that dimension's stride by the interval stride.
     */
    private void resolveIntervalsAndAll(INDArrayIndex[] indexes) {
        int resultRank = Math.max(this.arr.rank(), 2);
        int[] shape = new int[resultRank];
        Arrays.fill(shape, 1);
        int[] stride = new int[resultRank];
        Arrays.fill(stride, this.arr.elementStride());
        int[] dimOffsets = new int[resultRank];
        // NOTE(review): indexes[i] is read for every i < resultRank, and dimOffsets[i] below for
        // every i < indexes.length; both assume indexes.length == resultRank.
        for (int i = 0; i < resultRank; ++i) {
            INDArrayIndex idx = indexes[i];
            // NDArrayIndexAll must be tested before its IntervalIndex supertype.
            if (idx instanceof NDArrayIndexAll) {
                shape[i] = this.arr.size(i);
                stride[i] = this.arr.stride(i);
                dimOffsets[i] = idx.offset();
            } else if (idx instanceof IntervalIndex) {
                shape[i] = idx.length();
                stride[i] = idx.stride() * this.arr.stride(i);
                dimOffsets[i] = idx.offset();
            }
        }
        this.shapes = shape;
        this.strides = stride;
        this.offsets = dimOffsets;
        this.offset = 0;
        for (int i = 0; i < indexes.length; ++i) {
            // Divide the interval stride back out so the offset uses the array's own stride.
            this.offset += dimOffsets[i] * (stride[i] / indexes[i].stride());
        }
    }

    /**
     * Fast path for all-only indexes: the view has the array's own shape and strides and a
     * zero offset.
     */
    private void resolveAllOnly(INDArrayIndex[] indexes) {
        int resultRank = Math.max(this.arr.rank(), 2);
        int[] shape = new int[resultRank];
        Arrays.fill(shape, 1);
        int[] stride = new int[resultRank];
        Arrays.fill(stride, this.arr.elementStride());
        for (int i = 0; i < indexes.length; ++i) {
            shape[i] = this.arr.size(i);
            stride[i] = this.arr.stride(i);
        }
        // Previously this path returned true without storing anything, leaving every result
        // field null / -1 for callers.
        this.shapes = shape;
        this.strides = stride;
        this.offsets = new int[resultRank];
        this.offset = 0;
    }

    /**
     * Resolves {@code indexes} against the array and stores the resulting shape, strides,
     * per-dimension offsets and linear offset in this object.
     *
     * <p>The indexes are first normalized with {@link NDArrayIndex#resolve} against the
     * array's shape. Then {@link #tryShortCircuit} is attempted, and only if it declines does
     * the general algorithm below run.
     *
     * @param indexes the indexes to resolve
     */
    public void exec(INDArrayIndex ... indexes) {
        indexes = NDArrayIndex.resolve(this.arr.shape(), indexes);
        if (this.tryShortCircuit(indexes)) {
            return;
        }
        int[] shape = this.arr.shape();
        int numIntervals = 0;
        // NewAxis indexes seen before any NDArrayIndexAll are prepended as leading size-1 dims.
        int newAxesPrepend = 0;
        boolean encounteredAll = false;
        ArrayList<Integer> accumShape = new ArrayList<Integer>();
        ArrayList<Integer> accumStrides = new ArrayList<Integer>();
        ArrayList<Integer> accumOffsets = new ArrayList<Integer>();
        // Only consulted for emptiness when computing the final offset.
        ArrayList<Integer> intervalStrides = new ArrayList<Integer>();
        ArrayList<Integer> pointStrides = new ArrayList<Integer>();
        ArrayList<Integer> pointOffsets = new ArrayList<Integer>();
        // Positions of NewAxis indexes seen after an NDArrayIndexAll.
        ArrayList<Integer> prependNewAxes = new ArrayList<Integer>();
        int numPointIndexes = 0;
        int shapeIndex = 0;
        int strideIndex = 0;
        for (int i = 0; i < indexes.length; ++i) {
            INDArrayIndex idx = indexes[i];
            if (idx instanceof NDArrayIndexAll) {
                encounteredAll = true;
            }
            if (idx instanceof PointIndex) {
                // A point consumes a source dimension but adds none to the result.
                pointOffsets.add(idx.offset());
                pointStrides.add(this.arr.stride(strideIndex));
                ++numPointIndexes;
                ++shapeIndex;
                ++strideIndex;
                continue;
            }
            if (idx instanceof NewAxis) {
                // A new axis consumes no source dimension.
                if (encounteredAll) {
                    prependNewAxes.add(i);
                } else {
                    ++newAxesPrepend;
                }
                continue;
            }
            if (idx instanceof IntervalIndex && !(idx instanceof NDArrayIndexAll) || idx instanceof SpecifiedIndex) {
                if (idx instanceof IntervalIndex) {
                    accumStrides.add(this.arr.stride(strideIndex) * idx.stride());
                    intervalStrides.add(idx.stride());
                    ++numIntervals;
                } else {
                    accumStrides.add(this.arr.stride(strideIndex));
                }
                accumShape.add(idx.length());
                accumOffsets.add(idx.offset());
                ++shapeIndex;
                ++strideIndex;
                continue;
            }
            // Any other index (e.g. NDArrayIndexAll) keeps the source dimension unchanged.
            accumShape.add(shape[shapeIndex++]);
            accumStrides.add(this.arr.stride(strideIndex++));
            accumOffsets.add(idx.offset());
        }
        // Dimensions not covered by any index: copy their sizes (or 1 for vector inputs).
        boolean vectorInput = Shape.isVector(shape);
        while (shapeIndex < shape.length) {
            accumShape.add(vectorInput ? 1 : shape[shapeIndex]);
            ++shapeIndex;
        }
        int delta = shape.length <= 2 ? shape.length : shape.length - numPointIndexes;
        boolean needsFilledIn = accumShape.size() != accumStrides.size() && accumOffsets.size() != accumShape.size();
        while (accumOffsets.size() < delta && needsFilledIn) {
            accumOffsets.add(0);
        }
        // Pad the result to at least two dimensions; row-vector inputs are padded at the front.
        boolean rowVectorInput = Shape.isRowVectorShape(this.arr.shape());
        while (accumShape.size() < 2) {
            if (rowVectorInput) {
                accumShape.add(0, 1);
            } else {
                accumShape.add(1);
            }
        }
        while (strideIndex < accumShape.size()) {
            accumStrides.add(this.arr.stride(strideIndex++));
        }
        // Leading new axes: size 1, stride 0, offset 0.
        for (int i = 0; i < newAxesPrepend; ++i) {
            accumShape.add(0, 1);
            accumStrides.add(0, 0);
            accumOffsets.add(0, 0);
        }
        // Later new axes are inserted at their original index position. Subtract the number
        // already inserted, exactly as the original accounting did.
        for (int i = 0; i < prependNewAxes.size(); ++i) {
            int position = prependNewAxes.get(i) - i;
            accumShape.add(position, 1);
            accumStrides.add(position, 0);
        }
        // NOTE(review): this inspects accumOffsets[trailingZeroRemove] but always removes the
        // LAST element. If a non-zero offset is hit, the probe index keeps falling until get(-1)
        // throws. Behavior is kept as-is; confirm intent.
        int trailingZeroRemove = accumOffsets.size() - 1;
        while (accumOffsets.size() > accumShape.size()) {
            if (accumOffsets.get(trailingZeroRemove) == 0) {
                accumOffsets.remove(accumOffsets.size() - 1);
            }
            --trailingZeroRemove;
        }
        if (accumStrides.size() < accumOffsets.size()) {
            accumStrides.addAll(pointStrides);
        }
        while (accumOffsets.size() < accumShape.size()) {
            if (rowVectorInput) {
                accumOffsets.add(0, 0);
            } else {
                accumOffsets.add(0);
            }
        }
        // Matrix indexed as (point, all): the result shape is reported in reversed order.
        if (Shape.isMatrix(shape) && indexes[0] instanceof PointIndex && indexes[1] instanceof NDArrayIndexAll) {
            Collections.reverse(accumShape);
        }
        this.shapes = Ints.toArray(accumShape);
        // Pad strides with the element stride: at the front, or at the back for column vectors.
        boolean isColumnVector = Shape.isColumnVectorShape(this.shapes);
        while (accumStrides.size() < accumOffsets.size()) {
            if (isColumnVector) {
                accumStrides.add(this.arr.elementStride());
            } else {
                accumStrides.add(0, this.arr.elementStride());
            }
        }
        this.strides = Ints.toArray(accumStrides);
        this.offsets = Ints.toArray(accumOffsets);
        if (numPointIndexes > 0 && !pointStrides.isEmpty()) {
            if (newAxesPrepend >= 1) {
                // New axes have stride 0; zero the matching point strides so they add nothing.
                while (pointStrides.size() < accumOffsets.size()) {
                    pointStrides.add(1);
                }
                for (int i = 0; i < accumStrides.size(); ++i) {
                    if (accumStrides.get(i) == 0) {
                        pointStrides.set(i, 0);
                    }
                }
            }
            while (pointOffsets.size() < pointStrides.size()) {
                pointOffsets.add(0);
            }
            if (this.arr.isRowVector() && !intervalStrides.isEmpty() && pointOffsets.get(0) == 0) {
                // Row vector indexed as (point 0, interval): the interval's start is the offset.
                this.offset = indexes[1].offset();
            } else {
                this.offset = ArrayUtil.dotProduct(pointOffsets, pointStrides);
            }
        } else if (numIntervals > 0 && this.arr.rank() > 2) {
            int dot = ArrayUtil.dotProduct(accumOffsets, accumStrides);
            this.offset = encounteredAll && this.arr.size(0) != 1 ? dot : dot / numIntervals;
        } else {
            this.offset = ArrayUtil.calcOffset(accumShape, accumOffsets, accumStrides);
        }
    }

    /** @return the array indexes are resolved against */
    public INDArray getArr() {
        return this.arr;
    }

    /** @param arr the array indexes are resolved against */
    public void setArr(INDArray arr) {
        this.arr = arr;
    }

    /** @return the per-dimension offsets of the last resolution, or null if none ran */
    public int[] getOffsets() {
        return this.offsets;
    }

    /** @param offsets the per-dimension offsets to store */
    public void setOffsets(int[] offsets) {
        this.offsets = offsets;
    }

    /** @return the result shape of the last resolution, or null if none ran */
    public int[] getShapes() {
        return this.shapes;
    }

    /** @param shapes the result shape to store */
    public void setShapes(int[] shapes) {
        this.shapes = shapes;
    }

    /** @return the result strides of the last resolution, or null if none ran */
    public int[] getStrides() {
        return this.strides;
    }

    /** @param strides the result strides to store */
    public void setStrides(int[] strides) {
        this.strides = strides;
    }

    /** @return the linear offset of the last resolution, or -1 if none ran */
    public int getOffset() {
        return this.offset;
    }

    /** @param offset the linear offset to store */
    public void setOffset(int offset) {
        this.offset = offset;
    }
}

