package edu.usc.irds.agepredictor.spark.authorage;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import opennlp.tools.authorage.AgeClassifyME;
import opennlp.tools.util.featuregen.FeatureGenerator;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LassoModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

/* loaded from: input_file:edu/usc/irds/agepredictor/spark/authorage/AgePredictEvaluator.class */
public class AgePredictEvaluator {
    public static void evaluate(SparkSession sparkSession, File file, File file2, File file3, String str) throws IOException {
        AgePredictModel readModel = AgePredictModel.readModel(file2);
        final AgeClassifyModelWrapper ageClassifyModelWrapper = file == null ? null : new AgeClassifyModelWrapper(file);
        JavaRDD cache = sparkSession.sparkContext().textFile(str, 8).toJavaRDD().cache();
        final AgeClassifyContextGeneratorWrapper context = readModel.getContext();
        JavaRDD map = cache.map(new Function<String, Row>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictEvaluator.1
            public Row call(String str2) throws IOException {
                String str3 = str2.split("\t", 2)[0];
                String[] strArr = AgeClassifyContextGeneratorWrapper.this.getTokenizer().tokenize(str2.split("\t", 2)[1]);
                String str4 = null;
                if (ageClassifyModelWrapper != null) {
                    AgeClassifyME classifier = ageClassifyModelWrapper.getClassifier();
                    str4 = classifier.getBestCategory(classifier.getProbabilities(strArr));
                }
                ArrayList arrayList = new ArrayList();
                for (FeatureGenerator featureGenerator : AgeClassifyContextGeneratorWrapper.this.getFeatureGenerators()) {
                    arrayList.addAll(featureGenerator.extractFeatures(strArr));
                }
                if (str4 != null) {
                    for (int i = 0; i < strArr.length / 18; i++) {
                        arrayList.add("cat=" + str4);
                    }
                }
                if (arrayList.size() <= 0) {
                    return null;
                }
                try {
                    return RowFactory.create(new Object[]{Integer.valueOf(Integer.valueOf(str3).intValue()), arrayList.toArray()});
                } catch (Exception e) {
                    return null;
                }
            }
        });
        JavaRDD cache2 = map.filter(new Function<Row, Boolean>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictEvaluator.2
            public Boolean call(Row row) {
                return Boolean.valueOf(row != null);
            }
        }).cache();
        map.unpersist();
        Dataset cache3 = sparkSession.createDataFrame(cache2, new StructType(new StructField[]{new StructField("value", DataTypes.IntegerType, false, Metadata.empty()), new StructField("context", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())})).cache();
        System.out.println("Vocab: " + readModel.getVocabulary());
        CountVectorizerModel outputCol = new CountVectorizerModel(readModel.getVocabulary()).setInputCol("context").setOutputCol("feature");
        Normalizer p = new Normalizer().setInputCol("feature").setOutputCol("norm").setP(1.0d);
        Dataset select = outputCol.transform(cache3).select("value", new String[]{"feature"});
        JavaRDD cache4 = p.transform(select).select("value", new String[]{"norm"}).javaRDD().cache();
        select.unpersist();
        JavaRDD map2 = cache4.map(new Function<Row, LabeledPoint>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictEvaluator.3
            public LabeledPoint call(Row row) {
                Integer valueOf = Integer.valueOf(row.getInt(0));
                SparseVector sparseVector = (SparseVector) row.get(1);
                return new LabeledPoint(valueOf.intValue(), Vectors.sparse(sparseVector.size(), sparseVector.indices(), sparseVector.values()));
            }
        });
        map2.cache();
        final LassoModel model = readModel.getModel();
        JavaRDD cache5 = map2.map(new Function<LabeledPoint, Tuple2<Double, Double>>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictEvaluator.4
            public Tuple2<Double, Double> call(LabeledPoint labeledPoint) {
                return new Tuple2<>(Double.valueOf(model.predict(labeledPoint.features())), Double.valueOf(labeledPoint.label()));
            }
        }).cache();
        double doubleValue = new JavaDoubleRDD(cache5.map(new Function<Tuple2<Double, Double>, Object>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictEvaluator.5
            public Object call(Tuple2<Double, Double> tuple2) {
                return Double.valueOf(Math.abs(((Double) tuple2._1()).doubleValue() - ((Double) tuple2._2()).doubleValue()));
            }
        }).rdd()).mean().doubleValue();
        if (file3 != null) {
            Iterator localIterator = cache5.toLocalIterator();
            file3.createNewFile();
            FileWriter fileWriter = new FileWriter(file3);
            while (localIterator.hasNext()) {
                Tuple2 tuple2 = (Tuple2) localIterator.next();
                fileWriter.write(tuple2._1() + "," + tuple2._2() + "\n");
            }
            fileWriter.close();
        }
        System.out.println("Mean Absolute Error: " + doubleValue);
    }
}
