package edu.uci.jforests.learning.boosting;

import edu.uci.jforests.dataset.RankingDataset;
import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.eval.ranking.NDCGEval;
import edu.uci.jforests.eval.ranking.RankingEvaluationMetric;
import edu.uci.jforests.learning.trees.LeafInstances;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeLeafInstances;
import edu.uci.jforests.learning.trees.regression.RegressionTree;
import edu.uci.jforests.sample.RankingSample;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ArraysUtil;
import edu.uci.jforests.util.ConfigHolder;
import edu.uci.jforests.util.ScoreBasedComparator;
import edu.uci.jforests.util.concurrency.BlockingThreadPoolExecutor;
import edu.uci.jforests.util.concurrency.TaskCollection;
import edu.uci.jforests.util.concurrency.TaskItem;
import java.util.Arrays;

/* loaded from: input_file:edu/uci/jforests/learning/boosting/LambdaMART.class */
public class LambdaMART extends GradientBoosting {
    /**
     * Tiny constant added to both numerator and denominator of a leaf output so that a leaf
     * whose residuals and weights are all zero yields a finite value instead of 0/0.
     */
    private static final double LEAF_OUTPUT_SMOOTHING = 1.4E-45d;

    private TaskCollection<LambdaWorker> workers;
    private RankingEvaluationMetric.SwapScorer swapScorer;
    private double sigmoidParam;
    /** Pre-computed per-pair factor, indexed by binned score difference in [minScore, maxScore). */
    private double[] sigmoidCache;
    private double minScore;
    private double maxScore;
    private double sigmoidBinWidth;
    /** Per-document sums of second-order weights; denominators of the leaf outputs. */
    protected double[] denomWeights;
    /** Maps an index in the current sub-learner sample back to its index in the train set. */
    private int[] subLearnerSampleIndicesInTrainSet;

    /**
     * Computes the pairwise gradients (added into {@code residuals}) and weights (added into
     * {@code denomWeights}) for a contiguous range of queries of the current training sample.
     * Different workers are given non-overlapping query ranges by
     * {@link LambdaMART#getSubLearnerSample()}.
     */
    private class LambdaWorker extends TaskItem {
        /** Scratch buffer holding the score-sorted order of the documents of one query. */
        private int[] permutation;
        private int beginIdx;
        private int endIdx;
        private ScoreBasedComparator comparator = new ScoreBasedComparator();

        /**
         * @param maxDocsPerQuery capacity of the permutation buffer; must be at least the
         *     number of documents in the largest query this worker will process
         */
        public LambdaWorker(int maxDocsPerQuery) {
            this.permutation = new int[maxDocsPerQuery];
        }

        /**
         * Assigns the half-open query range {@code [beginQuery, endQuery)} to this worker.
         */
        public void init(int beginQuery, int endQuery) {
            this.beginIdx = beginQuery;
            this.endIdx = endQuery;
            this.comparator.labels = LambdaMART.this.curTrainSet.targets;
        }

