/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.util;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.Indices;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Static helpers for element-wise and cropping operations on {@link INDArray}s.
 */
public class NDArrayUtil {
    /**
     * Computes the element-wise exponential of a duplicate of the given array.
     *
     * @param toExp the array to exponentiate; it is duplicated via {@code dup()}
     *              before {@link #expi(INDArray)} is applied
     * @return an array with the shape of {@code toExp} holding {@code e^x} for each element
     */
    public static INDArray exp(INDArray toExp) {
        return expi(toExp.dup());
    }

    /**
     * Computes the element-wise exponential over the flattened form of the given array
     * and returns the result reshaped to the input's shape.
     *
     * <p>NOTE(review): the values are written into the array returned by
     * {@code toExp.ravel()}; whether {@code toExp} itself is modified depends on whether
     * {@code ravel()} returns a view or a copy - confirm against the INDArray
     * implementation in use. Callers should rely on the returned array.
     *
     * @param toExp the array to exponentiate
     * @return the flattened, exponentiated values reshaped to {@code toExp.shape()}
     */
    public static INDArray expi(INDArray toExp) {
        INDArray flattened = toExp.ravel();
        // long accommodates length() whether it is declared as int or long.
        long total = flattened.length();
        for (int i = 0; i < total; ++i) {
            // Use the primitive accessors instead of casting element() to Double:
            // the cast fails with ClassCastException when the element is not boxed
            // as a Double (e.g. float-backed arrays), and the primitive path also
            // avoids allocating two temporary scalar arrays per element.
            flattened.putScalar(i, Math.exp(flattened.getDouble(i)));
        }
        return flattened.reshape(toExp.shape());
    }

    /**
     * Extracts the centered region of the requested shape from the given array.
     *
     * <p>Per dimension the region starts at {@code floor((arrDim - targetDim) / 2)} and
     * spans {@code targetDim} entries. Entries of {@code shape} below 1 are treated as 1;
     * the caller's {@code shape} array is not modified.
     *
     * @param arr   the array to crop
     * @param shape the desired shape of the centered region
     * @return {@code arr} itself when it holds fewer elements than the product of
     *         {@code shape} (as passed in, before clamping); otherwise the centered
     *         region. For a multi-entry shape this is whatever {@code arr.get(...)}
     *         returns for the computed index ranges; for a single-entry shape it is
     *         a newly created array filled with the selected elements.
     */
    public static INDArray center(INDArray arr, int[] shape) {
        // Guard uses the shape exactly as supplied, matching the historical behaviour.
        if (arr.length() < ArrayUtil.prod(shape)) {
            return arr;
        }

        // Clamp on a private copy so the caller's array is never mutated.
        int[] targetShape = shape.clone();
        for (int i = 0; i < targetShape.length; ++i) {
            targetShape[i] = Math.max(1, targetShape[i]);
        }

        INDArray shapeMatrix = ArrayUtil.toNDArray(targetShape);
        INDArray currShape = ArrayUtil.toNDArray(arr.shape());
        // sub(...) is not the in-place variant, so the in-place divi(...) only touches
        // the intermediate difference, not currShape.
        INDArray startIndex = Transforms.floor(currShape.sub(shapeMatrix).divi(Nd4j.scalar(2.0f)));
        INDArray endIndex = startIndex.add(shapeMatrix);
        INDArrayIndex[] indexes = Indices.createFromStartAndEnd(startIndex, endIndex);

        if (shapeMatrix.length() > 1) {
            return arr.get(indexes);
        }

        // Single-dimension target: copy the [start, end) range element by element.
        INDArray ret = Nd4j.create(new int[]{(int) shapeMatrix.getDouble(0)});
        int start = (int) startIndex.getDouble(0);
        int end = (int) endIndex.getDouble(0);
        int count = 0;
        for (int i = start; i < end; ++i) {
            ret.putScalar(count++, arr.getDouble(i));
        }
        return ret;
    }
}

