package oracle.pgx.api.mllib;

import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import oracle.pgx.api.PgxGraph;
import oracle.pgx.api.PgxSession;
import oracle.pgx.api.internal.Core;
import oracle.pgx.api.internal.Graph;
import oracle.pgx.api.internal.mllib.SupervisedGraphWiseModelMetadata;
import oracle.pgx.common.types.PropertyType;
import oracle.pgx.common.util.ErrorMessages;
import oracle.pgx.config.mllib.GraphWisePredictionLayerConfig;
import oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig;
import oracle.pgx.config.mllib.batchgenerator.BatchGenerator;
import oracle.pgx.config.mllib.loss.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:oracle/pgx/api/mllib/SupervisedGraphWiseModelBuilder.class */
/**
 * Fluent builder for {@link SupervisedGraphWiseModel} instances.
 *
 * <p>Supervised-specific settings (target property, prediction layers, class weights, loss
 * function, batch generator, normalization) are written to a {@link SupervisedGraphWiseModelConfig}.
 * {@link #build()} validates that configuration and asks {@link Core} to create the model.
 */
public class SupervisedGraphWiseModelBuilder extends GraphWiseModelBuilder<SupervisedGraphWiseModel, SupervisedGraphWiseModelConfig, SupervisedGraphWiseModelBuilder> {
    static final Logger LOG = LoggerFactory.getLogger(SupervisedGraphWiseModelBuilder.class);
    private final PgxSession session;
    private final Core core;
    private final Supplier<String> keystorePathSupplier;
    private final Supplier<char[]> keystorePasswordSupplier;
    private final BiFunction<PgxSession, Graph, PgxGraph> graphConstructor;

    /**
     * Creates a builder starting from a fresh {@link SupervisedGraphWiseModelConfig}.
     *
     * @param pgxSession session the model is created in
     * @param core core used to create the model in {@link #build()}
     * @param keystorePathSupplier supplies the keystore path; handed to the built model
     * @param keystorePasswordSupplier supplies the keystore password; handed to the built model
     * @param graphConstructor turns an internal {@link Graph} into a {@link PgxGraph}; handed to the
     *     built model
     */
    public SupervisedGraphWiseModelBuilder(PgxSession pgxSession, Core core, Supplier<String> keystorePathSupplier, Supplier<char[]> keystorePasswordSupplier, BiFunction<PgxSession, Graph, PgxGraph> graphConstructor) {
        this.session = pgxSession;
        this.core = core;
        this.keystorePathSupplier = keystorePathSupplier;
        this.keystorePasswordSupplier = keystorePasswordSupplier;
        this.modelConfig = new SupervisedGraphWiseModelConfig();
        this.graphConstructor = graphConstructor;
    }

    /**
     * Sets the name of the vertex property holding the training labels. This setting is required:
     * {@link #build()} fails if it is still {@code null}.
     *
     * @param str vertex property name
     * @return this builder
     */
    public SupervisedGraphWiseModelBuilder setVertexTargetPropertyName(String str) {
        this.modelConfig.setVertexTargetPropertyName(str);
        return this;
    }

    /**
     * Sets the prediction layer configurations. {@link #build()} rejects an empty set of
     * configurations.
     *
     * @param graphWisePredictionLayerConfigArr prediction layer configurations
     * @return this builder
     */
    public SupervisedGraphWiseModelBuilder setPredictionLayerConfigs(GraphWisePredictionLayerConfig... graphWisePredictionLayerConfigArr) {
        this.modelConfig.setPredictionLayerConfigs(graphWisePredictionLayerConfigArr);
        return this;
    }

    /**
     * Sets per-class weights.
     *
     * <p>{@link #build()} requires the map to be non-empty. All keys must be non-null and of the same
     * class: {@link String}, {@link Integer}, {@link Long} or {@link Boolean}. The model's label type
     * is inferred from that key class.
     *
     * @param map class label to weight
     * @return this builder
     */
    public SupervisedGraphWiseModelBuilder setClassWeights(Map<?, Float> map) {
        this.modelConfig.setClassWeights(map);
        return this;
    }

    /**
     * Sets the loss function using the legacy enum.
     *
     * @param lossFunction legacy loss function constant
     * @return this builder
     * @deprecated use {@link #setLossFunction(LossFunction)} instead
     */
    @Deprecated
    public SupervisedGraphWiseModelBuilder setLossFunction(SupervisedGraphWiseModelConfig.LossFunction lossFunction) {
        this.modelConfig.setLossFunction(lossFunction);
        return this;
    }

    /**
     * Sets the loss function.
     *
     * @param lossFunction loss function instance
     * @return this builder
     */
    public SupervisedGraphWiseModelBuilder setLossFunction(LossFunction lossFunction) {
        this.modelConfig.setLossFunctionClass(lossFunction);
        return this;
    }

    /**
     * Sets the batch generator.
     *
     * @param batchGenerator batch generator
     * @return this builder
     */
    public SupervisedGraphWiseModelBuilder setBatchGenerator(BatchGenerator batchGenerator) {
        this.modelConfig.setBatchGenerator(batchGenerator);
        return this;
    }

