package oracle.pgx.config.mllib;

import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import oracle.pgx.common.types.PropertyType;
import oracle.pgx.common.util.ErrorMessages;

/* loaded from: input_file:oracle/pgx/config/mllib/GraphWiseModelConfig.class */
/**
 * Base configuration shared by GraphWise models.
 *
 * <p>Holds training hyper-parameters (batch size, number of epochs, learning rate, weight decay,
 * embedding dimension, optional seed), the per-layer convolution configuration, the vertex/edge
 * input property names and target vertex label sets, plus state recorded about a fitted model
 * (fitted flag, training loss, vertex/edge input feature dimensions).
 *
 * <p>Instances are mutable and this class performs no synchronization.
 */
public abstract class GraphWiseModelConfig {
    public static final int DEFAULT_BATCH_SIZE = 128;
    public static final int DEFAULT_NUM_EPOCHS = 3;
    public static final double DEFAULT_LEARNING_RATE = 0.01d;
    public static final double DEFAULT_WEIGHT_DECAY = 0.0d;
    public static final int DEFAULT_EMBEDDING_DIM = 128;
    public static final boolean DEFAULT_STANDARDIZE = false;
    public static final boolean DEFAULT_SHUFFLE = true;

    /** Property types accepted as model input features. */
    public static final EnumSet<PropertyType> SUPPORTED_INPUT_TYPES = EnumSet.of(PropertyType.INTEGER, PropertyType.LONG, PropertyType.FLOAT, PropertyType.DOUBLE, PropertyType.BOOLEAN);
    /** Default seed: {@code null}, i.e. no seed configured. */
    public static final Integer DEFAULT_SEED = null;
    public static final Backend DEFAULT_BACKEND = Backend.TORCH;
    /** Default model variant: {@code null}, i.e. not chosen yet (see {@link #setVariant}). */
    public static final GraphConvModelVariant DEFAULT_MODE = null;
    /**
     * Reference default layer configuration (two layers built with the no-arg constructor).
     *
     * <p>Instances do NOT share this array: each config gets its own fresh layer objects so that
     * mutating one config's layers cannot leak into this constant or into other configs.
     */
    public static final GraphWiseConvLayerConfig[] DEFAULT_CONV_LAYER_CONFIGS = {new GraphWiseConvLayerConfig(), new GraphWiseConvLayerConfig()};

    // Field initializers run before every constructor body, so they replace the identical
    // default-assignment preamble each constructor used to repeat.
    private int batchSize = DEFAULT_BATCH_SIZE;
    private int numEpochs = DEFAULT_NUM_EPOCHS;
    private double learningRate = DEFAULT_LEARNING_RATE;
    private double weightDecay = DEFAULT_WEIGHT_DECAY;
    private int embeddingDim = DEFAULT_EMBEDDING_DIM;
    private Integer seed = DEFAULT_SEED;
    private GraphWiseConvLayerConfig[] convLayerConfigs = newDefaultConvLayerConfigs();
    private boolean standardize = DEFAULT_STANDARDIZE;
    private boolean shuffle = DEFAULT_SHUFFLE;
    private Backend backend = DEFAULT_BACKEND;
    private GraphConvModelVariant variant = DEFAULT_MODE;
    private List<String> vertexInputPropertyNames;
    private List<String> edgeInputPropertyNames;
    private List<Set<String>> targetVertexLabelSets;
    private boolean fitted = false;
    private double trainingLoss;
    private int vertexInputFeatureDim;
    private int edgeInputFeatureDim;

    /** Available execution backends. */
    public enum Backend {
        TORCH
    }

    /** Available GraphWise model variants. */
    public enum GraphConvModelVariant {
        GRAPHWISE,
        INTERTWINED
    }

    /** Creates a configuration populated with the {@code DEFAULT_*} values. */
    public GraphWiseModelConfig() {
        // All defaults are applied by the field initializers.
    }

