/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec;

import java.util.Collection;
import java.util.List;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.StreamLineIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;

public class Word2Vec
extends SequenceVectors<VocabWord> {
    protected SentenceIterator sentenceIter;
    protected TokenizerFactory tokenizerFactory;

    public void setSentenceIter(@NonNull SentenceIterator iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        if (this.tokenizerFactory == null) {
            throw new IllegalStateException("Please call setTokenizerFactory() prior to setSentenceIter() call.");
        }
        SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator).tokenizerFactory(this.tokenizerFactory).build();
        this.iterator = new AbstractSequenceIterator.Builder<VocabWord>(transformer).build();
    }

    public SentenceIterator getSentenceIter() {
        return this.sentenceIter;
    }

    public TokenizerFactory getTokenizerFactory() {
        return this.tokenizerFactory;
    }

    public void setTokenizerFactory(TokenizerFactory tokenizerFactory) {
        this.tokenizerFactory = tokenizerFactory;
    }

    public static class Builder
    extends SequenceVectors.Builder<VocabWord> {
        protected SentenceIterator sentenceIterator;
        protected TokenizerFactory tokenizerFactory;

        public Builder() {
        }

        public Builder(@NonNull VectorsConfiguration configuration) {
            super(configuration);
            if (configuration == null) {
                throw new NullPointerException("configuration");
            }
        }

        public Builder iterate(@NonNull DocumentIterator iterator) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.sentenceIterator = new StreamLineIterator.Builder(iterator).setFetchSize(100).build();
            return this;
        }

        public Builder iterate(@NonNull SentenceIterator iterator) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.sentenceIterator = iterator;
            return this;
        }

        public Builder tokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
            if (tokenizerFactory == null) {
                throw new NullPointerException("tokenizerFactory");
            }
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        @Deprecated
        public Builder index(@NonNull InvertedIndex<VocabWord> index) {
            if (index == null) {
                throw new NullPointerException("index");
            }
            return this;
        }

        public Builder iterate(@NonNull SequenceIterator<VocabWord> iterator) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            super.iterate(iterator);
            return this;
        }

        public Builder batchSize(int batchSize) {
            super.batchSize(batchSize);
            return this;
        }

        public Builder iterations(int iterations) {
            super.iterations(iterations);
            return this;
        }

        public Builder epochs(int numEpochs) {
            super.epochs(numEpochs);
            return this;
        }

        public Builder layerSize(int layerSize) {
            super.layerSize(layerSize);
            return this;
        }

        public Builder learningRate(double learningRate) {
            super.learningRate(learningRate);
            return this;
        }

        public Builder minWordFrequency(int minWordFrequency) {
            super.minWordFrequency(minWordFrequency);
            return this;
        }

        public Builder minLearningRate(double minLearningRate) {
            super.minLearningRate(minLearningRate);
            return this;
        }

        public Builder resetModel(boolean reallyReset) {
            super.resetModel(reallyReset);
            return this;
        }

        public Builder vocabCache(@NonNull VocabCache<VocabWord> vocabCache) {
            if (vocabCache == null) {
                throw new NullPointerException("vocabCache");
            }
            super.vocabCache(vocabCache);
            return this;
        }

        public Builder lookupTable(@NonNull WeightLookupTable<VocabWord> lookupTable) {
            if (lookupTable == null) {
                throw new NullPointerException("lookupTable");
            }
            super.lookupTable(lookupTable);
            return this;
        }

        public Builder sampling(double sampling) {
            super.sampling(sampling);
            return this;
        }

        public Builder useAdaGrad(boolean reallyUse) {
            super.useAdaGrad(reallyUse);
            return this;
        }

        public Builder negativeSample(double negative) {
            super.negativeSample(negative);
            return this;
        }

        public Builder stopWords(@NonNull List<String> stopList) {
            if (stopList == null) {
                throw new NullPointerException("stopList");
            }
            super.stopWords(stopList);
            return this;
        }

        public Builder trainElementsRepresentation(boolean trainElements) {
            throw new IllegalStateException("You can't change this option for Word2Vec");
        }

        public Builder trainSequencesRepresentation(boolean trainSequences) {
            throw new IllegalStateException("You can't change this option for Word2Vec");
        }

        public Builder stopWords(@NonNull Collection<VocabWord> stopList) {
            if (stopList == null) {
                throw new NullPointerException("stopList");
            }
            super.stopWords(stopList);
            return this;
        }

        public Builder windowSize(int windowSize) {
            super.windowSize(windowSize);
            return this;
        }

        public Builder seed(long randomSeed) {
            super.seed(randomSeed);
            return this;
        }

        public Builder workers(int numWorkers) {
            super.workers(numWorkers);
            return this;
        }

        public Word2Vec build() {
            this.presetTables();
            Word2Vec ret = new Word2Vec();
            if (this.tokenizerFactory == null) {
                this.tokenizerFactory = new DefaultTokenizerFactory();
            }
            if (this.sentenceIterator != null) {
                SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(this.sentenceIterator).tokenizerFactory(this.tokenizerFactory).build();
                this.iterator = new AbstractSequenceIterator.Builder<VocabWord>(transformer).build();
            }
            ret.numEpochs = this.numEpochs;
            ret.numIterations = this.iterations;
            ret.vocab = this.vocabCache;
            ret.minWordFrequency = this.minWordFrequency;
            ret.learningRate.set(this.learningRate);
            ret.minLearningRate = this.minLearningRate;
            ret.sampling = this.sampling;
            ret.negative = this.negative;
            ret.layerSize = this.layerSize;
            ret.batchSize = this.batchSize;
            ret.learningRateDecayWords = this.learningRateDecayWords;
            ret.window = this.window;
            ret.resetModel = this.resetModel;
            ret.useAdeGrad = this.useAdaGrad;
            ret.stopWords = this.stopWords;
            ret.workers = this.workers;
            ret.iterator = this.iterator;
            ret.lookupTable = this.lookupTable;
            ret.tokenizerFactory = this.tokenizerFactory;
            ret.elementsLearningAlgorithm = this.elementsLearningAlgorithm;
            ret.sequenceLearningAlgorithm = this.sequenceLearningAlgorithm;
            this.configuration.setLearningRate(this.learningRate);
            this.configuration.setLayersSize(this.layerSize);
            this.configuration.setHugeModelExpected(this.hugeModelExpected);
            this.configuration.setWindow(this.window);
            this.configuration.setMinWordFrequency(this.minWordFrequency);
            this.configuration.setIterations(this.iterations);
            this.configuration.setSeed(this.seed);
            this.configuration.setBatchSize(this.batchSize);
            this.configuration.setLearningRateDecayWords(this.learningRateDecayWords);
            this.configuration.setMinLearningRate(this.minLearningRate);
            this.configuration.setSampling(this.sampling);
            this.configuration.setUseAdaGrad(this.useAdaGrad);
            this.configuration.setNegative(this.negative);
            this.configuration.setEpochs(this.numEpochs);
            ret.configuration = this.configuration;
            ret.trainSequenceVectors = false;
            ret.trainElementsVectors = true;
            return ret;
        }
    }
}

