package oracle.pgx.config.mllib;

import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import oracle.pgx.common.IllegalEnumConstantException;
import oracle.pgx.common.types.PropertyType;
import oracle.pgx.config.internal.categorymapping.CategoryMappingConfig;
import oracle.pgx.config.mllib.inputconfig.CategoricalEmbeddingType;
import oracle.pgx.config.mllib.inputconfig.CategoricalPropertyConfig;
import oracle.pgx.config.mllib.inputconfig.ContinuousPropertyConfig;
import oracle.pgx.config.mllib.inputconfig.EmbeddingTableConfig;
import oracle.pgx.config.mllib.inputconfig.InputPropertyConfig;
import oracle.pgx.config.mllib.inputconfig.OneHotEncodingConfig;
import oracle.pgx.config.mllib.loss.DevNetLoss;

/* loaded from: input_file:oracle/pgx/config/mllib/GraphWiseBaseModelConfig.class */
/**
 * Base configuration shared by GraphWise model configurations.
 *
 * <p>Holds training hyper-parameters (batch size, number of epochs, learning rate, weight decay,
 * embedding dimension, optional seed), the convolutional layer configurations, per-property input
 * configurations for vertices and edges, preprocessing flags (standardize, normalize, shuffle),
 * backend/accelerator settings, and model state such as the fitted flag, training loss and input
 * feature dimensions.
 *
 * <p>Instances are mutable and this class performs no synchronization.
 */
public abstract class GraphWiseBaseModelConfig {
    /** Default value of {@link #getBatchSize()}. */
    public static final int DEFAULT_BATCH_SIZE = 128;
    /** Default value of {@link #getNumEpochs()}. */
    public static final int DEFAULT_NUM_EPOCHS = 3;
    /** Default value of {@link #getLearningRate()}. */
    public static final double DEFAULT_LEARNING_RATE = 0.01d;
    /** Default value of {@link #getWeightDecay()}. */
    public static final double DEFAULT_WEIGHT_DECAY = 0.0d;
    /** Default value of {@link #getEmbeddingDim()}. */
    public static final int DEFAULT_EMBEDDING_DIM = 128;
    /** Default value of {@link #isStandardize()}. */
    public static final boolean DEFAULT_STANDARDIZE = false;
    /** Default value of {@link #isNormalize()}. */
    public static final boolean DEFAULT_NORMALIZE = true;
    /** Default value of {@link #isShuffle()}. */
    public static final boolean DEFAULT_SHUFFLE = true;
    /** Default value of {@link #isEnableAccelerator()}. */
    public static final boolean DEFAULT_ENABLE_ACCELERATOR = true;

    /**
     * Property types listed as supported model inputs. This class only publishes the set; it does
     * not validate against it itself.
     */
    public static final EnumSet<PropertyType> SUPPORTED_INPUT_TYPES = EnumSet.of(PropertyType.INTEGER, PropertyType.LONG, PropertyType.FLOAT, PropertyType.DOUBLE, PropertyType.BOOLEAN, PropertyType.STRING);
    /** Default value of {@link #getSeed()}: {@code null}, i.e. no seed configured. */
    public static final Integer DEFAULT_SEED = null;
    /** Default value of {@link #getBackend()}. */
    public static final Backend DEFAULT_BACKEND = Backend.TORCH;
    /**
     * Default layer stack: two {@link GraphWiseConvLayerConfig} instances with default settings.
     * Constructors deep-copy this array, so instances never share these layer objects; the array
     * itself should still be treated as read-only.
     */
    public static final GraphWiseBaseConvLayerConfig[] DEFAULT_CONV_LAYER_CONFIGS = {new GraphWiseConvLayerConfig(), new GraphWiseConvLayerConfig()};

