package edu.uci.jforests.applications;

import edu.uci.jforests.config.RankingTrainingConfig;
import edu.uci.jforests.dataset.Dataset;
import edu.uci.jforests.dataset.RankingDataset;
import edu.uci.jforests.dataset.RankingDatasetLoader;
import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.eval.ranking.MAPEval;
import edu.uci.jforests.eval.ranking.NDCGEval;
import edu.uci.jforests.eval.ranking.TRiskAwareFAROEval;
import edu.uci.jforests.eval.ranking.TRiskAwareSAROEval;
import edu.uci.jforests.eval.ranking.URiskAwareEval;
import edu.uci.jforests.learning.LearningModule;
import edu.uci.jforests.learning.boosting.LambdaMART;
import edu.uci.jforests.sample.RankingSample;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.Util;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/* loaded from: input_file:edu/uci/jforests/applications/RankingApp.class */
/**
 * Learning-to-rank application built on top of {@link ClassificationApp}.
 *
 * <p>Swaps the classification-specific pieces for ranking ones:
 * <ul>
 *   <li>the training configuration ({@link RankingTrainingConfig});</li>
 *   <li>the dataset type and loader ({@link RankingDataset}, {@link RankingDatasetLoader});</li>
 *   <li>the sample type ({@link RankingSample});</li>
 *   <li>the learner ({@link LambdaMART});</li>
 *   <li>the evaluation metrics (NDCG, MAP and the risk-aware wrappers).</li>
 * </ul>
 */
public class RankingApp extends ClassificationApp {

    /** Fraction of the ids read from {@code trainQidsFilename} that is moved into the validation sample. */
    private static final double HELD_OUT_QID_FRACTION = 0.4d;

    /**
     * Truncation level passed to {@link NDCGEval#getMaxDCGForAllQueriesUptoTruncation}.
     * NOTE(review): this value may be an inlined constant from another class; confirm it matches
     * the truncation the learner expects.
     */
    private static final int MAX_DCG_TRUNCATION = 10;

    /** Largest {@code maxDocsPerQuery} over the training dataset and, if present, the validation dataset. */
    protected int maxDocsPerQuery;

    /**
     * Initializes ranking-specific state, then delegates to {@link ClassificationApp#init()}.
     *
     * <p>{@link #maxDocsPerQuery} and {@link NDCGEval#initialize(int)} are set up before
     * {@code super.init()}. This is presumably because the superclass builds metrics through
     * {@link #getEvaluationMetric(String)}, which reads {@code maxDocsPerQuery}. TODO confirm.
     *
     * <p>If {@code trainQidsFilename} is configured, the training sample is re-split afterwards
     * (see {@link #splitTrainSetByQids(String)}). This replaces any validation sample set up by
     * {@code super.init()}.
     */
    @Override
    public void init() throws Exception {
        this.maxDocsPerQuery = ((RankingDataset) this.trainSet.dataset).maxDocsPerQuery;
        if (this.validSet != null) {
            this.maxDocsPerQuery =
                    Math.max(this.maxDocsPerQuery, ((RankingDataset) this.validSet.dataset).maxDocsPerQuery);
        }
        NDCGEval.initialize(this.maxDocsPerQuery);
        super.init();

        String qidsFilename = ((RankingTrainingConfig) this.trainingConfig).trainQidsFilename;
        if (qidsFilename != null) {
            splitTrainSetByQids(qidsFilename);
        }
    }

