/*
 * Decompiled with CFR 0.152.
 */
package oracle.pgx.api.mllib;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import oracle.pgx.api.mllib.GraphWiseModel;
import oracle.pgx.common.util.ErrorMessages;
import oracle.pgx.config.mllib.GraphWiseConvLayerConfig;
import oracle.pgx.config.mllib.GraphWiseModelConfig;

/**
 * Base class for fluent builders of {@link GraphWiseModel} instances.
 *
 * <p>Every setter writes its value into {@link #modelConfig} and returns the concrete builder
 * obtained from {@link #getThis()}, so calls can be chained. Subclasses implement
 * {@link #build()} and may call {@link #validateAll()} to check the accumulated configuration.
 *
 * @param <Model> type of model produced by {@link #build()}
 * @param <Config> type of configuration object the setters write into
 * @param <Self> concrete builder type returned by the setters for chaining
 */
public abstract class GraphWiseModelBuilder<Model extends GraphWiseModel, Config extends GraphWiseModelConfig, Self extends GraphWiseModelBuilder> {

    /**
     * Configuration accumulated by the setters. It is not initialized in this class; presumably
     * a concrete subclass assigns it before any setter is called (TODO confirm).
     */
    Config modelConfig;

    /**
     * Sets the training batch size. {@link #validateAll()} rejects values {@code <= 0}.
     *
     * @param batchSize the batch size
     * @return this builder
     */
    public Self setBatchSize(int batchSize) {
        this.modelConfig.setBatchSize(batchSize);
        return this.getThis();
    }

    /**
     * Sets the number of training epochs. {@link #validateAll()} rejects values {@code <= 0}.
     *
     * @param numEpochs the number of epochs
     * @return this builder
     */
    public Self setNumEpochs(int numEpochs) {
        this.modelConfig.setNumEpochs(numEpochs);
        return this.getThis();
    }

    /**
     * Sets the learning rate.
     *
     * @param learningRate the learning rate
     * @return this builder
     */
    public Self setLearningRate(double learningRate) {
        this.modelConfig.setLearningRate(learningRate);
        return this.getThis();
    }

    /**
     * Sets the weight decay.
     *
     * @param weightDecay the weight decay
     * @return this builder
     */
    public Self setWeightDecay(double weightDecay) {
        this.modelConfig.setWeightDecay(weightDecay);
        return this.getThis();
    }

    /**
     * Sets the embedding dimension. {@link #validateAll()} rejects values {@code <= 0}.
     *
     * @param embeddingDim the embedding dimension
     * @return this builder
     */
    public Self setEmbeddingDim(int embeddingDim) {
        this.modelConfig.setEmbeddingDim(embeddingDim);
        return this.getThis();
    }

    /**
     * Sets the shuffle flag.
     *
     * @param shuffle the shuffle flag forwarded to the configuration
     * @return this builder
     */
    public Self setShuffle(boolean shuffle) {
        this.modelConfig.setShuffle(shuffle);
        return this.getThis();
    }

    /**
     * Sets the random seed.
     *
     * @param seed the seed value; must not be {@code null}
     * @return this builder
     * @throws NullPointerException if {@code seed} is {@code null}
     */
    public Self setSeed(Integer seed) {
        // The value is unboxed to a primitive int for the config setter, so null cannot be
        // represented; fail with a descriptive message rather than an anonymous unboxing NPE.
        this.modelConfig.setSeed(Objects.requireNonNull(seed, "seed").intValue());
        return this.getThis();
    }

    /**
     * Sets the convolution layer configurations. Their order defines the hop index used by
     * {@link #validateAll()} (the first config is hop 0).
     *
     * @param layerConfigs the layer configurations
     * @return this builder
     */
    public Self setConvLayerConfigs(GraphWiseConvLayerConfig ... layerConfigs) {
        this.modelConfig.setConvLayerConfigs(layerConfigs);
        return this.getThis();
    }

    /**
     * Sets the names of the vertex properties used as input features.
     *
     * @param vertexInputPropertyNames the property names
     * @return this builder
     */
    public Self setVertexInputPropertyNames(List<String> vertexInputPropertyNames) {
        this.modelConfig.setVertexInputPropertyNames(vertexInputPropertyNames);
        return this.getThis();
    }