    private int batchSize = DEFAULT_BATCH_SIZE;
    private int numEpochs = DEFAULT_NUM_EPOCHS;
    private double learningRate = DEFAULT_LEARNING_RATE;
    private double weightDecay = DEFAULT_WEIGHT_DECAY;
    private int embeddingDim = DEFAULT_EMBEDDING_DIM;
    private Integer seed = DEFAULT_SEED;
    // Assigned explicitly by every constructor (defaults are deep-copied, never aliased).
    private GraphWiseBaseConvLayerConfig[] convLayerConfigs;
    // Keyed by InputPropertyConfig#getPropertyName() when filled through the setters below.
    private final Map<String, InputPropertyConfig> vertexInputPropertyConfigs;
    private final Map<String, InputPropertyConfig> edgeInputPropertyConfigs;
    protected CategoryMappingConfig categoryMappingConfig;
    private boolean standardize = DEFAULT_STANDARDIZE;
    private boolean normalize = DEFAULT_NORMALIZE;
    private boolean shuffle = DEFAULT_SHUFFLE;
    private Backend backend = DEFAULT_BACKEND;
    private boolean enableAccelerator = DEFAULT_ENABLE_ACCELERATOR;
    private List<String> vertexInputPropertyNames;
    private List<String> edgeInputPropertyNames;
    private boolean fitted = false;
    private double trainingLoss;
    private int vertexInputFeatureDim;
    private int edgeInputFeatureDim;

    /** Backend choices for a model; currently only {@link #TORCH}. */
    public enum Backend {
        TORCH
    }

    // NOTE: the decompiler reports this constructor was originally protected.
    /**
     * Creates a configuration with every setting at its {@code DEFAULT_*} value and empty vertex
     * and edge input property configurations. The layer configurations are a deep copy of
     * {@link #DEFAULT_CONV_LAYER_CONFIGS}, so modifying them cannot alter the shared defaults.
     */
    public GraphWiseBaseModelConfig() {
        this.convLayerConfigs = copyConvLayerConfigs(DEFAULT_CONV_LAYER_CONFIGS);
        this.vertexInputPropertyConfigs = new HashMap<>();
        this.edgeInputPropertyConfigs = new HashMap<>();
    }

    /**
     * Creates a configuration from explicit values. Arguments are stored by reference (no copies
     * are made); only {@code null} input property config maps are replaced by new empty maps.
     *
     * @param batchSize value for {@link #getBatchSize()}
     * @param numEpochs value for {@link #getNumEpochs()}
     * @param learningRate value for {@link #getLearningRate()}
     * @param weightDecay value for {@link #getWeightDecay()}
     * @param embeddingDim value for {@link #getEmbeddingDim()}
     * @param seed value for {@link #getSeed()}; may be {@code null}
     * @param convLayerConfigs value for {@link #getConvLayerConfigs()}, stored as given
     * @param standardize value for {@link #isStandardize()}
     * @param normalize value for {@link #isNormalize()}
     * @param shuffle value for {@link #isShuffle()}
     * @param vertexInputPropertyNames value for {@link #getVertexInputPropertyNames()}
     * @param edgeInputPropertyNames value for {@link #getEdgeInputPropertyNames()}
     * @param vertexInputPropertyConfigs vertex input configs; {@code null} means an empty map
     * @param edgeInputPropertyConfigs edge input configs; {@code null} means an empty map
     * @param categoryMappingConfig value for {@link #getCategoryMappingConfig()}
     * @param fitted value for {@link #isFitted()}
     * @param trainingLoss value for {@link #getTrainingLoss()}
     * @param vertexInputFeatureDim value for {@link #getInputFeatureDim()}
     * @param edgeInputFeatureDim value for {@link #getEdgeInputFeatureDim()}
     * @param backend value for {@link #getBackend()}, stored as given
     * @param enableAccelerator value for {@link #isEnableAccelerator()}
     */
    public GraphWiseBaseModelConfig(int batchSize, int numEpochs, double learningRate, double weightDecay, int embeddingDim, Integer seed, GraphWiseBaseConvLayerConfig[] convLayerConfigs, boolean standardize, boolean normalize, boolean shuffle, List<String> vertexInputPropertyNames, List<String> edgeInputPropertyNames, Map<String, InputPropertyConfig> vertexInputPropertyConfigs, Map<String, InputPropertyConfig> edgeInputPropertyConfigs, CategoryMappingConfig categoryMappingConfig, boolean fitted, double trainingLoss, int vertexInputFeatureDim, int edgeInputFeatureDim, Backend backend, boolean enableAccelerator) {
        this.batchSize = batchSize;
        this.numEpochs = numEpochs;
        this.learningRate = learningRate;
        this.weightDecay = weightDecay;
        this.embeddingDim = embeddingDim;
        this.seed = seed;
        this.convLayerConfigs = convLayerConfigs;
        this.standardize = standardize;
        this.normalize = normalize;
        this.shuffle = shuffle;
        this.vertexInputPropertyNames = vertexInputPropertyNames;
        this.edgeInputPropertyNames = edgeInputPropertyNames;
        this.vertexInputPropertyConfigs = vertexInputPropertyConfigs != null ? vertexInputPropertyConfigs : new HashMap<>();
        this.edgeInputPropertyConfigs = edgeInputPropertyConfigs != null ? edgeInputPropertyConfigs : new HashMap<>();
        this.categoryMappingConfig = categoryMappingConfig;
        this.fitted = fitted;
        this.trainingLoss = trainingLoss;
        this.vertexInputFeatureDim = vertexInputFeatureDim;
        this.edgeInputFeatureDim = edgeInputFeatureDim;
        this.backend = backend;
        this.enableAccelerator = enableAccelerator;
    }

