/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.indexing;

import com.google.common.primitives.Ints;
import java.util.ArrayList;
import java.util.Collections;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.LinearIndex;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.IntervalIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndexAll;
import org.nd4j.linalg.indexing.NDArrayIndexEmpty;
import org.nd4j.linalg.indexing.NewAxis;
import org.nd4j.linalg.indexing.PointIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Static helpers for working with {@link INDArrayIndex} arrays: computing the shape,
 * offsets and strides they describe, and normalizing an index array against an array's
 * shape.
 */
public class Indices {
    /**
     * Maps a linear element index to the vector along the last dimension that contains it.
     *
     * <p>The result is {@code floor(index / arr.size(-1))}, clamped so it never exceeds
     * {@code arr.vectorsAlongDimension(-1) - 1}.
     *
     * @param index linear element index
     * @param arr   the array being indexed
     * @return the row (vector) number for {@code index}
     */
    public static int rowNumber(int index, INDArray arr) {
        // Floating-point division is kept on purpose: a zero-sized last dimension yields
        // Infinity/NaN here instead of an ArithmeticException from integer division.
        int row = (int) Math.floor((double) index / (double) arr.size(-1));
        int vectors = arr.vectorsAlongDimension(-1);
        if (row >= vectors) {
            return vectors - 1;
        }
        return row;
    }

    /**
     * Computes a linear offset for a linear element index of {@code arr}.
     *
     * <p>For {@code 'c'} ordering the result is the {@code offset()} of the vector along the
     * last dimension numbered {@code index % arr.size(-1)}, plus {@code index}. For any other
     * ordering the vector number is {@code floor(index * arr.stride(-2) / arr.length())} and
     * the result is that vector's {@code linearIndex} for column {@code index % arr.size(-1)}.
     *
     * @param index linear element index
     * @param arr   the array being indexed
     * @return the computed offset
     */
    public static int linearOffset(int index, INDArray arr) {
        if (arr.ordering() == 'c') {
            // NOTE(review): a remainder (a column position) is used as a vector number here,
            // which looks suspicious; behavior is kept as-is — confirm intent.
            int vectorIndex = (int) Math.floor((double) index % (double) arr.size(-1));
            // Build the vector view only once.
            return arr.vectorAlongDimension(vectorIndex, -1).offset() + index;
        }
        int majorStride = arr.stride(-2);
        // Multiply in long so a large index * stride product cannot wrap around in int
        // before the division.
        double rowCalc = (double) ((long) index * majorStride) / (double) arr.length();
        int row = (int) Math.floor(rowCalc);
        INDArray rowVector = arr.vectorAlongDimension(row, -1);
        int columnIndex = index % arr.size(-1);
        return rowVector.linearIndex(columnIndex);
    }

    /**
     * Runs a {@link LinearIndex} op over all rows of {@code arr} (paired with a duplicate of
     * {@code arr}) and returns the indices the op collected.
     *
     * @param arr the array to index
     * @return the indices reported by {@link LinearIndex#getIndices()}
     */
    public static int[] linearIndices(INDArray arr) {
        LinearIndex index = new LinearIndex(arr, arr.dup(), true);
        Nd4j.getExecutioner().iterateOverAllRows(index);
        return index.getIndices();
    }

    /**
     * Computes per-dimension starting offsets described by {@code indices} for an array of
     * the given shape.
     *
     * <ul>
     *   <li>One index per dimension: each entry is that index's {@code offset()}, or 0 for
     *       an {@link NDArrayIndexEmpty}.</li>
     *   <li>Otherwise, if any point indices are present: the strictly positive offsets are
     *       packed, in order, into the front of the result; the rest stay 0.</li>
     *   <li>Otherwise: each non-empty slot takes the offset of the next index from a counter
     *       that only advances on non-empty slots; empty slots get 0.</li>
     * </ul>
     *
     * <p>A single-element result is padded to {@code {offset, 0}}.
     *
     * @param shape   the shape being indexed
     * @param indices the indices applied to it
     * @return the offsets, with length {@code max(shape.length, 2)} when shape has length 1,
     *         otherwise {@code shape.length}
     * @throws IllegalStateException if more positive offsets than dimensions are found
     */
    public static int[] offsets(int[] shape, INDArrayIndex ... indices) {
        int[] ret = new int[shape.length];
        if (indices.length == shape.length) {
            for (int i = 0; i < indices.length; ++i) {
                ret[i] = indices[i] instanceof NDArrayIndexEmpty ? 0 : indices[i].offset();
            }
        } else if (NDArrayIndex.numPoints(indices) > 0) {
            ArrayList<Integer> nonZeros = new ArrayList<Integer>();
            for (INDArrayIndex index : indices) {
                int offset = index.offset();
                if (offset > 0) {
                    nonZeros.add(offset);
                }
            }
            if (nonZeros.size() > shape.length) {
                throw new IllegalStateException("Non zeros greater than shape unable to continue");
            }
            for (int i = 0; i < nonZeros.size(); ++i) {
                ret[i] = nonZeros.get(i);
            }
        } else {
            // NOTE(review): this loop runs to indices.length while ret has shape.length slots,
            // so more indices than dimensions raises ArrayIndexOutOfBoundsException; the
            // shapeIndex counter also lags i after an empty index. Kept as-is — confirm intent.
            int shapeIndex = 0;
            for (int i = 0; i < indices.length; ++i) {
                ret[i] = indices[i] instanceof NDArrayIndexEmpty ? 0 : indices[shapeIndex++].offset();
            }
        }
        if (ret.length == 1) {
            ret = new int[]{ret[0], 0};
        }
        return ret;
    }

