/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.BlockList;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

/**
 * Skeletal {@link Block} implementation that keeps track of child blocks, direct parameters (with
 * optional shape callbacks) and input shapes, and that serializes parameters recursively.
 *
 * <p>Subclasses that register child blocks via {@link #addChildBlock} must override {@link
 * #initializeChildBlocks}; the default implementation rejects blocks that have children.
 */
public abstract class AbstractBlock implements Block {

    /**
     * Input shapes recorded by {@link #beforeInitialize} or read back by {@link #readInputShapes};
     * {@code null} until one of those has run.
     */
    protected Shape[] inputShapes;

    /** Names reported by {@link #describeInput()}; defaults to a single input named {@code "data"}. */
    protected List<String> inputNames = Collections.singletonList("data");

    /** Format version written by {@link #saveParameters} and checked by {@link #loadMetadata}. */
    protected byte version;

    /** Child blocks in insertion order, keyed by a 1-based two-digit index prefixed to their name. */
    protected BlockList children = new BlockList();

    /** Direct parameters of this block keyed by {@link Parameter#getName()}, in insertion order. */
    protected LinkedHashMap<String, Parameter> parameters = new LinkedHashMap<>();

    /**
     * Callbacks computing each direct parameter's shape from the block's input shapes, keyed by
     * parameter name; a value may be {@code null} when no callback was supplied.
     */
    protected LinkedHashMap<String, Function<Shape[], Shape>> parameterShapeCallbacks =
            new LinkedHashMap<>();

    /**
     * Creates a block with the given serialization format version.
     *
     * @param version the version byte written by {@link #saveParameters} and required by {@link
     *     #loadMetadata} when loading
     */
    public AbstractBlock(byte version) {
        this.version = version;
    }

    /**
     * Registers a child block.
     *
     * <p>The child is stored under the key {@code NN + name}, where {@code NN} is its 1-based position
     * zero-padded to at least two digits (for example {@code "01" + name}). {@link Locale#ROOT} is
     * used so the prefix always consists of ASCII digits, independent of the JVM default locale.
     *
     * @param name the child's name (suffix of the stored key)
     * @param block the child block
     * @param <B> the concrete block type
     * @return {@code block}, for convenient field assignment
     */
    protected final <B extends Block> B addChildBlock(String name, B block) {
        int childNumber = children.size() + 1;
        children.add(String.format(Locale.ROOT, "%02d%s", childNumber, name), block);
        return block;
    }

    /**
     * Registers a direct parameter without a shape callback; the subclass must then override {@link
     * #getParameterShape} for it.
     *
     * @param parameter the parameter to add
     * @param <P> the concrete parameter type
     * @return {@code parameter}
     */
    protected final <P extends Parameter> P addParameter(P parameter) {
        // The cast selects the Function overload rather than the Shape overload.
        return addParameter(parameter, (Function<Shape[], Shape>) null);
    }

    /**
     * Registers a direct parameter with a fixed shape that does not depend on the input shapes.
     *
     * @param parameter the parameter to add
     * @param shape the parameter's shape
     * @param <P> the concrete parameter type
     * @return {@code parameter}
     */
    protected final <P extends Parameter> P addParameter(P parameter, Shape shape) {
        return addParameter(parameter, (Shape[] inputShapes) -> shape);
    }

    /**
     * Registers a direct parameter whose shape is computed from the block's input shapes.
     *
     * @param parameter the parameter to add; it is keyed by {@link Parameter#getName()}
     * @param shapeCallback maps input shapes to the parameter shape; may be {@code null}
     * @param <P> the concrete parameter type
     * @return {@code parameter}
     */
    protected final <P extends Parameter> P addParameter(
            P parameter, Function<Shape[], Shape> shapeCallback) {
        parameters.put(parameter.getName(), parameter);
        parameterShapeCallbacks.put(parameter.getName(), shapeCallback);
        return parameter;
    }

    /**
     * Computes the shape of a direct parameter using the callback registered with it.
     *
     * @throws IllegalArgumentException if no parameter with that name was added to this block
     * @throws IllegalStateException if the parameter exists but has no shape callback
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        Function<Shape[], Shape> callback = parameterShapeCallbacks.get(name);
        if (callback == null) {
            Parameter parameter = parameters.get(name);
            if (parameter == null) {
                throw new IllegalArgumentException(
                        "No parameter named " + name + " found in this block.");
            }
            throw new IllegalStateException(
                    "No shape initializer for parameter "
                            + name
                            + " found. Either pass an initializer for the shape when adding the"
                            + " parameter or override getParameterShape in the subclass.");
        }
        return callback.apply(inputShapes);
    }

    /**
     * Returns a copy of the child list, so callers cannot modify this block's children.
     *
     * @return a new {@link BlockList} holding the same key/block pairs
     */
    @Override
    public BlockList getChildren() {
        BlockList defensiveCopy = new BlockList(children.size());
        for (Pair<String, Block> entry : children) {
            defensiveCopy.add(entry);
        }
        return defensiveCopy;
    }

