package oracle.pgx.api.mllib;

import com.google.common.base.Functions;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;
import oracle.pgx.api.PgxFuture;
import oracle.pgx.api.PgxGraph;
import oracle.pgx.api.PgxSession;
import oracle.pgx.api.frames.PgxFrame;
import oracle.pgx.api.frames.internal.PgxFrameImpl;
import oracle.pgx.api.internal.Core;
import oracle.pgx.api.internal.mllib.DeepWalkModelMetadata;
import oracle.pgx.api.internal.mllib.ModelMetadata;
import oracle.pgx.config.mllib.ModelKind;

/**
 * Client-side handle for a DeepWalk model that belongs to a {@link PgxSession}.
 *
 * <p>Remote operations are forwarded to {@link Core} together with the session context and the
 * model name taken from the {@link DeepWalkModelMetadata}. Each {@code xxxAsync} method returns a
 * {@link PgxFuture}; its blocking counterpart simply waits on that future and therefore throws
 * {@link ExecutionException} / {@link InterruptedException}.
 *
 * <p>The hyper-parameter getters read their values from the metadata object held by this instance.
 */
public class DeepWalkModel extends Model<DeepWalkModel> {
    /** Metadata of this model; {@link #fitAsync(PgxGraph)} stores the resulting loss in it. */
    private final DeepWalkModelMetadata modelMetadata;

    /**
     * Creates a handle from DeepWalk-specific metadata.
     *
     * @param pgxSession the session this model belongs to
     * @param core the backend used to run model operations
     * @param keystorePathSupplier forwarded to {@link Model}; also handed to every frame this model
     *     creates (presumably the keystore path — confirm against {@link Model})
     * @param keystorePasswordSupplier forwarded to {@link Model}; also handed to every frame this
     *     model creates (presumably the keystore password)
     * @param deepWalkModelMetadata metadata of the model
     */
    public DeepWalkModel(PgxSession pgxSession, Core core, Supplier<String> keystorePathSupplier,
            Supplier<char[]> keystorePasswordSupplier, DeepWalkModelMetadata deepWalkModelMetadata) {
        super(pgxSession, core, keystorePathSupplier, keystorePasswordSupplier);
        this.modelMetadata = deepWalkModelMetadata;
    }

    /**
     * Creates a handle from generic model metadata, which must be DeepWalk metadata.
     *
     * @param pgxSession the session this model belongs to
     * @param core the backend used to run model operations
     * @param keystorePathSupplier forwarded to {@link Model} and to frames created by this model
     * @param keystorePasswordSupplier forwarded to {@link Model} and to frames created by this model
     * @param modelMetadata metadata of the model
     * @throws IllegalArgumentException if {@code modelMetadata} is not a
     *     {@link DeepWalkModelMetadata} (this includes {@code null})
     */
    public DeepWalkModel(PgxSession pgxSession, Core core, Supplier<String> keystorePathSupplier,
            Supplier<char[]> keystorePasswordSupplier, ModelMetadata modelMetadata) {
        super(pgxSession, core, keystorePathSupplier, keystorePasswordSupplier);
        if (!(modelMetadata instanceof DeepWalkModelMetadata)) {
            throw new IllegalArgumentException("expected DeepWalkModelMetaData");
        }
        this.modelMetadata = (DeepWalkModelMetadata) modelMetadata;
    }

    /** Returns {@code this}; used by {@link Model} for fluent self-typed APIs. */
    // NOTE: the decompiler reports this was originally protected.
    @Override // oracle.pgx.api.mllib.Model
    public DeepWalkModel getThis() {
        return this;
    }

    /** Returns {@link ModelKind#DEEPWALK}. */
    // NOTE: the decompiler reports this was originally protected.
    @Override // oracle.pgx.api.mllib.Model
    public ModelKind getModelKind() {
        return ModelKind.DEEPWALK;
    }

    /** Returns the model name recorded in the metadata; it identifies the model to the backend. */
    // NOTE: the decompiler reports this was originally package-private.
    @Override // oracle.pgx.api.mllib.Model
    public String getModelName() {
        return this.modelMetadata.getModelName();
    }