        @Override // java.lang.Runnable
        public void run() {
            RankingSample sample = (RankingSample) LambdaMART.this.curTrainSet;
            double[] labels = sample.targets;
            double[] scores = LambdaMART.this.trainPredictions;
            double[] residualSums = LambdaMART.this.residuals;
            double[] weightSums = LambdaMART.this.denomWeights;
            this.comparator.scores = scores;
            try {
                for (int query = this.beginIdx; query < this.endIdx; query++) {
                    int queryBegin = sample.queryBoundaries[query];
                    int numDocs = sample.queryBoundaries[query + 1] - queryBegin;

                    // Order this query's documents using the comparator (current scores/labels).
                    this.comparator.offset = queryBegin;
                    for (int doc = 0; doc < numDocs; doc++) {
                        this.permutation[doc] = doc;
                    }
                    ArraysUtil.insertionSort(this.permutation, numDocs, this.comparator);

                    for (int hiRank = 0; hiRank < numDocs; hiRank++) {
                        int hiDoc = queryBegin + this.permutation[hiRank];
                        // Only documents with a positive label can be the better side of a pair.
                        if (labels[hiDoc] <= 0.0d) {
                            continue;
                        }
                        for (int loRank = 0; loRank < numDocs; loRank++) {
                            if (hiRank == loRank) {
                                continue;
                            }
                            int loDoc = queryBegin + this.permutation[loRank];
                            if (labels[hiDoc] > labels[loDoc]) {
                                double rho = LambdaMART.this.lookupSigmoid(scores[hiDoc] - scores[loDoc]);
                                double deltaMetric = Math.abs(LambdaMART.this.swapScorer.getDelta(
                                        sample.queryIndices[query], hiDoc, hiRank, loDoc, loRank));
                                // Push the better document up and the worse one down by the same amount.
                                residualSums[hiDoc] = residualSums[hiDoc] + (rho * deltaMetric);
                                residualSums[loDoc] = residualSums[loDoc] - (rho * deltaMetric);
                                double pairWeight = rho * (1.0d - rho) * deltaMetric;
                                weightSums[hiDoc] = weightSums[hiDoc] + pairWeight;
                                weightSums[loDoc] = weightSums[loDoc] + pairWeight;
                            }
                        }
                    }
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    public LambdaMART() {
        super("LambdaMART");
    }

    /**
     * Initializes the learner for the given ranking dataset.
     *
     * @param configHolder source of {@link LambdaMARTConfig} and {@link GradientBoostingConfig}
     * @param rankingDataset dataset whose targets/query boundaries seed the swap scorer
     * @param maxNumTrainInstances capacity of the per-document training buffers
     * @param maxNumValidInstances capacity of the validation buffers (forwarded to super)
     * @param evaluationMetric must be a {@link RankingEvaluationMetric}
     * @throws Exception if super initialization fails or the sigmoid configuration is invalid
     */
    public void init(ConfigHolder configHolder, RankingDataset rankingDataset, int maxNumTrainInstances, int maxNumValidInstances, EvaluationMetric evaluationMetric) throws Exception {
        super.init(configHolder, maxNumTrainInstances, maxNumValidInstances, evaluationMetric);
        LambdaMARTConfig lambdaMARTConfig = (LambdaMARTConfig) configHolder.getConfig(LambdaMARTConfig.class);
        GradientBoostingConfig gradientBoostingConfig = (GradientBoostingConfig) configHolder.getConfig(GradientBoostingConfig.class);
        this.swapScorer = ((RankingEvaluationMetric) evaluationMetric).getSwapScorer(
                rankingDataset.targets,
                rankingDataset.queryBoundaries,
                lambdaMARTConfig.maxDCGTruncation,
                NDCGEval.getLabelCountsForQueries(rankingDataset.targets, rankingDataset.queryBoundaries));
        // The sigmoid parameter is deliberately tied to the learning rate.
        this.sigmoidParam = gradientBoostingConfig.learningRate;
        initSigmoidCache(lambdaMARTConfig.sigmoidBins, lambdaMARTConfig.costFunction);

        // One worker per pool thread; each owns a permutation buffer sized for the largest query.
        this.workers = new TaskCollection<>();
        int poolSize = BlockingThreadPoolExecutor.getInstance().getMaximumPoolSize();
        for (int w = 0; w < poolSize; w++) {
            this.workers.addTask(new LambdaWorker(rankingDataset.maxDocsPerQuery));
        }
        this.denomWeights = new double[maxNumTrainInstances];
        this.subLearnerSampleIndicesInTrainSet = new int[maxNumTrainInstances];
    }

    /**
     * Tabulates the per-pair factor over score differences in [-50/sigmoidParam, 50/sigmoidParam).
     *
     * @param numBins number of cache bins; must be positive
     * @param costFunction either {@code "cross-entropy"} or {@code "fidelity"}
     * @throws Exception if {@code numBins} is not positive or the cost function is unknown
     */
    private void initSigmoidCache(int numBins, String costFunction) throws Exception {
        if (numBins <= 0) {
            throw new IllegalArgumentException("Number of sigmoid bins must be positive: " + numBins);
        }
        this.minScore = (-50.0d) / this.sigmoidParam;
        this.maxScore = -this.minScore;
        this.sigmoidCache = new double[numBins];
        this.sigmoidBinWidth = (this.maxScore - this.minScore) / numBins;
        if (costFunction.equals("cross-entropy")) {
            for (int bin = 0; bin < numBins; bin++) {
                double score = this.minScore + (bin * this.sigmoidBinWidth);
                // Two algebraically equal forms of 1 / (1 + exp(sigmoidParam * score)),
                // chosen by sign so the exponent argument is never positive.
                if (score > 0.0d) {
                    this.sigmoidCache[bin] = 1.0d - (1.0d / (1.0d + Math.exp((-this.sigmoidParam) * score)));
                } else {
                    this.sigmoidCache[bin] = 1.0d / (1.0d + Math.exp(this.sigmoidParam * score));
                }
            }
            return;
        }
        if (!costFunction.equals("fidelity")) {
            throw new Exception("Unknown cost function: " + costFunction);
        }
        for (int bin = 0; bin < numBins; bin++) {
            double score = this.minScore + (bin * this.sigmoidBinWidth);
            double exp = score > 0.0d
                    ? Math.exp((-2.0d) * this.sigmoidParam * score)
                    : Math.exp(this.sigmoidParam * score);
            this.sigmoidCache[bin] = ((-this.sigmoidParam) / 2.0d) * Math.sqrt(exp / Math.pow(1.0d + exp, 3.0d));
        }
    }

    /**
     * Looks up the cached per-pair factor for a score difference. Values outside
     * (minScore, maxScore) saturate to the first/last bin. The bin index is also clamped,
     * because floating-point rounding can map a difference just below maxScore to
     * {@code sigmoidCache.length}.
     */
    private double lookupSigmoid(double scoreDiff) {
        int last = this.sigmoidCache.length - 1;
        if (scoreDiff <= this.minScore) {
            return this.sigmoidCache[0];
        }
        if (scoreDiff >= this.maxScore) {
            return this.sigmoidCache[last];
        }
        int bin = (int) ((scoreDiff - this.minScore) / this.sigmoidBinWidth);
        return this.sigmoidCache[Math.min(bin, last)];
    }

    /**
     * Evaluates {@code scores} per query on the current training sample, using the parent
     * metric of the configured evaluation metric. Prints the stack trace and returns
     * {@code null} on failure, matching the original best-effort behavior.
     */
    private double[] evaluateTrainSetByQuery(double[] scores) {
        try {
            return ((RankingSample) this.curTrainSet).evaluateByQuery(scores, (RankingEvaluationMetric) ((RankingEvaluationMetric) this.evaluationMetric).getParentMetric());
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Resets train/validation predictions to zero and seeds the swap scorer with the
     * per-query evaluation of the natural document order (iteration 0).
     */
    @Override // edu.uci.jforests.learning.boosting.GradientBoosting
    protected void preprocess() {
        Arrays.fill(this.trainPredictions, 0, this.curTrainSet.size, 0.0d);
        if (this.curValidSet != null) {
            Arrays.fill(this.validPredictions, 0, this.curValidSet.size, 0.0d);
        }
        double[] naturalOrderScores = RankingEvaluationMetric.computeNaturalOrderScores(this.curTrainSet.size, this.swapScorer.getQueryBoundaries());
        this.swapScorer.setCurrentIterationEvaluation(0, evaluateTrainSetByQuery(naturalOrderScores));
    }

    @Override // edu.uci.jforests.learning.boosting.GradientBoosting
    protected void postProcessScores() {
    }

    /**
     * Computes a leaf's output as (sum of residuals) / (sum of denominator weights) over the
     * leaf's instances, after mapping each sub-sample index back to the train set.
     */
    protected double getAdjustedOutput(LeafInstances leafInstances) {
        double residualSum = 0.0d;
        double weightSum = 0.0d;
        for (int pos = leafInstances.begin; pos < leafInstances.end; pos++) {
            int trainIdx = this.subLearnerSampleIndicesInTrainSet[leafInstances.indices[pos]];
            residualSum += this.residuals[trainIdx];
            weightSum += this.denomWeights[trainIdx];
        }
        return (residualSum + LEAF_OUTPUT_SMOOTHING) / (weightSum + LEAF_OUTPUT_SMOOTHING);
    }

    /** Replaces every leaf output of the (regression) tree with {@link #getAdjustedOutput}. */
    @Override // edu.uci.jforests.learning.boosting.GradientBoosting
    protected void adjustOutputs(Tree tree, TreeLeafInstances treeLeafInstances) {
        RegressionTree regressionTree = (RegressionTree) tree;
        LeafInstances leafInstances = new LeafInstances();
        for (int leaf = 0; leaf < tree.numLeaves; leaf++) {
            treeLeafInstances.loadLeafInstances(leaf, leafInstances);
            regressionTree.setLeafOutput(leaf, getAdjustedOutput(leafInstances));
        }
    }

    /** Hook for subclasses to assign per-instance weights to the sub-learner sample; no-op here. */
    protected void setSubLearnerSampleWeights(RankingSample rankingSample) {
    }

    /**
     * Computes this iteration's gradients in parallel (one query chunk per worker), then
     * returns a random sub-sample of the training set whose targets are the residuals.
     * Also records, for each sub-sample instance, its index in the training set.
     */
    @Override // edu.uci.jforests.learning.boosting.GradientBoosting
    protected Sample getSubLearnerSample() {
        Arrays.fill(this.residuals, 0, this.curTrainSet.size, 0.0d);
        Arrays.fill(this.denomWeights, 0, this.curTrainSet.size, 0.0d);
        RankingSample rankingSample = (RankingSample) this.curTrainSet;
        int numQueries = rankingSample.numQueries;
        int numWorkers = this.workers.getSize();
        int chunkSize = 1 + (numQueries / numWorkers);
        int chunkBegin = 0;
        for (int w = 0; w < numWorkers && chunkBegin < numQueries; w++) {
            LambdaWorker worker = this.workers.getTask(w);
            worker.init(chunkBegin, chunkBegin + Math.min(numQueries - chunkBegin, chunkSize));
            BlockingThreadPoolExecutor.getInstance().execute(worker);
            chunkBegin += chunkSize;
        }
        BlockingThreadPoolExecutor.getInstance().await();

        RankingSample residualSample = rankingSample.getClone();
        residualSample.targets = this.residuals;
        setSubLearnerSampleWeights(residualSample);
        RankingSample parent = residualSample.getClone();
        RankingSample subSample = parent.getRandomSubSample(this.samplingRate, this.rnd);
        for (int k = 0; k < subSample.size; k++) {
            this.subLearnerSampleIndicesInTrainSet[k] = parent.indicesInParentSample[subSample.indicesInParentSample[k]];
        }
        return subSample;
    }

    /**
     * Records the per-query evaluation of the current train predictions in the swap scorer
     * before delegating to the superclass.
     */
    @Override // edu.uci.jforests.learning.LearningModule
    public void onIterationEnd() {
        this.swapScorer.setCurrentIterationEvaluation(this.curIteration, evaluateTrainSetByQuery(this.trainPredictions));
        super.onIterationEnd();
    }
}
