/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.util.Pair;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Static helpers for building common {@link Block}s and for rendering a textual description of a
 * block tree.
 */
public final class Blocks {

    /** Matches the start of every line (multiline mode); used to indent nested descriptions. */
    private static final Pattern LINE_START = Pattern.compile("(?m)^");

    private Blocks() {
    }

    /**
     * Reshapes an array to two dimensions, keeping axis 0 (the batch axis) and collapsing every
     * remaining axis into the second dimension.
     *
     * @param array the array to flatten
     * @return {@code array} reshaped to {@code (array.size(0), <product of remaining axes>)}
     */
    public static NDArray batchFlatten(NDArray array) {
        long batch = array.size(0);
        if (batch == 0L) {
            // With an empty batch the total element count is 0, so a -1 dimension presumably
            // cannot be inferred; compute the size of the trailing axes explicitly instead.
            return array.reshape(batch, array.getShape().slice(1).size());
        }
        return array.reshape(batch, -1L);
    }

    /**
     * Reshapes an array to {@code (-1, size)}, letting the leading dimension be inferred.
     *
     * @param array the array to flatten
     * @param size the size of the second (flattened) dimension
     * @return {@code array} reshaped to {@code (-1, size)}
     */
    public static NDArray batchFlatten(NDArray array, long size) {
        return array.reshape(-1L, size);
    }

    /**
     * Creates a single-input block that applies {@link #batchFlatten(NDArray)}.
     *
     * @return a {@link LambdaBlock} named {@code "batchFlatten"}
     */
    public static Block batchFlattenBlock() {
        return LambdaBlock.singleton(Blocks::batchFlatten, "batchFlatten");
    }

    /**
     * Creates a single-input block that applies {@link #batchFlatten(NDArray, long)} with a fixed
     * {@code size}.
     *
     * @param size the size of the second (flattened) dimension
     * @return a {@link LambdaBlock} named {@code "batchFlatten"}
     */
    public static Block batchFlattenBlock(long size) {
        return LambdaBlock.singleton(array -> Blocks.batchFlatten(array, size), "batchFlatten");
    }

    /**
     * Creates a block that returns its input unchanged.
     *
     * @return a {@link LambdaBlock} named {@code "identity"}
     */
    public static Block identityBlock() {
        return new LambdaBlock(x -> x, "identity");
    }

    /**
     * Builds a human-readable, recursively indented description of {@code block} and its children.
     *
     * <p>The description starts with the block's display name. If the block is initialized, the
     * input shapes follow, and the output shapes are appended after {@code " -> "}. Each shape is
     * rendered as {@code shape.slice(beginAxis)}, and multiple shapes are joined with {@code "+"}.
     * Children are rendered inside {@code { ... }}, with each line of their description indented
     * by one tab.
     *
     * @param block the block to describe
     * @param blockName the name to display when the block is not a named {@link LambdaBlock}; if
     *     {@code null}, the block's simple class name is used
     * @param beginAxis the first axis of each shape to include in the output
     * @return the description text
     */
    public static String describe(Block block, String blockName, int beginAxis) {
        Shape[] inputShapes = block.isInitialized() ? block.getInputShapes() : null;
        Shape[] outputShapes = inputShapes != null ? block.getOutputShapes(inputShapes) : null;
        StringBuilder sb = new StringBuilder(200);
        sb.append(displayName(block, blockName));
        if (inputShapes != null) {
            sb.append(joinShapes(inputShapes, beginAxis));
        }
        if (!block.getChildren().isEmpty()) {
            sb.append(" {\n");
            for (Pair<String, Block> pair : block.getChildren()) {
                String child =
                        describe(pair.getValue(), stripChildKeyPrefix(pair.getKey()), beginAxis);
                // Indent every line of the (possibly multi-line) child description by one tab.
                sb.append(LINE_START.matcher(child).replaceAll("\t")).append('\n');
            }
            sb.append('}');
        }
        if (outputShapes != null) {
            sb.append(" -> ");
            sb.append(joinShapes(outputShapes, beginAxis));
        }
        return sb.toString();
    }

    /** Chooses the label for a block: a non-"anonymous" lambda name, then the given name, then the class name. */
    private static String displayName(Block block, String blockName) {
        if (block instanceof LambdaBlock) {
            String lambdaName = ((LambdaBlock) block).getName();
            if (!"anonymous".equals(lambdaName)) {
                return lambdaName;
            }
        }
        if (blockName != null) {
            return blockName;
        }
        return block.getClass().getSimpleName();
    }

    /**
     * Drops the first two characters of a child key. The prefix is presumably an ordering prefix
     * added when children are registered (TODO confirm against Block implementations). Null or
     * too-short keys are returned unchanged instead of throwing.
     */
    private static String stripChildKeyPrefix(String key) {
        if (key == null || key.length() < 2) {
            return key;
        }
        return key.substring(2);
    }

    /** Renders each shape sliced from {@code beginAxis} and joins them with {@code "+"}. */
    private static String joinShapes(Shape[] shapes, int beginAxis) {
        return Stream.of(shapes)
                .map(shape -> shape.slice(beginAxis).toString())
                .collect(Collectors.joining("+"));
    }
}

