/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** Static helpers for building and describing commonly used {@link Block}s. */
public final class Blocks {

    private Blocks() {
    }

    /**
     * Reshapes {@code array} to two dimensions, keeping the size of axis 0 and folding all
     * remaining axes into axis 1.
     *
     * @param array the array to flatten
     * @return the reshaped array
     */
    public static NDArray batchFlatten(NDArray array) {
        long batch = array.size(0);
        if (batch == 0L) {
            // An inferred (-1) dimension is ambiguous when axis 0 has size 0, so the product of
            // the remaining axes is computed explicitly instead.
            return array.reshape(batch, array.getShape().slice(1).size());
        }
        return array.reshape(batch, -1L);
    }

    /**
     * Reshapes {@code array} to two dimensions with {@code size} elements in axis 1; the size of
     * axis 0 is inferred.
     *
     * @param array the array to flatten
     * @param size the size of axis 1 of the result
     * @return the reshaped array
     */
    public static NDArray batchFlatten(NDArray array, long size) {
        return array.reshape(-1L, size);
    }

    /**
     * Creates a block named {@code "batchFlatten"} that applies {@link #batchFlatten(NDArray)}
     * via {@link LambdaBlock#singleton}.
     *
     * @return the flattening block
     */
    public static Block batchFlattenBlock() {
        return LambdaBlock.singleton(Blocks::batchFlatten, "batchFlatten");
    }

    /**
     * Creates a block named {@code "batchFlatten"} that applies {@link #batchFlatten(NDArray,
     * long)} with the given {@code size} via {@link LambdaBlock#singleton}.
     *
     * @param size the size of axis 1 of the flattened output
     * @return the flattening block
     */
    public static Block batchFlattenBlock(long size) {
        return LambdaBlock.singleton(array -> Blocks.batchFlatten(array, size), "batchFlatten");
    }

    /**
     * Creates a block named {@code "identity"} that returns its input {@link NDList} unchanged.
     *
     * @return the identity block
     */
    public static Block identityBlock() {
        return new LambdaBlock(x -> x, "identity");
    }

    /**
     * Creates a block named {@code "ones"} that outputs one all-ones array per entry of {@code
     * shapes}, in the same order.
     *
     * <p>Each output uses the data type and shape of its entry. Any dimension given as {@code -1}
     * is replaced by the size of the same axis of the first input array's shape.
     *
     * @param shapes the data type and shape of each output array, in output order
     * @param names names for the output arrays; they are applied only when {@code names.length}
     *     equals {@code shapes.size()}, otherwise (or when {@code null}) the arrays are left
     *     unnamed
     * @return the block producing the all-ones arrays
     */
    public static Block onesBlock(PairList<DataType, Shape> shapes, String[] names) {
        return new LambdaBlock(a -> {
            Shape[] inShapes = a.getShapes();
            NDManager manager = a.getManager();
            NDList list = new NDList(shapes.size());
            // Names are all-or-nothing: only a name per output can be mapped unambiguously.
            boolean applyNames = names != null && names.length == shapes.size();
            int outputIndex = 0;
            for (Pair<DataType, Shape> pair : shapes) {
                long[] shape = pair.getValue().getShape().clone();
                for (int i = 0; i < shape.length; ++i) {
                    if (shape[i] == -1L) {
                        // NOTE(review): every output resolves -1 from the first input's shape;
                        // confirm whether output k was instead meant to follow input k.
                        shape[i] = inShapes[0].get(i);
                    }
                }
                NDArray arr = manager.ones(new Shape(shape), pair.getKey());
                if (applyNames) {
                    arr.setName(names[outputIndex]);
                }
                list.add(arr);
                ++outputIndex;
            }
            return list;
        }, "ones");
    }

    /**
     * Builds a human-readable, recursive description of {@code block}.
     *
     * <p>The description is the block's label, then (when the block is initialized) its input
     * shapes joined by {@code "+"}. Child descriptions follow, each indented by a tab inside
     * {@code " { ... }"}. It ends with {@code " -> "} and the output shapes. Every shape is
     * sliced from {@code beginAxis} before printing.
     *
     * <p>The label is the {@link LambdaBlock} name when it is not {@code "anonymous"}. Otherwise it
     * is {@code blockName} when non-null, and the block's simple class name as a last resort.
     *
     * @param block the block to describe
     * @param blockName the label to use for the block, or {@code null}
     * @param beginAxis the first axis of each shape to include in the output
     * @return the description
     */
    public static String describe(Block block, String blockName, int beginAxis) {
        Shape[] inputShapes = block.isInitialized() ? block.getInputShapes() : null;
        Shape[] outputShapes = inputShapes != null ? block.getOutputShapes(inputShapes) : null;
        StringBuilder sb = new StringBuilder(200);
        if (block instanceof LambdaBlock && !"anonymous".equals(((LambdaBlock) block).getName())) {
            sb.append(((LambdaBlock) block).getName());
        } else if (blockName != null) {
            sb.append(blockName);
        } else {
            sb.append(block.getClass().getSimpleName());
        }
        if (inputShapes != null) {
            sb.append(joinShapes(inputShapes, beginAxis));
        }
        if (!block.getChildren().isEmpty()) {
            sb.append(" {\n");
            for (Pair<String, Block> pair : block.getChildren()) {
                // Drops the first two characters of the child key; presumably an ordering
                // prefix added when the child was registered - TODO confirm.
                String child = describe(pair.getValue(), pair.getKey().substring(2), beginAxis);
                // Indent every line of the nested description by one tab.
                sb.append(child.replaceAll("(?m)^", "\t")).append('\n');
            }
            sb.append('}');
        }
        if (outputShapes != null) {
            sb.append(" -> ");
            sb.append(joinShapes(outputShapes, beginAxis));
        }
        return sb.toString();
    }

    /** Joins the given shapes, each sliced from {@code beginAxis}, with {@code "+"}. */
    private static String joinShapes(Shape[] shapes, int beginAxis) {
        return Stream.of(shapes)
                .map(shape -> shape.slice(beginAxis).toString())
                .collect(Collectors.joining("+"));
    }
}

