/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.applications;

import edu.uci.jforests.applications.ClassificationApp;
import edu.uci.jforests.config.RankingTrainingConfig;
import edu.uci.jforests.dataset.Dataset;
import edu.uci.jforests.dataset.RankingDataset;
import edu.uci.jforests.dataset.RankingDatasetLoader;
import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.eval.ranking.MAPEval;
import edu.uci.jforests.eval.ranking.NDCGEval;
import edu.uci.jforests.eval.ranking.TRiskAwareFAROEval;
import edu.uci.jforests.eval.ranking.TRiskAwareSAROEval;
import edu.uci.jforests.eval.ranking.URiskAwareEval;
import edu.uci.jforests.learning.LearningModule;
import edu.uci.jforests.learning.boosting.LambdaMART;
import edu.uci.jforests.sample.RankingSample;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.Util;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Learning-to-rank application built on top of {@link ClassificationApp}.
 *
 * <p>Swaps in ranking-specific configuration ({@link RankingTrainingConfig}), datasets
 * ({@link RankingDataset}), samples ({@link RankingSample}), the {@link LambdaMART} learner and
 * ranking evaluation metrics (NDCG, MAP and the risk-aware wrappers around a parent metric).
 */
public class RankingApp extends ClassificationApp {
    /** Fraction of the qids listed in {@code trainQidsFilename} that is moved into the validation set. */
    private static final double VALID_QID_FRACTION = 0.4;

    /** Truncation passed to {@code NDCGEval.getMaxDCGForAllQueriesUptoTruncation} in {@link #initDataset}. */
    private static final int MAX_DCG_TRUNCATION = 10;

    /** Largest per-query document count over the training and (if present) validation datasets. */
    protected int maxDocsPerQuery;

    /**
     * Prepares ranking-specific state, then runs the superclass {@code init()}.
     *
     * <p>{@link #maxDocsPerQuery} is taken as the maximum over the training dataset and, when a
     * validation sample exists, the validation dataset; {@link NDCGEval} is initialized with it
     * before the superclass runs. If {@code trainQidsFilename} is configured, a random
     * {@link #VALID_QID_FRACTION} share of the listed query ids is moved into a newly built
     * validation sample (replacing any existing one) and the training sample is restricted to the
     * remaining ids.
     *
     * @throws Exception if the superclass initialization or reading the qid file fails
     */
    protected void init() throws Exception {
        this.maxDocsPerQuery = ((RankingDataset) this.trainSet.dataset).maxDocsPerQuery;
        if (this.validSet != null) {
            this.maxDocsPerQuery = Math.max(this.maxDocsPerQuery, ((RankingDataset) this.validSet.dataset).maxDocsPerQuery);
        }
        NDCGEval.initialize(this.maxDocsPerQuery);
        super.init();
        String trainQidsFilename = ((RankingTrainingConfig) this.trainingConfig).trainQidsFilename;
        if (trainQidsFilename != null) {
            this.splitTrainSetByQids(trainQidsFilename);
        }
    }

    /**
     * Randomly moves a {@link #VALID_QID_FRACTION} share of the query ids listed in the given file
     * out of the training sample and into a new validation sample.
     *
     * @param trainQidsFilename file of integer query ids, read with {@link Util#loadIntegersFromFile}
     * @throws Exception if the file cannot be read
     */
    private void splitTrainSetByQids(String trainQidsFilename) throws Exception {
        List<Integer> trainQids = Util.loadIntegersFromFile(trainQidsFilename);
        ArrayList<Integer> validQids = new ArrayList<Integer>();
        int validSize = (int) ((double) trainQids.size() * VALID_QID_FRACTION);
        for (int i = 0; i < validSize; ++i) {
            // Draw without replacement using this.rnd; remove(int) removes by index, not by value.
            int idx = this.rnd.nextInt(trainQids.size());
            validQids.add(trainQids.remove(idx));
        }
        // The drawn ids arrive in random order, so sort them; the remaining training ids keep the file order.
        // NOTE(review): if getFilteredSubSample needs sorted qids, the qid file must already be sorted -- confirm.
        Collections.sort(validQids);
        RankingSample fullTrainSet = (RankingSample) this.trainSet;
        RankingSample newTrainSet = fullTrainSet.getFilteredSubSample(trainQids);
        this.validSet = fullTrainSet.getFilteredSubSample(validQids);
        this.trainSet = newTrainSet;
    }

    /** Installs a {@link RankingTrainingConfig} populated from {@code configHolder}. */
    protected void loadConfig() {
        this.trainingConfig = new RankingTrainingConfig();
        this.trainingConfig.init(this.configHolder);
    }

    /** @return a new, empty {@link RankingDataset} */
    protected Dataset newDataset() {
        return new RankingDataset();
    }

    /**
     * Computes {@code maxDCG} for a ranking dataset whose {@code needsInitialization} flag is set.
     * A {@code null} dataset, or one that does not need initialization, is left untouched.
     *
     * @param dataset a {@link RankingDataset}, or {@code null}
     * @throws Exception propagated from the NDCG helpers
     */
    public void initDataset(Dataset dataset) throws Exception {
        if (dataset == null || !dataset.needsInitialization) {
            return;
        }
        RankingDataset rankingDataset = (RankingDataset) dataset;
        int[][] labelCounts = NDCGEval.getLabelCountsForQueries(rankingDataset.targets, rankingDataset.queryBoundaries);
        // NOTE(review): this truncation is fixed and independent of validNDCGTruncation -- confirm NDCG is
        // never evaluated at a deeper truncation than MAX_DCG_TRUNCATION.
        rankingDataset.maxDCG = NDCGEval.getMaxDCGForAllQueriesUptoTruncation(
                rankingDataset.targets, rankingDataset.queryBoundaries, MAX_DCG_TRUNCATION, labelCounts);
    }

