/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.util;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Static helpers that compute element counts and linear offsets from the shape of an
 * {@link INDArray}.
 *
 * <p>Every method derives its result only from {@code rank()}, {@code size(int)},
 * {@code shape()} and {@code slices()} of the supplied array (plus {@link ArrayUtil} helpers); no
 * method reads or writes the array's elements.
 *
 * <p>Offset and count products are computed with {@link Math#multiplyExact(int, int)} /
 * {@link Math#toIntExact(long)} so that an {@code int} overflow raises
 * {@link ArithmeticException} instead of silently wrapping to a meaningless (often negative)
 * value.
 */
public class NDArrayMath {

    /**
     * Returns the linear offset at which the given slice begins, i.e.
     * {@code slice * lengthPerSlice(arr)}.
     *
     * @param arr the array whose slices (along dimension 0) are addressed
     * @param slice the slice index
     * @return the offset of the first element of that slice
     * @throws ArithmeticException if the product overflows an {@code int}
     */
    public static int offsetForSlice(INDArray arr, int slice) {
        return Math.multiplyExact(slice, lengthPerSlice(arr));
    }

    /**
     * Returns the product of the array's dimension sizes after dropping the given dimensions.
     *
     * @param arr the array whose shape is inspected
     * @param dimension dimension indices to remove from {@code arr.shape()} before multiplying
     * @return the product of the remaining dimension sizes
     */
    public static int lengthPerSlice(INDArray arr, int ... dimension) {
        // NOTE(review): assumes ArrayUtil.removeIndex returns the shape without the listed
        // indices and ArrayUtil.prod multiplies the remaining entries — confirm in ArrayUtil.
        int[] remaining = ArrayUtil.removeIndex(arr.shape(), dimension);
        return ArrayUtil.prod(remaining);
    }

    /**
     * Returns the number of elements in one slice along dimension 0; equivalent to
     * {@code lengthPerSlice(arr, 0)}.
     *
     * @param arr the array whose shape is inspected
     * @return the product of all dimension sizes except dimension 0
     */
    public static int lengthPerSlice(INDArray arr) {
        return lengthPerSlice(arr, 0);
    }

    /**
     * Returns the product of every dimension size except the last one, i.e. the number of 1-D
     * fibres running along the last dimension. Arrays of rank 0 or 1 yield {@code 1}.
     *
     * @param arr the array whose shape is inspected
     * @return the product of {@code size(0) .. size(rank - 2)}
     * @throws ArithmeticException if the product overflows an {@code int}
     */
    public static int numVectors(INDArray arr) {
        // For rank <= 1 the loop body never runs, so the result is 1; for rank 2 it is size(0).
        int count = 1;
        for (int i = 0; i < arr.rank() - 1; ++i) {
            count = Math.multiplyExact(count, arr.size(i));
        }
        return count;
    }

    /**
     * For arrays of rank greater than 2, returns the product of the sizes at indices {@code -1}
     * and {@code -2}; otherwise returns {@code arr.slices()}.
     *
     * @param arr the array whose shape is inspected
     * @return the count described above
     */
    public static int vectorsPerSlice(INDArray arr) {
        if (arr.rank() > 2) {
            // NOTE(review): presumably size(-1)/size(-2) are the last two dimensions; the result
            // is their product (element count of a trailing matrix), not its row count — confirm
            // this is the intended meaning of "vectors per slice".
            return ArrayUtil.prod(new int[]{arr.size(-1), arr.size(-2)});
        }
        return arr.slices();
    }

    /**
     * Returns how many tensors of the given shape fit into one slice along dimension 0, using
     * integer (truncating) division.
     *
     * @param arr the array whose slice length is used
     * @param tensorShape the shape of a single tensor
     * @return {@code lengthPerSlice(arr) / prod(tensorShape)}
     * @throws ArithmeticException if the product of {@code tensorShape} is zero
     */
    public static int tensorsPerSlice(INDArray arr, int[] tensorShape) {
        return lengthPerSlice(arr) / ArrayUtil.prod(tensorShape);
    }

    /**
     * Returns the number of matrices in one slice along dimension 0.
     *
     * <ul>
     *   <li>rank 3: {@code 1}</li>
     *   <li>rank &gt; 3: the product of {@code size(1) .. size(rank - 3)}</li>
     *   <li>otherwise: {@code arr.size(-2)}</li>
     * </ul>
     *
     * @param arr the array whose shape is inspected
     * @return the count described above
     * @throws ArithmeticException if the product overflows an {@code int}
     */
    public static int matricesPerSlice(INDArray arr) {
        if (arr.rank() == 3) {
            return 1;
        }
        if (arr.rank() > 3) {
            // Skip dimension 0 (the slice axis) and the trailing two (the matrix itself).
            int count = 1;
            for (int i = 1; i < arr.rank() - 2; ++i) {
                count = Math.multiplyExact(count, arr.size(i));
            }
            return count;
        }
        return arr.size(-2);
    }

    /**
     * For arrays of rank greater than 2, returns {@code size(-2) * size(-1)}; otherwise returns
     * {@code size(-1)}.
     *
     * @param arr the array whose shape is inspected
     * @param rank ignored; retained only for source compatibility with existing callers
     * @return the count described above
     */
    public static int vectorsPerSlice(INDArray arr, int ... rank) {
        if (arr.rank() > 2) {
            return arr.size(-2) * arr.size(-1);
        }
        return arr.size(-1);
    }

    /**
     * Maps a tensor index to the slice (along dimension 0) that contains it:
     * {@code index * prod(tensorShape) / lengthPerSlice(arr)}, using truncating division.
     *
     * @param index the tensor index
     * @param arr the array being traversed
     * @param tensorShape the shape of a single tensor
     * @return the slice offset for that tensor
     * @throws ArithmeticException if {@code lengthPerSlice(arr)} is zero, or the result does not
     *     fit in an {@code int}
     */
    public static int sliceOffsetForTensor(int index, INDArray arr, int[] tensorShape) {
        int tensorLength = ArrayUtil.prod(tensorShape);
        // Widen before multiplying: index * tensorLength can exceed Integer.MAX_VALUE even when
        // the quotient fits, which with int arithmetic silently produced a wrong offset.
        long elementOffset = (long) index * tensorLength;
        return Math.toIntExact(elementOffset / lengthPerSlice(arr));
    }

    /**
     * Returns the linear offset of the {@code index}-th tensor, where a tensor spans the
     * dimensions that remain after removing {@code rank} from the shape.
     *
     * @param index the tensor index
     * @param arr the array being traversed
     * @param rank dimension indices to remove from {@code arr.shape()} to obtain the tensor
     *     length
     * @return {@code index * lengthPerSlice(arr, rank)}
     * @throws ArithmeticException if the product overflows an {@code int}
     */
    public static int mapIndexOntoTensor(int index, INDArray arr, int ... rank) {
        return Math.multiplyExact(index, lengthPerSlice(arr, rank));
    }

    /**
     * Returns the linear offset of the {@code index}-th vector along the last dimension, i.e.
     * {@code index * arr.size(-1)}.
     *
     * @param index the vector index
     * @param arr the array being traversed
     * @return the offset of the first element of that vector
     * @throws ArithmeticException if the product overflows an {@code int}
     */
    public static int mapIndexOntoVector(int index, INDArray arr) {
        return Math.multiplyExact(index, arr.size(-1));
    }
}

