/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.activation;

import org.nd4j.linalg.api.activation.BaseActivationFunction;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.ArrayOps;
import org.nd4j.linalg.ops.ElementWiseOp;
import org.nd4j.linalg.ops.transforms.Exp;

/**
 * Softmax activation function.
 *
 * <p>For each row (or each column), subtracts the maximum of that row/column, exponentiates every
 * element, and divides by the sum of the exponentials so the row/column sums to one. Subtracting
 * the maximum leaves the mathematical result unchanged but keeps {@code exp} from overflowing on
 * large inputs.
 */
public class SoftMax extends BaseActivationFunction {
    /** {@code true}: normalize along dimension 1 (each row); {@code false}: along dimension 0. */
    private boolean rows;
    private static final long serialVersionUID = -3407472284248637360L;

    /**
     * Creates a softmax activation.
     *
     * @param rows {@code true} to normalize each row (reduce over dimension 1),
     *             {@code false} to normalize each column (reduce over dimension 0)
     */
    public SoftMax(boolean rows) {
        this.rows = rows;
    }

    /** Creates a softmax that normalizes along dimension 0 (same as {@code new SoftMax(false)}). */
    public SoftMax() {
        this(false);
    }

    /**
     * Computes the softmax of {@code input} along rows or columns.
     *
     * <p>Memory ordering ('c' vs 'f') does not change this computation, so there is a single code
     * path for both orderings.
     *
     * @param input the array to normalize; the working copy comes from the non-in-place
     *              {@code subColumnVector}/{@code subRowVector} ops, so {@code input} itself is
     *              presumably left unmodified (ND4J's "i"-suffix convention) — confirm if relied on
     * @param row   {@code true} to reduce over dimension 1, {@code false} over dimension 0
     * @return the array produced by the subtraction step, exponentiated and normalized in place
     */
    public static INDArray softmax(INDArray input, boolean row) {
        // NOTE(review): the transpose() calls are intentionally asymmetric between the two
        // branches (kept exactly as before); presumably the *ColumnVector/*RowVector ops accept a
        // vector of either orientation — TODO confirm against the ND4J version in use.
        if (row) {
            INDArray max = input.max(1);
            INDArray diff = input.subColumnVector(max);
            expInPlace(diff);
            diff.diviColumnVector(diff.sum(1).transpose());
            return diff;
        }
        INDArray max = input.max(0).transpose();
        INDArray diff = input.subRowVector(max);
        expInPlace(diff);
        diff.diviRowVector(diff.sum(0));
        return diff;
    }

    /** Applies the element-wise {@link Exp} transform to {@code arr}, overwriting its contents. */
    private static void expInPlace(INDArray arr) {
        new ArrayOps().from(arr).op(Exp.class).build().exec();
    }

    /** Applies {@link #softmax(INDArray, boolean)} using this instance's row/column setting. */
    @Override
    public INDArray apply(INDArray input) {
        return SoftMax.softmax(input, this.rows);
    }

    /**
     * Returns {@code null}. Softmax depends on a whole row/column (max and sum reductions), so it
     * is not exposed as a single element-wise op class.
     */
    @Override
    public Class<? extends ElementWiseOp> transformClazz() {
        return null;
    }

    /**
     * Computes the element-wise term {@code s * (1 - s)}, where {@code s = softmax(input)}.
     *
     * <p>The softmax is computed once and reused. {@code subi} overwrites only the freshly created
     * ones array, and {@code mul} is the non-in-place variant, so {@code activated} is still intact
     * when it is multiplied.
     *
     * @param input the pre-activation values
     * @return {@code s * (1 - s)}; a complex ones array is used when {@code input} is complex
     */
    @Override
    public INDArray applyDerivative(INDArray input) {
        INDArray activated = SoftMax.softmax(input, this.rows);
        if (input instanceof IComplexNDArray) {
            return activated.mul(Nd4j.complexOnes(input.shape()).subi(activated));
        }
        return activated.mul(Nd4j.ones(input.shape()).subi(activated));
    }
}