    /**
     * Varargs form of {@link #setVertexInputPropertyNames(List)}.
     *
     * @param vertexInputPropertyNames the property names
     * @return this builder
     */
    public Self setVertexInputPropertyNames(String ... vertexInputPropertyNames) {
        this.modelConfig.setVertexInputPropertyNames(Arrays.asList(vertexInputPropertyNames));
        return this.getThis();
    }

    /**
     * Sets the names of the edge properties used as input features. A non-null value also
     * allows edge-based connections in the conv layer configs (see {@link #validateAll()}).
     *
     * @param edgeInputPropertyNames the property names
     * @return this builder
     */
    public Self setEdgeInputPropertyNames(List<String> edgeInputPropertyNames) {
        this.modelConfig.setEdgeInputPropertyNames(edgeInputPropertyNames);
        return this.getThis();
    }

    /**
     * Varargs form of {@link #setEdgeInputPropertyNames(List)}.
     *
     * @param edgeInputPropertyNames the property names
     * @return this builder
     */
    public Self setEdgeInputPropertyNames(String ... edgeInputPropertyNames) {
        this.modelConfig.setEdgeInputPropertyNames(Arrays.asList(edgeInputPropertyNames));
        return this.getThis();
    }

    /**
     * Sets the target vertex labels.
     *
     * @param targetVertexLabels the labels
     * @return this builder
     */
    public Self setTargetVertexLabels(List<String> targetVertexLabels) {
        this.modelConfig.setTargetVertexLabels(targetVertexLabels);
        return this.getThis();
    }

    /**
     * Varargs form of {@link #setTargetVertexLabels(List)}.
     *
     * @param targetVertexLabels the labels
     * @return this builder
     */
    public Self setTargetVertexLabels(String ... targetVertexLabels) {
        this.modelConfig.setTargetVertexLabels(Arrays.asList(targetVertexLabels));
        return this.getThis();
    }

    /**
     * Sets the standardize flag.
     *
     * @param standardize the standardize flag forwarded to the configuration
     * @return this builder
     */
    public Self setStandardize(boolean standardize) {
        this.modelConfig.setStandardize(standardize);
        return this.getThis();
    }

    /**
     * Builds the model from the accumulated configuration.
     *
     * @return the built model
     * @throws InterruptedException if interrupted while building
     * @throws ExecutionException if the build fails
     */
    public abstract Model build() throws InterruptedException, ExecutionException;

    /**
     * Returns this builder typed as the concrete subclass, enabling fluent chaining.
     *
     * @return {@code this}
     */
    protected abstract Self getThis();

    /**
     * Validates the accumulated configuration.
     *
     * <p>It checks, in order:
     * <ol>
     *   <li>at least one vertex or edge input property is given;</li>
     *   <li>the conv layer configs are valid;</li>
     *   <li>embedding dimension, epochs and batch size are positive.</li>
     * </ol>
     *
     * @throws IllegalArgumentException if any check fails
     */
    protected void validateAll() {
        // An empty list supplies no input features, just like an absent one (e.g. a varargs
        // setter called with no arguments), so both count as "no input properties".
        if (isNullOrEmpty(this.modelConfig.getVertexInputPropertyNames())
                && isNullOrEmpty(this.modelConfig.getEdgeInputPropertyNames())) {
            throw new IllegalArgumentException(
                    ErrorMessages.getMessage("NO_INPUT_PROPERTIES", new Object[0]));
        }
        this.validateConvLayerConfigs();
        this.validateSimpleParams();
    }

    /** Returns true when {@code list} is {@code null} or has no elements. */
    private static boolean isNullOrEmpty(List<?> list) {
        return list == null || list.isEmpty();
    }

    /** Rejects non-positive embedding dimension, number of epochs and batch size. */
    private void validateSimpleParams() {
        if (this.modelConfig.getEmbeddingDim() <= 0) {
            throw new IllegalArgumentException(ErrorMessages.getMessage(
                    "INVALID_EMBEDDING_DIMENSION", new Object[] {this.modelConfig.getEmbeddingDim()}));
        }
        if (this.modelConfig.getNumEpochs() <= 0) {
            throw new IllegalArgumentException(ErrorMessages.getMessage(
                    "INVALID_NUM_EPOCHS", new Object[] {this.modelConfig.getNumEpochs()}));
        }
        if (this.modelConfig.getBatchSize() <= 0) {
            throw new IllegalArgumentException(ErrorMessages.getMessage(
                    "INVALID_BATCH_SIZE", new Object[] {this.modelConfig.getBatchSize()}));
        }
    }

