package oracle.pgx.api.mllib;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import oracle.pgx.api.PgxFuture;
import oracle.pgx.api.PgxGraph;
import oracle.pgx.api.PgxSession;
import oracle.pgx.api.PgxVertex;
import oracle.pgx.api.frames.PgxFrame;
import oracle.pgx.api.frames.internal.PgxFrameImpl;
import oracle.pgx.api.internal.Core;
import oracle.pgx.api.internal.Graph;
import oracle.pgx.api.internal.mllib.ModelMetadata;
import oracle.pgx.api.internal.mllib.SupervisedGnnExplainerConfig;
import oracle.pgx.api.internal.mllib.SupervisedGraphWiseModelMetadata;
import oracle.pgx.common.util.ErrorMessages;
import oracle.pgx.config.mllib.GraphWisePredictionLayerConfig;
import oracle.pgx.config.mllib.ModelKind;
import oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig;
import oracle.pgx.config.mllib.loss.LossFunction;

/* loaded from: input_file:oracle/pgx/api/mllib/SupervisedGraphWiseModel.class */
/**
 * Client-side handle for a supervised GraphWise model.
 *
 * <p>Every operation is forwarded to {@code core}, identified by the current session context and
 * the model name held in {@code modelMetadata}. Asynchronous variants return the resulting
 * {@link PgxFuture}; synchronous variants block on the matching asynchronous call.
 *
 * <p>NOTE(review): this file was recovered by a decompiler, so parameter names such as
 * {@code f}, {@code str} and {@code z} are synthetic.
 */
public class SupervisedGraphWiseModel extends GraphWiseModel<SupervisedGraphWiseModelConfig, SupervisedGraphWiseModelMetadata, SupervisedGraphWiseModel> {
    /** Name of the algorithm this model class represents. */
    public static final String ALGORITHM_NAME = "SupervisedGraphWise";

    /**
     * Kinds of inference request. The constants mirror the inference entry points of the enclosing
     * class; the enum itself is not referenced anywhere else in this class.
     */
    public enum SupervisedGraphWiseInferenceType {
        INFER_EMBEDDINGS,
        INFER_LABELS,
        EVALUATE_LABELS,
        INFER_LOGITS
    }

    /**
     * Creates a model handle from already-typed metadata.
     *
     * @param pgxSession session the model belongs to
     * @param core backend used to execute model operations
     * @param supplier NOTE(review): presumably the keystore path supplier — confirm against the superclass
     * @param supplier2 NOTE(review): presumably the keystore password supplier — confirm against the superclass
     * @param biFunction factory turning a session and an internal graph into a {@link PgxGraph}
     * @param supervisedGraphWiseModelMetadata metadata describing this model
     */
    public SupervisedGraphWiseModel(PgxSession pgxSession, Core core, Supplier<String> supplier, Supplier<char[]> supplier2, BiFunction<PgxSession, Graph, PgxGraph> biFunction, SupervisedGraphWiseModelMetadata supervisedGraphWiseModelMetadata) {
        super(pgxSession, core, supplier, supplier2, supervisedGraphWiseModelMetadata, biFunction);
    }

    /**
     * Creates a model handle from untyped metadata after checking that it describes a model of
     * kind {@link #getModelKind()}.
     *
     * <p>The superclass is initialized with {@code null} metadata; the metadata field is only
     * assigned once the kind check has passed. On a mismatch the failure is reported through
     * {@link ErrorMessages#throwException} using {@code IllegalArgumentException::new} and the
     * message key {@code UNEXPECTED_MODEL_KIND}.
     *
     * @param pgxSession session the model belongs to
     * @param core backend used to execute model operations
     * @param supplier NOTE(review): presumably the keystore path supplier — confirm against the superclass
     * @param supplier2 NOTE(review): presumably the keystore password supplier — confirm against the superclass
     * @param biFunction factory turning a session and an internal graph into a {@link PgxGraph}
     * @param modelMetadata metadata to validate and adopt; must not be {@code null}
     */
    public SupervisedGraphWiseModel(PgxSession pgxSession, Core core, Supplier<String> supplier, Supplier<char[]> supplier2, BiFunction<PgxSession, Graph, PgxGraph> biFunction, ModelMetadata modelMetadata) {
        super(pgxSession, core, supplier, supplier2, null, biFunction);
        if (modelMetadata.getModelKind() != getModelKind()) {
            ErrorMessages.throwException(
                IllegalArgumentException::new,
                "UNEXPECTED_MODEL_KIND",
                new Object[]{getModelKind(), modelMetadata.getModelKind()});
        } else {
            this.modelMetadata = (SupervisedGraphWiseModelMetadata) modelMetadata;
        }
    }

