package oracle.pgx.api.mllib;

import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import oracle.pgx.api.PgxFuture;
import oracle.pgx.api.PgxGraph;
import oracle.pgx.api.PgxSession;
import oracle.pgx.api.PgxVertex;
import oracle.pgx.api.frames.PgxFrame;
import oracle.pgx.api.frames.internal.PgxFrameImpl;
import oracle.pgx.api.internal.Core;
import oracle.pgx.api.internal.Graph;
import oracle.pgx.api.internal.mllib.ModelMetadata;
import oracle.pgx.api.internal.mllib.UnsupervisedGnnExplainerConfig;
import oracle.pgx.api.internal.mllib.UnsupervisedGraphWiseModelMetadata;
import oracle.pgx.common.util.ErrorMessages;
import oracle.pgx.config.mllib.GraphWiseDgiLayerConfig;
import oracle.pgx.config.mllib.ModelKind;
import oracle.pgx.config.mllib.UnsupervisedGraphWiseModelConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:oracle/pgx/api/mllib/UnsupervisedGraphWiseModel.class */
public class UnsupervisedGraphWiseModel extends GraphWiseModel<UnsupervisedGraphWiseModelConfig, UnsupervisedGraphWiseModelMetadata, UnsupervisedGraphWiseModel> {
    public static final String ALGORITHM_NAME = "UnsupervisedGraphWise";
    private static final Logger LOG = LoggerFactory.getLogger(UnsupervisedGraphWiseModel.class);

    /** Number of clusters used by the deprecated explanation overloads that take no cluster count. */
    private static final int DEFAULT_EXPLANATION_NUM_CLUSTERS = 50;

    /* loaded from: input_file:oracle/pgx/api/mllib/UnsupervisedGraphWiseModel$UnsupervisedGraphWiseInferenceType.class */
    public enum UnsupervisedGraphWiseInferenceType {
        INFER_EMBEDDINGS
    }

    /**
     * Creates a model wrapper around already-typed unsupervised GraphWise metadata.
     *
     * @param pgxSession session the model belongs to
     * @param core core used to issue model operations
     * @param supplier keystore path supplier (forwarded to created frames)
     * @param supplier2 keystore password supplier (forwarded to created frames)
     * @param unsupervisedGraphWiseModelMetadata metadata of the model
     * @param biFunction factory turning an internal {@link Graph} into a {@link PgxGraph}
     */
    public UnsupervisedGraphWiseModel(PgxSession pgxSession, Core core, Supplier<String> supplier, Supplier<char[]> supplier2, UnsupervisedGraphWiseModelMetadata unsupervisedGraphWiseModelMetadata, BiFunction<PgxSession, Graph, PgxGraph> biFunction) {
        super(pgxSession, core, supplier, supplier2, unsupervisedGraphWiseModelMetadata, biFunction);
    }

    /**
     * Creates a model wrapper from generic model metadata, checking that the metadata describes an
     * unsupervised GraphWise model.
     *
     * @param modelMetadata generic metadata; its model kind must equal {@link #getModelKind()}
     * @throws IllegalArgumentException (via {@code ErrorMessages.throwException}) if the model kind
     *     does not match
     */
    public UnsupervisedGraphWiseModel(PgxSession pgxSession, Core core, Supplier<String> supplier, Supplier<char[]> supplier2, ModelMetadata modelMetadata, BiFunction<PgxSession, Graph, PgxGraph> biFunction) {
        super(pgxSession, core, supplier, supplier2, null, biFunction);
        if (modelMetadata.getModelKind() != getModelKind()) {
            ErrorMessages.throwException(IllegalArgumentException::new, "UNEXPECTED_MODEL_KIND", new Object[]{getModelKind(), modelMetadata.getModelKind()});
        } else {
            this.modelMetadata = (UnsupervisedGraphWiseModelMetadata) modelMetadata;
        }
    }

    /** Returns this instance with its concrete type (self-type hook of the model hierarchy). */
    @Override // oracle.pgx.api.mllib.Model
    public UnsupervisedGraphWiseModel getThis() {
        return this;
    }

    /** Returns {@link ModelKind#UNSUPERVISED_GRAPHWISE}. */
    @Override // oracle.pgx.api.mllib.Model
    public ModelKind getModelKind() {
        return ModelKind.UNSUPERVISED_GRAPHWISE;
    }