    // NOTE: the decompiler reports this constructor was originally protected.
    /**
     * Creates a copy of {@code other}.
     *
     * <p>Scalar settings, the backend and the seed are copied. Layer configurations, vertex/edge
     * input property configurations and input property name lists are deep-copied, so the copy can
     * be modified independently.
     *
     * <p>NOTE(review): {@link #getCategoryMappingConfig()} is not copied (it stays {@code null}),
     * matching the original behavior; confirm this is intentional.
     *
     * @param other configuration to copy
     * @throws IllegalArgumentException if a layer config is neither a {@link GraphWiseConvLayerConfig}
     *     nor a {@link GraphWiseAttentionLayerConfig}
     * @throws IllegalEnumConstantException if a categorical input config has an unsupported
     *     embedding type
     */
    public GraphWiseBaseModelConfig(GraphWiseBaseModelConfig other) {
        setBatchSize(other.getBatchSize());
        setNumEpochs(other.getNumEpochs());
        setLearningRate(other.getLearningRate());
        // Previously omitted: copies silently reset weight decay to the default.
        setWeightDecay(other.getWeightDecay());
        setEmbeddingDim(other.getEmbeddingDim());
        // Direct assignment: setSeed(int) cannot carry a null seed.
        this.seed = other.getSeed();
        setFitted(other.isFitted());
        setShuffle(other.isShuffle());
        setTrainingLoss(other.getTrainingLoss());
        setInputFeatureDim(other.getInputFeatureDim());
        setEdgeInputFeatureDim(other.getEdgeInputFeatureDim());
        setStandardize(other.isStandardize());
        setNormalize(other.isNormalize());
        // setEnableAccelerator is overridable, so avoid calling it on a partially built object.
        this.enableAccelerator = other.isEnableAccelerator();
        // Previously omitted: copies were reset to DEFAULT_BACKEND.
        this.backend = other.getBackend();
        setConvLayerConfigs(copyConvLayerConfigs(other.getConvLayerConfigs()));
        this.vertexInputPropertyConfigs = copyInputPropertyConfigs(other.getVertexInputPropertyConfigs());
        this.edgeInputPropertyConfigs = copyInputPropertyConfigs(other.getEdgeInputPropertyConfigs());
        if (other.getVertexInputPropertyNames() != null) {
            this.vertexInputPropertyNames = new ArrayList<>(other.getVertexInputPropertyNames());
        }
        if (other.getEdgeInputPropertyNames() != null) {
            this.edgeInputPropertyNames = new ArrayList<>(other.getEdgeInputPropertyNames());
        }
    }

    /**
     * Deep-copies a layer configuration array using each layer type's copy constructor.
     *
     * @param source array to copy; may be {@code null}
     * @return a new array of copied layers, or {@code null} if {@code source} is {@code null}
     * @throws IllegalArgumentException if an element (including {@code null}) is neither a
     *     {@link GraphWiseConvLayerConfig} nor a {@link GraphWiseAttentionLayerConfig}
     */
    private static GraphWiseBaseConvLayerConfig[] copyConvLayerConfigs(GraphWiseBaseConvLayerConfig[] source) {
        if (source == null) {
            return null;
        }
        GraphWiseBaseConvLayerConfig[] copy = new GraphWiseBaseConvLayerConfig[source.length];
        for (int i = 0; i < source.length; i++) {
            GraphWiseBaseConvLayerConfig layer = source[i];
            // Check order kept from the original: plain conv layers first, then attention layers.
            if (layer instanceof GraphWiseConvLayerConfig) {
                copy[i] = new GraphWiseConvLayerConfig((GraphWiseConvLayerConfig) layer);
            } else if (layer instanceof GraphWiseAttentionLayerConfig) {
                copy[i] = new GraphWiseAttentionLayerConfig((GraphWiseAttentionLayerConfig) layer);
            } else {
                throw new IllegalArgumentException("Unsupported type of convolutional layer config type.");
            }
        }
        return copy;
    }