    /**
     * Creates a configuration with every field given explicitly.
     *
     * <p>Arguments are stored as-is (no defensive copies), matching the setters.
     *
     * @param batchSize batch size
     * @param numEpochs number of epochs
     * @param learningRate learning rate
     * @param weightDecay weight decay
     * @param embeddingDim embedding dimension
     * @param seed seed, may be {@code null}
     * @param convLayerConfigs per-layer convolution configurations
     * @param standardize standardize flag
     * @param shuffle shuffle flag
     * @param vertexInputPropertyNames vertex input property names
     * @param edgeInputPropertyNames edge input property names
     * @param fitted whether the model has been fitted
     * @param trainingLoss recorded training loss
     * @param vertexInputFeatureDim vertex input feature dimension
     * @param edgeInputFeatureDim edge input feature dimension
     * @param targetVertexLabelSets target vertex label sets
     * @param backend backend
     * @param variant model variant
     */
    public GraphWiseModelConfig(int batchSize, int numEpochs, double learningRate, double weightDecay, int embeddingDim, Integer seed, GraphWiseConvLayerConfig[] convLayerConfigs, boolean standardize, boolean shuffle, List<String> vertexInputPropertyNames, List<String> edgeInputPropertyNames, boolean fitted, double trainingLoss, int vertexInputFeatureDim, int edgeInputFeatureDim, List<Set<String>> targetVertexLabelSets, Backend backend, GraphConvModelVariant variant) {
        this.batchSize = batchSize;
        this.numEpochs = numEpochs;
        this.learningRate = learningRate;
        this.weightDecay = weightDecay;
        this.embeddingDim = embeddingDim;
        this.seed = seed;
        this.convLayerConfigs = convLayerConfigs;
        this.standardize = standardize;
        this.shuffle = shuffle;
        this.vertexInputPropertyNames = vertexInputPropertyNames;
        this.edgeInputPropertyNames = edgeInputPropertyNames;
        this.fitted = fitted;
        this.trainingLoss = trainingLoss;
        this.vertexInputFeatureDim = vertexInputFeatureDim;
        this.edgeInputFeatureDim = edgeInputFeatureDim;
        this.targetVertexLabelSets = targetVertexLabelSets;
        this.backend = backend;
        this.variant = variant;
    }

    /**
     * Copy constructor.
     *
     * <p>Copies every field of {@code other}. The layer-config array and its elements, the input
     * property name lists and the outer target-label-set list are copied so that the new instance
     * does not alias them; the individual label {@code Set}s are shared.
     *
     * @param other configuration to copy
     */
    public GraphWiseModelConfig(GraphWiseModelConfig other) {
        setBatchSize(other.getBatchSize());
        setNumEpochs(other.getNumEpochs());
        setLearningRate(other.getLearningRate());
        // Previously omitted, which silently reset weight decay to 0.0 on every copy.
        setWeightDecay(other.getWeightDecay());
        setEmbeddingDim(other.getEmbeddingDim());
        // Assigned directly: setSeed(int) would throw on a null seed.
        this.seed = other.getSeed();
        setFitted(other.isFitted());
        setShuffle(other.isShuffle());
        setTrainingLoss(other.getTrainingLoss());
        setInputFeatureDim(other.getInputFeatureDim());
        setEdgeInputFeatureDim(other.getEdgeInputFeatureDim());
        setTargetVertexLabelSets(copyList(other.getTargetVertexLabelSets()));
        setStandardize(other.isStandardize());
        this.backend = other.getBackend();
        // Assigned directly: setVariant refuses to overwrite an already-set variant.
        this.variant = other.getVariant();
        setConvLayerConfigs(copyConvLayerConfigs(other.getConvLayerConfigs()));
        this.vertexInputPropertyNames = copyList(other.getVertexInputPropertyNames());
        this.edgeInputPropertyNames = copyList(other.getEdgeInputPropertyNames());
    }

    public boolean isShuffle() {
        return this.shuffle;
    }

    public final void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    /** Returns the vertex input feature dimension. */
    public int getInputFeatureDim() {
        return this.vertexInputFeatureDim;
    }

    /** Sets the vertex input feature dimension. */
    public final void setInputFeatureDim(int inputFeatureDim) {
        this.vertexInputFeatureDim = inputFeatureDim;
    }

    public int getEdgeInputFeatureDim() {
        return this.edgeInputFeatureDim;
    }

    public final void setEdgeInputFeatureDim(int edgeInputFeatureDim) {
        this.edgeInputFeatureDim = edgeInputFeatureDim;
    }

    public boolean isFitted() {
        return this.fitted;
    }

    public final void setFitted(boolean fitted) {
        this.fitted = fitted;
    }

    public double getTrainingLoss() {
        return this.trainingLoss;
    }

    public final void setTrainingLoss(double trainingLoss) {
        this.trainingLoss = trainingLoss;
    }

    public int getBatchSize() {
        return this.batchSize;
    }

    public final void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    public int getNumEpochs() {
        return this.numEpochs;
    }

