package oracle.pgx.api.mllib;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import oracle.pgx.api.mllib.GraphWiseModel;
import oracle.pgx.api.mllib.GraphWiseModelBuilder;
import oracle.pgx.common.util.ErrorMessages;
import oracle.pgx.config.mllib.GraphWiseConvLayerConfig;
import oracle.pgx.config.mllib.GraphWiseModelConfig;

/* loaded from: input_file:oracle/pgx/api/mllib/GraphWiseModelBuilder.class */
/**
 * Fluent base builder for GraphWise models.
 *
 * <p>Every setter forwards its argument to the shared {@code modelConfig} object and returns
 * the concrete builder (via {@link #getThis()}) so calls can be chained. Subclasses supply the
 * configuration instance, the concrete {@link #build()} implementation and {@link #getThis()}.
 *
 * @param <Model> type of model produced by {@link #build()}
 * @param <Config> type of the configuration object the setters write into
 * @param <Self> concrete builder type returned by the fluent setters
 */
public abstract class GraphWiseModelBuilder<Model extends GraphWiseModel, Config extends GraphWiseModelConfig, Self extends GraphWiseModelBuilder> {
    /** Configuration that all setters mutate; not assigned in this class (presumably set by subclasses). */
    Config modelConfig;

    /**
     * Sets the batch size on the model configuration.
     *
     * @param batchSize value forwarded to the configuration; must be positive to pass {@link #validateAll()}
     * @return this builder
     */
    public Self setBatchSize(int batchSize) {
        this.modelConfig.setBatchSize(batchSize);
        return getThis();
    }

    /**
     * Sets the number of epochs on the model configuration.
     *
     * @param numEpochs value forwarded to the configuration; must be positive to pass {@link #validateAll()}
     * @return this builder
     */
    public Self setNumEpochs(int numEpochs) {
        this.modelConfig.setNumEpochs(numEpochs);
        return getThis();
    }

    /**
     * Sets the learning rate on the model configuration.
     *
     * @param learningRate value forwarded to the configuration
     * @return this builder
     */
    public Self setLearningRate(double learningRate) {
        this.modelConfig.setLearningRate(learningRate);
        return getThis();
    }

    /**
     * Sets the weight decay on the model configuration.
     *
     * @param weightDecay value forwarded to the configuration
     * @return this builder
     */
    public Self setWeightDecay(double weightDecay) {
        this.modelConfig.setWeightDecay(weightDecay);
        return getThis();
    }

    /**
     * Sets the embedding dimension on the model configuration.
     *
     * @param embeddingDim value forwarded to the configuration; must be positive to pass {@link #validateAll()}
     * @return this builder
     */
    public Self setEmbeddingDim(int embeddingDim) {
        this.modelConfig.setEmbeddingDim(embeddingDim);
        return getThis();
    }

    /**
     * Sets the shuffle flag on the model configuration.
     *
     * @param shuffle value forwarded to the configuration
     * @return this builder
     */
    public Self setShuffle(boolean shuffle) {
        this.modelConfig.setShuffle(shuffle);
        return getThis();
    }

    /**
     * Sets the seed on the model configuration.
     *
     * @param seed seed value; must not be {@code null}
     * @return this builder
     * @throws NullPointerException if {@code seed} is {@code null}
     */
    public Self setSeed(Integer seed) {
        // Fail with an explicit message instead of an anonymous unboxing NPE.
        if (seed == null) {
            throw new NullPointerException("seed must not be null");
        }
        this.modelConfig.setSeed(seed.intValue());
        return getThis();
    }

    /**
     * Sets the convolutional layer configurations on the model configuration.
     *
     * @param convLayerConfigs layer configurations, forwarded as given
     * @return this builder
     */
    public Self setConvLayerConfigs(GraphWiseConvLayerConfig... convLayerConfigs) {
        this.modelConfig.setConvLayerConfigs(convLayerConfigs);
        return getThis();
    }

    /**
     * Sets the vertex input property names on the model configuration.
     *
     * @param propertyNames list forwarded to the configuration as given
     * @return this builder
     */
    public Self setVertexInputPropertyNames(List<String> propertyNames) {
        this.modelConfig.setVertexInputPropertyNames(propertyNames);
        return getThis();
    }

