/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Base implementation of {@link Block} providing input description, initializer propagation,
 * and parameter collection over a block's direct parameters and its children.
 */
public abstract class AbstractBlock
implements Block {
    /** Input shapes recorded by {@link #beforeInitialize(Shape[])}; {@code null} until then. */
    protected Shape[] inputShapes;
    /** Input names paired with {@link #inputShapes} by {@link #describeInput()}; defaults to {@code "data"}. */
    protected List<String> inputNames = Collections.singletonList("data");

    /**
     * Describes this block's inputs as (name, shape) pairs.
     *
     * @return pairs built from {@link #inputNames} and {@link #inputShapes}
     * @throws IllegalStateException if no input shapes have been recorded yet, or if any
     *     parameter of this block (or its children) is not initialized
     */
    @Override
    public PairList<String, Shape> describeInput() {
        // isInitialized() is vacuously true for a block without parameters, so the recorded
        // shapes must be checked separately; otherwise Arrays.asList(null) throws a bare NPE.
        if (this.inputShapes == null || !this.isInitialized()) {
            throw new IllegalStateException("Parameter of this block are not initialised");
        }
        return new PairList<>(this.inputNames, Arrays.asList(this.inputShapes));
    }

    /**
     * Sets {@code initializer} on every direct parameter of this block and then, recursively,
     * on every child block.
     *
     * <p>Direct parameters receive {@code false} as the second argument of
     * {@code Parameter.setInitializer}. NOTE(review): presumably this means an initializer that
     * was already set is not overwritten — confirm against {@code Parameter}.
     *
     * @param initializer the initializer to apply
     */
    @Override
    public void setInitializer(Initializer initializer) {
        for (Parameter parameter : this.getDirectParameters()) {
            parameter.setInitializer(initializer, false);
        }
        for (Block child : this.getChildren().values()) {
            child.setInitializer(initializer);
        }
    }

    /**
     * Sets {@code initializer} on the single direct parameter named {@code paramName}, passing
     * {@code true} as the second argument of {@code Parameter.setInitializer}. Only this block's
     * direct parameters are searched; child blocks are not.
     *
     * @param initializer the initializer to apply
     * @param paramName the name of the direct parameter to target
     * @throws IllegalArgumentException if no direct parameter has that name
     */
    @Override
    public void setInitializer(Initializer initializer, String paramName) {
        Parameter parameter = this.getDirectParameters().stream()
                .filter(param -> param.getName().equals(paramName))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("Could not find parameter " + paramName));
        parameter.setInitializer(initializer, true);
    }

    /**
     * Collects all parameters of this block: first the direct parameters keyed by their own
     * name, then every child's parameters keyed as {@code childName + "_" + paramName}.
     *
     * @return a new list containing all parameters
     */
    @Override
    public ParameterList getParameters() {
        ParameterList parameters = new ParameterList();
        for (Parameter param : this.getDirectParameters()) {
            parameters.add(param.getName(), param);
        }
        this.getChildrenParameters().forEach(parameters::add);
        return parameters;
    }

    /**
     * Records the input shapes so that {@link #describeInput()} can report them.
     *
     * @param inputShapes the input shapes to store (stored by reference, not copied)
     */
    protected void beforeInitialize(Shape[] inputShapes) {
        this.inputShapes = inputShapes;
    }

    /**
     * Returns whether every parameter of this block, including those of its children, reports
     * itself initialized. A block with no parameters returns {@code true}.
     *
     * @return {@code true} if all parameters are initialized
     */
    @Override
    public boolean isInitialized() {
        for (Parameter param : this.getParameters().values()) {
            if (!param.isInitialized()) {
                return false;
            }
        }
        return true;
    }

    /** Closes every parameter of this block, including those of its children. */
    @Override
    public void clear() {
        this.getParameters().values().forEach(Parameter::close);
    }

    /**
     * Not supported by this base class.
     *
     * @param dataType the requested data type
     * @throws UnsupportedOperationException always
     */
    @Override
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Gathers the parameters of every child block. Each is keyed as
     * {@code childName + "_" + paramName} so that names stay distinct across children.
     *
     * @return a new list of the children's parameters
     */
    private ParameterList getChildrenParameters() {
        ParameterList parameters = new ParameterList();
        for (Pair<String, Block> childPair : this.getChildren()) {
            String prefix = childPair.getKey() + "_";
            for (Pair<String, Parameter> paramPair : childPair.getValue().getParameters()) {
                parameters.add(prefix + paramPair.getKey(), paramPair.getValue());
            }
        }
        return parameters;
    }
}

