package edu.usc.irds.agepredictor.spark.authorage;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import opennlp.tools.util.TrainingParameters;
import org.apache.commons.io.FileUtils;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.DoubleFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LassoModel;
import org.apache.spark.mllib.regression.LassoWithSGD;
import org.apache.spark.mllib.stat.Statistics;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

/* loaded from: input_file:edu/usc/irds/agepredictor/spark/authorage/AgePredictSGDTrainer.class */
public class AgePredictSGDTrainer {
    public static final String CUTOFF_PARAM = "Cutoff";
    public static final int CUTOFF_DEFAULT = 5;
    public static final String ITERATIONS_PARAM = "Iterations";
    public static final int ITERATIONS_DEFAULT = 100;
    public static final String STEPSIZE_PARAM = "StepSize";
    public static final double STEPSIZE_DEFAULT = 1.0d;
    public static final String REG_PARAM = "Regularization";
    public static final double REG_DEFAULT = 0.1d;

    public static void generateEvents(SparkSession sparkSession, String str, String str2, String str3, String str4) throws IOException {
        JavaRDD repartition = sparkSession.sparkContext().textFile(str, 48).toJavaRDD().cache().map(new CreateEvents(new AgeClassifyContextGeneratorWrapper(str2, str3))).cache().filter(new Function<EventWrapper, Boolean>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictSGDTrainer.1
            public Boolean call(EventWrapper eventWrapper) {
                if (eventWrapper != null) {
                    return Boolean.valueOf(eventWrapper.getValue() != null);
                }
                return false;
            }
        }).repartition(8);
        File file = new File(str4);
        if (file.exists()) {
            FileUtils.cleanDirectory(file);
            FileUtils.forceDelete(file);
        }
        repartition.saveAsTextFile(str4);
    }

    private static int getCutoff(Map<String, String> map) {
        String str = map.get(CUTOFF_PARAM);
        if (str != null) {
            return Integer.parseInt(str);
        }
        return 5;
    }

    private static int getIterations(Map<String, String> map) {
        String str = map.get(ITERATIONS_PARAM);
        if (str != null) {
            return Integer.parseInt(str);
        }
        return 100;
    }

    private static double getStepSize(Map<String, String> map) {
        String str = map.get(STEPSIZE_PARAM);
        if (str != null) {
            return Double.parseDouble(str);
        }
        return 1.0d;
    }

    private static double getReg(Map<String, String> map) {
        String str = map.get(REG_PARAM);
        if (str != null) {
            return Double.parseDouble(str);
        }
        return 0.1d;
    }

    public static AgePredictModel createModel(String str, SparkSession sparkSession, String str2, AgeClassifyContextGeneratorWrapper ageClassifyContextGeneratorWrapper, TrainingParameters trainingParameters) throws IOException {
        Map settings = trainingParameters.getSettings();
        int cutoff = getCutoff(settings);
        int iterations = getIterations(settings);
        JavaRDD cache = sparkSession.sparkContext().textFile(str2, 24).toJavaRDD().cache().map(new Function<String, Row>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictSGDTrainer.2
            public Row call(String str3) {
                if (str3 == null) {
                    return null;
                }
                String[] split = str3.split(",");
                if (split.length != 3) {
                    return null;
                }
                try {
                    if (split[0] == "-1") {
                        return null;
                    }
                    Integer valueOf = Integer.valueOf(Integer.parseInt(split[0]));
                    String[] split2 = split[2].split(" ");
                    ArrayList arrayList = new ArrayList(Arrays.asList(split2));
                    for (int i = 0; i < split2.length / 18; i++) {
                        arrayList.add("cat=" + split[1]);
                    }
                    return RowFactory.create(new Object[]{valueOf, arrayList.toArray()});
                } catch (Exception e) {
                    return null;
                }
            }
        }).cache();
        JavaRDD cache2 = cache.filter(new Function<Row, Boolean>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictSGDTrainer.3
            public Boolean call(Row row) {
                return Boolean.valueOf(row != null);
            }
        }).cache();
        cache.unpersist();
        Dataset cache3 = sparkSession.createDataFrame(cache2, new StructType(new StructField[]{new StructField("value", DataTypes.IntegerType, false, Metadata.empty()), new StructField("context", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())})).cache();
        CountVectorizerModel fit = new CountVectorizer().setInputCol("context").setOutputCol("feature").setMinDF(cutoff).fit(cache3);
        Normalizer p = new Normalizer().setInputCol("feature").setOutputCol("normFeature").setP(1.0d);
        Dataset select = fit.transform(cache3).select("value", new String[]{"feature"});
        Dataset select2 = p.transform(select).select("value", new String[]{"normFeature"});
        JavaRDD cache4 = select2.javaRDD().cache();
        select.unpersist();
        select2.unpersist();
        JavaRDD cache5 = cache4.map(new Function<Row, LabeledPoint>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictSGDTrainer.4
            public LabeledPoint call(Row row) {
                Integer valueOf = Integer.valueOf(row.getInt(0));
                SparseVector sparseVector = (SparseVector) row.get(1);
                return new LabeledPoint(valueOf.intValue(), Vectors.sparse(sparseVector.size(), sparseVector.indices(), sparseVector.values()));
            }
        }).cache();
        double stepSize = getStepSize(settings);
        double reg = getReg(settings);
        LassoWithSGD intercept = new LassoWithSGD().setIntercept(true);
        intercept.optimizer().setNumIterations(iterations).setStepSize(stepSize).setRegParam(reg);
        final LassoModel run = intercept.run(JavaRDD.toRDD(cache5));
        System.out.println("Coefficients: " + Arrays.toString(run.weights().toArray()));
        System.out.println("Intercept: " + run.intercept());
        JavaRDD cache6 = cache5.map(new Function<LabeledPoint, Tuple2<Double, Double>>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictSGDTrainer.5
            public Tuple2<Double, Double> call(LabeledPoint labeledPoint) {
                double predict = run.predict(labeledPoint.features());
                System.out.println(predict + "," + labeledPoint.label());
                return new Tuple2<>(Double.valueOf(predict), Double.valueOf(labeledPoint.label()));
            }
        }).cache();
        double doubleValue = new JavaDoubleRDD(cache6.map(new Function<Tuple2<Double, Double>, Object>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictSGDTrainer.6
            public Object call(Tuple2<Double, Double> tuple2) {
                return Double.valueOf(Math.abs(((Double) tuple2._1()).doubleValue() - ((Double) tuple2._2()).doubleValue()));
            }
        }).rdd()).mean().doubleValue();
        Matrix corr = Statistics.corr(cache6.map(new Function<Tuple2<Double, Double>, Vector>() { // from class: edu.usc.irds.agepredictor.spark.authorage.AgePredictSGDTrainer.7
            public Vector call(Tuple2<Double, Double> tuple2) {
                return Vectors.dense(((Double) tuple2._1()).doubleValue(), new double[]{((Double) tuple2._2()).doubleValue()});
            }
        }).rdd(), "pearson");
        System.out.println("Training Mean Absolute Error: " + doubleValue);
        System.out.println("Correlation:\n" + corr.toString());
        new HashMap();
        return new AgePredictModel(str, run, fit.vocabulary(), ageClassifyContextGeneratorWrapper);
    }
}
