package oracle.pgx.api.mllib;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import oracle.pgx.api.PgxFuture;
import oracle.pgx.api.PgxGraph;
import oracle.pgx.api.PgxSession;
import oracle.pgx.api.PgxVertex;
import oracle.pgx.api.VertexProperty;
import oracle.pgx.api.frames.PgxFrame;
import oracle.pgx.api.internal.Core;
import oracle.pgx.api.internal.Graph;
import oracle.pgx.api.internal.mllib.GnnExplanationMetaData;
import oracle.pgx.api.internal.mllib.GraphWiseModelMetadata;
import oracle.pgx.api.mllib.GraphWiseModel;
import oracle.pgx.common.PgxId;
import oracle.pgx.config.mllib.GraphWiseConvLayerConfig;
import oracle.pgx.config.mllib.GraphWiseModelConfig;

/* loaded from: input_file:oracle/pgx/api/mllib/GraphWiseModel.class */
/**
 * Base class for GraphWise models.
 *
 * <p>Most getters delegate to the model configuration held in {@link #modelMetadata}. Concrete
 * subclasses implement training ({@link #fitAsync}) and embedding inference
 * ({@link #inferEmbeddingsAsync}).
 *
 * @param <Config> the configuration type of the model
 * @param <Metadata> the metadata type, which carries a {@code Config}
 * @param <ModelType> the concrete self type of the model
 */
public abstract class GraphWiseModel<Config extends GraphWiseModelConfig, Metadata extends GraphWiseModelMetadata<Config>, ModelType extends GraphWiseModel<Config, Metadata, ModelType>> extends Model<ModelType> {

    /** Metadata of this model; its configuration backs most getters of this class. */
    Metadata modelMetadata;

    /** Builds a client-side {@link PgxGraph} from a session and an internal {@link Graph}. */
    protected final BiFunction<PgxSession, Graph, PgxGraph> graphConstructor;

    /**
     * Creates a GraphWise model.
     *
     * @param pgxSession session this model belongs to
     * @param core core instance used for server calls (e.g. destroying the model)
     * @param supplier first supplier forwarded to the {@code Model} constructor
     * @param supplier2 second ({@code char[]}) supplier forwarded to the {@code Model} constructor
     * @param metadata metadata (name and configuration) of this model
     * @param biFunction factory used to wrap internal graphs into {@link PgxGraph} instances
     */
    public GraphWiseModel(PgxSession pgxSession, Core core, Supplier<String> supplier, Supplier<char[]> supplier2, Metadata metadata, BiFunction<PgxSession, Graph, PgxGraph> biFunction) {
        super(pgxSession, core, supplier, supplier2);
        this.modelMetadata = metadata;
        this.graphConstructor = biFunction;
    }

    /** Returns the model name stored in this model's metadata. */
    @Override
    String getModelName() {
        return modelMetadata.getModelName();
    }

    /**
     * Asynchronously destroys this model by calling {@code destroyMlModel} on the core with the
     * current session context and this model's name.
     *
     * @return a future completing when the destroy call finishes
     */
    @Override
    public PgxFuture<Void> destroyAsync() {
        return core.destroyMlModel(session.getSessionContext(), getModelName());
    }

    /**
     * Destroys this model, blocking until {@link #destroyAsync()} completes.
     *
     * @throws ExecutionException if the asynchronous destroy failed
     * @throws InterruptedException if the waiting thread was interrupted
     */
    public void destroy() throws ExecutionException, InterruptedException {
        destroyAsync().get();
    }

    /** Returns the number of training epochs from the model configuration. */
    public int getNumEpochs() {
        return getConfig().getNumEpochs();
    }

    /** Returns the learning rate from the model configuration. */
    public double getLearningRate() {
        return getConfig().getLearningRate();
    }

    /** Returns the batch size from the model configuration. */
    public int getBatchSize() {
        return getConfig().getBatchSize();
    }

    /** Returns the embedding dimension from the model configuration. */
    public int getEmbeddingDim() {
        return getConfig().getEmbeddingDim();
    }

    /**
     * Returns the random seed from the model configuration.
     *
     * @throws NullPointerException if the configuration holds no seed (the seed is stored boxed)
     */
    public int getSeed() {
        return Objects.requireNonNull(getConfig().getSeed(), "seed is not set in the model configuration").intValue();
    }

    /** Returns the convolution layer configurations from the model configuration. */
    public GraphWiseConvLayerConfig[] getConvLayerConfigs() {
        return getConfig().getConvLayerConfigs();
    }

    /** Returns the names of the vertex input properties from the model configuration. */
    public List<String> getVertexInputPropertyNames() {
        return getConfig().getVertexInputPropertyNames();
    }

    /** Returns the names of the edge input properties from the model configuration. */
    public List<String> getEdgeInputPropertyNames() {
        return getConfig().getEdgeInputPropertyNames();
    }

    /** Returns whether the model configuration marks this model as fitted. */
    public boolean isFitted() {
        return getConfig().isFitted();
    }

