/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.graph.vertex.impl;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.Or;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A parameter-free graph vertex that combines all of its inputs element-wise using a fixed {@link Op}.
 *
 * <p>With a single input the vertex is a pass-through in both the forward and the backward pass.
 * {@link Op#Subtract} requires exactly two inputs; {@link Op#Add} and {@link Op#Product} accept any
 * number of inputs.
 */
public class ElementWiseVertex
extends BaseGraphVertex {
    /** The element-wise operation applied to the inputs; fixed at construction. */
    private final Op op;
    /**
     * Number of inputs seen by the most recent {@link #doForward(boolean)} call; {@link #doBackward(boolean)}
     * uses it to decide how many epsilons to return.
     */
    private int nInForwardPass;

    /**
     * Creates a vertex with no input/output vertex indices set.
     *
     * @param graph       the owning computation graph
     * @param name        the vertex name
     * @param vertexIndex the index of this vertex in the graph
     * @param op          the element-wise operation to apply
     */
    public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, Op op) {
        this(graph, name, vertexIndex, null, null, op);
    }

    /**
     * Creates a vertex with explicit input and output vertex indices.
     *
     * @param graph          the owning computation graph
     * @param name           the vertex name
     * @param vertexIndex    the index of this vertex in the graph
     * @param inputVertices  the vertices feeding into this one (may be null)
     * @param outputVertices the vertices this one feeds into (may be null)
     * @param op             the element-wise operation to apply
     */
    public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices, Op op) {
        super(graph, name, vertexIndex, inputVertices, outputVertices);
        this.op = op;
    }

    /** @return {@code false}: this vertex wraps no layer */
    @Override
    public boolean hasLayer() {
        return false;
    }

    /** @return {@code false}: this vertex is never a network output */
    @Override
    public boolean isOutputVertex() {
        return false;
    }

    /** @return {@code null}: this vertex wraps no layer */
    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Combines the current inputs element-wise.
     *
     * @param training not used by this vertex
     * @return the single input itself when there is exactly one input; otherwise the combined result.
     *         For Add and Product, {@code inputs[0]} is copied before accumulating, so the inputs are not modified.
     * @throws IllegalStateException         if the inputs have not been set
     * @throws IllegalArgumentException      if {@code op} is {@link Op#Subtract} and there are not exactly two inputs
     * @throws UnsupportedOperationException if {@code op} is not handled
     */
    @Override
    public INDArray doForward(boolean training) {
        if (!this.canDoForward()) {
            throw new IllegalStateException("Cannot do forward pass: inputs not set");
        }
        // Remember the arity so doBackward returns one epsilon per input.
        this.nInForwardPass = this.inputs.length;
        if (this.inputs.length == 1) {
            return this.inputs[0];
        }
        switch (this.op) {
            case Add: {
                INDArray sum = this.inputs[0].dup();
                for (int i = 1; i < this.inputs.length; ++i) {
                    sum.addi(this.inputs[i]);
                }
                return sum;
            }
            case Subtract: {
                if (this.inputs.length != 2) {
                    throw new IllegalArgumentException("ElementWise subtraction only supports 2 inputs");
                }
                return this.inputs[0].sub(this.inputs[1]);
            }
            case Product: {
                INDArray product = this.inputs[0].dup();
                for (int i = 1; i < this.inputs.length; ++i) {
                    product.muli(this.inputs[i]);
                }
                return product;
            }
        }
        throw new UnsupportedOperationException("Unknown op: " + this.op);
    }

    /**
     * Propagates the incoming epsilon back to each input.
     *
     * <ul>
     *   <li>Add: every input receives its own copy of epsilon.</li>
     *   <li>Subtract: the first input receives epsilon and the second receives {@code -epsilon}.</li>
     *   <li>Product: input {@code i} receives epsilon multiplied by every other input.</li>
     * </ul>
     *
     * @param tbptt not used by this vertex
     * @return a pair with a {@code null} gradient (this vertex has no parameters) and one epsilon per input
     * @throws IllegalStateException         if epsilon has not been set
     * @throws UnsupportedOperationException if {@code op} is not handled
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
        if (!this.canDoBackward()) {
            throw new IllegalStateException("Cannot do backward pass: errors not set");
        }
        if (this.nInForwardPass == 1) {
            return new Pair<Gradient, INDArray[]>(null, new INDArray[]{this.epsilon});
        }
        switch (this.op) {
            case Add: {
                // Each input gets its own copy so that downstream in-place updates cannot alias each other.
                INDArray[] out = new INDArray[this.nInForwardPass];
                for (int i = 0; i < this.nInForwardPass; ++i) {
                    out[i] = this.epsilon.dup();
                }
                return new Pair<Gradient, INDArray[]>(null, out);
            }
            case Subtract: {
                INDArray[] out2 = new INDArray[]{this.epsilon, this.epsilon.neg()};
                return new Pair<Gradient, INDArray[]>(null, out2);
            }
            case Product: {
                // d(x_0 * ... * x_{n-1}) / dx_i = product of all x_j with j != i
                INDArray[] outProduct = new INDArray[this.nInForwardPass];
                for (int i = 0; i < this.nInForwardPass; ++i) {
                    outProduct[i] = this.epsilon.dup();
                    for (int j = 0; j < this.nInForwardPass; ++j) {
                        if (i == j) continue;
                        outProduct[i].muli(this.inputs[j]);
                    }
                }
                return new Pair<Gradient, INDArray[]>(null, outProduct);
            }
        }
        throw new UnsupportedOperationException("Unknown op: " + this.op);
    }

    /**
     * This vertex has no parameters, so only a {@code null} view array is accepted.
     *
     * @param backpropGradientsViewArray must be {@code null}
     * @throws RuntimeException if a non-null array is supplied
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
        if (backpropGradientsViewArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    /**
     * Combines the input mask arrays with an element-wise OR.
     *
     * <p>If the mask array itself is {@code null} or empty, or if any individual mask is {@code null},
     * no mask is returned. A single mask is returned as-is, not copied. Otherwise a new array holding
     * the OR of all masks is returned.
     *
     * @param maskArrays       one mask per input (may be null)
     * @param currentMaskState passed through unchanged
     * @param minibatchSize    not used by this vertex
     * @return the combined mask (possibly {@code null}) paired with {@code currentMaskState}
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) {
        // Guard the empty case too: the code below indexes maskArrays[0] unconditionally.
        if (maskArrays == null || maskArrays.length == 0) {
            return new Pair<INDArray, MaskState>(null, currentMaskState);
        }
        for (INDArray arr : maskArrays) {
            if (arr == null) {
                // NOTE(review): a null mask presumably means "all present", which makes the OR all ones,
                // so no mask is returned.
                return new Pair<INDArray, MaskState>(null, currentMaskState);
            }
        }
        if (maskArrays.length == 1) {
            return new Pair<INDArray, MaskState>(maskArrays[0], currentMaskState);
        }
        // The dup only allocates an output buffer with matching shape/ordering; its contents are overwritten.
        INDArray ret = maskArrays[0].dup(maskArrays[0].ordering());
        // The fully-qualified Op type is required because the nested enum Op shadows the ND4J interface.
        Nd4j.getExecutioner().exec((org.nd4j.linalg.api.ops.Op) new Or(maskArrays[0], maskArrays[1], ret));
        for (int i = 2; i < maskArrays.length; ++i) {
            Nd4j.getExecutioner().exec((org.nd4j.linalg.api.ops.Op) new Or(maskArrays[i], ret, ret));
        }
        return new Pair<INDArray, MaskState>(ret, currentMaskState);
    }

    @Override
    public String toString() {
        return "ElementWiseVertex(id=" + this.getVertexIndex() + ",name=\"" + this.getVertexName() + "\",op=" + this.op + ")";
    }

    /** Element-wise operations supported by {@link ElementWiseVertex}. */
    public enum Op {
        /** Element-wise sum of all inputs. */
        Add,
        /** {@code inputs[0] - inputs[1]}; exactly two inputs are required. */
        Subtract,
        /** Element-wise product of all inputs. */
        Product
    }
}

