/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.transformer;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.BlockList;
import ai.djl.nn.Parameter;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/**
 * Base class for transformer blocks that keeps track of its child blocks, its direct parameters
 * and an optional shape callback per parameter.
 *
 * <p>Children and parameters are kept in insertion order. That order is also the order in which
 * they are written by {@link #saveParameters(DataOutputStream)} and read back by
 * {@link #loadParameters(NDManager, DataInputStream)}.
 */
public abstract class TransformerBaseBlock extends AbstractBlock {
    /** Version of the parameter format; written as a single byte when saving. */
    protected int version;
    /** Child blocks by name, in the order they were added. */
    protected LinkedHashMap<String, Block> children = new LinkedHashMap<>();
    /** Direct parameters of this block by name, in the order they were added. */
    protected LinkedHashMap<String, Parameter> parameters = new LinkedHashMap<>();
    /**
     * Shape callbacks by parameter name. The value is {@code null} for parameters that were added
     * without a shape or callback.
     */
    protected LinkedHashMap<String, Function<Shape[], Shape>> parameterShapeCallbacks = new LinkedHashMap<>();

    /**
     * Creates the block with the given parameter format version.
     *
     * @param version the version stored with saved parameters; only the low eight bits are
     *     persisted by {@link #saveParameters(DataOutputStream)}, so it should be in the range
     *     0..255 for saved parameters to be loadable again
     */
    public TransformerBaseBlock(int version) {
        this.version = version;
    }

    /**
     * Returns the parameter format version of this block.
     *
     * @return the version passed to the constructor
     */
    public int getVersion() {
        return this.version;
    }

    /**
     * Registers a child block under the given name, replacing any child previously registered
     * under the same name.
     *
     * @param name the name of the child block
     * @param block the child block
     * @param <B> the concrete block type
     * @return the block that was passed in
     */
    protected <B extends Block> B addChildBlock(String name, B block) {
        this.children.put(name, block);
        return block;
    }

    /**
     * Registers a parameter without a shape callback. For such a parameter
     * {@link #getParameterShape(String, Shape[])} throws unless a subclass overrides it.
     *
     * @param parameter the parameter to register
     * @param <P> the concrete parameter type
     * @return the parameter that was passed in
     */
    protected <P extends Parameter> P addParameter(P parameter) {
        // The cast selects the Function overload; a bare null would be ambiguous with Shape.
        return this.addParameter(parameter, (Function<Shape[], Shape>) null);
    }

    /**
     * Registers a parameter whose shape is fixed and independent of the input shapes.
     *
     * @param parameter the parameter to register
     * @param shape the shape reported for this parameter
     * @param <P> the concrete parameter type
     * @return the parameter that was passed in
     */
    protected <P extends Parameter> P addParameter(P parameter, Shape shape) {
        return this.addParameter(parameter, (Shape[] inputShapes) -> shape);
    }

    /**
     * Registers a parameter together with a callback that computes its shape from the input
     * shapes. The parameter is keyed by {@code parameter.getName()}; a parameter registered
     * earlier under the same name is replaced.
     *
     * @param parameter the parameter to register
     * @param shapeCallback maps the input shapes to the parameter shape; may be {@code null}
     * @param <P> the concrete parameter type
     * @return the parameter that was passed in
     */
    protected <P extends Parameter> P addParameter(P parameter, Function<Shape[], Shape> shapeCallback) {
        String name = parameter.getName();
        this.parameters.put(name, parameter);
        this.parameterShapeCallbacks.put(name, shapeCallback);
        return parameter;
    }

    /**
     * Returns the shape of the named parameter by applying its registered shape callback.
     *
     * @param name the parameter name
     * @param inputShapes the input shapes handed to the callback
     * @return the shape produced by the callback
     * @throws IllegalArgumentException if no parameter with that name is registered
     * @throws IllegalStateException if the parameter is registered without a shape callback
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        Function<Shape[], Shape> callback = this.parameterShapeCallbacks.get(name);
        if (callback != null) {
            return callback.apply(inputShapes);
        }
        // A missing callback means either an unknown name or a parameter added without a shape.
        if (!this.parameters.containsKey(name) || this.parameters.get(name) == null) {
            throw new IllegalArgumentException("No parameter named " + name + " found in this block.");
        }
        throw new IllegalStateException("No shape initializer for parameter " + name + " found. Either pass an initializer for the shape when adding the parameter or override getParameterShape in the subclass.");
    }

    /**
     * Returns the child blocks of this block.
     *
     * @return a new {@link BlockList} built from the registered children
     */
    @Override
    public BlockList getChildren() {
        return new BlockList(this.children);
    }

    /**
     * Initializes the direct parameters of this block, then delegates to
     * {@link #initializeChildBlocks(NDManager, DataType, Shape...)}.
     *
     * @param manager the manager passed on to the parameters and child blocks
     * @param dataType the data type passed on to the parameters and child blocks
     * @param inputShapes the shapes of the inputs
     * @return the result of {@code getOutputShapes(manager, inputShapes)}
     */
    @Override
    public Shape[] initialize(NDManager manager, DataType dataType, Shape ... inputShapes) {
        this.beforeInitialize(inputShapes);
        for (Parameter parameter : this.getDirectParameters()) {
            parameter.initialize(manager, dataType, inputShapes);
        }
        this.initializeChildBlocks(manager, dataType, inputShapes);
        return this.getOutputShapes(manager, inputShapes);
    }

    /**
     * Initializes the child blocks. Called by {@link #initialize(NDManager, DataType, Shape...)}
     * after the direct parameters have been initialized.
     *
     * @param var1 the manager
     * @param var2 the data type
     * @param var3 the input shapes of this block
     */
    public abstract void initializeChildBlocks(NDManager var1, DataType var2, Shape ... var3);

    /**
     * Returns the parameters registered directly on this block.
     *
     * @return a new list holding the parameters in registration order
     */
    @Override
    public List<Parameter> getDirectParameters() {
        return new ArrayList<>(this.parameters.values());
    }

    /**
     * Writes the version as a single byte, followed by every direct parameter and then every
     * child block, each in registration order.
     *
     * @param os the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        // One byte only (low eight bits of the version); loadParameters reads exactly one byte.
        os.writeByte(this.version);
        for (Parameter parameter : this.parameters.values()) {
            parameter.save(os);
        }
        for (Block child : this.children.values()) {
            child.saveParameters(os);
        }
    }

    /**
     * Reads the data written by {@link #saveParameters(DataOutputStream)}: a one-byte version,
     * then every direct parameter, then every child block, each in registration order.
     *
     * @param manager the manager passed on to the parameters and child blocks
     * @param is the stream to read from
     * @throws IOException if reading from the stream fails
     * @throws MalformedModelException if the stored version differs from {@link #getVersion()}
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        // Must mirror saveParameters, which emits a single byte. Reading an int here would
        // consume three bytes of the first parameter and misalign the rest of the stream.
        int loadVersion = is.readUnsignedByte();
        if (loadVersion != this.getVersion()) {
            throw new MalformedModelException("Cannot load parameters for " + this.getClass().getCanonicalName() + ", expected version " + this.getVersion() + ", got " + loadVersion + ".");
        }
        for (Parameter parameter : this.parameters.values()) {
            parameter.load(manager, is);
        }
        for (Block child : this.children.values()) {
            child.loadParameters(manager, is);
        }
    }
}