    /**
     * Pairs {@link #inputNames} with the recorded {@link #inputShapes}.
     *
     * @throws IllegalStateException if the block is not initialized or no input shapes are recorded
     */
    @Override
    public PairList<String, Shape> describeInput() {
        // isInitialized() is vacuously true for a block without parameters, so inputShapes must be
        // checked as well; otherwise Arrays.asList(null) would throw a NullPointerException.
        if (!isInitialized() || inputShapes == null) {
            throw new IllegalStateException("Parameter of this block are not initialised");
        }
        return new PairList<>(inputNames, Arrays.asList(inputShapes));
    }

    /**
     * Sets {@code initializer} on every direct parameter (calling {@code setInitializer(initializer,
     * false)}) and recursively on every child block.
     */
    @Override
    public void setInitializer(Initializer initializer) {
        for (Parameter parameter : parameters.values()) {
            parameter.setInitializer(initializer, false);
        }
        for (Block child : children.values()) {
            child.setInitializer(initializer);
        }
    }

    /**
     * Sets {@code initializer} on a single direct parameter (calling {@code
     * setInitializer(initializer, true)}); child blocks are not searched.
     *
     * @throws IllegalArgumentException if this block has no direct parameter named {@code paramName}
     */
    @Override
    public void setInitializer(Initializer initializer, String paramName) {
        Parameter parameter = parameters.get(paramName);
        if (parameter == null) {
            throw new IllegalArgumentException("Could not find parameter " + paramName);
        }
        parameter.setInitializer(initializer, true);
    }