    /**
     * Varargs variant of {@link #setVertexInputPropertyNames(List)}; the array is wrapped with
     * {@link Arrays#asList(Object[])} before being forwarded.
     *
     * @param propertyNames vertex input property names
     * @return this builder
     */
    public Self setVertexInputPropertyNames(String... propertyNames) {
        this.modelConfig.setVertexInputPropertyNames(Arrays.asList(propertyNames));
        return getThis();
    }

    /**
     * Sets the edge input property names on the model configuration.
     *
     * @param propertyNames list forwarded to the configuration as given
     * @return this builder
     */
    public Self setEdgeInputPropertyNames(List<String> propertyNames) {
        this.modelConfig.setEdgeInputPropertyNames(propertyNames);
        return getThis();
    }

    /**
     * Varargs variant of {@link #setEdgeInputPropertyNames(List)}; the array is wrapped with
     * {@link Arrays#asList(Object[])} before being forwarded.
     *
     * @param propertyNames edge input property names
     * @return this builder
     */
    public Self setEdgeInputPropertyNames(String... propertyNames) {
        this.modelConfig.setEdgeInputPropertyNames(Arrays.asList(propertyNames));
        return getThis();
    }

    /**
     * Sets the target vertex labels on the model configuration.
     *
     * @param labels list forwarded to the configuration as given
     * @return this builder
     */
    public Self setTargetVertexLabels(List<String> labels) {
        this.modelConfig.setTargetVertexLabels(labels);
        return getThis();
    }

    /**
     * Varargs variant of {@link #setTargetVertexLabels(List)}; the array is wrapped with
     * {@link Arrays#asList(Object[])} before being forwarded.
     *
     * @param labels target vertex labels
     * @return this builder
     */
    public Self setTargetVertexLabels(String... labels) {
        this.modelConfig.setTargetVertexLabels(Arrays.asList(labels));
        return getThis();
    }

    /**
     * Sets the standardize flag on the model configuration.
     *
     * @param standardize value forwarded to the configuration
     * @return this builder
     */
    public Self setStandardize(boolean standardize) {
        this.modelConfig.setStandardize(standardize);
        return getThis();
    }

    /**
     * Builds the model from the accumulated configuration.
     *
     * @return the built model
     * @throws InterruptedException declared by the contract; raised by implementations
     * @throws ExecutionException declared by the contract; raised by implementations
     */
    public abstract Model build() throws InterruptedException, ExecutionException;

    /**
     * Returns this builder typed as the concrete subclass, so fluent setters can chain.
     *
     * @return this builder
     */
    protected abstract Self getThis();

    /**
     * Validates the accumulated configuration.
     *
     * <p>Checks, in order: that vertex and edge input property names are not both {@code null},
     * the per-layer convolution settings, and the positive-valued scalar parameters.
     *
     * @throws IllegalArgumentException if any check fails
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void validateAll() {
        boolean noVertexInputs = this.modelConfig.getVertexInputPropertyNames() == null;
        boolean noEdgeInputs = this.modelConfig.getEdgeInputPropertyNames() == null;
        if (noVertexInputs && noEdgeInputs) {
            throw new IllegalArgumentException(ErrorMessages.getMessage("NO_INPUT_PROPERTIES", new Object[0]));
        }
        validateConvLayerConfigs();
        validateSimpleParams();
    }

    /** Rejects non-positive embedding dimension, number of epochs and batch size. */
    private void validateSimpleParams() {
        if (this.modelConfig.getEmbeddingDim() <= 0) {
            throw new IllegalArgumentException(ErrorMessages.getMessage("INVALID_EMBEDDING_DIMENSION", new Object[]{Integer.valueOf(this.modelConfig.getEmbeddingDim())}));
        }
        if (this.modelConfig.getNumEpochs() <= 0) {
            throw new IllegalArgumentException(ErrorMessages.getMessage("INVALID_NUM_EPOCHS", new Object[]{Integer.valueOf(this.modelConfig.getNumEpochs())}));
        }
        if (this.modelConfig.getBatchSize() <= 0) {
            throw new IllegalArgumentException(ErrorMessages.getMessage("INVALID_BATCH_SIZE", new Object[]{Integer.valueOf(this.modelConfig.getBatchSize())}));
        }
    }