    /**
     * Returns this instance typed as the concrete model class.
     *
     * <p>Decompiler note: the original declaration was {@code protected}; the access modifier was
     * widened during decompilation.
     *
     * @return {@code this}
     */
    @Override // oracle.pgx.api.mllib.Model
    public SupervisedGraphWiseModel getThis() {
        return this;
    }

    /**
     * Returns the model kind handled by this class.
     *
     * <p>Decompiler note: the original declaration was {@code protected}; the access modifier was
     * widened during decompilation.
     *
     * @return {@link ModelKind#SUPERVISED_GRAPHWISE}
     */
    @Override // oracle.pgx.api.mllib.Model
    public ModelKind getModelKind() {
        return ModelKind.SUPERVISED_GRAPHWISE;
    }

    /**
     * Asks the backend to fit this model on the given graph.
     *
     * <p>When the backend call completes, the metadata it returns replaces this instance's
     * metadata, and the future yields the training loss read from the returned configuration.
     *
     * @param pgxGraph graph to train on
     * @return future completing with the training loss
     */
    @Override // oracle.pgx.api.mllib.GraphWiseModel
    public PgxFuture<Double> fitAsync(PgxGraph pgxGraph) {
        return this.core
            .fitSupervisedGraphWiseModel(
                this.session.getSessionContext(),
                ((SupervisedGraphWiseModelMetadata) this.modelMetadata).getModelName(),
                pgxGraph.getId())
            .thenApply(trainedMetadata -> {
                // Adopt the backend's post-training metadata before exposing the loss.
                this.modelMetadata = trainedMetadata;
                return Double.valueOf(trainedMetadata.getConfig().getTrainingLoss());
            });
    }