    /**
     * Re-splits the current training sample using the ids listed in {@code qidsFilename}.
     *
     * <p>A random {@link #HELD_OUT_QID_FRACTION} of the listed ids (drawn without replacement
     * using {@code rnd}) becomes the validation sample. The remaining listed ids become the new
     * training sample. Both are filtered sub-samples of the current training sample.
     *
     * @param qidsFilename file read by {@link Util#loadIntegersFromFile(String)}
     */
    private void splitTrainSetByQids(String qidsFilename) throws Exception {
        List<Integer> trainQids = Util.loadIntegersFromFile(qidsFilename);
        int heldOutCount = (int) (trainQids.size() * HELD_OUT_QID_FRACTION);
        List<Integer> validQids = new ArrayList<>(heldOutCount);
        for (int i = 0; i < heldOutCount; i++) {
            // List.remove(int) removes by index and returns the element, so each id is drawn at most once.
            validQids.add(trainQids.remove(this.rnd.nextInt(trainQids.size())));
        }
        Collections.sort(validQids);

        RankingSample currentTrainSample = (RankingSample) this.trainSet;
        RankingSample newTrainSample = currentTrainSample.getFilteredSubSample(trainQids);
        this.validSet = currentTrainSample.getFilteredSubSample(validQids);
        this.trainSet = newTrainSample;
    }

    /** Creates a {@link RankingTrainingConfig} and initializes it from {@code configHolder}. */
    @Override
    protected void loadConfig() {
        this.trainingConfig = new RankingTrainingConfig();
        this.trainingConfig.init(this.configHolder);
    }

    /** @return a new, empty {@link RankingDataset} */
    @Override
    protected Dataset newDataset() {
        return new RankingDataset();
    }

    /**
     * Precomputes per-query maximum DCG values for a ranking dataset.
     *
     * <p>This is a no-op when {@code dataset} is {@code null} or does not need initialization.
     * Otherwise {@code maxDCG} is computed from the dataset's targets and query boundaries,
     * truncated at {@link #MAX_DCG_TRUNCATION}.
     *
     * @param dataset dataset to initialize; expected to be a {@link RankingDataset}
     */
    @Override
    public void initDataset(Dataset dataset) throws Exception {
        if (dataset == null || !dataset.needsInitialization) {
            return;
        }
        RankingDataset rankingDataset = (RankingDataset) dataset;
        rankingDataset.maxDCG = NDCGEval.getMaxDCGForAllQueriesUptoTruncation(
                rankingDataset.targets,
                rankingDataset.queryBoundaries,
                MAX_DCG_TRUNCATION,
                NDCGEval.getLabelCountsForQueries(rankingDataset.targets, rankingDataset.queryBoundaries));
    }

    /**
     * Loads ranking data from {@code inputStream} into {@code dataset} via {@link RankingDatasetLoader}.
     *
     * @param inputStream source stream; not closed here
     * @param dataset destination; must be a {@link RankingDataset}
     */
    @Override
    public void loadDataset(InputStream inputStream, Dataset dataset) throws Exception {
        RankingDatasetLoader.load(inputStream, (RankingDataset) dataset);
    }

    /**
     * Returns the learner named {@code str}.
     *
     * <p>{@code "LambdaMART"} is handled here. Any other name is delegated to the superclass.
     * The maximum number of training instances is computed only on the LambdaMART path.
     *
     * @param str learner name
     * @return an initialized learning module
     */
    @Override
    public LearningModule getLearningModule(String str) throws Exception {
        if (!str.equals("LambdaMART")) {
            return super.getLearningModule(str);
        }
        int maxTrainInstances = getMaxTrainInstances();
        int maxValidInstances =
                this.validDataset != null ? this.validDataset.numInstances : this.trainDataset.numInstances;
        LambdaMART lambdaMART = new LambdaMART();
        lambdaMART.init(this.configHolder, (RankingDataset) this.trainDataset, maxTrainInstances,
                maxValidInstances, this.evaluationMetric);
        return lambdaMART;
    }

