package edu.stanford.nlp.neural;

import edu.stanford.nlp.coref.fastneural.FastNeuralCorefModel;
import edu.stanford.nlp.coref.neural.EmbeddingExtractor;
import edu.stanford.nlp.coref.neural.NeuralCorefModel;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.parser.dvparser.DVModel;
import edu.stanford.nlp.parser.dvparser.DVModelReranker;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.sentiment.RNNOptions;
import edu.stanford.nlp.sentiment.SentimentModel;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.TwoDimensionalMap;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.function.Function;
import org.ejml.simple.SimpleMatrix;

/**
 * Converts serialized neural models (sentiment, DV parser, word embeddings, neural coref and
 * fast neural coref) to and from an intermediate serialized form. In that form every
 * {@link SimpleMatrix} is replaced by a {@code List<List<Double>>} (one inner list per row) and
 * every {@link SimpleTensor} by a {@code List<List<List<Double>>>} (one matrix per slice).
 *
 * <p>Each {@code writeX} method produces the intermediate form and the matching {@code readX}
 * method rebuilds the model from it. A read method must consume objects in exactly the order
 * its write method produced them.
 *
 * <p>NOTE(review): the purpose is presumably to let models survive changes to EJML's own
 * serialized format. Confirm with the model maintainers.
 *
 * <p>Usage: {@code ConvertModels -stage OLD|NEW -model TYPE -input IN -output OUT}; see
 * {@link #main(String[])}.
 */
public class ConvertModels {

    /** Model families this tool can convert; selected with {@code -model}. */
    public enum Model {
        SENTIMENT,
        DVPARSER,
        COREF,
        EMBEDDING,
        FASTCOREF
    }

    /**
     * Conversion direction; selected with {@code -stage}.
     *
     * <ul>
     *   <li>{@code OLD} loads a model in its native serialized form and writes the intermediate
     *       list-based form.</li>
     *   <li>{@code NEW} reads the intermediate form and writes the native model.</li>
     * </ul>
     */
    public enum Stage {
        OLD,
        NEW
    }

    /**
     * Copies {@code twoDimensionalMap} into a new {@link TwoDimensionalMap#treeMap()}, passing
     * values through {@code function}.
     *
     * @param twoDimensionalMap source map; not modified
     * @param function value transformation
     * @return a new tree-backed map with the same keys and transformed values
     */
    public static <K1, K2, V, V2> TwoDimensionalMap<K1, K2, V2> transform2DMap(TwoDimensionalMap<K1, K2, V> twoDimensionalMap, Function<V, V2> function) {
        TwoDimensionalMap<K1, K2, V2> result = TwoDimensionalMap.treeMap();
        result.addAll(twoDimensionalMap, function);
        return result;
    }

    /**
     * Copies a matrix into nested lists.
     *
     * @param simpleMatrix matrix to copy
     * @return one inner list per row, each holding that row's values in column order
     */
    public static List<List<Double>> fromMatrix(SimpleMatrix simpleMatrix) {
        int numRows = simpleMatrix.numRows();
        int numCols = simpleMatrix.numCols();
        List<List<Double>> rows = new ArrayList<>(numRows);
        for (int row = 0; row < numRows; row++) {
            List<Double> values = new ArrayList<>(numCols);
            for (int col = 0; col < numCols; col++) {
                values.add(simpleMatrix.get(row, col));
            }
            rows.add(values);
        }
        return rows;
    }

    /**
     * Copies a tensor into nested lists, converting each slice with {@link #fromMatrix}.
     *
     * @param simpleTensor tensor to copy
     * @return one element per slice, in slice order
     */
    public static List<List<List<Double>>> fromTensor(SimpleTensor simpleTensor) {
        int numSlices = simpleTensor.numSlices();
        List<List<List<Double>>> slices = new ArrayList<>(numSlices);
        for (int slice = 0; slice < numSlices; slice++) {
            slices.add(fromMatrix(simpleTensor.getSlice(slice)));
        }
        return slices;
    }