    /**
     * Validates each conv layer config in order; its position is the hop index.
     *
     * <p>Every layer must sample at least one neighbor. When a layer sets at least one of its
     * four connection flags, the flags are checked by {@link #validateConnections}. Unset
     * ({@code null}) flags are treated as disabled for that check.
     * NOTE(review): confirm this matches how the config resolves unset flags.
     */
    private void validateConvLayerConfigs() {
        boolean loadEdges = this.modelConfig.getEdgeInputPropertyNames() != null;
        int hop = 0;
        for (GraphWiseConvLayerConfig convLayerConfig : this.modelConfig.getConvLayerConfigs()) {
            if (convLayerConfig.getNumSampledNeighbors() < 1) {
                throw new IllegalArgumentException(ErrorMessages.getMessage(
                        "ENCODER_MUST_SAMPLE_AT_LEAST_ONE_NEIGHBOR",
                        new Object[] {convLayerConfig.getNumSampledNeighbors()}));
            }
            Boolean vertexToVertex = convLayerConfig.getVertexToVertexConnection();
            Boolean edgeToVertex = convLayerConfig.getEdgeToVertexConnection();
            Boolean vertexToEdge = convLayerConfig.getVertexToEdgeConnection();
            Boolean edgeToEdge = convLayerConfig.getEdgeToEdgeConnection();
            if (vertexToVertex != null || edgeToVertex != null || vertexToEdge != null || edgeToEdge != null) {
                // Unboxing a partially-set group of flags directly would throw a NullPointerException.
                validateConnections(isEnabled(vertexToVertex), isEnabled(edgeToVertex),
                        isEnabled(vertexToEdge), isEnabled(edgeToEdge), loadEdges, hop);
            }
            ++hop;
        }
    }

    /** Null-safe read of a connection flag: only {@link Boolean#TRUE} counts as enabled. */
    private static boolean isEnabled(Boolean flag) {
        return Boolean.TRUE.equals(flag);
    }

    /**
     * Checks one layer's connection flags.
     *
     * @param enableVertexToVertexConnection vertex-to-vertex flag
     * @param enableEdgeToVertexConnection edge-to-vertex flag
     * @param enableVertexToEdgeConnection vertex-to-edge flag
     * @param enableEdgeToEdgeConnection edge-to-edge flag
     * @param loadEdges whether edge input properties were configured
     * @param hop zero-based index of the layer
     * @throws IllegalArgumentException if an edge-involving connection is enabled without edge
     *     inputs, if hop 0 enables neither vertex-to-vertex nor edge-to-vertex, or if no
     *     connection at all is enabled
     */
    private static void validateConnections(boolean enableVertexToVertexConnection, boolean enableEdgeToVertexConnection, boolean enableVertexToEdgeConnection, boolean enableEdgeToEdgeConnection, boolean loadEdges, int hop) {
        if (!loadEdges && (enableEdgeToVertexConnection || enableEdgeToEdgeConnection || enableVertexToEdgeConnection)) {
            throw new IllegalArgumentException(ErrorMessages.getMessage(
                    "EDGE_CONNECTION_ENABLED_WITHOUT_EDGES", new Object[] {hop}));
        }
        if (hop == 0 && !enableVertexToVertexConnection && !enableEdgeToVertexConnection) {
            throw new IllegalArgumentException(
                    ErrorMessages.getMessage("NO_VERTEX_CONNECTION_ENABLED", new Object[0]));
        }
        if (!(enableVertexToVertexConnection || enableEdgeToVertexConnection || enableVertexToEdgeConnection || enableEdgeToEdgeConnection)) {
            throw new IllegalArgumentException(ErrorMessages.getMessage(
                    "NO_GRAPHWISE_CONNECTION_ENABLED", new Object[] {hop}));
        }
    }
}

