/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.eval.ranking;

import edu.uci.jforests.dataset.RankingDataset;
import edu.uci.jforests.eval.ranking.RankingEvaluationMetric;
import edu.uci.jforests.sample.RankingSample;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ArraysUtil;
import edu.uci.jforests.util.Constants;
import edu.uci.jforests.util.ScoreBasedComparator;
import edu.uci.jforests.util.concurrency.BlockingThreadPoolExecutor;
import edu.uci.jforests.util.concurrency.TaskCollection;
import edu.uci.jforests.util.concurrency.TaskItem;
import java.util.Arrays;

/**
 * Normalized Discounted Cumulative Gain (NDCG) evaluation for ranking samples.
 *
 * <p>NDCG@k is computed for every k from 1 to {@link #MAX_TRUNCATION_LEVEL} in a single pass;
 * {@link #measure} and {@link #measureByQuery} report the value at the truncation level given to
 * the constructor. Queries are split into contiguous chunks that are scored in parallel on
 * {@link BlockingThreadPoolExecutor}.
 *
 * <p>The position-discount table {@link #discounts} is static and shared by all instances. The
 * worker objects of an instance are reused by every evaluation call, so one instance must not run
 * two evaluations concurrently.
 */
public class NDCGEval
extends RankingEvaluationMetric {
    /** Largest truncation level k for which NDCG@k is computed. */
    public static final int MAX_TRUNCATION_LEVEL = 10;
    /** Number of supported relevance labels; labels are used as indices in [0, GAIN_LEVELS). */
    public static final int GAIN_LEVELS = 5;
    /** Gain of each relevance label: 2^label - 1. */
    public static double[] GAINS = new double[]{0.0, 1.0, 3.0, 7.0, 15.0};
    /**
     * Position discounts, {@code discounts[p] = Constants.LN2 / ln(p + 2)} (presumably
     * 1 / log2(p + 2), assuming Constants.LN2 is ln 2). {@code null} until {@link #initialize}
     * has been called; only ever grows.
     */
    public static double[] discounts;
    /** One worker per thread-pool slot; reused by every evaluation call. */
    private final TaskCollection<PerQueryNDCGWorker> ndcgWorkers;
    /** 1-based truncation level reported by measure / measureByQuery. */
    private final int evalTruncationLevel;
    /** Capacity of each worker's permutation buffer (largest query the workers can rank). */
    private final int maxDocsPerQuery;

    /**
     * Makes sure the shared discount table covers at least {@code maxDocsPerQuery} positions.
     *
     * @param maxDocsPerQuery number of ranking positions that need a discount
     */
    public static synchronized void initialize(int maxDocsPerQuery) {
        if (discounts == null || discounts.length < maxDocsPerQuery) {
            // Fill a local array first and only then publish it, so the shared reference is never
            // pointed at a table that is still being filled (workers read it unsynchronized).
            double[] newDiscounts = new double[maxDocsPerQuery];
            for (int p = 0; p < maxDocsPerQuery; ++p) {
                newDiscounts[p] = Constants.LN2 / Math.log(2 + p);
            }
            discounts = newDiscounts;
        }
    }

    /**
     * Creates an NDCG evaluator.
     *
     * @param maxDocsPerQuery largest number of documents in any query to be evaluated
     * @param evalTruncationLevel truncation level k reported by measure, in [1, MAX_TRUNCATION_LEVEL]
     * @throws Exception if evalTruncationLevel lies outside [1, MAX_TRUNCATION_LEVEL]
     */
    public NDCGEval(int maxDocsPerQuery, int evalTruncationLevel) throws Exception {
        // NOTE(review): the meaning of the boolean is defined by RankingEvaluationMetric
        // (presumably "larger is better") — confirm there.
        super(true);
        this.maxDocsPerQuery = maxDocsPerQuery;
        NDCGEval.initialize(maxDocsPerQuery);
        if (evalTruncationLevel > MAX_TRUNCATION_LEVEL) {
            throw new Exception("Evalutation truncation level " + evalTruncationLevel + " is larger than " + MAX_TRUNCATION_LEVEL);
        }
        // measure() indexes with evalTruncationLevel - 1, so levels below 1 would only fail later.
        if (evalTruncationLevel < 1) {
            throw new Exception("Evaluation truncation level " + evalTruncationLevel + " is smaller than 1");
        }
        this.evalTruncationLevel = evalTruncationLevel;
        int numWorkers = BlockingThreadPoolExecutor.getInstance().getMaximumPoolSize();
        this.ndcgWorkers = new TaskCollection<PerQueryNDCGWorker>();
        for (int i = 0; i < numWorkers; ++i) {
            this.ndcgWorkers.addTask(new PerQueryNDCGWorker());
        }
    }

    /**
     * Counts, for every query, how many documents carry each relevance label.
     *
     * <p>Side effect: labels at or above {@link #GAIN_LEVELS} are reported on stderr and clamped
     * in place to {@code GAIN_LEVELS - 1}. Negative labels are not handled and raise
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param labels document labels (modified in place when clamped)
     * @param boundaries query boundaries; query q spans [boundaries[q], boundaries[q + 1])
     * @return {@code counts[q][label]} for every query q
     */
    public static int[][] getLabelCountsForQueries(double[] labels, int[] boundaries) {
        int numQueries = boundaries.length - 1;
        int[][] labelCounts = new int[numQueries][GAIN_LEVELS];
        for (int q = 0; q < numQueries; ++q) {
            int begin = boundaries[q];
            int end = boundaries[q + 1];
            for (int i = begin; i < end; ++i) {
                if ((int)labels[i] >= GAIN_LEVELS) {
                    System.err.println("query " + q + " line " + i + " label " + (int)labels[i]);
                    labels[i] = GAIN_LEVELS - 1;
                }
                labelCounts[q][(int)labels[i]]++;
            }
        }
        return labelCounts;
    }

    /**
     * Computes the ideal (maximum) DCG of every query at every truncation level 1..trunc.
     *
     * @return {@code maxDCG[t - 1][q]} = ideal DCG of query q truncated at t
     * @throws Exception if the discount table has not been initialized
     */
    public static double[][] getMaxDCGForAllQueriesUptoTruncation(double[] labels, int[] boundaries, int trunc, int[][] labelCounts) throws Exception {
        double[][] maxDCG = new double[trunc][];
        for (int t = 1; t <= trunc; ++t) {
            maxDCG[t - 1] = NDCGEval.getMaxDCGForAllQueriesAtTruncation(labels, boundaries, t, labelCounts);
        }
        return maxDCG;
    }

    /**
     * Computes the ideal DCG of every query truncated at {@code trunc} (or at the query size, if
     * smaller) by placing the documents in descending label order.
     *
     * @param labels unused here; kept for interface compatibility
     * @param boundaries query boundaries; query q spans [boundaries[q], boundaries[q + 1])
     * @param trunc truncation level
     * @param labelCounts per-query label histogram as produced by getLabelCountsForQueries
     * @return ideal DCG per query
     * @throws Exception if the discount table has not been initialized
     */
    public static double[] getMaxDCGForAllQueriesAtTruncation(double[] labels, int[] boundaries, int trunc, int[][] labelCounts) throws Exception {
        if (discounts == null) {
            throw new Exception("Not initialized.");
        }
        int numQueries = boundaries.length - 1;
        double[] maxDCG = new double[numQueries];
        int[] remaining = new int[GAIN_LEVELS];
        for (int q = 0; q < numQueries; ++q) {
            int maxTrunc = Math.min(trunc, boundaries[q + 1] - boundaries[q]);
            int topLabel = GAIN_LEVELS - 1;
            System.arraycopy(labelCounts[q], 0, remaining, 0, GAIN_LEVELS);
            for (int t = 0; t < maxTrunc; ++t) {
                // Descend to the best label that still has documents left (label 0 is the floor).
                while (remaining[topLabel] == 0 && topLabel > 0) {
                    --topLabel;
                }
                maxDCG[q] += GAINS[topLabel] * discounts[t];
                remaining[topLabel]--;
            }
        }
        return maxDCG;
    }

    /**
     * Scores all queries of {@code sample} in parallel and waits for the workers to finish.
     *
     * @return number of workers that were given a (non-empty) chunk of queries
     * @throws Exception if the dataset's maxDCG table has not been initialized
     */
    private int dispatchWorkers(double[] predictions, Sample sample, ScoreBasedComparator.TieBreaker tieBreaker) throws Exception {
        if (((RankingDataset)sample.dataset).maxDCG == null) {
            throw new Exception("maxDCG is not initialized for dataset.");
        }
        RankingSample rankingSample = (RankingSample)sample;
        int numWorkers = this.ndcgWorkers.getSize();
        int numQueries = rankingSample.numQueries;
        int chunkSize = 1 + numQueries / numWorkers;
        int workerCount = 0;
        for (int offset = 0; workerCount < numWorkers && offset < numQueries; offset += chunkSize) {
            int endOffset = offset + Math.min(numQueries - offset, chunkSize);
            PerQueryNDCGWorker worker = this.ndcgWorkers.getTask(workerCount);
            worker.init(rankingSample, predictions, offset, endOffset, tieBreaker);
            BlockingThreadPoolExecutor.getInstance().execute(worker);
            ++workerCount;
        }
        BlockingThreadPoolExecutor.getInstance().await();
        return workerCount;
    }

    /**
     * Computes the mean NDCG over all queries of {@code sample} at every truncation level.
     *
     * @param predictions document scores used to rank each query
     * @param sample ranking sample to evaluate
     * @param tieBreaker how documents with equal scores are ordered
     * @return array of length MAX_TRUNCATION_LEVEL whose element t is the mean NDCG@(t + 1);
     *     NaN entries if the sample has no queries
     * @throws Exception if the dataset's maxDCG table has not been initialized
     */
    public double[] getNDCGatAllTruncations(double[] predictions, Sample sample, ScoreBasedComparator.TieBreaker tieBreaker) throws Exception {
        int workerCount = this.dispatchWorkers(predictions, sample, tieBreaker);
        RankingSample rankingSample = (RankingSample)sample;
        double[] result = new double[MAX_TRUNCATION_LEVEL];
        for (int i = 0; i < workerCount; ++i) {
            double[] localResult = this.ndcgWorkers.getTask(i).getResults();
            for (int t = 0; t < MAX_TRUNCATION_LEVEL; ++t) {
                result[t] += localResult[t];
            }
        }
        for (int t = 0; t < MAX_TRUNCATION_LEVEL; ++t) {
            result[t] /= rankingSample.numQueries;
        }
        return result;
    }

    /**
     * Computes the NDCG of every query of {@code sample} at every truncation level.
     *
     * @param predictions document scores used to rank each query
     * @param sample ranking sample to evaluate
     * @param tieBreaker how documents with equal scores are ordered
     * @return {@code result[t][q]} = NDCG@(t + 1) of query q (in sample order)
     * @throws Exception if the dataset's maxDCG table has not been initialized
     */
    public double[][] getNDCGatAllTruncationsAllQueries(double[] predictions, Sample sample, ScoreBasedComparator.TieBreaker tieBreaker) throws Exception {
        int workerCount = this.dispatchWorkers(predictions, sample, tieBreaker);
        RankingSample rankingSample = (RankingSample)sample;
        double[][] result = new double[MAX_TRUNCATION_LEVEL][rankingSample.numQueries];
        int offset = 0;
        for (int i = 0; i < workerCount; ++i) {
            // Workers own consecutive query ranges, so their per-query rows are concatenated in order.
            double[][] localResult = this.ndcgWorkers.getTask(i).getQueryResults();
            for (int t = 0; t < MAX_TRUNCATION_LEVEL; ++t) {
                System.arraycopy(localResult[t], 0, result[t], offset, localResult[t].length);
            }
            offset += localResult[0].length;
        }
        assert (offset == rankingSample.numQueries);
        return result;
    }

    /**
     * Mean NDCG at the configured truncation level, breaking score ties with
     * {@code TieBreaker.ReverseLabels} (semantics defined by ScoreBasedComparator).
     */
    public double measure(double[] predictions, Sample sample) throws Exception {
        return this.getNDCGatAllTruncations(predictions, sample, ScoreBasedComparator.TieBreaker.ReverseLabels)[this.evalTruncationLevel - 1];
    }

    /**
     * Per-query NDCG at the configured truncation level, breaking score ties with
     * {@code TieBreaker.ReverseLabels}.
     */
    public double[] measureByQuery(double[] predictions, Sample sample) throws Exception {
        return this.getNDCGatAllTruncationsAllQueries(predictions, sample, ScoreBasedComparator.TieBreaker.ReverseLabels)[this.evalTruncationLevel - 1];
    }

    /** Returns a scorer for the NDCG change caused by swapping two documents of a query. */
    public RankingEvaluationMetric.SwapScorer getSwapScorer(double[] targets, int[] boundaries, int trunc, int[][] labelCounts) throws Exception {
        return new NDCGSwapScorer(targets, boundaries, trunc, labelCounts);
    }

    /**
     * Scores the queries [beginIdx, endIdx) of a sample and sums, per truncation level, their
     * NDCG into {@link #result}.
     */
    private class NDCGWorker
    extends TaskItem {
        /** Scratch buffer: document offsets within the current query, in ranked order. */
        protected int[] permutation;
        protected RankingSample sample;
        protected int beginIdx;
        protected int endIdx;
        /** result[t] = sum of NDCG@(t + 1) over the queries processed since init. */
        protected double[] result;
        protected ScoreBasedComparator comparator;
        /** NDCG@1..NDCG@MAX_TRUNCATION_LEVEL of the query currently being processed. */
        protected final double[] queryNDCG;

        public NDCGWorker() {
            this.permutation = new int[NDCGEval.this.maxDocsPerQuery];
            this.comparator = new ScoreBasedComparator();
            this.result = new double[MAX_TRUNCATION_LEVEL];
            this.queryNDCG = new double[MAX_TRUNCATION_LEVEL];
        }

        /** Assigns the query range and scores to process and clears the accumulated sums. */
        public void init(RankingSample sample, double[] scores, int beginIdx, int endIdx, ScoreBasedComparator.TieBreaker tieBreaker) {
            this.sample = sample;
            this.beginIdx = beginIdx;
            this.endIdx = endIdx;
            this.comparator.labels = sample.targets;
            this.comparator.scores = scores;
            this.comparator.tieBreaker = tieBreaker;
            Arrays.fill(this.result, 0.0);
        }

        public double[] getResults() {
            return this.result;
        }

        public void run() {
            double[][] maxDCG = ((RankingDataset)this.sample.dataset).maxDCG;
            for (int q = this.beginIdx; q < this.endIdx; ++q) {
                int numDocs = this.rankDocuments(q);
                try {
                    this.computeQueryNDCG(q, numDocs, maxDCG);
                }
                catch (Exception e) {
                    // Best effort: report and skip the query; it contributes nothing to the sums.
                    e.printStackTrace();
                    continue;
                }
                this.accumulate(q - this.beginIdx);
            }
        }

        /**
         * Fills the first numDocs slots of {@link #permutation} with the documents of query q in
         * the order defined by the comparator.
         *
         * @return number of documents in query q
         */
        protected int rankDocuments(int q) {
            int begin = this.sample.queryBoundaries[q];
            int numDocs = this.sample.queryBoundaries[q + 1] - begin;
            this.comparator.offset = begin;
            for (int d = 0; d < numDocs; ++d) {
                this.permutation[d] = d;
            }
            ArraysUtil.sort(this.permutation, numDocs, this.comparator);
            return numDocs;
        }

        /**
         * Writes NDCG@1..NDCG@MAX_TRUNCATION_LEVEL of the already-ranked query q into
         * {@link #queryNDCG}.
         *
         * <p>maxDCG is assumed to be indexed {@code [truncation - 1][queryIndex]}, as returned by
         * getMaxDCGForAllQueriesUptoTruncation — NOTE(review): confirm where RankingDataset.maxDCG
         * is populated.
         */
        protected void computeQueryNDCG(int q, int numDocs, double[][] maxDCG) {
            int begin = this.sample.queryBoundaries[q];
            int queryIndex = this.sample.queryIndices[q];
            if (maxDCG[0][queryIndex] == 0.0) {
                // No achievable gain: the query is treated as perfectly ranked at every level.
                Arrays.fill(this.queryNDCG, 1.0);
                return;
            }
            int depth = Math.min(numDocs, MAX_TRUNCATION_LEVEL);
            double dcg = 0.0;
            for (int t = 0; t < depth; ++t) {
                dcg += GAINS[(int)this.sample.targets[begin + this.permutation[t]]] * discounts[t];
                this.queryNDCG[t] = dcg > 0.0 ? dcg / maxDCG[t][queryIndex] : 0.0;
            }
            // Past the last document the DCG no longer grows; it is normalized per deeper level.
            for (int t = depth; t < MAX_TRUNCATION_LEVEL; ++t) {
                this.queryNDCG[t] = dcg > 0.0 ? dcg / maxDCG[t][queryIndex] : 0.0;
            }
        }

        /** Adds the current query's NDCG values to the running sums. */
        protected void accumulate(int localQueryIdx) {
            for (int t = 0; t < MAX_TRUNCATION_LEVEL; ++t) {
                this.result[t] += this.queryNDCG[t];
            }
        }
    }

    /** NDCG worker that additionally keeps each query's individual NDCG values. */
    private class PerQueryNDCGWorker
    extends NDCGWorker {
        /** queryresult[t][i] = NDCG@(t + 1) of query beginIdx + i. */
        private double[][] queryresult;

        public double[][] getQueryResults() {
            return this.queryresult;
        }

        @Override
        public void init(RankingSample sample, double[] scores, int beginIdx, int endIdx, ScoreBasedComparator.TieBreaker tieBreaker) {
            super.init(sample, scores, beginIdx, endIdx, tieBreaker);
            this.queryresult = new double[MAX_TRUNCATION_LEVEL][endIdx - beginIdx];
        }

        @Override
        protected void accumulate(int localQueryIdx) {
            super.accumulate(localQueryIdx);
            for (int t = 0; t < MAX_TRUNCATION_LEVEL; ++t) {
                this.queryresult[t][localQueryIdx] += this.queryNDCG[t];
            }
        }
    }

    /**
     * Computes the NDCG change of swapping two documents within a query, normalized by the
     * query's ideal DCG at the scorer's truncation level.
     */
    static final class NDCGSwapScorer
    extends RankingEvaluationMetric.SwapScorer {
        /** Ideal DCG per query at the truncation level given to the constructor. */
        final double[] maxDCG;
        /** Integer relevance label of every document. */
        final int[] labels;

        NDCGSwapScorer(double[] targets, int[] boundaries, int trunc, int[][] labelCounts) throws Exception {
            super(targets, boundaries, trunc, labelCounts);
            this.maxDCG = NDCGEval.getMaxDCGForAllQueriesAtTruncation(targets, boundaries, trunc, labelCounts);
            this.labels = new int[targets.length];
            for (int i = 0; i < targets.length; ++i) {
                this.labels[i] = (int)targets[i];
            }
            assert (discounts != null);
        }

        /**
         * Returns {@code (gain(better) - gain(worse)) * (discount[rank_j] - discount[rank_i]) /
         * maxDCG[queryIndex]}: the DCG change of moving the better document from rank_i to rank_j
         * (and the worse one the other way), normalized. Not finite if the query's ideal DCG is 0.
         */
        public final double getDelta(int queryIndex, int betterIdx, int rank_i, int worseIdx, int rank_j) {
            assert (betterIdx < this.labels.length);
            assert (worseIdx < this.labels.length);
            double queryMaxDcg = this.maxDCG[queryIndex];
            return (GAINS[this.labels[betterIdx]] - GAINS[this.labels[worseIdx]]) * (discounts[rank_j] - discounts[rank_i]) / queryMaxDcg;
        }
    }
}

