/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp.embedding;

import ai.djl.modality.nlp.embedding.EmbeddingException;
import ai.djl.modality.nlp.embedding.TextEmbedding;
import ai.djl.modality.nlp.embedding.TrainableWordEmbedding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.List;

/**
 * A {@link TextEmbedding} implemented as a block that wraps a single {@link
 * TrainableWordEmbedding} child block.
 *
 * <p>Text is handled word by word through the wrapped word embedding. The block operations
 * ({@link #forward}, {@link #initializeChildBlocks}, {@link #getOutputShapes}) are delegated
 * unchanged to that child.
 */
public class TrainableTextEmbedding
extends AbstractBlock
implements TextEmbedding {
    private static final byte VERSION = 1;
    private final TrainableWordEmbedding trainableWordEmbedding;

    /**
     * Creates a text embedding that wraps the given word embedding.
     *
     * @param wordEmbedding the word embedding; it is registered as the child block named
     *     {@code "trainableWordEmbedding"}
     */
    public TrainableTextEmbedding(TrainableWordEmbedding wordEmbedding) {
        // Use the declared constant rather than a duplicated magic literal so the
        // block version is defined in exactly one place.
        super(VERSION);
        this.trainableWordEmbedding = this.addChildBlock("trainableWordEmbedding", wordEmbedding);
    }

    /**
     * Converts each word of {@code text} to its index via the wrapped word embedding.
     *
     * @param text the words to convert
     * @return an array with one index per word, in the same order as {@code text}
     */
    @Override
    public int[] preprocessTextToEmbed(List<String> text) {
        int[] result = new int[text.size()];
        for (int i = 0; i < text.size(); ++i) {
            result[i] = this.trainableWordEmbedding.preprocessWordToEmbed(text.get(i));
        }
        return result;
    }

    /**
     * Not supported by this class.
     *
     * @param textIndices ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray embedText(NDArray textIndices) throws EmbeddingException {
        throw new UnsupportedOperationException("EmbedText operation is not supported by this class.");
    }

    /**
     * Maps an embedding array back to words.
     *
     * <p>The array is split along axis 0 into {@code getShape().get(0)} pieces. Element 0 of each
     * piece is then passed to the wrapped word embedding's {@code unembedWord}.
     *
     * @param textEmbedding the embedding array; assumed to be laid out as (numTokens,
     *     embeddingSize). TODO confirm against callers.
     * @return one word per slice along axis 0, in order
     */
    @Override
    public List<String> unembedText(NDArray textEmbedding) {
        NDList split = textEmbedding.split(textEmbedding.getShape().get(0));
        List<String> result = new ArrayList<>(split.size());
        for (NDArray token : split) {
            // Each slice keeps a leading axis of size 1; index 0 drops it before unembedding.
            result.add(this.trainableWordEmbedding.unembedWord(token.get(0L)));
        }
        return result;
    }

    /** Delegates the forward pass unchanged to the wrapped word embedding. */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        return this.trainableWordEmbedding.forward(parameterStore, inputs, training, params);
    }

    /** Initializes the wrapped word embedding with the same manager, data type and input shapes. */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape ... inputShapes) {
        this.trainableWordEmbedding.initialize(manager, dataType, inputShapes);
    }

    /** Returns the output shapes reported by the wrapped word embedding for {@code inputShapes}. */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return this.trainableWordEmbedding.getOutputShapes(manager, inputShapes);
    }
}

