/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.eval.ranking;

import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.eval.ranking.NDCGEval;
import edu.uci.jforests.eval.ranking.RankingEvaluationMetric;
import edu.uci.jforests.sample.RankingSample;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.MathUtil;
import edu.uci.jforests.util.concurrency.BlockingThreadPoolExecutor;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.Test;

public class URiskAwareEval
extends RankingEvaluationMetric {
    EvaluationMetric parent;
    final double ALPHA;

    /**
     * Wraps an existing metric so that {@code measure} can score predictions relative to a baseline.
     *
     * @param _parent metric being wrapped; its {@code largerIsBetter()} value is forwarded to the
     *                superclass constructor and is additionally asserted to be {@code true}
     * @param alpha   risk weight; {@code measure} multiplies its risk term by {@code 1 + alpha}
     */
    public URiskAwareEval(EvaluationMetric _parent, double alpha) {
        super(_parent.largerIsBetter());
        // Only enforced when the JVM runs with assertions enabled (-ea).
        assert (_parent.largerIsBetter());
        this.ALPHA = alpha;
        this.parent = _parent;
    }

    /**
     * Scores {@code predictions} against a baseline ranking using the wrapped metric.
     *
     * <p>The baseline is the wrapped metric evaluated per query on the scores returned by
     * {@code computeNaturalOrderScores} (defined outside this method; presumably scores that keep
     * the documents in the order they were supplied -- confirm against that helper). With
     * {@code m[q]} the per-query value for {@code predictions} and {@code b[q]} the per-query
     * baseline value, the result is
     * {@code mean(max(0, m[q] - b[q])) - (1 + ALPHA) * mean(max(0, b[q] - m[q]))}.
     *
     * <p>If the wrapped metric reports zero queries, both means are {@code 0.0 / 0} and the result
     * is {@code NaN}.
     *
     * @param predictions one score per document
     * @param sample      must be a {@link RankingSample}
     * @return reward minus weighted risk, as described above
     * @throws ClassCastException if {@code sample} is not a {@code RankingSample} or the wrapped
     *                            metric is not a {@code RankingEvaluationMetric}
     * @throws Exception          whatever the wrapped metric's {@code measureByQuery} throws
     */
    public double measure(double[] predictions, Sample sample) throws Exception {
        RankingSample rankingSample = (RankingSample)sample;
        assert (rankingSample.queryBoundaries.length - 1 == rankingSample.numQueries);
        double[] naturalOrder = URiskAwareEval.computeNaturalOrderScores(predictions.length, rankingSample.queryBoundaries);

        // Evaluate the baseline first, then the model, exactly as before.
        RankingEvaluationMetric wrapped = (RankingEvaluationMetric)this.parent;
        double[] baselinePerQuery = wrapped.measureByQuery(naturalOrder, sample);
        double[] perQuery = wrapped.measureByQuery(predictions, sample);

        int queryCount = perQuery.length;
        double rewardSum = 0.0;
        double riskSum = 0.0;
        for (int q = 0; q < queryCount; ++q) {
            // Exactly one of these two terms is non-zero for any query where the values differ.
            rewardSum += Math.max(0.0, perQuery[q] - baselinePerQuery[q]);
            riskSum += Math.max(0.0, baselinePerQuery[q] - perQuery[q]);
        }

        double meanReward = rewardSum / (double)queryCount;
        double meanRisk = riskSum / (double)queryCount;
        return meanReward - (1.0 + this.ALPHA) * meanRisk;
    }

    /**
     * Creates a {@code URiskSwapScorer} that wraps the swap scorer of the wrapped metric.
     *
     * <p>All four arguments are passed unchanged both to the wrapped metric's
     * {@code getSwapScorer} and to the {@code URiskSwapScorer} constructor, together with this
     * metric's {@code ALPHA}.
     *
     * @throws ClassCastException if the wrapped metric is not a {@code RankingEvaluationMetric}
     * @throws Exception          whatever the wrapped metric's {@code getSwapScorer} throws
     */
    public RankingEvaluationMetric.SwapScorer getSwapScorer(double[] targets, int[] boundaries, int trunc, int[][] labelCounts) throws Exception {
        final RankingEvaluationMetric wrapped = (RankingEvaluationMetric)this.parent;
        final RankingEvaluationMetric.SwapScorer delegate = wrapped.getSwapScorer(targets, boundaries, trunc, labelCounts);
        return new URiskSwapScorer(targets, boundaries, trunc, labelCounts, this.ALPHA, delegate);
    }

    /**
     * Not implemented for this metric: every call fails.
     *
     * @throws UnsupportedOperationException always
     */
    public double[] measureByQuery(double[] predictions, Sample sample) throws Exception {
        final String reason = "Hmmm, not sure how to calculate this one yet";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns whatever the wrapped metric reports as its own parent metric.
     *
     * <p>Note that this delegates to the wrapped metric's {@code getParentMetric()} rather than
     * returning the wrapped metric itself.
     *
     * @throws ClassCastException if the wrapped metric is not a {@code RankingEvaluationMetric}
     */
    public EvaluationMetric getParentMetric() {
        final RankingEvaluationMetric wrapped = (RankingEvaluationMetric)this.parent;
        return wrapped.getParentMetric();
    }

    /**
     * Smoke test for the swap scorer produced by {@code URiskAwareEval}.
     *
     * <p>The computed deltas are only printed to {@code System.err}; the single JUnit assertion is
     * vacuous, so the test passes whenever the calls complete without throwing.
     */
    public static class TestURiskSwaps {
        @Test
        public void testTwoQueries() throws Exception {
            BlockingThreadPoolExecutor.init(1);
            final RankingEvaluationMetric metric = new URiskAwareEval(new NDCGEval(2, 2), 1.0);
            final int truncation = 2;

            // First scenario: four documents split by the boundaries {0, 2, 4}.
            final double[] fourDocTargets = {0.0, 1.0, 0.0, 1.0};
            final int[] fourDocBoundaries = {0, 2, 4};
            final int[][] fourDocLabelCounts = {{1, 1, 0, 0, 0}, {1, 1, 0, 0, 0}};
            RankingEvaluationMetric.SwapScorer scorer = metric.getSwapScorer(fourDocTargets, fourDocBoundaries, truncation, fourDocLabelCounts);
            scorer.setCurrentIterationEvaluation(0, new double[]{0.1, 0.2});
            System.err.println(scorer.getDelta(0, 0, 0, 1, 1));

            // Second scenario: two documents with the boundaries {0, 2}.
            final double[] twoDocTargets = {0.0, 1.0};
            final int[] twoDocBoundaries = {0, 2};
            final int[][] twoDocLabelCounts = {{1, 1, 0, 0, 0}};
            scorer = metric.getSwapScorer(twoDocTargets, twoDocBoundaries, truncation, twoDocLabelCounts);
            scorer.setCurrentIterationEvaluation(0, new double[]{0.1});
            System.err.println(scorer.getDelta(0, 0, 0, 1, 1));

            Assert.assertTrue(true);
        }
    }

    static class URiskSwapScorer
    extends RankingEvaluationMetric.SwapScorer {
        double alpha;
        RankingEvaluationMetric.SwapScorer parentSwap;
        double[] modelEval;
        double[] baselineEval;

        /**
         * Creates a scorer that adjusts the deltas produced by another swap scorer.
         *
         * <p>The first four arguments are forwarded unchanged to the superclass constructor.
         * {@code modelEval} and {@code baselineEval} are not set here; they are populated by
         * {@code setCurrentIterationEvaluation}.
         *
         * @param _alpha  risk weight used by {@code getDelta}
         * @param _parent swap scorer whose {@code getDelta} supplies the unadjusted delta
         */
        public URiskSwapScorer(double[] targets, int[] boundaries, int trunc, int[][] labelCounts, double _alpha, RankingEvaluationMetric.SwapScorer _parent) {
            super(targets, boundaries, trunc, labelCounts);
            this.parentSwap = _parent;
            this.alpha = _alpha;
        }

        /**
         * Records the per-query evaluation for an iteration and logs it to {@code System.err}.
         *
         * <p>Iteration {@code 0} is treated as the baseline: the array is stored as
         * {@code baselineEval} and {@code modelEval} is reset to a zero-filled array of the same
         * length. For any other iteration the array is stored as {@code modelEval}. The array is
         * stored by reference, not copied.
         *
         * @param iteration iteration number; {@code 0} selects the baseline branch
         * @param nDCG      per-query evaluation values
         */
        public void setCurrentIterationEvaluation(int iteration, double[] nDCG) {
            final double average = MathUtil.getAvg(nDCG);
            final boolean isBaselineRound = iteration == 0;
            if (!isBaselineRound) {
                System.err.println("Model in iteration " + iteration + " has peformance " + average);
                this.modelEval = nDCG;
            } else {
                System.err.println("Baseline in iteration 0 has peformance " + average);
                this.baselineEval = nDCG;
                // Same length as the baseline that was just stored, all zeros.
                this.modelEval = new double[nDCG.length];
            }
            final String perQueryDump = Arrays.toString(nDCG);
            System.err.println("Iteration " + iteration + " NDCG=" + perQueryDump);
        }

        /**
         * Adjusts the parent scorer's delta for a document pair according to where the model
         * currently stands relative to the baseline on that query.
         *
         * <p>Inputs read: {@code modelEval[queryIndex]} (M_m), {@code baselineEval[queryIndex]}
         * (M_b), the parent scorer's {@code getDelta} for the same arguments (delta_M), and the
         * target values of the two documents (rel_i, rel_j). Both evaluation arrays are only
         * assigned by {@code setCurrentIterationEvaluation}; calling this method before that one
         * dereferences {@code null}.
         *
         * <p>Result by case (alpha is this scorer's risk weight):
         * <ul>
         *   <li>M_m &lt;= M_b, and rel_i &gt; rel_j with rank_i &lt; rank_j:
         *       {@code (1 + alpha) * delta_M}</li>
         *   <li>M_m &lt;= M_b, otherwise: {@code (1 + alpha) * delta_M} if
         *       {@code M_b > M_m + delta_M}, else {@code alpha * (M_b - M_m) + delta_M}</li>
         *   <li>M_m &gt; M_b, and rel_i &gt; rel_j with rank_i &lt; rank_j:
         *       {@code alpha * (M_m - M_b) - (1 + alpha) * |delta_M|} if
         *       {@code M_b > M_m - |delta_M|}, else {@code delta_M}</li>
         *   <li>M_m &gt; M_b, otherwise: {@code delta_M}</li>
         * </ul>
         *
         * <p>The many {@code assert} statements document the expected sign relationships and are
         * only evaluated when the JVM runs with assertions enabled. With assertions disabled, any
         * pair that does not satisfy {@code rel_i > rel_j && rank_i < rank_j} silently takes the
         * "otherwise" branch, including ties in relevance or rank.
         *
         * @param queryIndex index into the per-query evaluation arrays
         * @param betterIdx  index into {@code targets} of the first document; asserted to have a
         *                   target value at least as large as the second document's
         * @param rank_i     rank of the first document
         * @param worseIdx   index into {@code targets} of the second document
         * @param rank_j     rank of the second document
         * @return the adjusted delta (delta_T) as tabulated above
         */
        public double getDelta(int queryIndex, int betterIdx, int rank_i, int worseIdx, int rank_j) {
            double delta_T;
            // Current model and baseline evaluation for this query.
            double M_m = this.modelEval[queryIndex];
            double M_b = this.baselineEval[queryIndex];
            // Unadjusted delta from the wrapped scorer.
            // NOTE(review): presumably the change in the parent metric if the pair is swapped -- confirm.
            double delta_M = this.parentSwap.getDelta(queryIndex, betterIdx, rank_i, worseIdx, rank_j);
            // `targets` is not declared in this class; presumably inherited from SwapScorer.
            double rel_i = this.targets[betterIdx];
            double rel_j = this.targets[worseIdx];
            assert (rel_i >= rel_j);
            // Sanity checks on the sign of the parent delta relative to the rank order.
            if (rank_i > rank_j) assert (delta_M >= 0.0) : "rank_i=" + rank_i + " rank_j=" + rank_j + " delta_M=" + delta_M;
            if (rank_i < rank_j) assert (delta_M <= 0.0) : "rank_i=" + rank_i + " rank_j=" + rank_j + " delta_M=" + delta_M;
            if (M_m <= M_b) {
                // Model is at or below the baseline on this query.
                if (rel_i > rel_j && rank_i < rank_j) {
                    // Higher-target document already ranked first: delta_M is expected negative
                    // and is scaled by (1 + alpha).
                    assert (delta_M < 0.0);
                    delta_T = (1.0 + this.alpha) * delta_M;
                    assert (delta_T < 0.0) : "M_b=" + M_b + " M_m=" + M_m + " delta_M=" + delta_M + " => delta_T=" + delta_T;
                } else {
                    // Expected here: higher-target document ranked after the other one.
                    assert (rel_i > rel_j && rank_i > rank_j) : "M_b=" + M_b + " M_m=" + M_m + " delta_M=" + delta_M;
                    assert (delta_M >= 0.0) : "M_b=" + M_b + " M_m=" + M_m + " delta_M=" + delta_M + " rel_i=" + rel_i + " rel_j=" + rel_j + " rank_i=" + rank_i + " rank_j=" + rank_j;
                    if (M_b > M_m + delta_M) {
                        // Even after adding delta_M the model stays below the baseline.
                        delta_T = (1.0 + this.alpha) * delta_M;
                    } else {
                        // Adding delta_M reaches or passes the baseline: only the gap up to the
                        // baseline gets the extra alpha weight.
                        assert (M_b <= M_m + delta_M) : "M_b=" + M_b + " M_m=" + M_m + " delta_M=" + delta_M;
                        delta_T = this.alpha * (M_b - M_m) + delta_M;
                    }
                    assert (delta_T > 0.0) : "M_b=" + M_b + " M_m=" + M_m + " delta_M=" + delta_M + " rel_i=" + rel_i + " rel_j=" + rel_j + " rank_i=" + rank_i + " rank_j=" + rank_j + "  => delta_T=" + delta_T;
                }
            } else {
                // Model is strictly above the baseline on this query.
                assert (M_m > M_b) : "M_b=" + M_b + " M_m=" + M_m + " delta_M=" + delta_M;
                if (rel_i > rel_j && rank_i < rank_j) {
                    assert (delta_M <= 0.0) : "rank_i=" + rank_i + " rank_j=" + rank_j + " delta_M=" + delta_M;
                    if (M_b > M_m - Math.abs(delta_M)) {
                        // Subtracting |delta_M| would drop the model below the baseline.
                        delta_T = this.alpha * (M_m - M_b) - (1.0 + this.alpha) * Math.abs(delta_M);
                    } else {
                        // Model stays at or above the baseline: parent delta passes through.
                        assert (M_b <= M_m - Math.abs(delta_M)) : "M_b=" + M_b + " M_m=" + M_m + " delta_M=" + delta_M;
                        delta_T = delta_M;
                    }
                    assert (delta_T < 0.0) : "M_b=" + M_b + " M_m=" + M_m + " delta_M=" + delta_M + " rel_i=" + rel_i + " rel_j=" + rel_j + " rank_i=" + rank_i + " rank_j=" + rank_j + " => delta_T=" + delta_T;
                } else {
                    // Already above the baseline and delta_M expected positive: pass through.
                    assert (rel_i > rel_j && rank_i > rank_j);
                    delta_T = delta_M;
                    assert (delta_T > 0.0);
                }
            }
            return delta_T;
        }
    }
}

