package oracle.pgx.api.mllib;

import oracle.pgx.api.internal.algorithm.arguments.SpeakerListenerLabelPropagationArguments;
import oracle.pgx.common.util.ErrorMessages;
import oracle.pgx.config.mllib.ActivationFunction;
import oracle.pgx.config.mllib.GraphWisePredictionLayerConfig;
import oracle.pgx.config.mllib.WeightInitScheme;

/* loaded from: input_file:oracle/pgx/api/mllib/GraphWisePredictionLayerConfigBuilder.class */
/**
 * Fluent builder for a {@link GraphWisePredictionLayerConfig}.
 *
 * <p>Each setter forwards its value to a single underlying configuration object and returns this
 * builder so calls can be chained. The dropout rate is validated when {@link #build()} is called.
 */
public class GraphWisePredictionLayerConfigBuilder {

    /** Smallest accepted dropout rate (inclusive). */
    private static final double MIN_DROPOUT_RATE = 0.0d;

    /** Upper bound for the dropout rate (exclusive). */
    private static final double MAX_DROPOUT_RATE = 1.0d;

    private final GraphWisePredictionLayerConfig predictionLayerConfig = new GraphWisePredictionLayerConfig();

    /**
     * Sets the hidden dimension on the configuration being built.
     *
     * @param num the hidden dimension, forwarded unchanged to the configuration
     * @return this builder
     */
    public GraphWisePredictionLayerConfigBuilder setHiddenDimension(Integer num) {
        this.predictionLayerConfig.setHiddenDimension(num);
        return this;
    }

    /**
     * Sets the activation function on the configuration being built.
     *
     * @param activationFunction the activation function, forwarded unchanged to the configuration
     * @return this builder
     */
    public GraphWisePredictionLayerConfigBuilder setActivationFunction(ActivationFunction activationFunction) {
        this.predictionLayerConfig.setActivationFunction(activationFunction);
        return this;
    }

    /**
     * Sets the weight initialization scheme on the configuration being built.
     *
     * @param weightInitScheme the weight init scheme, forwarded unchanged to the configuration
     * @return this builder
     */
    public GraphWisePredictionLayerConfigBuilder setWeightInitScheme(WeightInitScheme weightInitScheme) {
        this.predictionLayerConfig.setWeightInitScheme(weightInitScheme);
        return this;
    }

    /**
     * Sets the dropout rate on the configuration being built.
     *
     * <p>The value is not checked here; it is validated by {@link #build()}.
     *
     * @param d the dropout rate; must lie in {@code [0.0, 1.0)} by the time {@link #build()} runs
     * @return this builder
     */
    public GraphWisePredictionLayerConfigBuilder setDropoutRate(double d) {
        this.predictionLayerConfig.setDropoutRate(d);
        return this;
    }

    /**
     * Ensures the configured dropout rate lies in {@code [0.0, 1.0)}.
     *
     * @throws IllegalArgumentException if the rate is outside that range or is {@code NaN}
     */
    private void validateDropoutRate() {
        double rate = this.predictionLayerConfig.getDropoutRate();
        // Written as a negated in-range test so that NaN (for which every comparison is false)
        // is rejected instead of slipping through two failing "out of range" comparisons.
        if (!(rate >= MIN_DROPOUT_RATE && rate < MAX_DROPOUT_RATE)) {
            throw new IllegalArgumentException(ErrorMessages.getMessage("INVALID_DROPOUT_RATE", new Object[]{Double.valueOf(rate)}));
        }
    }

    /**
     * Validates the accumulated settings and returns the configuration.
     *
     * <p>The returned object is the builder's own instance, not a copy. Calling setters on this
     * builder after {@code build()} will therefore also modify the previously returned configuration.
     *
     * @return the configured {@link GraphWisePredictionLayerConfig}
     * @throws IllegalArgumentException if the dropout rate is not in {@code [0.0, 1.0)}
     */
    public GraphWisePredictionLayerConfig build() {
        validateDropoutRate();
        return this.predictionLayerConfig;
    }
}