    /**
     * Pads {@code indexes} up to {@code shape.length} entries, filling each missing trailing
     * dimension with {@code NDArrayIndex.interval(0, shape[i])}.
     *
     * @param shape   the target shape
     * @param indexes the leading indices; returned unchanged when already one per dimension
     * @return an index array with one entry per dimension
     * @throws IndexOutOfBoundsException if {@code indexes} is longer than {@code shape}
     */
    public static INDArrayIndex[] fillIn(int[] shape, INDArrayIndex ... indexes) {
        if (shape.length == indexes.length) {
            return indexes;
        }
        INDArrayIndex[] newIndexes = new INDArrayIndex[shape.length];
        System.arraycopy(indexes, 0, newIndexes, 0, indexes.length);
        for (int i = indexes.length; i < shape.length; ++i) {
            newIndexes[i] = NDArrayIndex.interval(0, shape[i]);
        }
        return newIndexes;
    }

    /**
     * Normalizes {@code indexes} so there is exactly one index per dimension of
     * {@code originalShape}. A single index on a vector shape is returned untouched; too few
     * indices are padded via {@link #fillIn}; extra trailing indices are dropped.
     *
     * @param originalShape the shape being indexed
     * @param indexes       the indices to adjust
     * @return the adjusted indices (possibly the same array)
     */
    public static INDArrayIndex[] adjustIndices(int[] originalShape, INDArrayIndex ... indexes) {
        if (Shape.isVector(originalShape) && indexes.length == 1) {
            return indexes;
        }
        if (indexes.length < originalShape.length) {
            // fillIn always yields exactly originalShape.length entries here.
            return Indices.fillIn(originalShape, indexes);
        }
        if (indexes.length > originalShape.length) {
            INDArrayIndex[] ret = new INDArrayIndex[originalShape.length];
            System.arraycopy(indexes, 0, ret, 0, originalShape.length);
            return ret;
        }
        return indexes;
    }

    /**
     * Computes strides, in the given ordering, for the shape produced by {@link #shape}.
     *
     * @param ordering the stride ordering passed to {@code Nd4j.getStrides}
     * @param indexes  the indices whose lengths define the shape
     * @return the strides
     */
    public static int[] strides(char ordering, NDArrayIndex ... indexes) {
        return Nd4j.getStrides(Indices.shape(indexes), ordering);
    }

    /**
     * Returns the lengths of {@code indices}, in order, omitting any length that is
     * {@code <= 0}.
     *
     * @param indices the indices to measure
     * @return the positive lengths
     */
    public static int[] shape(INDArrayIndex ... indices) {
        ArrayList<Integer> nonZeros = new ArrayList<Integer>();
        for (INDArrayIndex index : indices) {
            int length = index.length();
            if (length > 0) {
                nonZeros.add(length);
            }
        }
        return ArrayUtil.toArray(nonZeros);
    }