    /**
     * Trains the model on the given graph.
     *
     * <p>On completion the cached model metadata is replaced by the metadata returned from the core.
     *
     * @param pgxGraph graph to train on
     * @return future completing with the training loss reported in the updated model config
     */
    @Override // oracle.pgx.api.mllib.GraphWiseModel
    public PgxFuture<Double> fitAsync(PgxGraph pgxGraph) {
        return this.core.fitUnsupervisedGraphWiseModel(this.session.getSessionContext(), ((UnsupervisedGraphWiseModelMetadata) this.modelMetadata).getModelName(), pgxGraph.getId()).thenApply(unsupervisedGraphWiseModelMetadata -> {
            this.modelMetadata = unsupervisedGraphWiseModelMetadata;
            return Double.valueOf(unsupervisedGraphWiseModelMetadata.getConfig().getTrainingLoss());
        });
    }

    /**
     * Infers embeddings for the given vertices.
     *
     * @param pgxGraph graph the vertices belong to
     * @param iterable vertices to embed; serialized before being sent to the core
     * @return future completing with a frame wrapping the inferred embeddings
     */
    @Override // oracle.pgx.api.mllib.GraphWiseModel
    public <ID> PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable) {
        return this.core.inferEmbeddingsUnsupervisedGraphWiseModel(this.session.getSessionContext(), ((UnsupervisedGraphWiseModelMetadata) this.modelMetadata).getModelName(), pgxGraph.getId(), serializeVertices(iterable)).thenApply(frameMetaData -> {
            return new PgxFrameImpl(this.session, this.core, frameMetaData, this.keystorePathSupplier, this.keystorePasswordSupplier);
        });
    }

    /**
     * Blocking variant of {@link #inferAndGetExplanationAsync(PgxGraph, PgxVertex, int)}.
     *
     * @deprecated use the explainer returned by {@link #gnnExplainer()} instead.
     */
    @Deprecated
    public <ID> UnsupervisedGnnExplanation<ID> inferAndGetExplanation(PgxGraph pgxGraph, PgxVertex<ID> pgxVertex, int i) {
        return inferAndGetExplanationAsync(pgxGraph, pgxVertex, i).join();
    }

    /**
     * Blocking variant of {@link #inferAndGetExplanationAsync(PgxGraph, PgxVertex)}.
     *
     * @deprecated use the explainer returned by {@link #gnnExplainer()} instead.
     */
    @Deprecated
    public <ID> UnsupervisedGnnExplanation<ID> inferAndGetExplanation(PgxGraph pgxGraph, PgxVertex<ID> pgxVertex) {
        return inferAndGetExplanationAsync(pgxGraph, pgxVertex).join();
    }

    /**
     * Infers and explains the given vertex using the default number of clusters (50).
     *
     * @deprecated use the explainer returned by {@link #gnnExplainer()} instead.
     */
    @Deprecated
    public <ID> PgxFuture<UnsupervisedGnnExplanation<ID>> inferAndGetExplanationAsync(PgxGraph pgxGraph, PgxVertex<ID> pgxVertex) {
        return inferAndGetExplanationAsync(pgxGraph, pgxVertex, DEFAULT_EXPLANATION_NUM_CLUSTERS);
    }

    /**
     * Infers and explains the given vertex using {@code i} clusters.
     *
     * @param i requested number of clusters; clamped to the graph's vertex count
     * @deprecated use the explainer returned by {@link #gnnExplainer()} instead.
     */
    @Deprecated
    public <ID> PgxFuture<UnsupervisedGnnExplanation<ID>> inferAndGetExplanationAsync(PgxGraph pgxGraph, PgxVertex<ID> pgxVertex, int i) {
        UnsupervisedGnnExplainerConfig unsupervisedGnnExplainerConfig = new UnsupervisedGnnExplainerConfig();
        unsupervisedGnnExplainerConfig.setNumClusters(i);
        return inferAndExplainWithConfigAsync(pgxGraph, pgxVertex, unsupervisedGnnExplainerConfig);
    }

    /**
     * Infers and explains the given vertex with the given explainer configuration.
     *
     * <p>If the configured number of clusters exceeds the number of vertices in the graph, it is
     * lowered to the vertex count (with a warning) and written back to the config before the request
     * is sent.
     *
     * @param pgxGraph graph containing the vertex
     * @param pgxVertex vertex to explain
     * @param unsupervisedGnnExplainerConfig explainer configuration; its cluster count may be updated
     * @return future completing with the explanation
     * @throws IllegalStateException if the effective number of clusters (after clamping) is below 2;
     *     the config is left untouched in that case
     */
    // The explanation type is constructed raw (as in the original decompiled code), so the result
    // future needs an unchecked cast; both are confined to this method.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <ID> PgxFuture<UnsupervisedGnnExplanation<ID>> inferAndExplainWithConfigAsync(PgxGraph pgxGraph, PgxVertex<ID> pgxVertex, UnsupervisedGnnExplainerConfig unsupervisedGnnExplainerConfig) {
        long numVertices = pgxGraph.getNumVertices();
        int requestedClusters = unsupervisedGnnExplainerConfig.getNumClusters();
        // Never ask for more clusters than the graph has vertices. The cast is safe because
        // numVertices < requestedClusters <= Integer.MAX_VALUE in that branch.
        int effectiveClusters = requestedClusters > numVertices ? (int) numVertices : requestedClusters;
        // Validate the value that will actually be used: clamping on a tiny graph can leave fewer
        // than two clusters even when the requested value was fine.
        if (effectiveClusters < 2) {
            throw new IllegalStateException(ErrorMessages.getMessage("NOT_ENOUGH_CLUSTERS_FOR_EXPLANATION", new Object[0]));
        }
        if (effectiveClusters != requestedClusters) {
            LOG.warn("Setting number of clusters used for explanation to {} instead of {} as graph only has {} vertices.", Long.valueOf(numVertices), Integer.valueOf(requestedClusters), Long.valueOf(numVertices));
            unsupervisedGnnExplainerConfig.setNumClusters(effectiveClusters);
        }
        return (PgxFuture<UnsupervisedGnnExplanation<ID>>) this.core.inferAndGetExplanationUnsupervisedGraphWiseModel(this.session.getSessionContext(), ((UnsupervisedGraphWiseModelMetadata) this.modelMetadata).getModelName(), pgxGraph.getId(), pgxVertex.serialize(), unsupervisedGnnExplainerConfig).thenApply(gnnExplanationMetaData -> {
            GnnExplanation<ID> processExplanationResult = processExplanationResult(pgxGraph, gnnExplanationMetaData);
            return new UnsupervisedGnnExplanation(processExplanationResult.getVertexFeatureImportance(), processExplanationResult.getImportanceGraph(), processExplanationResult.getVertexImportanceProperty(), processExplanationResult.getEmbedding());
        });
    }

    /** Returns a new explainer bound to this model. */
    public UnsupervisedGnnExplainer gnnExplainer() {
        return new UnsupervisedGnnExplainer(this);
    }

    /**
     * Stores the model to {@code str} using key {@code str2}, without overwriting.
     *
     * <p>The declared checked exceptions are never thrown by this method. They are kept only for
     * source compatibility with callers that catch them.
     */
    public PgxFuture<Void> storeAsync(String str, String str2) throws ExecutionException, InterruptedException {
        return storeAsync(str, str2, false);
    }

    /**
     * Stores the model through the file exporter.
     *
     * @param str path passed to the exporter
     * @param str2 key passed to the exporter (NOTE(review): presumably an encryption key — confirm)
     * @param z whether an existing file may be overwritten
     * @return future completing when the export finishes
     */
    public PgxFuture<Void> storeAsync(String str, String str2, boolean z) {
        return export().file().path(str).key(str2).overwrite(z).storeAsync();
    }

    /** Blocking variant of {@link #storeAsync(String, String)}. */
    public void store(String str, String str2) throws ExecutionException, InterruptedException {
        storeAsync(str, str2).get();
    }

    /** Blocking variant of {@link #storeAsync(String, String, boolean)}. */
    public void store(String str, String str2, boolean z) throws ExecutionException, InterruptedException {
        storeAsync(str, str2, z).get();
    }

    /** Returns the loss function from this model's config. */
    public UnsupervisedGraphWiseModelConfig.LossFunction getLossFunction() {
        return getConfig().getLossFunction();
    }

    /** Returns the DGI layer config from this model's config. */
    public GraphWiseDgiLayerConfig getDgiLayerConfigs() {
        return getConfig().getDgiLayerConfig();
    }
}