    /**
     * Records the input shapes, initializes the direct parameters, then the child blocks, and
     * returns the output shapes for {@code inputShapes}.
     */
    @Override
    public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
        beforeInitialize(inputShapes);
        for (Parameter parameter : parameters.values()) {
            parameter.initialize(manager, dataType, inputShapes);
        }
        initializeChildBlocks(manager, dataType, inputShapes);
        return getOutputShapes(manager, inputShapes);
    }

    /**
     * Initializes child blocks. The default implementation only verifies that there are none;
     * blocks with children must override it.
     *
     * @throws IllegalStateException if this block has children
     */
    protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        if (!children.isEmpty()) {
            throw new IllegalStateException(
                    getClass().getSimpleName()
                            + " has child blocks but initializeChildBlocks is not overwritten.");
        }
    }

    /**
     * Returns the direct parameters followed by every child's parameters (as returned by the child's
     * own {@code getParameters()}), each child parameter named {@code childKey + "_" + paramName}.
     */
    @Override
    public ParameterList getParameters() {
        ParameterList allParams = getDirectParameters();
        for (Pair<String, Block> childPair : getChildren()) {
            for (Pair<String, Parameter> paramPair : childPair.getValue().getParameters()) {
                allParams.add(childPair.getKey() + "_" + paramPair.getKey(), paramPair.getValue());
            }
        }
        return allParams;
    }

    /** Returns a new {@link ParameterList} built from this block's direct parameters. */
    @Override
    public ParameterList getDirectParameters() {
        return new ParameterList(parameters);
    }

    /**
     * Hook run at the start of {@link #initialize}; records the input shapes.
     *
     * @param inputShapes the shapes passed to {@link #initialize}
     */
    protected void beforeInitialize(Shape[] inputShapes) {
        this.inputShapes = inputShapes;
    }

    /**
     * Returns whether every parameter from {@link #getParameters()} reports itself as initialized.
     * Note that this is vacuously {@code true} for a block without parameters.
     */
    @Override
    public boolean isInitialized() {
        for (Parameter param : getParameters().values()) {
            if (!param.isInitialized()) {
                return false;
            }
        }
        return true;
    }

    /** Closes every parameter returned by {@link #getParameters()}, including those of children. */
    @Override
    public void clear() {
        getParameters().values().forEach(Parameter::close);
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Writes the version byte, the metadata, each direct parameter, and then each child block
     * recursively; {@link #loadParameters} reads them back in the same order.
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.write(version);
        saveMetadata(os);
        for (Parameter parameter : parameters.values()) {
            parameter.save(os);
        }
        for (Block child : children.values()) {
            child.saveParameters(os);
        }
    }

    /**
     * Reads data written by {@link #saveParameters}: version byte, metadata, direct parameters, then
     * child blocks.
     *
     * @throws MalformedModelException if the metadata is rejected, e.g. on a version mismatch
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte loadVersion = is.readByte();
        loadMetadata(loadVersion, is);
        for (Parameter parameter : parameters.values()) {
            parameter.load(manager, is);
        }
        for (Block child : children.values()) {
            child.loadParameters(manager, is);
        }
    }

    /**
     * Writes block metadata; by default only the input shapes.
     *
     * @param os the output stream
     * @throws IOException on write failure
     */
    protected void saveMetadata(DataOutputStream os) throws IOException {
        saveInputShapes(os);
    }

    /**
     * Checks the stored version and reads the input shapes.
     *
     * @param loadVersion the version byte read from the stream
     * @param is the input stream
     * @throws IOException on read failure
     * @throws MalformedModelException if {@code loadVersion} differs from {@link #version}
     */
    protected void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedModelException {
        if (loadVersion != version) {
            throw new MalformedModelException(
                    "Cannot load parameters for "
                            + getClass().getCanonicalName()
                            + ", expected version "
                            + version
                            + ", got "
                            + loadVersion
                            + ".");
        }
        readInputShapes(is);
    }

    /**
     * Writes the number of input shapes followed by each encoded shape.
     *
     * @param os the output stream
     * @throws IOException on write failure
     * @throws IllegalStateException if no input shapes have been recorded yet
     */
    protected void saveInputShapes(DataOutputStream os) throws IOException {
        if (inputShapes == null) {
            throw new IllegalStateException(
                    "Cannot save input shapes: the block has not been initialized");
        }
        os.writeInt(inputShapes.length);
        for (Shape shape : inputShapes) {
            os.write(shape.getEncoded());
        }
    }

    /**
     * Reads input shapes in the format written by {@link #saveInputShapes} and stores them in {@link
     * #inputShapes}.
     *
     * @param is the input stream
     * @throws IOException on read failure
     */
    protected void readInputShapes(DataInputStream is) throws IOException {
        int len = is.readInt();
        inputShapes = new Shape[len];
        for (int i = 0; i < len; ++i) {
            inputShapes[i] = Shape.decode(is);
        }
    }

    /**
     * Returns {@code Name(inputShapes -> outputShapes)}, where {@code Name} is the simple class name
     * without a trailing {@code "Block"}, or {@code Name(Uninitialized)} when no shapes are known.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        String className = getClass().getSimpleName();
        if (className.endsWith("Block")) {
            className = className.substring(0, className.length() - "Block".length());
        }
        sb.append(className).append('(');
        // A parameterless block reports isInitialized() == true before initialize() has recorded any
        // input shapes, so both conditions are required to avoid failing inside describeInput().
        if (isInitialized() && inputShapes != null) {
            Shape[] inShapes = describeInput().values().toArray(new Shape[0]);
            appendShape(sb, inShapes);
            sb.append(" -> ");
            appendShape(sb, getOutputShapes(null, inShapes));
        } else {
            sb.append("Uninitialized");
        }
        sb.append(')');
        return sb.toString();
    }

    /**
     * Appends a comma-separated rendering of {@code shapes} to {@code sb}. A leading {@code -1}
     * dimension is omitted; an empty shape renders as {@code ()}, a single remaining dimension as a
     * bare number, and several dimensions as {@code (d1, d2, ...)}.
     */
    private void appendShape(StringBuilder sb, Shape[] shapes) {
        boolean first = true;
        for (Shape shape : shapes) {
            if (first) {
                first = false;
            } else {
                sb.append(", ");
            }
            long[] sh = shape.getShape();
            int length = sh.length;
            if (length == 0) {
                sb.append("()");
                continue;
            }
            int index = 0;
            // Skip a leading -1 dimension.
            if (sh[0] == -1L) {
                --length;
                index = 1;
            }
            if (length == 0) {
                sb.append("()");
                continue;
            }
            if (length == 1) {
                sb.append(sh[index]);
                continue;
            }
            sb.append('(');
            for (int i = index; i < sh.length; ++i) {
                if (i > index) {
                    sb.append(", ");
                }
                sb.append(sh[i]);
            }
            sb.append(')');
        }
    }
}