    /**
     * Builds the evaluation metric described by {@code str}.
     *
     * <p>Supported forms:
     * <ul>
     *   <li>{@code "NDCG"}: NDCG truncated at the configured {@code validNDCGTruncation};</li>
     *   <li>{@code "MAP"};</li>
     *   <li>{@code "URiskAwareEval:<number>:<base>"};</li>
     *   <li>{@code "TRiskAwareSAROEval:<number>:<base>"} or {@code "TRiskAwareEvalSARO:..."};</li>
     *   <li>{@code "TRiskAwareFAROEval:<number>:<base>"} or {@code "TRiskAwareEvalFARO:..."}.</li>
     * </ul>
     * In the risk-aware forms, {@code <base>} is everything after the second colon, so it may
     * itself be a risk-aware spec. Anything else is delegated to the superclass.
     *
     * @param str metric specification
     * @return the metric
     * @throws IllegalArgumentException if a risk-aware spec lacks the numeric or base-metric part
     */
    @Override
    public EvaluationMetric getEvaluationMetric(String str) throws Exception {
        if (str.startsWith("URiskAwareEval:")) {
            String[] spec = splitRiskAwareSpec(str);
            return new URiskAwareEval(getEvaluationMetric(spec[2]), Double.parseDouble(spec[1]));
        }
        if (str.startsWith("TRiskAwareEvalSARO:") || str.startsWith("TRiskAwareSAROEval:")) {
            String[] spec = splitRiskAwareSpec(str);
            return new TRiskAwareSAROEval(getEvaluationMetric(spec[2]), Double.parseDouble(spec[1]));
        }
        if (str.startsWith("TRiskAwareEvalFARO:") || str.startsWith("TRiskAwareFAROEval:")) {
            String[] spec = splitRiskAwareSpec(str);
            return new TRiskAwareFAROEval(getEvaluationMetric(spec[2]), Double.parseDouble(spec[1]));
        }
        if (str.equals("NDCG")) {
            return new NDCGEval(this.maxDocsPerQuery, ((RankingTrainingConfig) this.trainingConfig).validNDCGTruncation);
        }
        if (str.equals("MAP")) {
            return new MAPEval(this.maxDocsPerQuery);
        }
        return super.getEvaluationMetric(str);
    }

    /**
     * Splits a risk-aware metric spec into exactly three parts: name, numeric parameter and base spec.
     *
     * <p>The split is limited to three parts so that a base spec containing colons is kept intact.
     * An unlimited split would cut a nested spec down to its first token.
     */
    private static String[] splitRiskAwareSpec(String spec) {
        String[] parts = spec.split(":", 3);
        if (parts.length < 3 || parts[2].isEmpty()) {
            throw new IllegalArgumentException(
                    "Risk-aware metric must have the form '<name>:<number>:<baseMetric>', got: " + spec);
        }
        return parts;
    }

    /**
     * Wraps {@code dataset} in a {@link RankingSample}.
     *
     * <p>When {@code z} is true and augmentation doc sampling is enabled in the config, the sample
     * returned is the augmented one built with the configured times and rate using {@code rnd}.
     *
     * @param dataset dataset to sample; must be a {@link RankingDataset}
     * @param z whether augmentation may be applied (presumably true for the training set; verify against callers)
     */
    @Override
    protected Sample createSample(Dataset dataset, boolean z) {
        RankingSample rankingSample = new RankingSample((RankingDataset) dataset);
        RankingTrainingConfig rankingTrainingConfig = (RankingTrainingConfig) this.trainingConfig;
        if (z && rankingTrainingConfig.augmentationDocSamplingEnabled) {
            return rankingSample.getAugmentedSampleWithDocSampling(
                    rankingTrainingConfig.augmentationDocSamplingTimes,
                    rankingTrainingConfig.augmentationDocSamplingRate,
                    this.rnd);
        }
        return rankingSample;
    }

    /**
     * Returns an upper bound on training instances.
     *
     * <p>The result is {@code trainDataset.numInstances}. When augmentation doc sampling is enabled,
     * it is multiplied by {@code augmentationDocSamplingTimes + 1}.
     */
    @Override
    protected int getMaxTrainInstances() {
        RankingTrainingConfig rankingTrainingConfig = (RankingTrainingConfig) this.trainingConfig;
        if (!rankingTrainingConfig.augmentationDocSamplingEnabled) {
            return this.trainDataset.numInstances;
        }
        return this.trainDataset.numInstances * (rankingTrainingConfig.augmentationDocSamplingTimes + 1);
    }
}