    /**
     * Asynchronously trains this model on {@code graph}.
     *
     * <p>When the backend future completes, the value it completes with is stored as the model
     * loss, which is then available through {@link #getLoss()}.
     *
     * @param graph the graph to train on
     * @return a future that completes once training has finished and the loss has been recorded
     */
    public PgxFuture<Void> fitAsync(PgxGraph graph) {
        return this.core.fitDeepWalkModel(this.session.getSessionContext(),
                this.modelMetadata.getModelName(), graph.getId()).thenApply(loss -> {
                    this.modelMetadata.setLoss(loss);
                    return null;
                });
    }

    /**
     * Blocking variant of {@link #fitAsync(PgxGraph)}.
     *
     * @param graph the graph to train on
     */
    public void fit(PgxGraph graph) throws ExecutionException, InterruptedException {
        fitAsync(graph).get();
    }

    /**
     * Asynchronously retrieves the trained vertex vectors as a frame.
     *
     * @return a future yielding a {@link PgxFrame} bound to this model's session and keystore
     *     suppliers
     */
    public PgxFuture<PgxFrame> getTrainedVertexVectorsAsync() {
        return this.core.getTrainedVertexVectorsDeepWalkModel(this.session.getSessionContext(),
                this.modelMetadata.getModelName())
                .thenApply(frameMetaData -> new PgxFrameImpl(this.session, this.core, frameMetaData,
                        this.keystorePathSupplier, this.keystorePasswordSupplier));
    }

    /** Blocking variant of {@link #getTrainedVertexVectorsAsync()}. */
    public PgxFrame getTrainedVertexVectors() throws ExecutionException, InterruptedException {
        return getTrainedVertexVectorsAsync().get();
    }

    /**
     * Asynchronously computes the vertices most similar to a single vertex.
     *
     * @param vertexId the vertex id; sent to the backend as {@code String.valueOf(vertexId)}
     * @param topK requested result size, passed through unchanged (presumably the number of
     *     similar vertices to return)
     * @return a future yielding the result frame
     */
    public PgxFuture<PgxFrame> computeSimilarsAsync(Object vertexId, int topK) {
        return this.core.computeSimilarsDeepWalkModel(this.session.getSessionContext(),
                this.modelMetadata.getModelName(), String.valueOf(vertexId), topK)
                .thenApply(frameMetaData -> new PgxFrameImpl(this.session, this.core, frameMetaData,
                        this.keystorePathSupplier, this.keystorePasswordSupplier));
    }

    /** Blocking variant of {@link #computeSimilarsAsync(Object, int)}. */
    public PgxFrame computeSimilars(Object vertexId, int topK)
            throws ExecutionException, InterruptedException {
        return computeSimilarsAsync(vertexId, topK).get();
    }

    /**
     * Asynchronously computes the most similar vertices for a batch of vertices.
     *
     * <p>The parameter is {@code List<?>} so that lists of any element type (for example
     * {@code List<Integer>}) select this overload; with {@code List<Object>} they would silently
     * bind to {@link #computeSimilarsAsync(Object, int)} and send the whole list as one id.
     *
     * @param vertexIds the vertex ids; each is converted with {@code toString()}, so null elements
     *     are rejected with a {@link NullPointerException}
     * @param topK requested result size per vertex, passed through unchanged
     * @return a future yielding the result frame
     */
    public PgxFuture<PgxFrame> computeSimilarsAsync(List<?> vertexIds, int topK) {
        // Copy eagerly: Lists.transform is a lazy view that would reflect later changes to the
        // caller's list while the asynchronous request is still in flight.
        List<String> ids = Lists.newArrayList(Lists.transform(vertexIds, Functions.toStringFunction()));
        return this.core.computeSimilarsBatchedDeepWalkModel(this.session.getSessionContext(),
                this.modelMetadata.getModelName(), ids, topK)
                .thenApply(frameMetaData -> new PgxFrameImpl(this.session, this.core, frameMetaData,
                        this.keystorePathSupplier, this.keystorePasswordSupplier));
    }

    /** Blocking variant of {@link #computeSimilarsAsync(List, int)}. */
    public PgxFrame computeSimilars(List<?> vertexIds, int topK)
            throws ExecutionException, InterruptedException {
        return computeSimilarsAsync(vertexIds, topK).get();
    }