    /**
     * Deep-copies a map of input property configurations into a new {@link HashMap}.
     *
     * @param source map to copy
     * @return a new map with the same keys and copied values
     */
    private static Map<String, InputPropertyConfig> copyInputPropertyConfigs(Map<String, InputPropertyConfig> source) {
        Map<String, InputPropertyConfig> copy = new HashMap<>();
        for (Map.Entry<String, InputPropertyConfig> entry : source.entrySet()) {
            copy.put(entry.getKey(), copyInputPropertyConfig(entry.getValue()));
        }
        return copy;
    }

    /**
     * Copies one input property configuration with the copy constructor that matches its kind.
     *
     * @param config configuration to copy
     * @return the copy
     * @throws IllegalEnumConstantException if a categorical config has an unsupported embedding type
     */
    private static InputPropertyConfig copyInputPropertyConfig(InputPropertyConfig config) {
        if (!config.getCategorical()) {
            return new ContinuousPropertyConfig((ContinuousPropertyConfig) config);
        }
        CategoricalEmbeddingType embeddingType = ((CategoricalPropertyConfig) config).getCategoricalEmbeddingType();
        switch (embeddingType) {
            case EMBEDDING_TABLE:
                return new EmbeddingTableConfig((EmbeddingTableConfig) config);
            case ONE_HOT_ENCODING:
                return new OneHotEncodingConfig((OneHotEncodingConfig) config);
            default:
                throw new IllegalEnumConstantException(embeddingType);
        }
    }

    /** @return the shuffle flag */
    public boolean isShuffle() {
        return this.shuffle;
    }

    /** @param shuffle new shuffle flag */
    public final void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    /** @return the vertex input feature dimension */
    public int getInputFeatureDim() {
        return this.vertexInputFeatureDim;
    }

    /** @param vertexInputFeatureDim new vertex input feature dimension */
    public final void setInputFeatureDim(int vertexInputFeatureDim) {
        this.vertexInputFeatureDim = vertexInputFeatureDim;
    }

    /** @return the edge input feature dimension */
    public int getEdgeInputFeatureDim() {
        return this.edgeInputFeatureDim;
    }

    /** @param edgeInputFeatureDim new edge input feature dimension */
    public final void setEdgeInputFeatureDim(int edgeInputFeatureDim) {
        this.edgeInputFeatureDim = edgeInputFeatureDim;
    }

    /** @return the fitted flag ({@code false} by default) */
    public boolean isFitted() {
        return this.fitted;
    }

    /** @param fitted new fitted flag */
    public final void setFitted(boolean fitted) {
        this.fitted = fitted;
    }

    /** @return the stored training loss ({@code 0.0} unless set) */
    public double getTrainingLoss() {
        return this.trainingLoss;
    }

    /** @param trainingLoss new training loss value */
    public final void setTrainingLoss(double trainingLoss) {
        this.trainingLoss = trainingLoss;
    }

    /** @return the batch size */
    public int getBatchSize() {
        return this.batchSize;
    }

    /** @param batchSize new batch size (not validated) */
    public final void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    /** @return the number of epochs */
    public int getNumEpochs() {
        return this.numEpochs;
    }

    /** @param numEpochs new number of epochs (not validated) */
    public final void setNumEpochs(int numEpochs) {
        this.numEpochs = numEpochs;
    }

    /** @return the learning rate */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @param learningRate new learning rate (not validated) */
    public final void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @return the weight decay */
    public double getWeightDecay() {
        return this.weightDecay;
    }

    /** @param weightDecay new weight decay (not validated) */
    public final void setWeightDecay(double weightDecay) {
        this.weightDecay = weightDecay;
    }

    /** @return the embedding dimension */
    public int getEmbeddingDim() {
        return this.embeddingDim;
    }

    /** @param embeddingDim new embedding dimension (not validated) */
    public final void setEmbeddingDim(int embeddingDim) {
        this.embeddingDim = embeddingDim;
    }

