/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.BlockList;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * A block that runs every child block on the same input and merges their outputs.
 *
 * <p>{@link #forward} invokes each child, in the order it was added, with the identical input
 * {@link NDList}. The resulting per-child outputs are passed, in that same order, to the merge
 * function supplied at construction time. Its return value is this block's output.
 */
public class ParallelBlock
extends AbstractBlock {
    /** Encoding version written by {@link #saveParameters}; version 1 streams are still readable. */
    private static final byte VERSION = 2;

    /** Child blocks, in insertion order; always a private mutable list owned by this block. */
    private final List<Block> blocks;

    /** Merges the list of per-child outputs (one entry per child, in child order) into one output. */
    private final Function<List<NDList>, NDList> function;

    /**
     * Creates a parallel block with no children.
     *
     * @param function merges the per-child outputs (one {@link NDList} per child, in child order)
     *     into this block's output
     */
    public ParallelBlock(Function<List<NDList>, NDList> function) {
        this.function = function;
        this.blocks = new ArrayList<>();
    }

    /**
     * Creates a parallel block with an initial set of children.
     *
     * @param function merges the per-child outputs (one {@link NDList} per child, in child order)
     *     into this block's output
     * @param blocks the initial children; the list is copied, so later changes to it do not affect
     *     this block
     * @throws NullPointerException if {@code blocks} is null
     */
    public ParallelBlock(Function<List<NDList>, NDList> function, List<Block> blocks) {
        this.function = function;
        // Defensive copy: keeping the caller's list made add/addAll fail for fixed-size or
        // immutable lists (e.g. Arrays.asList) and let outside mutation change this block.
        this.blocks = new ArrayList<>(blocks);
    }

    /**
     * Appends the given blocks as children.
     *
     * @param blocks the blocks to append, in order
     * @return this block, for chaining
     */
    public final ParallelBlock addAll(Block ... blocks) {
        this.blocks.addAll(Arrays.asList(blocks));
        return this;
    }

    /**
     * Appends the given blocks as children, in the collection's iteration order.
     *
     * @param blocks the blocks to append
     * @return this block, for chaining
     */
    public final ParallelBlock addAll(Collection<? extends Block> blocks) {
        this.blocks.addAll(blocks);
        return this;
    }

    /**
     * Appends a single child block.
     *
     * @param block the block to append
     * @return this block, for chaining
     */
    public final ParallelBlock add(Block block) {
        this.blocks.add(block);
        return this;
    }

    /**
     * Appends a function as a child, wrapped in a {@link LambdaBlock}.
     *
     * @param f the function to wrap and append
     * @return this block, for chaining
     */
    public final ParallelBlock add(Function<NDList, NDList> f) {
        this.blocks.add(new LambdaBlock(f));
        return this;
    }

    /**
     * Runs every child on {@code inputs} and merges the results with the merge function.
     *
     * @param parameterStore the parameter store forwarded to each child
     * @param inputs the input given, unchanged, to every child
     * @param training the training flag forwarded to each child
     * @param params extra parameters forwarded to each child
     * @return the merge function's result over the per-child outputs
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        List<NDList> outputs = this.blocks.stream()
                .map(block -> block.forward(parameterStore, inputs, training, params))
                .collect(Collectors.toList());
        return this.function.apply(outputs);
    }

    /**
     * Initializes every child with the same input shapes, then computes this block's output shapes.
     *
     * @param manager the manager passed to each child's initialization and to shape inference
     * @param dataType the data type passed to each child's initialization
     * @param inputShapes the input shapes, shared by all children
     * @return the output shapes of this block
     * @throws IllegalArgumentException if the block has no children
     */
    @Override
    public Shape[] initialize(NDManager manager, DataType dataType, Shape ... inputShapes) {
        this.beforeInitialize(inputShapes);
        for (Block child : this.getChildren().values()) {
            child.initialize(manager, dataType, inputShapes);
        }
        return this.getOutputShapes(manager, inputShapes);
    }

    /**
     * Infers output shapes by applying the merge function to placeholder arrays shaped like each
     * child's outputs.
     *
     * @param manager the manager passed to the children; placeholder arrays live in a sub-manager
     *     of it and are released before returning
     * @param inputShapes the input shapes, shared by all children
     * @return the shapes of the arrays produced by the merge function
     * @throws IllegalArgumentException if the block has no children
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        if (this.blocks.isEmpty()) {
            throw new IllegalArgumentException("The parallel block is empty");
        }
        try (NDManager subManager = manager.newSubManager()) {
            List<NDList> inputs = new ArrayList<>(this.blocks.size());
            for (Block block : this.blocks) {
                Shape[] shapes = block.getOutputShapes(manager, inputShapes);
                NDList placeholders = new NDList(shapes.length);
                for (Shape shape : shapes) {
                    // Only the shape matters; the array contents are never read.
                    placeholders.add(subManager.create(shape));
                }
                inputs.add(placeholders);
            }
            NDList merged = this.function.apply(inputs);
            Shape[] outputShapes = new Shape[merged.size()];
            for (int i = 0; i < merged.size(); ++i) {
                outputShapes[i] = ((NDArray) merged.get(i)).getShape();
            }
            return outputShapes;
        }
    }

    /**
     * Returns the parameters owned directly by this block, which are always none.
     *
     * @return an empty list
     */
    @Override
    public List<Parameter> getDirectParameters() {
        return Collections.emptyList();
    }

    /**
     * Always throws, because this block owns no parameters of its own.
     *
     * @param name ignored
     * @param inputShapes ignored
     * @return never returns normally
     * @throws IllegalArgumentException always
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        throw new IllegalArgumentException("ParallelBlock have no parameters");
    }

    /**
     * Returns the children keyed as {@code "<index>:<SimpleClassName>"}. The index is zero-padded
     * to the number of decimal digits in the child count, so names sort in insertion order.
     *
     * @return a new {@link BlockList} of the children
     */
    @Override
    public BlockList getChildren() {
        int size = this.blocks.size();
        BlockList children = new BlockList(size);
        // Number of decimal digits in size (same as floor(log10(size)) + 1 for size >= 1),
        // computed with integer arithmetic.
        int precision = Integer.toString(size).length();
        String format = "%0" + precision + "d:%s";
        for (int i = 0; i < size; ++i) {
            Block block = this.blocks.get(i);
            String name = String.format(format, i, block.getClass().getSimpleName());
            children.add(name, block);
        }
        return children;
    }

    /**
     * Writes the encoding version, the input shapes, and then each child's parameters in order.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        this.saveInputShapes(os);
        for (Block block : this.blocks) {
            block.saveParameters(os);
        }
    }

    /**
     * Reads parameters written by {@link #saveParameters}.
     *
     * <p>Version {@value #VERSION} streams carry input shapes. Version 1 streams do not, and go
     * straight to the children's parameters.
     *
     * @param manager the manager passed to each child's loader
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the encoding version is not 1 or {@value #VERSION}
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version == VERSION) {
            this.readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        for (Block block : this.blocks) {
            block.loadParameters(manager, is);
        }
    }

    /**
     * Renders the block as {@code Parallel(} followed by each child's string, tab-indented, one
     * per line, and a closing parenthesis.
     *
     * @return a multi-line description of this block
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Parallel(\n");
        for (Block block : this.blocks) {
            // Prefix every line of the child's (possibly multi-line) description with a tab.
            String blockString = block.toString().replaceAll("(?m)^", "\t");
            sb.append(blockString).append('\n');
        }
        sb.append(')');
        return sb.toString();
    }
}