    /**
     * Reads ranking data from {@code in} into {@code dataset} via {@link RankingDatasetLoader}.
     *
     * @param in source stream
     * @param dataset target; must be a {@link RankingDataset}
     * @throws Exception if loading fails
     */
    public void loadDataset(InputStream in, Dataset dataset) throws Exception {
        RankingDatasetLoader.load(in, (RankingDataset) dataset);
    }

    /**
     * Builds the learner named {@code name}. {@code "LambdaMART"} is handled here; every other name
     * is delegated to the superclass.
     *
     * @param name learner name
     * @return the initialized learning module
     * @throws Exception if learner initialization fails
     */
    protected LearningModule getLearningModule(String name) throws Exception {
        int maxTrainInstances = this.getMaxTrainInstances();
        if (name.equals("LambdaMART")) {
            // Without a validation dataset, the training dataset size is passed in its place.
            int maxValidInstances = this.validDataset != null ? this.validDataset.numInstances : this.trainDataset.numInstances;
            LambdaMART learner = new LambdaMART();
            learner.init(this.configHolder, (RankingDataset) this.trainDataset, maxTrainInstances, maxValidInstances, this.evaluationMetric);
            return learner;
        }
        return super.getLearningModule(name);
    }

    /**
     * Resolves an evaluation metric by name.
     *
     * <p>Risk-aware metrics use the form {@code <prefix>:<alpha>:<parentMeasure>}, where the parent
     * measure is resolved recursively and may itself be a risk-aware spec. {@code NDCG} and
     * {@code MAP} are built here; anything else goes to the superclass.
     *
     * @param name metric name or spec
     * @return the metric
     * @throws IllegalArgumentException if a risk-aware spec is missing its alpha or parent part, or
     *     alpha is not a number
     * @throws Exception if the parent or superclass lookup fails
     */
    protected EvaluationMetric getEvaluationMetric(String name) throws Exception {
        if (name.startsWith("URiskAwareEval:")) {
            String[] spec = splitRiskAwareSpec(name);
            double alpha = Double.parseDouble(spec[1]);
            return new URiskAwareEval(this.getEvaluationMetric(spec[2]), alpha);
        }
        if (name.startsWith("TRiskAwareEvalSARO:") || name.startsWith("TRiskAwareSAROEval:")) {
            String[] spec = splitRiskAwareSpec(name);
            double alpha = Double.parseDouble(spec[1]);
            return new TRiskAwareSAROEval(this.getEvaluationMetric(spec[2]), alpha);
        }
        if (name.startsWith("TRiskAwareEvalFARO:") || name.startsWith("TRiskAwareFAROEval:")) {
            String[] spec = splitRiskAwareSpec(name);
            double alpha = Double.parseDouble(spec[1]);
            return new TRiskAwareFAROEval(this.getEvaluationMetric(spec[2]), alpha);
        }
        if (name.equals("NDCG")) {
            return new NDCGEval(this.maxDocsPerQuery, ((RankingTrainingConfig) this.trainingConfig).validNDCGTruncation);
        }
        if (name.equals("MAP")) {
            return new MAPEval(this.maxDocsPerQuery);
        }
        return super.getEvaluationMetric(name);
    }

    /**
     * Splits a risk-aware metric spec {@code <prefix>:<alpha>:<parentMeasure>} into three parts.
     * The split is limited to three pieces so that a parent measure containing colons (for example
     * a nested risk-aware spec) is passed through intact rather than truncated.
     *
     * @param name the full metric spec
     * @return {prefix, alpha text, parent measure name}
     * @throws IllegalArgumentException if the alpha or parent measure part is missing
     */
    private static String[] splitRiskAwareSpec(String name) {
        String[] parts = name.split(":", 3);
        if (parts.length < 3 || parts[2].length() == 0) {
            throw new IllegalArgumentException("Expected <measure>:<alpha>:<parentMeasure> but got: " + name);
        }
        return parts;
    }

    /**
     * Wraps {@code dataset} in a {@link RankingSample}. For training samples with doc-sampling
     * augmentation enabled, the augmented sample produced with the configured times/rate and
     * {@code this.rnd} is returned instead.
     *
     * @param dataset a {@link RankingDataset}
     * @param trainSample whether the sample is for training
     * @return the (possibly augmented) sample
     */
    protected Sample createSample(Dataset dataset, boolean trainSample) {
        RankingSample sample = new RankingSample((RankingDataset) dataset);
        RankingTrainingConfig config = (RankingTrainingConfig) this.trainingConfig;
        if (trainSample && config.augmentationDocSamplingEnabled) {
            return sample.getAugmentedSampleWithDocSampling(config.augmentationDocSamplingTimes, config.augmentationDocSamplingRate, this.rnd);
        }
        return sample;
    }

    /**
     * Returns the capacity to reserve for training instances: the training dataset size, multiplied
     * by {@code augmentationDocSamplingTimes + 1} when doc-sampling augmentation is enabled
     * (presumably the originals plus one copy per augmentation round -- confirm in RankingSample).
     *
     * @return maximum number of training instances
     */
    protected int getMaxTrainInstances() {
        RankingTrainingConfig config = (RankingTrainingConfig) this.trainingConfig;
        if (config.augmentationDocSamplingEnabled) {
            return this.trainDataset.numInstances * (config.augmentationDocSamplingTimes + 1);
        }
        return this.trainDataset.numInstances;
    }
}