    public final void setNumEpochs(int numEpochs) {
        this.numEpochs = numEpochs;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public final void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    public double getWeightDecay() {
        return this.weightDecay;
    }

    public final void setWeightDecay(double weightDecay) {
        this.weightDecay = weightDecay;
    }

    public int getEmbeddingDim() {
        return this.embeddingDim;
    }

    public final void setEmbeddingDim(int embeddingDim) {
        this.embeddingDim = embeddingDim;
    }

    /** Returns the configured seed, or {@code null} if none was set. */
    public Integer getSeed() {
        return this.seed;
    }

    public final void setSeed(int seed) {
        this.seed = Integer.valueOf(seed);
    }

    /** Returns the per-layer configurations (the internal array, not a copy). */
    public GraphWiseConvLayerConfig[] getConvLayerConfigs() {
        return this.convLayerConfigs;
    }

    /** Sets the per-layer configurations; the array is stored as-is. */
    public final void setConvLayerConfigs(GraphWiseConvLayerConfig... convLayerConfigs) {
        this.convLayerConfigs = convLayerConfigs;
    }

    public List<String> getVertexInputPropertyNames() {
        return this.vertexInputPropertyNames;
    }

    public final void setVertexInputPropertyNames(List<String> vertexInputPropertyNames) {
        this.vertexInputPropertyNames = vertexInputPropertyNames;
    }

    public List<String> getEdgeInputPropertyNames() {
        return this.edgeInputPropertyNames;
    }

    public final void setEdgeInputPropertyNames(List<String> edgeInputPropertyNames) {
        this.edgeInputPropertyNames = edgeInputPropertyNames;
    }

    public List<Set<String>> getTargetVertexLabelSets() {
        return this.targetVertexLabelSets;
    }

    public final void setTargetVertexLabelSets(List<Set<String>> targetVertexLabelSets) {
        this.targetVertexLabelSets = targetVertexLabelSets;
    }

    /**
     * Sets the target vertex labels, wrapping each label in its own single-element set.
     *
     * @param targetVertexLabels labels, or {@code null} to clear the target label sets
     */
    public void setTargetVertexLabels(List<String> targetVertexLabels) {
        this.targetVertexLabelSets =
                targetVertexLabels == null ? null : listOfStringsToListOfSetOfStrings(targetVertexLabels);
    }

    public boolean isStandardize() {
        return this.standardize;
    }

    public final void setStandardize(boolean standardize) {
        this.standardize = standardize;
    }

    public Backend getBackend() {
        return this.backend;
    }

    /**
     * Sets the model variant. The variant can be set only once.
     *
     * @param variant the variant to use
     * @throws IllegalStateException if a variant has already been set
     */
    public final void setVariant(GraphConvModelVariant variant) {
        if (this.variant != null) {
            throw new IllegalStateException(ErrorMessages.getMessage("IMMUTABLE_GRAPHWISE_VARIANT", new Object[0]));
        }
        this.variant = variant;
    }

    /** Returns the model variant, or {@code null} if not set yet. */
    public GraphConvModelVariant getVariant() {
        return this.variant;
    }

    /** Builds fresh default layer configs so instances never share mutable layer objects. */
    private static GraphWiseConvLayerConfig[] newDefaultConvLayerConfigs() {
        GraphWiseConvLayerConfig[] layers = new GraphWiseConvLayerConfig[DEFAULT_CONV_LAYER_CONFIGS.length];
        for (int i = 0; i < layers.length; i++) {
            layers[i] = new GraphWiseConvLayerConfig();
        }
        return layers;
    }

    /** Copies a layer-config array element by element; {@code null} in, {@code null} out. */
    private static GraphWiseConvLayerConfig[] copyConvLayerConfigs(GraphWiseConvLayerConfig[] source) {
        if (source == null) {
            return null;
        }
        GraphWiseConvLayerConfig[] copy = new GraphWiseConvLayerConfig[source.length];
        for (int i = 0; i < source.length; i++) {
            copy[i] = source[i] == null ? null : copyConvLayerConfig(source[i]);
        }
        return copy;
    }

    /**
     * Copies the visible settings of one layer config into a new instance.
     *
     * <p>An unset ({@code null}) vertex-to-vertex connection is copied as {@code true}; the other
     * unset connection flags are copied as {@code false}.
     */
    private static GraphWiseConvLayerConfig copyConvLayerConfig(GraphWiseConvLayerConfig layer) {
        GraphWiseConvLayerConfig copy = new GraphWiseConvLayerConfig();
        copy.setNumSampledNeighbors(layer.getNumSampledNeighbors());
        copy.setWeightedAggregationProperty(layer.getNeighborWeightPropertyName());
        copy.setActivationFunction(layer.getActivationFunction());
        copy.setWeightInitScheme(layer.getWeightInitScheme());
        boolean vertexToVertex = !Boolean.FALSE.equals(layer.getVertexToVertexConnection());
        boolean edgeToVertex = Boolean.TRUE.equals(layer.getEdgeToVertexConnection());
        boolean vertexToEdge = Boolean.TRUE.equals(layer.getVertexToEdgeConnection());
        boolean edgeToEdge = Boolean.TRUE.equals(layer.getEdgeToEdgeConnection());
        copy.useVertexToVertexConnection(vertexToVertex);
        copy.useVertexToEdgeConnection(vertexToEdge);
        copy.useEdgeToVertexConnection(edgeToVertex);
        copy.useEdgeToEdgeConnection(edgeToEdge);
        return copy;
    }

    /** Shallow-copies a list into a new {@link ArrayList}; {@code null} in, {@code null} out. */
    private static <T> List<T> copyList(List<T> source) {
        return source == null ? null : new ArrayList<>(source);
    }

    /** Wraps each string of {@code list} in its own single-element {@link HashSet}. */
    private static List<Set<String>> listOfStringsToListOfSetOfStrings(List<String> list) {
        return list.stream()
                .<Set<String>>map(label -> new HashSet<>(Collections.singletonList(label)))
                .collect(Collectors.toList());
    }
}