    /**
     * Sets the normalize flag on the model configuration.
     *
     * @param z normalize flag
     * @return this builder
     */
    public SupervisedGraphWiseModelBuilder setNormalize(boolean z) {
        this.modelConfig.setNormalize(z);
        return this;
    }

    /** Rejects an explicitly configured but empty set of prediction layer configs; unset (null) is allowed. */
    private void validatePredLayerConfigs() {
        GraphWisePredictionLayerConfig[] predictionLayerConfigs = this.modelConfig.getPredictionLayerConfigs();
        if (predictionLayerConfigs != null && predictionLayerConfigs.length == 0) {
            throw new IllegalArgumentException(ErrorMessages.getMessage("INVALID_PREDICTION_LAYERS", new Object[0]));
        }
    }

    /**
     * Validates the class weights, if any were set, and records the label type inferred from them.
     *
     * <p>The label type is derived from the first key and stored on the config before the
     * remaining keys are checked to share that key's exact class.
     *
     * @throws IllegalArgumentException if the map is empty, contains a {@code null} key, uses an
     *     unsupported key type, or mixes key classes
     */
    private void validateClassWeights() {
        Map<?, ?> classWeights = this.modelConfig.getClassWeights();
        if (classWeights == null) {
            return;
        }
        if (classWeights.isEmpty()) {
            throw new IllegalArgumentException(ErrorMessages.getMessage("EMPTY_CLASS_WEIGHTS", new Object[0]));
        }
        Object firstKey = classWeights.keySet().iterator().next();
        this.modelConfig.setLabelType(labelTypeOf(firstKey));
        Class<?> keyClass = firstKey.getClass();
        for (Object key : classWeights.keySet()) {
            if (key == null) {
                // Maps such as HashMap accept a null key; report it as an invalid key type rather
                // than failing with a NullPointerException on key.getClass().
                throw unsupportedKeyType("null");
            }
            // Exact class match is intentional: e.g. mixing Integer and Long keys is rejected.
            if (key.getClass() != keyClass) {
                throw new IllegalArgumentException(ErrorMessages.getMessage("INCONSISTENT_CLASS_WEIGHTS_KEY_TYPES", new Object[0]));
            }
        }
    }

    /**
     * Maps a class-weight key to the corresponding label property type.
     *
     * @param key a class-weight key, possibly {@code null}
     * @return the label type for the key's class
     * @throws IllegalArgumentException if the key is {@code null} or of an unsupported class
     */
    private static PropertyType labelTypeOf(Object key) {
        if (key instanceof String) {
            return PropertyType.STRING;
        }
        if (key instanceof Integer) {
            return PropertyType.INTEGER;
        }
        if (key instanceof Long) {
            return PropertyType.LONG;
        }
        if (key instanceof Boolean) {
            return PropertyType.BOOLEAN;
        }
        throw unsupportedKeyType(key == null ? "null" : key.getClass());
    }

    /** Builds the exception reported for a class-weight key of an unsupported type. */
    private static IllegalArgumentException unsupportedKeyType(Object keyType) {
        return new IllegalArgumentException(ErrorMessages.getMessage("UNSUPPORTED_CLASS_WEIGHTS_KEY_TYPE", new Object[]{keyType}));
    }

    /**
     * Validates the configuration and creates the model, blocking until creation completes.
     *
     * @return the newly created model
     * @throws IllegalArgumentException if the configuration is invalid (for example, no target
     *     property, invalid class weights, or an empty set of prediction layers)
     * @throws InterruptedException if interrupted while waiting for model creation
     * @throws ExecutionException if model creation fails
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // oracle.pgx.api.mllib.GraphWiseModelBuilder
    public SupervisedGraphWiseModel build() throws InterruptedException, ExecutionException {
        super.validateAll();
        validateClassWeights();
        validatePredLayerConfigs();
        if (this.modelConfig.getVertexTargetPropertyName() == null) {
            throw new IllegalArgumentException(ErrorMessages.getMessage("NO_TARGET_PROPERTY", new Object[0]));
        }
        LOG.debug("Building Model...");
        return (SupervisedGraphWiseModel) this.core.createSupervisedGraphWiseModel(this.session.getSessionContext(), new SupervisedGraphWiseModelMetadata(null, this.modelConfig)).thenApply(supervisedGraphWiseModelMetadata -> {
            return new SupervisedGraphWiseModel(this.session, this.core, this.keystorePathSupplier, this.keystorePasswordSupplier, this.graphConstructor, supervisedGraphWiseModelMetadata);
        }).get();
    }

    /**
     * Returns this builder, typed as the concrete subclass, for fluent chaining of inherited setters.
     *
     * @return this builder
     */
    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // oracle.pgx.api.mllib.GraphWiseModelBuilder
    public SupervisedGraphWiseModelBuilder getThis() {
        return this;
    }
}