    /**
     * Requests embeddings for the given vertices and wraps the returned frame metadata in a
     * {@link PgxFrame}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to run inference on
     * @param <ID> vertex id type
     * @return future completing with the result frame
     */
    @Override // oracle.pgx.api.mllib.GraphWiseModel
    public <ID> PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable) {
        return this.core
            .inferEmbeddingsSupervisedGraphWiseModel(
                this.session.getSessionContext(),
                ((SupervisedGraphWiseModelMetadata) this.modelMetadata).getModelName(),
                pgxGraph.getId(),
                serializeVertices(iterable))
            .thenApply(resultFrame ->
                new PgxFrameImpl(this.session, this.core, resultFrame, this.keystorePathSupplier, this.keystorePasswordSupplier));
    }

    /**
     * Requests logits for the given vertices and wraps the returned frame metadata in a
     * {@link PgxFrame}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to run inference on
     * @param <ID> vertex id type
     * @return future completing with the result frame
     */
    public <ID> PgxFuture<PgxFrame> inferLogitsAsync(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable) {
        return this.core
            .inferLogitsSupervisedGraphWiseModel(
                this.session.getSessionContext(),
                ((SupervisedGraphWiseModelMetadata) this.modelMetadata).getModelName(),
                pgxGraph.getId(),
                serializeVertices(iterable))
            .thenApply(resultFrame ->
                new PgxFrameImpl(this.session, this.core, resultFrame, this.keystorePathSupplier, this.keystorePasswordSupplier));
    }

    /**
     * Blocking variant of {@link #inferLogitsAsync(PgxGraph, Iterable)}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to run inference on
     * @param <ID> vertex id type
     * @return the result frame
     */
    public <ID> PgxFrame inferLogits(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable) {
        return inferLogitsAsync(pgxGraph, iterable).join();
    }

    /**
     * Same as {@link #inferLabelsAsync(PgxGraph, Iterable, float)} with {@code f = 0.0f}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to run inference on
     * @param <ID> vertex id type
     * @return future completing with the result frame
     */
    public <ID> PgxFuture<PgxFrame> inferLabelsAsync(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable) {
        return inferLabelsAsync(pgxGraph, iterable, 0.0f);
    }

    /**
     * Requests labels for the given vertices and wraps the returned frame metadata in a
     * {@link PgxFrame}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to run inference on
     * @param f value forwarded unchanged to the backend; NOTE(review): presumably the
     *     classification decision threshold — confirm against the Core API
     * @param <ID> vertex id type
     * @return future completing with the result frame
     */
    public <ID> PgxFuture<PgxFrame> inferLabelsAsync(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable, float f) {
        return this.core
            .inferLabelsSupervisedGraphWiseModel(
                this.session.getSessionContext(),
                ((SupervisedGraphWiseModelMetadata) this.modelMetadata).getModelName(),
                pgxGraph.getId(),
                serializeVertices(iterable),
                f)
            .thenApply(resultFrame ->
                new PgxFrameImpl(this.session, this.core, resultFrame, this.keystorePathSupplier, this.keystorePasswordSupplier));
    }

    /**
     * Same as {@link #evaluateLabelsAsync(PgxGraph, Iterable, float)} with {@code f = 0.0f}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to evaluate on
     * @param <ID> vertex id type
     * @return future completing with the result frame
     */
    public <ID> PgxFuture<PgxFrame> evaluateLabelsAsync(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable) {
        return evaluateLabelsAsync(pgxGraph, iterable, 0.0f);
    }

    /**
     * Requests a label evaluation for the given vertices and wraps the returned frame metadata in
     * a {@link PgxFrame}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to evaluate on
     * @param f value forwarded unchanged to the backend; NOTE(review): presumably the
     *     classification decision threshold — confirm against the Core API
     * @param <ID> vertex id type
     * @return future completing with the result frame
     */
    public <ID> PgxFuture<PgxFrame> evaluateLabelsAsync(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable, float f) {
        return this.core
            .evaluateLabelsSupervisedGraphWiseModel(
                this.session.getSessionContext(),
                ((SupervisedGraphWiseModelMetadata) this.modelMetadata).getModelName(),
                pgxGraph.getId(),
                serializeVertices(iterable),
                f)
            .thenApply(resultFrame ->
                new PgxFrameImpl(this.session, this.core, resultFrame, this.keystorePathSupplier, this.keystorePasswordSupplier));
    }

    /**
     * Blocking variant of {@link #inferLabelsAsync(PgxGraph, Iterable, float)} with
     * {@code f = 0.0f}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to run inference on
     * @param <ID> vertex id type
     * @return the result frame
     */
    public <ID> PgxFrame inferLabels(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable) {
        return inferLabelsAsync(pgxGraph, iterable, 0.0f).join();
    }

    /**
     * Blocking variant of {@link #inferLabelsAsync(PgxGraph, Iterable, float)}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to run inference on
     * @param f value forwarded unchanged to the backend (see the asynchronous variant)
     * @param <ID> vertex id type
     * @return the result frame
     */
    public <ID> PgxFrame inferLabels(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable, float f) {
        return inferLabelsAsync(pgxGraph, iterable, f).join();
    }

    /**
     * Blocking variant of {@link #evaluateLabelsAsync(PgxGraph, Iterable, float)} with
     * {@code f = 0.0f}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to evaluate on
     * @param <ID> vertex id type
     * @return the result frame
     */
    public <ID> PgxFrame evaluateLabels(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable) {
        return evaluateLabelsAsync(pgxGraph, iterable, 0.0f).join();
    }

    /**
     * Blocking variant of {@link #evaluateLabelsAsync(PgxGraph, Iterable, float)}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to evaluate on
     * @param f value forwarded unchanged to the backend (see the asynchronous variant)
     * @param <ID> vertex id type
     * @return the result frame
     */
    public <ID> PgxFrame evaluateLabels(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable, float f) {
        return evaluateLabelsAsync(pgxGraph, iterable, f).join();
    }

    /**
     * Same as {@link #inferAndGetExplanationAsync(PgxGraph, PgxVertex, float)} with
     * {@code f = 0.0f}.
     *
     * @param pgxGraph graph containing the vertex
     * @param pgxVertex vertex to explain
     * @param <ID> vertex id type
     * @return future completing with the explanation
     * @deprecated NOTE(review): presumably superseded by {@link #gnnExplainer()} — confirm.
     */
    @Deprecated
    public <ID> PgxFuture<SupervisedGnnExplanation<ID>> inferAndGetExplanationAsync(PgxGraph pgxGraph, PgxVertex<ID> pgxVertex) {
        return inferAndGetExplanationAsync(pgxGraph, pgxVertex, 0.0f);
    }

    /**
     * Runs inference and explanation for one vertex using a default-constructed
     * {@link SupervisedGnnExplainerConfig}.
     *
     * @param pgxGraph graph containing the vertex
     * @param pgxVertex vertex to explain
     * @param f value forwarded unchanged to the backend; NOTE(review): presumably the
     *     classification decision threshold — confirm against the Core API
     * @param <ID> vertex id type
     * @return future completing with the explanation
     * @deprecated NOTE(review): presumably superseded by {@link #gnnExplainer()} — confirm.
     */
    @Deprecated
    public <ID> PgxFuture<SupervisedGnnExplanation<ID>> inferAndGetExplanationAsync(PgxGraph pgxGraph, PgxVertex<ID> pgxVertex, float f) {
        return inferAndExplainWithConfigAsync(pgxGraph, pgxVertex, new SupervisedGnnExplainerConfig(), f);
    }

    /**
     * Blocking variant of {@link #inferAndGetExplanationAsync(PgxGraph, PgxVertex, float)} with
     * {@code f = 0.0f}.
     *
     * @param pgxGraph graph containing the vertex
     * @param pgxVertex vertex to explain
     * @param <ID> vertex id type
     * @return the explanation
     * @deprecated NOTE(review): presumably superseded by {@link #gnnExplainer()} — confirm.
     */
    @Deprecated
    public <ID> SupervisedGnnExplanation<ID> inferAndGetExplanation(PgxGraph pgxGraph, PgxVertex<ID> pgxVertex) {
        return inferAndGetExplanationAsync(pgxGraph, pgxVertex, 0.0f).join();
    }

    /**
     * Blocking variant of {@link #inferAndGetExplanationAsync(PgxGraph, PgxVertex, float)}.
     *
     * @param pgxGraph graph containing the vertex
     * @param pgxVertex vertex to explain
     * @param f value forwarded unchanged to the backend (see the asynchronous variant)
     * @param <ID> vertex id type
     * @return the explanation
     * @deprecated NOTE(review): presumably superseded by {@link #gnnExplainer()} — confirm.
     */
    @Deprecated
    public <ID> SupervisedGnnExplanation<ID> inferAndGetExplanation(PgxGraph pgxGraph, PgxVertex<ID> pgxVertex, float f) {
        return inferAndGetExplanationAsync(pgxGraph, pgxVertex, f).join();
    }

    /**
     * Same as {@link #storeAsync(String, String, boolean)} with overwriting disabled.
     *
     * <p>The checked exceptions in the signature are never thrown by this body; they are retained
     * so existing callers keep compiling.
     *
     * @param str destination path
     * @param str2 key used by the file exporter
     * @return future completing when the model has been stored
     * @throws ExecutionException declared for compatibility only
     * @throws InterruptedException declared for compatibility only
     */
    public PgxFuture<Void> storeAsync(String str, String str2) throws ExecutionException, InterruptedException {
        return storeAsync(str, str2, false);
    }

    /**
     * Stores the model through the file exporter obtained from {@code export().file()}.
     *
     * @param str destination path
     * @param str2 key used by the file exporter
     * @param z overwrite flag handed to the exporter
     * @return future completing when the model has been stored
     */
    public PgxFuture<Void> storeAsync(String str, String str2, boolean z) {
        return export()
            .file()
            .path(str)
            .key(str2)
            .overwrite(z)
            .storeAsync();
    }

    /**
     * Blocking variant of {@link #storeAsync(String, String)}.
     *
     * @param str destination path
     * @param str2 key used by the file exporter
     * @throws ExecutionException if the store operation completed exceptionally
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void store(String str, String str2) throws ExecutionException, InterruptedException {
        storeAsync(str, str2).get();
    }

    /**
     * Blocking variant of {@link #storeAsync(String, String, boolean)}.
     *
     * @param str destination path
     * @param str2 key used by the file exporter
     * @param z overwrite flag handed to the exporter
     * @throws ExecutionException if the store operation completed exceptionally
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void store(String str, String str2, boolean z) throws ExecutionException, InterruptedException {
        storeAsync(str, str2, z).get();
    }

    /**
     * Returns the loss function from the model configuration.
     *
     * @return the configured loss function
     * @deprecated NOTE(review): presumably superseded by {@link #getLossFunctionClass()} — confirm.
     */
    @Deprecated
    public SupervisedGraphWiseModelConfig.LossFunction getLossFunction() {
        return getConfig().getLossFunction();
    }

    /**
     * Returns the loss function object from the model configuration.
     *
     * @return the configured loss function
     */
    public LossFunction getLossFunctionClass() {
        return getConfig().getLossFunctionClass();
    }

    /**
     * Returns the prediction layer configurations from the model configuration.
     *
     * @return the array as returned by the configuration
     */
    public GraphWisePredictionLayerConfig[] getPredictionLayerConfigs() {
        return getConfig().getPredictionLayerConfigs();
    }

    /**
     * Returns the class weights from the model configuration.
     *
     * @return the class-weight map as returned by the configuration
     */
    public Map<?, Float> getClassWeights() {
        return getConfig().getClassWeights();
    }

    /**
     * Returns the vertex target property name from the model configuration.
     *
     * @return the target property name
     */
    public String getVertexTargetPropertyName() {
        return getConfig().getVertexTargetPropertyName();
    }

    /**
     * Runs inference and explanation for one vertex with an explicit explainer configuration.
     *
     * <p>The generic explanation parts come from {@code processExplanationResult}; the logits and
     * label are read from the metadata returned by the backend.
     *
     * <p>Decompiler note: the original declaration was package-private; the access modifier was
     * widened during decompilation.
     *
     * @param pgxGraph graph containing the vertex
     * @param pgxVertex vertex to explain
     * @param supervisedGnnExplainerConfig explainer settings forwarded to the backend
     * @param f value forwarded unchanged to the backend; NOTE(review): presumably the
     *     classification decision threshold — confirm against the Core API
     * @param <ID> vertex id type
     * @return future completing with the explanation
     */
    public <ID> PgxFuture<SupervisedGnnExplanation<ID>> inferAndExplainWithConfigAsync(PgxGraph pgxGraph, PgxVertex<ID> pgxVertex, SupervisedGnnExplainerConfig supervisedGnnExplainerConfig, float f) {
        return this.core
            .inferAndGetExplanationSupervisedGraphWiseModel(
                this.session.getSessionContext(),
                ((SupervisedGraphWiseModelMetadata) this.modelMetadata).getModelName(),
                pgxGraph.getId(),
                pgxVertex.serialize(),
                f,
                supervisedGnnExplainerConfig)
            .thenApply(explanationMetaData -> {
                GnnExplanation<ID> baseExplanation = processExplanationResult(pgxGraph, explanationMetaData);
                // Parameterized construction keeps the ID type checked end to end, so no
                // unchecked cast of the resulting future is needed.
                return new SupervisedGnnExplanation<>(
                    baseExplanation.getVertexFeatureImportance(),
                    baseExplanation.getImportanceGraph(),
                    baseExplanation.getVertexImportanceProperty(),
                    baseExplanation.getEmbedding(),
                    explanationMetaData.getLogits(),
                    explanationMetaData.getLabel());
            });
    }

    /**
     * Creates an explainer bound to this model.
     *
     * @return a new {@link SupervisedGnnExplainer} wrapping this instance
     */
    public SupervisedGnnExplainer gnnExplainer() {
        return new SupervisedGnnExplainer(this);
    }

    /**
     * Returns the target vertex label sets from the model configuration.
     *
     * @return the label sets as returned by the configuration
     */
    public List<Set<String>> getTargetVertexLabels() {
        return getConfig().getTargetVertexLabelSets();
    }
}
