/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.shape;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * Cross product of two inputs, {@code a × b}, executed by the native {@code "cross"} custom op.
 *
 * <p>The single output has the datatype of the first input. NOTE(review): the axis along which
 * the 3-component vectors are laid out is decided by the native op implementation, not by this
 * class; confirm there before relying on a particular layout.
 */
public class Cross extends DynamicCustomOp {

    /**
     * Creates an op with no inputs attached. NOTE(review): presumably required for
     * reflection/deserialization by the op framework; confirm.
     */
    public Cross() {
    }

    /**
     * Creates the op inside a SameDiff graph.
     *
     * @param sameDiff the graph that owns this op
     * @param args     the two input variables; {@link #doDiff} treats the left argument as
     *                 {@code a} and the right argument as {@code b}
     */
    public Cross(SameDiff sameDiff, SDVariable[] args) {
        // The trailing flag is presumably the in-place flag of DynamicCustomOp; false here.
        super(null, sameDiff, args, false);
    }

    /**
     * Creates the op for direct (eager) execution on arrays.
     *
     * @param a   left operand
     * @param b   right operand
     * @param out optional output array; when {@code null}, no output array is passed to the op
     */
    public Cross(INDArray a, INDArray b, INDArray out) {
        // The explicit constructor invocation must be the first statement, so the input/output
        // arrays are built inline. The (int[]) cast selects the int[] iArguments overload.
        super(null, new INDArray[]{a, b}, out == null ? null : new INDArray[]{out}, null, (int[]) null);
    }

    /** @return the native op name, {@code "cross"} */
    @Override
    public String opName() {
        return "cross";
    }

    /** @return the TensorFlow op name this op maps to, {@code "Cross"} */
    @Override
    public String tensorflowName() {
        return "Cross";
    }

    /**
     * Backpropagates through {@code c = a × b}.
     *
     * <p>With upstream gradient {@code g = dL/dc}, the scalar triple product identity
     * {@code g·(a×b) = a·(b×g) = b·(g×a)} gives {@code dL/da = b × g} and
     * {@code dL/db = g × a}.
     *
     * @param gradients list whose first element is the gradient with respect to the output
     * @return gradients with respect to the left and right inputs, in that order
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        SDVariable grad = gradients.get(0);
        SDVariable a = this.larg();
        SDVariable b = this.rarg();
        // The previous form, grad * cross(b, ones) and grad * cross(ones, a), only matches the
        // true gradient when every component of grad is equal (e.g. a plain sum loss).
        SDVariable gradLeft = this.sameDiff.math().cross(b, grad);
        SDVariable gradRight = this.sameDiff.math().cross(grad, a);
        return Arrays.asList(gradLeft, gradRight);
    }

    /**
     * Infers the output datatype from the input datatypes.
     *
     * @param dataTypes datatypes of the two inputs
     * @return a singleton list holding the first input's datatype
     * @throws IllegalStateException if {@code dataTypes} is null or does not hold exactly two
     *                               entries (raised via {@code Preconditions.checkState})
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkState(dataTypes != null && dataTypes.size() == 2,
                "Expected list with exactly 2 datatype for %s, got %s", this.getClass(), dataTypes);
        return Collections.singletonList(dataTypes.get(0));
    }
}