    /** Returns the training loss recorded in the model configuration. */
    public double getTrainingLoss() {
        return getConfig().getTrainingLoss();
    }

    /** Returns the vertex input feature dimension from the model configuration. */
    public int getInputFeatureDim() {
        return getConfig().getInputFeatureDim();
    }

    /** Returns the edge input feature dimension from the model configuration. */
    public int getEdgeInputFeatureDim() {
        return getConfig().getEdgeInputFeatureDim();
    }

    /** Returns the configuration carried by this model's metadata. */
    @SuppressWarnings("unchecked") // Metadata is bound to GraphWiseModelMetadata<Config>, so the value is a Config
    public Config getConfig() {
        return (Config) modelMetadata.getConfig();
    }

    /**
     * Asynchronously trains this model on the given graph.
     *
     * @param pgxGraph graph to train on
     * @return a future holding a {@code Double} result (presumably the training loss — confirm in
     *     subclasses)
     */
    public abstract PgxFuture<Double> fitAsync(PgxGraph pgxGraph);

    /**
     * Trains this model, blocking until {@link #fitAsync(PgxGraph)} completes.
     *
     * @param pgxGraph graph to train on
     * @return the value produced by {@link #fitAsync(PgxGraph)}, unboxed
     * @throws ExecutionException if the asynchronous fit failed
     * @throws InterruptedException if the waiting thread was interrupted
     */
    public double fit(PgxGraph pgxGraph) throws ExecutionException, InterruptedException {
        return fitAsync(pgxGraph).get().doubleValue();
    }

    /**
     * Asynchronously infers embeddings for the given vertices.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to embed
     * @param <ID> vertex id type
     * @return a future holding a frame with the inferred embeddings
     */
    public abstract <ID> PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable);

    /**
     * Infers embeddings, blocking with {@code join()} on {@link #inferEmbeddingsAsync}.
     *
     * <p>Because {@code join()} is used instead of {@code get()}, failures surface as unchecked
     * exceptions rather than {@link ExecutionException}.
     *
     * @param pgxGraph graph containing the vertices
     * @param iterable vertices to embed
     * @param <ID> vertex id type
     * @return a frame with the inferred embeddings
     */
    public <ID> PgxFrame inferEmbeddings(PgxGraph pgxGraph, Iterable<PgxVertex<ID>> iterable) {
        return inferEmbeddingsAsync(pgxGraph, iterable).join();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Serializes each vertex via {@code PgxVertex.serialize()}, preserving iteration order.
     *
     * @param iterable vertices to serialize; must not be {@code null}
     * @param <ID> vertex id type
     * @return a new mutable list with one serialized entry per vertex
     * @throws NullPointerException if {@code iterable} is {@code null}
     */
    public <ID> List<Object> serializeVertices(Iterable<PgxVertex<ID>> iterable) {
        Objects.requireNonNull(iterable, "iterable");
        List<Object> serialized = new ArrayList<>();
        for (PgxVertex<ID> vertex : iterable) {
            serialized.add(vertex.serialize());
        }
        return serialized;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Converts explanation metadata returned by the core into a {@link GnnExplanation}.
     *
     * <p>The importance graph is wrapped with {@link #graphConstructor}. Its vertex importance
     * property is looked up by id. Feature importances, which are keyed by property id, are
     * re-keyed onto the matching vertex properties of {@code pgxGraph}. Properties without an
     * importance entry are skipped.
     *
     * @param pgxGraph graph whose vertex properties the feature importances refer to
     * @param gnnExplanationMetaData raw explanation data
     * @param <ID> vertex id type
     * @return the assembled explanation
     * @throws IllegalStateException if the importance graph lacks the vertex importance property
     */
    public <ID> GnnExplanation<ID> processExplanationResult(PgxGraph pgxGraph, GnnExplanationMetaData gnnExplanationMetaData) {
        PgxGraph importanceGraph = graphConstructor.apply(session, gnnExplanationMetaData.getImportanceGraph());
        PgxId importancePropertyId = gnnExplanationMetaData.getVertexImportancePropertyId();
        VertexProperty<?, ?> importanceProperty = importanceGraph.getVertexProperties().stream()
                .filter(property -> property.getPropertyId().equals(importancePropertyId))
                .findAny()
                .orElseThrow(() -> new IllegalStateException(
                        "vertex importance property " + importancePropertyId + " not found in importance graph"));

        Map<PgxId, Float> importancesById = gnnExplanationMetaData.getVertexFeatureImportances();
        Map<VertexProperty<?, ?>, Float> featureImportances = new HashMap<>();
        for (VertexProperty<?, ?> property : pgxGraph.getVertexProperties()) {
            // containsKey + get (not get + null check) keeps entries whose importance value is null
            if (importancesById.containsKey(property.getPropertyId())) {
                featureImportances.put(property, importancesById.get(property.getPropertyId()));
            }
        }
        return new GnnExplanation<>(featureImportances, importanceGraph, importanceProperty, gnnExplanationMetaData.getEmbedding());
    }
}