    /** @return the seed, or {@code null} if none has been set */
    public Integer getSeed() {
        return this.seed;
    }

    /**
     * Sets the seed. Because the parameter is a primitive, this setter cannot reset the seed back to
     * {@code null}.
     *
     * @param seed new seed
     */
    public final void setSeed(int seed) {
        this.seed = Integer.valueOf(seed);
    }

    /** @return the internal layer configuration array (not a copy); may be {@code null} */
    public GraphWiseBaseConvLayerConfig[] getConvLayerConfigs() {
        return this.convLayerConfigs;
    }

    /** @param convLayerConfigs new layer configurations, stored by reference (not copied) */
    public final void setConvLayerConfigs(GraphWiseBaseConvLayerConfig... convLayerConfigs) {
        this.convLayerConfigs = convLayerConfigs;
    }

    /** @return the live internal map of vertex input property configurations (not a copy) */
    public Map<String, InputPropertyConfig> getVertexInputPropertyConfigs() {
        return this.vertexInputPropertyConfigs;
    }

    /** @return the live internal map of edge input property configurations (not a copy) */
    public Map<String, InputPropertyConfig> getEdgeInputPropertyConfigs() {
        return this.edgeInputPropertyConfigs;
    }

    /**
     * Replaces all vertex input property configurations. Each config is keyed by its property name;
     * if two share a name, the later one wins.
     *
     * @param inputPropertyConfigs the new configurations
     */
    public final void setVertexInputPropertyConfigs(InputPropertyConfig... inputPropertyConfigs) {
        this.vertexInputPropertyConfigs.clear();
        for (InputPropertyConfig config : inputPropertyConfigs) {
            this.vertexInputPropertyConfigs.put(config.getPropertyName(), config);
        }
    }

    /**
     * Replaces all edge input property configurations. Each config is keyed by its property name;
     * if two share a name, the later one wins.
     *
     * @param inputPropertyConfigs the new configurations
     */
    public final void setEdgeInputPropertyConfigs(InputPropertyConfig... inputPropertyConfigs) {
        this.edgeInputPropertyConfigs.clear();
        for (InputPropertyConfig config : inputPropertyConfigs) {
            this.edgeInputPropertyConfigs.put(config.getPropertyName(), config);
        }
    }

    /** @return the category mapping configuration, or {@code null} if unset or cleared */
    public CategoryMappingConfig getCategoryMappingConfig() {
        return this.categoryMappingConfig;
    }

    /** @param categoryMappingConfig new category mapping configuration */
    public final void setCategoryMappingConfig(CategoryMappingConfig categoryMappingConfig) {
        this.categoryMappingConfig = categoryMappingConfig;
    }

    /** Clears the data this class treats as sensitive: the category mapping configuration. */
    public final void clearSensitiveData() {
        this.categoryMappingConfig = null;
    }

    /** @return the vertex input property names (live reference); may be {@code null} */
    public List<String> getVertexInputPropertyNames() {
        return this.vertexInputPropertyNames;
    }

    /** @param vertexInputPropertyNames new vertex input property names, stored by reference */
    public final void setVertexInputPropertyNames(List<String> vertexInputPropertyNames) {
        this.vertexInputPropertyNames = vertexInputPropertyNames;
    }

    /** @return the edge input property names (live reference); may be {@code null} */
    public List<String> getEdgeInputPropertyNames() {
        return this.edgeInputPropertyNames;
    }

    /** @param edgeInputPropertyNames new edge input property names, stored by reference */
    public final void setEdgeInputPropertyNames(List<String> edgeInputPropertyNames) {
        this.edgeInputPropertyNames = edgeInputPropertyNames;
    }

    /** @return the standardize flag */
    public boolean isStandardize() {
        return this.standardize;
    }

    /** @param standardize new standardize flag */
    public final void setStandardize(boolean standardize) {
        this.standardize = standardize;
    }

    /** @return the normalize flag */
    public boolean isNormalize() {
        return this.normalize;
    }

    /** @param normalize new normalize flag */
    public final void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /** @return the configured backend */
    public Backend getBackend() {
        return this.backend;
    }

    /** @param enableAccelerator new accelerator flag */
    public void setEnableAccelerator(boolean enableAccelerator) {
        this.enableAccelerator = enableAccelerator;
    }

    /** @return the accelerator flag */
    public boolean isEnableAccelerator() {
        return this.enableAccelerator;
    }
}
