/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow.op.nn;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.AssertThat;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.ReduceAll;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Shapes;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.math.Equal;
import org.tensorflow.types.TBfloat16;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TNumber;

public class SparseSoftmaxCrossEntropyWithLogits {
    public static <T extends TNumber, U extends TNumber> Operand sparseSoftmaxCrossEntropyWithLogits(Scope scope, Operand<T> labels, Operand<U> logits) {
        boolean staticShapesFullyDefined;
        scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits");
        Operand<U> preciseLogits = logits;
        if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) {
            preciseLogits = Cast.create(scope, logits, TFloat32.class, new Cast.Options[0]);
        }
        Shape labelsStaticShape = labels.shape();
        org.tensorflow.op.core.Shape<TInt32> labelsShape = org.tensorflow.op.core.Shape.create(scope, labels);
        Shape logitsShape = logits.shape();
        Shape logitsShortened = logitsShape.take(logitsShape.numDimensions() - 1);
        boolean bl = staticShapesFullyDefined = !labelsStaticShape.hasUnknownDimension() && !logitsShortened.hasUnknownDimension();
        if (logitsShape.numDimensions() == 0) {
            throw new IllegalArgumentException(String.format("Logits cannot be scalars - received shape %s.", logitsShape));
        }
        if (!logitsShape.hasUnknownDimension() && !labelsStaticShape.hasUnknownDimension() && labelsStaticShape.numDimensions() != logitsShape.numDimensions() - 1) {
            throw new IllegalArgumentException(String.format("Rank mismatch: Rank of labels (received %s) should equal rank of logits minus 1 (received %s).", labelsStaticShape, logitsShape));
        }
        if (staticShapesFullyDefined && !labelsStaticShape.equals(logitsShortened)) {
            throw new IllegalArgumentException(String.format("Shape mismatch: The shape of labels (received %s) should equal the shape of logits except for the last dimension (received %s).", labelsStaticShape, logitsShape));
        }
        if (logitsShape.numDimensions() == 2) {
            org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits<U> smax = org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create(scope, preciseLogits, labels);
            Operand<Object> loss = smax.loss();
            if (logits.asOutput().type() == TFloat16.class) {
                loss = Cast.create(scope, loss, TFloat16.class, new Cast.Options[0]);
            }
            return loss;
        }
        ArrayList<Op> shapeChecks = new ArrayList<Op>();
        if (!staticShapesFullyDefined) {
            shapeChecks.add(AssertThat.create(scope, Equal.create(scope, org.tensorflow.op.core.Shape.create(scope, labels), Shapes.take(scope, org.tensorflow.op.core.Shape.create(scope, logits), Constant.scalarOf(scope, -1)), new Equal.Options[0]), Collections.singletonList(Constant.scalarOf(scope, "Shape mismatch: The shape of labels  should equal the shape of logits except for the last dimension ")), new AssertThat.Options[0]));
        }
        long numClassses = logitsShape.size(-1);
        preciseLogits = Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClassses));
        labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1));
        scope.withControlDependencies(shapeChecks);
        org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits<U> smax = org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create(scope, preciseLogits, labels);
        Operand<Object> cost = smax.loss();
        cost = Reshape.create(scope, cost, labelsShape);
        if (logits.asOutput().type() == TFloat16.class) {
            cost = Cast.create(scope, cost, TFloat16.class, new Cast.Options[0]);
        }
        return cost;
    }
}