    /**
     * Asynchronously stores this model to a file through {@code export().file()}.
     *
     * @param path destination path
     * @param key value passed to the exporter's {@code key(...)} option
     * @return a future that completes when the model has been stored
     */
    public PgxFuture<Void> storeAsync(String path, String key) {
        return export().file().path(path).key(key).storeAsync();
    }

    /**
     * Asynchronously stores this model to a file, controlling whether an existing file may be
     * overwritten.
     *
     * @param path destination path
     * @param key value passed to the exporter's {@code key(...)} option
     * @param overwrite value passed to the exporter's {@code overwrite(...)} option
     * @return a future that completes when the model has been stored
     */
    public PgxFuture<Void> storeAsync(String path, String key, boolean overwrite) {
        return export().file().path(path).key(key).overwrite(overwrite).storeAsync();
    }

    /** Blocking variant of {@link #storeAsync(String, String)}. */
    public void store(String path, String key) throws ExecutionException, InterruptedException {
        storeAsync(path, key).get();
    }

    /** Blocking variant of {@link #storeAsync(String, String, boolean)}. */
    public void store(String path, String key, boolean overwrite)
            throws ExecutionException, InterruptedException {
        storeAsync(path, key, overwrite).get();
    }

    /**
     * Asynchronously asks the backend to destroy this model (by model name).
     *
     * @return the backend's completion future
     */
    @Override // oracle.pgx.api.mllib.Model
    public PgxFuture<Void> destroyAsync() {
        return this.core.destroyMlModel(this.session.getSessionContext(), this.modelMetadata.getModelName());
    }

    /** Blocking variant of {@link #destroyAsync()}. */
    public void destroy() throws ExecutionException, InterruptedException {
        destroyAsync().get();
    }

    /** Returns the negative-sample setting recorded in the model metadata. */
    public int getNegativeSample() {
        return this.modelMetadata.getNegativeSample();
    }

    /** Returns the sample rate recorded in the model metadata. */
    public double getSampleRate() {
        return this.modelMetadata.getSampleRate();
    }

    /** Returns the minimum word frequency recorded in the model metadata. */
    public int getMinWordFrequency() {
        return this.modelMetadata.getMinWordFrequency();
    }

    /** Returns the batch size recorded in the model metadata. */
    public int getBatchSize() {
        return this.modelMetadata.getBatchSize();
    }

    /** Returns the number of epochs recorded in the model metadata. */
    public int getNumEpochs() {
        return this.modelMetadata.getNumEpochs();
    }

    /** Returns the layer size recorded in the model metadata. */
    public int getLayerSize() {
        return this.modelMetadata.getLayerSize();
    }

    /** Returns the learning rate recorded in the model metadata. */
    public double getLearningRate() {
        return this.modelMetadata.getLearningRate();
    }

    /** Returns the minimum learning rate recorded in the model metadata. */
    public double getMinLearningRate() {
        return this.modelMetadata.getMinLearningRate();
    }

    /** Returns the window size recorded in the model metadata. */
    public int getWindowSize() {
        return this.modelMetadata.getWindowSize();
    }

    /** Returns the walk length recorded in the model metadata. */
    public int getWalkLength() {
        return this.modelMetadata.getWalkLength();
    }

    /** Returns the walks-per-vertex setting recorded in the model metadata. */
    public int getWalksPerVertex() {
        return this.modelMetadata.getWalksPerVertex();
    }

    /** Returns the validation fraction recorded in the model metadata. */
    public double getValidationFraction() {
        return this.modelMetadata.getValidationFraction();
    }

    /**
     * Returns the loss recorded in the model metadata.
     *
     * @return the loss
     * @throws IllegalStateException if no loss has been recorded, e.g. before
     *     {@link #fit(PgxGraph)} has completed (previously this surfaced as a bare
     *     {@link NullPointerException})
     */
    public double getLoss() {
        Double loss = this.modelMetadata.getLoss();
        if (loss == null) {
            throw new IllegalStateException("no loss recorded for this model; fit the model first");
        }
        return loss;
    }

    /**
     * Returns the seed recorded in the model metadata, widened to {@code double}.
     *
     * <p>NOTE: the seed is held as a {@code Long}; widening to {@code double} loses precision for
     * magnitudes above 2^53. The {@code double} return type is kept for API compatibility.
     */
    public double getSeed() {
        long seed = this.modelMetadata.getSeed().longValue();
        return (double) seed;
    }
}