    /**
     * Reports whether every pair of neighbouring values in {@code indices} differs by at most
     * {@code diff} (in absolute value). Arrays with fewer than two elements are contiguous.
     *
     * @param indices the values to check
     * @param diff    the maximum allowed gap between neighbours
     * @return true if no neighbouring gap exceeds {@code diff}
     */
    public static boolean isContiguous(int[] indices, int diff) {
        for (int i = 1; i < indices.length; ++i) {
            // Subtract in long: in int, e.g. -1 - Integer.MAX_VALUE wraps to Integer.MIN_VALUE,
            // whose Math.abs is still negative and would falsely pass the check.
            long gap = Math.abs((long) indices[i] - indices[i - 1]);
            if (gap > diff) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds one {@code NDArrayIndex.interval(start[i], end[i])} per element of the two
     * arrays.
     *
     * @param start interval starts
     * @param end   interval ends
     * @return the interval indices
     * @throws IllegalArgumentException if {@code start} and {@code end} differ in length
     */
    public static INDArrayIndex[] createFromStartAndEnd(INDArray start, INDArray end) {
        if (start.length() != end.length()) {
            throw new IllegalArgumentException("Start length must be equal to end length");
        }
        INDArrayIndex[] indexes = new INDArrayIndex[start.length()];
        for (int i = 0; i < indexes.length; ++i) {
            indexes[i] = NDArrayIndex.interval(start.getInt(i), end.getInt(i));
        }
        return indexes;
    }

    /**
     * Builds one {@code NDArrayIndex.interval(start[i], end[i], inclusive)} per element of the
     * two arrays.
     *
     * @param start     interval starts
     * @param end       interval ends
     * @param inclusive forwarded to {@code NDArrayIndex.interval}
     * @return the interval indices
     * @throws IllegalArgumentException if {@code start} and {@code end} differ in length
     */
    public static INDArrayIndex[] createFromStartAndEnd(INDArray start, INDArray end, boolean inclusive) {
        if (start.length() != end.length()) {
            throw new IllegalArgumentException("Start length must be equal to end length");
        }
        INDArrayIndex[] indexes = new INDArrayIndex[start.length()];
        for (int i = 0; i < indexes.length; ++i) {
            indexes[i] = NDArrayIndex.interval(start.getInt(i), end.getInt(i), inclusive);
        }
        return indexes;
    }

    /**
     * Computes the result shape of applying {@code indices} to an array of shape
     * {@code shape}.
     *
     * <ul>
     *   <li>{@link PointIndex}: consumes a dimension and contributes nothing.</li>
     *   <li>{@link NewAxis}: contributes a size-1 dimension. It goes at the front if no
     *       {@link NDArrayIndexAll} has been seen yet; otherwise it is inserted at a position
     *       derived from its place in {@code indices}.</li>
     *   <li>{@link IntervalIndex} (other than {@link NDArrayIndexAll}) or
     *       {@link SpecifiedIndex}: contributes {@code idx.length()}.</li>
     *   <li>Anything else (including {@link NDArrayIndexAll}): keeps the original size.</li>
     * </ul>
     *
     * <p>Dimensions not covered by {@code indices} are appended unchanged. The result is
     * padded with 1s to at least two entries before new axes are inserted. A lone
     * {@link PointIndex} applied to a 2-d shape reverses the result.
     *
     * @param shape   the shape being indexed
     * @param indices the indices applied to it
     * @return the resulting shape
     */
    public static int[] shape(int[] shape, INDArrayIndex ... indices) {
        int leadingNewAxes = 0;
        boolean encounteredAll = false;
        ArrayList<Integer> accumShape = new ArrayList<Integer>();
        int shapeIndex = 0;
        // Positions (within indices) of new axes that follow an NDArrayIndexAll.
        ArrayList<Integer> prependNewAxes = new ArrayList<Integer>();
        for (int i = 0; i < indices.length; ++i) {
            INDArrayIndex idx = indices[i];
            if (idx instanceof NDArrayIndexAll) {
                encounteredAll = true;
            }
            if (idx instanceof PointIndex) {
                ++shapeIndex;
            } else if (idx instanceof NewAxis) {
                if (encounteredAll) {
                    prependNewAxes.add(i);
                } else {
                    ++leadingNewAxes;
                }
            } else {
                boolean explicitLength = idx instanceof IntervalIndex && !(idx instanceof NDArrayIndexAll)
                        || idx instanceof SpecifiedIndex;
                accumShape.add(explicitLength ? idx.length() : shape[shapeIndex]);
                ++shapeIndex;
            }
        }
        while (shapeIndex < shape.length) {
            accumShape.add(shape[shapeIndex++]);
        }
        while (accumShape.size() < 2) {
            accumShape.add(1);
        }
        if (indices.length == 1 && indices[0] instanceof PointIndex && shape.length == 2) {
            Collections.reverse(accumShape);
        }
        for (int i = 0; i < leadingNewAxes; ++i) {
            accumShape.add(0, 1);
        }
        for (int i = 0; i < prependNewAxes.size(); ++i) {
            // NOTE(review): the insertion point subtracts the number of axes already inserted;
            // kept as-is — confirm this is the intended placement.
            accumShape.add(prependNewAxes.get(i) - i, 1);
        }
        return Ints.toArray(accumShape);
    }

    /**
     * Placeholder for computing the strides of an indexed view.
     *
     * <p>NOTE(review): this computes nothing yet. It always returns an empty array and
     * ignores {@code arr} and {@code shape}. Callers needing real strides require a proper
     * implementation.
     *
     * @param arr     unused
     * @param indexes the indices (must be non-null)
     * @param shape   unused
     * @return an empty array
     * @throws NullPointerException if {@code indexes} is null
     */
    public static int[] stride(INDArray arr, INDArrayIndex[] indexes, int ... shape) {
        if (indexes == null) {
            throw new NullPointerException("indexes");
        }
        return new int[0];
    }

    /**
     * Reports whether every index in {@code indexes} has length 1.
     *
     * <p>NOTE(review): {@code indexOver}'s rank does not affect the result; the previous
     * rank/new-axis comparisons all returned the same value. Confirm whether a rank check was
     * intended.
     *
     * @param indexOver the array being indexed (currently unused)
     * @param indexes   the indices to check
     * @return true if all indices have length 1 (vacuously true for none)
     */
    public static boolean isScalar(INDArray indexOver, INDArrayIndex ... indexes) {
        for (INDArrayIndex index : indexes) {
            if (index.length() != 1) {
                return false;
            }
        }
        return true;
    }
}