    /**
     * Builds a matrix from nested lists; the inverse of {@link #fromMatrix}.
     *
     * @param list one inner list per row; all rows must have the same, non-zero length
     * @return a new matrix of size {@code list.size() x list.get(0).size()}
     * @throws IllegalArgumentException if there are no rows, no columns, or rows of unequal length
     */
    public static SimpleMatrix toMatrix(List<List<Double>> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Input array with 0 rows");
        }
        int numCols = list.get(0).size();
        if (numCols == 0) {
            throw new IllegalArgumentException("Input array with 0 columns");
        }
        for (List<Double> row : list) {
            if (row.size() != numCols) {
                throw new IllegalArgumentException("Input array with uneven columns");
            }
        }
        SimpleMatrix matrix = new SimpleMatrix(list.size(), numCols);
        for (int row = 0; row < list.size(); row++) {
            List<Double> values = list.get(row);
            for (int col = 0; col < numCols; col++) {
                matrix.set(row, col, values.get(col));
            }
        }
        return matrix;
    }

    /**
     * Builds a tensor from nested lists; the inverse of {@link #fromTensor}.
     *
     * @param list one element per slice, each converted with {@link #toMatrix}
     * @return a new tensor built from those slices
     */
    public static SimpleTensor toTensor(List<List<List<Double>>> list) {
        SimpleMatrix[] slices = new SimpleMatrix[list.size()];
        for (int i = 0; i < slices.length; i++) {
            slices[i] = toMatrix(list.get(i));
        }
        return new SimpleTensor(slices);
    }

    /**
     * Returns a copy of {@code map} with every value passed through {@code function}.
     *
     * <p>The copy has the same concrete class as {@code map}; it is created reflectively through
     * that class's public no-argument constructor.
     *
     * @param map source map; not modified
     * @param function value transformation
     * @return a new map with the same keys and transformed values
     * @throws RuntimeException if the map class cannot be instantiated that way
     */
    public static <K, V, V2> Map<K, V2> transformMap(Map<K, V> map, Function<V, V2> function) {
        Map<K, V2> result;
        try {
            result = ErasureUtils.uncheckedCast(map.getClass().getConstructor().newInstance());
        } catch (IllegalAccessException | InstantiationException | NoSuchMethodException | InvocationTargetException e) {
            throw new RuntimeException("Cannot instantiate map of type " + map.getClass().getName(), e);
        }
        for (Map.Entry<K, V> entry : map.entrySet()) {
            result.put(entry.getKey(), function.apply(entry.getValue()));
        }
        return result;
    }

    /** Reads the next object as a 2-D map of nested lists and rebuilds every value as a matrix. */
    private static <K1, K2> TwoDimensionalMap<K1, K2, SimpleMatrix> readMatrix2DMap(ObjectInputStream in) throws IOException, ClassNotFoundException {
        TwoDimensionalMap<K1, K2, List<List<Double>>> raw = ErasureUtils.uncheckedCast(in.readObject());
        return transform2DMap(raw, ConvertModels::toMatrix);
    }

    /** Reads the next object as a map of nested lists and rebuilds every value as a matrix. */
    private static <K> Map<K, SimpleMatrix> readMatrixMap(ObjectInputStream in) throws IOException, ClassNotFoundException {
        Map<K, List<List<Double>>> raw = ErasureUtils.uncheckedCast(in.readObject());
        return transformMap(raw, ConvertModels::toMatrix);
    }

    /** Reads the next object as a collection of nested lists and rebuilds each element as a matrix. */
    private static List<SimpleMatrix> readMatrixList(ObjectInputStream in) throws IOException, ClassNotFoundException {
        Collection<List<List<Double>>> raw = ErasureUtils.uncheckedCast(in.readObject());
        return CollectionUtils.transformAsList(raw, ConvertModels::toMatrix);
    }

    /** Reads the next object as nested lists and rebuilds it as a single matrix. */
    private static SimpleMatrix readMatrix(ObjectInputStream in) throws IOException, ClassNotFoundException {
        List<List<Double>> raw = ErasureUtils.uncheckedCast(in.readObject());
        return toMatrix(raw);
    }

    /**
     * Writes a sentiment model in the intermediate form.
     *
     * <p>Write order: binaryTransform, binaryTensors, binaryClassification,
     * unaryClassification, wordVectors, op.
     *
     * @param sentimentModel model to write
     * @param objectOutputStream destination stream; not closed
     * @throws IOException if writing fails
     */
    public static void writeSentiment(SentimentModel sentimentModel, ObjectOutputStream objectOutputStream) throws IOException {
        Function<SimpleMatrix, List<List<Double>>> matrixToLists = ConvertModels::fromMatrix;
        objectOutputStream.writeObject(transform2DMap(sentimentModel.binaryTransform, matrixToLists));
        objectOutputStream.writeObject(transform2DMap(sentimentModel.binaryTensors, ConvertModels::fromTensor));
        objectOutputStream.writeObject(transform2DMap(sentimentModel.binaryClassification, matrixToLists));
        objectOutputStream.writeObject(transformMap(sentimentModel.unaryClassification, matrixToLists));
        objectOutputStream.writeObject(transformMap(sentimentModel.wordVectors, matrixToLists));
        objectOutputStream.writeObject(sentimentModel.op);
    }

    /**
     * Reads a sentiment model written by {@link #writeSentiment}.
     *
     * @param objectInputStream source stream; not closed
     * @return the rebuilt model
     * @throws IOException if reading fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public static SentimentModel readSentiment(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        // Read order must mirror writeSentiment exactly.
        TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform = readMatrix2DMap(objectInputStream);
        TwoDimensionalMap<String, String, List<List<List<Double>>>> rawTensors = ErasureUtils.uncheckedCast(objectInputStream.readObject());
        TwoDimensionalMap<String, String, SimpleTensor> binaryTensors = transform2DMap(rawTensors, ConvertModels::toTensor);
        TwoDimensionalMap<String, String, SimpleMatrix> binaryClassification = readMatrix2DMap(objectInputStream);
        Map<String, SimpleMatrix> unaryClassification = readMatrixMap(objectInputStream);
        Map<String, SimpleMatrix> wordVectors = readMatrixMap(objectInputStream);
        RNNOptions op = ErasureUtils.uncheckedCast(objectInputStream.readObject());
        return new SentimentModel(binaryTransform, binaryTensors, binaryClassification, unaryClassification, wordVectors, op);
    }

    /**
     * Writes a parser and the matrices of its DV reranker in the intermediate form.
     *
     * <p>The parser object is written as-is first. Then the reranker's model is written in
     * this order: binaryTransform, unaryTransform, binaryScore, unaryScore, wordVectors.
     *
     * @param lexicalizedParser parser to write (callers detach the reranker first)
     * @param dVModelReranker reranker whose model matrices are written
     * @param objectOutputStream destination stream; not closed
     * @throws IOException if writing fails
     */
    public static void writeParser(LexicalizedParser lexicalizedParser, DVModelReranker dVModelReranker, ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeObject(lexicalizedParser);
        Function<SimpleMatrix, List<List<Double>>> matrixToLists = ConvertModels::fromMatrix;
        DVModel model = dVModelReranker.getModel();
        objectOutputStream.writeObject(transform2DMap(model.binaryTransform, matrixToLists));
        objectOutputStream.writeObject(transformMap(model.unaryTransform, matrixToLists));
        objectOutputStream.writeObject(transform2DMap(model.binaryScore, matrixToLists));
        objectOutputStream.writeObject(transformMap(model.unaryScore, matrixToLists));
        objectOutputStream.writeObject(transformMap(model.wordVectors, matrixToLists));
    }

    /**
     * Reads a parser written by {@link #writeParser} and reattaches a rebuilt DV reranker.
     *
     * @param objectInputStream source stream; not closed
     * @return the parser with its {@code reranker} field set
     * @throws IOException if reading fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public static LexicalizedParser readParser(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        // Read order must mirror writeParser exactly.
        LexicalizedParser parser = ErasureUtils.uncheckedCast(objectInputStream.readObject());
        TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform = readMatrix2DMap(objectInputStream);
        Map<String, SimpleMatrix> unaryTransform = readMatrixMap(objectInputStream);
        TwoDimensionalMap<String, String, SimpleMatrix> binaryScore = readMatrix2DMap(objectInputStream);
        Map<String, SimpleMatrix> unaryScore = readMatrixMap(objectInputStream);
        Map<String, SimpleMatrix> wordVectors = readMatrixMap(objectInputStream);
        DVModel model = new DVModel(binaryTransform, unaryTransform, binaryScore, unaryScore, wordVectors, parser.getOp());
        parser.reranker = new DVModelReranker(model);
        return parser;
    }

    /**
     * Writes an embedding's word vectors as a map of nested lists.
     *
     * @param embedding embedding to write
     * @param objectOutputStream destination stream; not closed
     * @throws IOException if writing fails
     */
    public static void writeEmbedding(Embedding embedding, ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeObject(transformMap(embedding.getWordVectors(), ConvertModels::fromMatrix));
    }

    /**
     * Reads an embedding written by {@link #writeEmbedding}.
     *
     * @param objectInputStream source stream; not closed
     * @return the rebuilt embedding
     * @throws IOException if reading fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public static Embedding readEmbedding(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        Map<String, SimpleMatrix> wordVectors = readMatrixMap(objectInputStream);
        return new Embedding(wordVectors);
    }

    /**
     * Writes a neural coref model in the intermediate form.
     *
     * <p>Write order: antecedent, anaphor, pair-features and pairwise first-layer bias matrices;
     * then the anaphoricity and pairwise model lists; then the word-embedding vectors.
     *
     * @param neuralCorefModel model to write
     * @param objectOutputStream destination stream; not closed
     * @throws IOException if writing fails
     */
    public static void writeCoref(NeuralCorefModel neuralCorefModel, ObjectOutputStream objectOutputStream) throws IOException {
        Function<SimpleMatrix, List<List<Double>>> matrixToLists = ConvertModels::fromMatrix;
        objectOutputStream.writeObject(fromMatrix(neuralCorefModel.getAntecedentMatrix()));
        objectOutputStream.writeObject(fromMatrix(neuralCorefModel.getAnaphorMatrix()));
        objectOutputStream.writeObject(fromMatrix(neuralCorefModel.getPairFeaturesMatrix()));
        objectOutputStream.writeObject(fromMatrix(neuralCorefModel.getPairwiseFirstLayerBias()));
        objectOutputStream.writeObject(CollectionUtils.transformAsList(neuralCorefModel.getAnaphoricityModel(), matrixToLists));
        objectOutputStream.writeObject(CollectionUtils.transformAsList(neuralCorefModel.getPairwiseModel(), matrixToLists));
        objectOutputStream.writeObject(transformMap(neuralCorefModel.getWordEmbeddings().getWordVectors(), matrixToLists));
    }

    /**
     * Reads a neural coref model written by {@link #writeCoref}.
     *
     * @param objectInputStream source stream; not closed
     * @return the rebuilt model
     * @throws IOException if reading fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public static NeuralCorefModel readCoref(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        // Read order must mirror writeCoref exactly.
        SimpleMatrix antecedentMatrix = readMatrix(objectInputStream);
        SimpleMatrix anaphorMatrix = readMatrix(objectInputStream);
        SimpleMatrix pairFeaturesMatrix = readMatrix(objectInputStream);
        SimpleMatrix pairwiseFirstLayerBias = readMatrix(objectInputStream);
        List<SimpleMatrix> anaphoricityModel = readMatrixList(objectInputStream);
        List<SimpleMatrix> pairwiseModel = readMatrixList(objectInputStream);
        Map<String, SimpleMatrix> wordVectors = readMatrixMap(objectInputStream);
        return new NeuralCorefModel(antecedentMatrix, anaphorMatrix, pairFeaturesMatrix, pairwiseFirstLayerBias,
                anaphoricityModel, pairwiseModel, new Embedding(wordVectors));
    }

    /**
     * Writes a fast neural coref model in the intermediate form.
     *
     * <p>Write order:
     * <ol>
     *   <li>the CoNLL flag;</li>
     *   <li>a presence flag, followed by the static embedding only when it is non-null;</li>
     *   <li>the tuned embedding and the NA embedding value;</li>
     *   <li>the pair and mention feature-id maps (written as-is);</li>
     *   <li>the weight list.</li>
     * </ol>
     *
     * @param fastNeuralCorefModel model to write
     * @param objectOutputStream destination stream; not closed
     * @throws IOException if writing fails
     */
    public static void writeFastCoref(FastNeuralCorefModel fastNeuralCorefModel, ObjectOutputStream objectOutputStream) throws IOException {
        Function<SimpleMatrix, List<List<Double>>> matrixToLists = ConvertModels::fromMatrix;
        EmbeddingExtractor extractor = fastNeuralCorefModel.getEmbeddingExtractor();
        objectOutputStream.writeObject(extractor.isConll());
        Embedding staticWordEmbeddings = extractor.getStaticWordEmbeddings();
        // Presence flag tells readFastCoref whether a static embedding follows in the stream.
        objectOutputStream.writeObject(staticWordEmbeddings != null);
        if (staticWordEmbeddings != null) {
            writeEmbedding(staticWordEmbeddings, objectOutputStream);
        }
        writeEmbedding(extractor.getTunedWordEmbeddings(), objectOutputStream);
        objectOutputStream.writeObject(extractor.getNAEmbedding());
        objectOutputStream.writeObject(fastNeuralCorefModel.getPairFeatureIds());
        objectOutputStream.writeObject(fastNeuralCorefModel.getMentionFeatureIds());
        objectOutputStream.writeObject(CollectionUtils.transformAsList(fastNeuralCorefModel.getAllWeights(), matrixToLists));
    }

    /**
     * Reads a fast neural coref model written by {@link #writeFastCoref}.
     *
     * @param objectInputStream source stream; not closed
     * @return the rebuilt model
     * @throws IOException if reading fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public static FastNeuralCorefModel readFastCoref(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        // Read order must mirror writeFastCoref exactly.
        boolean conll = (Boolean) objectInputStream.readObject();
        boolean hasStaticEmbeddings = (Boolean) objectInputStream.readObject();
        Embedding staticWordEmbeddings = hasStaticEmbeddings ? readEmbedding(objectInputStream) : null;
        Embedding tunedWordEmbeddings = readEmbedding(objectInputStream);
        String naEmbedding = (String) objectInputStream.readObject();
        EmbeddingExtractor extractor = new EmbeddingExtractor(conll, staticWordEmbeddings, tunedWordEmbeddings, naEmbedding);
        Map<String, Integer> pairFeatureIds = ErasureUtils.uncheckedCast(objectInputStream.readObject());
        Map<String, Integer> mentionFeatureIds = ErasureUtils.uncheckedCast(objectInputStream.readObject());
        List<SimpleMatrix> weights = readMatrixList(objectInputStream);
        return new FastNeuralCorefModel(extractor, pairFeatureIds, mentionFeatureIds, weights);
    }

    /**
     * Parses a case-insensitive enum-valued command-line flag.
     *
     * @param type enum class to parse into
     * @param value raw flag value, possibly null when the flag was not given
     * @param usage message for the exception thrown on a missing or unknown value
     * @return the matching constant
     * @throws IllegalArgumentException if {@code value} is null or names no constant of {@code type}
     */
    private static <E extends Enum<E>> E parseEnumFlag(Class<E> type, String value, String usage) {
        if (value == null) {
            throw new IllegalArgumentException(usage);
        }
        try {
            // Locale.ROOT: default-locale casing (e.g. Turkish dotted I) would turn
            // "sentiment" into "SENTİMENT" and make valueOf fail.
            return Enum.valueOf(type, value.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException(usage, e);
        }
    }

    /** Converts a sentiment model in the direction given by {@code stage}. */
    private static void convertSentiment(Stage stage, String input, String output) throws IOException, ClassNotFoundException {
        if (stage == Stage.OLD) {
            SentimentModel model = SentimentModel.loadSerialized(input);
            try (ObjectOutputStream out = IOUtils.writeStreamFromString(output)) {
                writeSentiment(model, out);
            }
        } else {
            SentimentModel model;
            try (ObjectInputStream in = IOUtils.readStreamFromString(input)) {
                model = readSentiment(in);
            }
            model.saveSerialized(output);
        }
    }

    /** Converts a DV parser in the direction given by {@code stage}. */
    private static void convertParser(Stage stage, String input, String output) throws IOException, ClassNotFoundException {
        if (stage == Stage.OLD) {
            LexicalizedParser parser = LexicalizedParser.loadModel(input);
            if (parser.reranker == null) {
                System.out.println("Nothing to do for " + input);
                return;
            }
            DVModelReranker reranker = (DVModelReranker) parser.reranker;
            // Detach the reranker so writeObject(parser) inside writeParser does not serialize
            // the DVModel matrices directly; writeParser writes them separately as nested lists.
            parser.reranker = null;
            try (ObjectOutputStream out = IOUtils.writeStreamFromString(output)) {
                writeParser(parser, reranker, out);
            }
        } else {
            LexicalizedParser parser;
            try (ObjectInputStream in = IOUtils.readStreamFromString(input)) {
                parser = readParser(in);
            }
            parser.saveParserToSerialized(output);
        }
    }

    /** Converts a word embedding in the direction given by {@code stage}. */
    private static void convertEmbedding(Stage stage, String input, String output) throws IOException, ClassNotFoundException {
        if (stage == Stage.OLD) {
            Embedding embedding = ErasureUtils.uncheckedCast(IOUtils.readObjectFromURLOrClasspathOrFileSystem(input));
            try (ObjectOutputStream out = IOUtils.writeStreamFromString(output)) {
                writeEmbedding(embedding, out);
            }
        } else {
            Embedding embedding;
            try (ObjectInputStream in = IOUtils.readStreamFromString(input)) {
                embedding = readEmbedding(in);
            }
            IOUtils.writeObjectToFile(embedding, output);
        }
    }

    /** Converts a neural coref model in the direction given by {@code stage}. */
    private static void convertCoref(Stage stage, String input, String output) throws IOException, ClassNotFoundException {
        if (stage == Stage.OLD) {
            NeuralCorefModel model = ErasureUtils.uncheckedCast(IOUtils.readObjectFromURLOrClasspathOrFileSystem(input));
            try (ObjectOutputStream out = IOUtils.writeStreamFromString(output)) {
                writeCoref(model, out);
            }
        } else {
            NeuralCorefModel model;
            try (ObjectInputStream in = IOUtils.readStreamFromString(input)) {
                model = readCoref(in);
            }
            IOUtils.writeObjectToFile(model, output);
        }
    }

    /** Converts a fast neural coref model in the direction given by {@code stage}. */
    private static void convertFastCoref(Stage stage, String input, String output) throws IOException, ClassNotFoundException {
        if (stage == Stage.OLD) {
            FastNeuralCorefModel model = ErasureUtils.uncheckedCast(IOUtils.readObjectFromURLOrClasspathOrFileSystem(input));
            try (ObjectOutputStream out = IOUtils.writeStreamFromString(output)) {
                writeFastCoref(model, out);
            }
        } else {
            FastNeuralCorefModel model;
            try (ObjectInputStream in = IOUtils.readStreamFromString(input)) {
                model = readFastCoref(in);
            }
            IOUtils.writeObjectToFile(model, output);
        }
    }

    /**
     * Command-line entry point.
     *
     * <p>Required flags:
     * <ul>
     *   <li>{@code -stage}: OLD or NEW, case-insensitive;</li>
     *   <li>{@code -model}: SENTIMENT, DVPARSER, EMBEDDING, COREF or FASTCOREF, case-insensitive;</li>
     *   <li>{@code -input} and {@code -output}: file paths.</li>
     * </ul>
     *
     * <p>With {@code -stage OLD} the native model at {@code -input} is written to
     * {@code -output} in the intermediate list-based form. With {@code -stage NEW} the
     * intermediate file is read and the native model is written.
     *
     * @param strArr command-line arguments, parsed with {@link StringUtils#argsToProperties}
     * @throws IllegalArgumentException if a required flag is missing or has an unknown value
     */
    public static void main(String[] strArr) throws IOException, ClassNotFoundException, InstantiationException, NoSuchMethodException {
        Properties props = StringUtils.argsToProperties(strArr);
        // Each flag is validated in its own narrow scope. Errors from later checks, or from the
        // conversion itself, therefore propagate unchanged instead of being re-reported as a
        // misleading "-stage" usage error.
        Stage stage = parseEnumFlag(Stage.class, props.getProperty("stage"),
                "Please specify -stage, either OLD or NEW");
        Model modelType = parseEnumFlag(Model.class, props.getProperty("model"),
                "Please specify -model, either SENTIMENT, DVPARSER, EMBEDDING, COREF, FASTCOREF");
        if (!props.containsKey("input")) {
            throw new IllegalArgumentException("Please specify -input");
        }
        if (!props.containsKey("output")) {
            throw new IllegalArgumentException("Please specify -output");
        }
        String input = props.getProperty("input");
        String output = props.getProperty("output");

        switch (modelType) {
            case SENTIMENT:
                convertSentiment(stage, input, output);
                break;
            case DVPARSER:
                convertParser(stage, input, output);
                break;
            case EMBEDDING:
                convertEmbedding(stage, input, output);
                break;
            case COREF:
                convertCoref(stage, input, output);
                break;
            case FASTCOREF:
                convertFastCoref(stage, input, output);
                break;
            default:
                throw new IllegalArgumentException("Unknown model type " + modelType);
        }
    }
}
