/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp.embedding;

import ai.djl.modality.nlp.embedding.EmbeddingException;
import ai.djl.modality.nlp.embedding.TextEmbedding;
import ai.djl.modality.nlp.embedding.WordEmbedding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * A {@link TextEmbedding} that embeds text one word at a time by delegating every token to a
 * single {@link WordEmbedding}.
 */
public class SimpleTextEmbedding
implements TextEmbedding {
    private final WordEmbedding wordEmbedding;

    /**
     * Constructs a text embedding backed by the given word embedding.
     *
     * @param wordEmbedding the word embedding applied to every token; must not be {@code null}
     * @throws NullPointerException if {@code wordEmbedding} is {@code null}
     */
    public SimpleTextEmbedding(WordEmbedding wordEmbedding) {
        // Fail fast at construction instead of with an NPE on the first embed/unembed call.
        this.wordEmbedding = Objects.requireNonNull(wordEmbedding, "wordEmbedding");
    }

    /**
     * Converts each token to the index produced by
     * {@link WordEmbedding#preprocessWordToEmbed(String)}.
     *
     * @param text the tokens to convert
     * @return one index per token, in the same order as {@code text}
     */
    @Override
    public int[] preprocessTextToEmbed(List<String> text) {
        int[] result = new int[text.size()];
        // Iterate rather than call text.get(i): get(i) is O(n) on non-random-access lists,
        // which would make this loop quadratic.
        int i = 0;
        for (String word : text) {
            result[i++] = wordEmbedding.preprocessWordToEmbed(word);
        }
        return result;
    }

    /**
     * Embeds each word index with the underlying word embedding and stacks the results into a
     * single array via {@link NDArrays#stack(NDList)}.
     *
     * <p>NOTE(review): the stacked shape is presumably (numWords, embeddingSize), and an empty
     * {@code textIndices} is presumably rejected by {@code NDArrays.stack}. Confirm both against
     * the NDArrays documentation.
     *
     * @param manager the manager used to create the per-word embeddings
     * @param textIndices the word indices to embed, typically from
     *     {@link #preprocessTextToEmbed(List)}
     * @return the stacked word embeddings
     * @throws EmbeddingException if the underlying word embedding fails to embed an index
     */
    @Override
    public NDArray embedText(NDManager manager, int[] textIndices) throws EmbeddingException {
        NDList result = new NDList();
        for (int index : textIndices) {
            result.add(wordEmbedding.embedWord(manager, index));
        }
        return NDArrays.stack(result);
    }

    /**
     * Converts a text embedding back into words. The embedding is split along its first
     * dimension, and each slice is un-embedded with the underlying word embedding.
     *
     * @param textEmbedding the embedding to decode; its first dimension is treated as the word
     *     count
     * @return the decoded words in order; empty if the first dimension is 0
     * @throws EmbeddingException if the underlying word embedding fails to un-embed a slice
     */
    @Override
    public List<String> unembedText(NDArray textEmbedding) throws EmbeddingException {
        long numWords = textEmbedding.getShape().get(0);
        if (numWords == 0) {
            // Splitting into 0 sections is not a meaningful request; an empty leading
            // dimension simply means there are no words to decode.
            return new ArrayList<>(0);
        }
        NDList split = textEmbedding.split(numWords);
        List<String> result = new ArrayList<>(split.size());
        for (NDArray token : split) {
            // Each section presumably keeps a leading axis of length 1; get(0) selects its row.
            result.add(wordEmbedding.unembedWord(token.get(0L)));
        }
        return result;
    }
}