    /**
     * Validates every convolutional layer configuration: each layer must sample at least one
     * neighbor, and its explicitly set connection flags must be consistent (see
     * {@link #validateConnections}).
     */
    private void validateConvLayerConfigs() {
        boolean hasEdgeInputProperties = this.modelConfig.getEdgeInputPropertyNames() != null;
        int layerIndex = 0;
        for (GraphWiseConvLayerConfig layerConfig : this.modelConfig.getConvLayerConfigs()) {
            if (layerConfig.getNumSampledNeighbors() < 1) {
                throw new IllegalArgumentException(ErrorMessages.getMessage("ENCODER_MUST_SAMPLE_AT_LEAST_ONE_NEIGHBOR", new Object[]{Integer.valueOf(layerConfig.getNumSampledNeighbors())}));
            }
            // Pass the flags as nullable Booleans: a layer may set only some of them, and
            // unboxing the unset ones would throw a NullPointerException.
            validateConnections(
                    layerConfig.getVertexToVertexConnection(),
                    layerConfig.getEdgeToVertexConnection(),
                    layerConfig.getVertexToEdgeConnection(),
                    layerConfig.getEdgeToEdgeConnection(),
                    hasEdgeInputProperties,
                    layerIndex);
            layerIndex++;
        }
    }

    /**
     * Checks the connection flags of one convolutional layer.
     *
     * <p>A {@code null} flag means "not explicitly set". How unset flags are resolved later is
     * not visible in this class, so they are treated as unknown: a rule only fires on values
     * that were explicitly set. When all four flags are non-null the outcome is the same as
     * evaluating the rules on plain booleans; when all four are {@code null} nothing is rejected.
     *
     * @param vertexToVertex explicit vertex-to-vertex flag, or {@code null} if unset
     * @param edgeToVertex explicit edge-to-vertex flag, or {@code null} if unset
     * @param vertexToEdge explicit vertex-to-edge flag, or {@code null} if unset
     * @param edgeToEdge explicit edge-to-edge flag, or {@code null} if unset
     * @param hasEdgeInputProperties whether edge input property names were configured
     * @param layerIndex zero-based index of the layer, used in error messages
     * @throws IllegalArgumentException if the explicitly set flags violate a rule
     */
    private static void validateConnections(Boolean vertexToVertex, Boolean edgeToVertex, Boolean vertexToEdge, Boolean edgeToEdge, boolean hasEdgeInputProperties, int layerIndex) {
        boolean edgeConnectionEnabled = Boolean.TRUE.equals(edgeToVertex)
                || Boolean.TRUE.equals(edgeToEdge)
                || Boolean.TRUE.equals(vertexToEdge);
        if (!hasEdgeInputProperties && edgeConnectionEnabled) {
            throw new IllegalArgumentException(ErrorMessages.getMessage("EDGE_CONNECTION_ENABLED_WITHOUT_EDGES", new Object[]{Integer.valueOf(layerIndex)}));
        }

        boolean vertexConnectionsDisabled = Boolean.FALSE.equals(vertexToVertex) && Boolean.FALSE.equals(edgeToVertex);
        // The first layer must have at least one connection feeding into vertices.
        if (layerIndex == 0 && vertexConnectionsDisabled) {
            throw new IllegalArgumentException(ErrorMessages.getMessage("NO_VERTEX_CONNECTION_ENABLED", new Object[0]));
        }

        boolean edgeConnectionsDisabled = Boolean.FALSE.equals(vertexToEdge) && Boolean.FALSE.equals(edgeToEdge);
        if (vertexConnectionsDisabled && edgeConnectionsDisabled) {
            throw new IllegalArgumentException(ErrorMessages.getMessage("NO_GRAPHWISE_CONNECTION_ENABLED", new Object[]{Integer.valueOf(layerIndex)}));
        }
    }
}
