/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.learning.classification;

import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.learning.LearningUtils;
import edu.uci.jforests.learning.boosting.GradientBoosting;
import edu.uci.jforests.learning.boosting.GradientBoostingConfig;
import edu.uci.jforests.learning.trees.LeafInstances;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeLeafInstances;
import edu.uci.jforests.learning.trees.regression.RegressionTree;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ConfigHolder;
import java.util.Arrays;

/**
 * Gradient-boosted binary classifier built from regression trees.
 *
 * <p>Targets equal to {@code 0.0} are treated as the negative class ({@code y = -1}); every other
 * target value is treated as the positive class ({@code y = +1}).
 *
 * <p>Each boosting iteration proceeds in two steps:
 * <ul>
 *   <li>A regression tree is fitted to the pseudo-residuals {@code r = 2y / (1 + exp(2yF))}, where
 *       {@code F} is the current score.</li>
 *   <li>Every leaf value is then replaced by
 *       {@code learningRate * sum(r * b) / sum(|r| * (2 - |r|) * b)}, where {@code b} is the
 *       per-instance balancing factor.</li>
 * </ul>
 */
public class GradientBoostingBinaryClassifier extends GradientBoosting {
    /**
     * Tiny constant added to both the numerator and the denominator of a leaf's update.
     * It prevents a leaf whose weights sum to zero from dividing by zero.
     */
    private static final double LEAF_OUTPUT_SMOOTHING = 1.4E-45;

    /** Per-instance loss weights, indexed by position in the current training set. */
    protected double[] balancingFactors;
    /** Training-set probabilities derived from {@code trainPredictions} in {@link #postProcessScores()}. */
    protected double[] prob;
    /** Validation-set probabilities derived from {@code validPredictions} in {@link #getValidMeasurement()}. */
    protected double[] validProb;
    /** Per-instance second-order weights {@code |r| * (2 - |r|)}, written in {@link #getSubLearnerSample()}. */
    protected double[] weights;
    /** For each instance of the latest sub-learner sample, its position in the current training set. */
    private int[] subLearnerSampleIndicesInTrainSet;
    /** When true, each class receives the same total loss weight regardless of its size. */
    private boolean imbalanceCostAdjustment;

    public GradientBoostingBinaryClassifier() throws Exception {
        super("GradientBoostingBinaryClassifier");
    }

    /**
     * Initializes the learner and allocates its per-instance buffers.
     *
     * @param configHolder source of the {@code GradientBoostingConfig} (reads {@code imbalanceCostAdjustment})
     * @param maxNumTrainInstances capacity of the training-side buffers
     * @param maxNumValidInstances capacity of the validation probability buffer
     * @param evaluationMetric metric forwarded to the superclass
     * @throws Exception propagated from the superclass initialization
     */
    public void init(ConfigHolder configHolder, int maxNumTrainInstances, int maxNumValidInstances, EvaluationMetric evaluationMetric) throws Exception {
        super.init(configHolder, maxNumTrainInstances, maxNumValidInstances, evaluationMetric);
        this.imbalanceCostAdjustment = configHolder.getConfig(GradientBoostingConfig.class).imbalanceCostAdjustment;
        this.prob = new double[maxNumTrainInstances];
        this.validProb = new double[maxNumValidInstances];
        this.weights = new double[maxNumTrainInstances];
        this.subLearnerSampleIndicesInTrainSet = new int[maxNumTrainInstances];
    }

    /**
     * Computes the balancing factors and seeds the train/validation scores with the initial constant.
     *
     * <p>The initial constant is the value of {@code F} that minimizes the (weighted) logistic loss
     * whose gradient is used in {@link #getSubLearnerSample()}.
     */
    protected void preprocess() {
        int size = this.curTrainSet.size;
        if (this.balancingFactors == null || this.balancingFactors.length < size) {
            // Never allocate fewer slots than the training set needs, even if residuals is shorter.
            this.balancingFactors = new double[Math.max(this.residuals.length, size)];
        }
        int totalPositive = 0;
        int totalNegative = 0;
        for (int i = 0; i < size; ++i) {
            if (isNegativeLabel(this.curTrainSet.targets[i])) {
                ++totalNegative;
            } else {
                ++totalPositive;
            }
        }
        if (!this.imbalanceCostAdjustment) {
            Arrays.fill(this.balancingFactors, 1.0);
        } else {
            // Each class gets a total weight of 1.0, independent of how many instances it has.
            for (int i = 0; i < size; ++i) {
                this.balancingFactors[i] = isNegativeLabel(this.curTrainSet.targets[i])
                        ? 1.0 / (double) totalNegative
                        : 1.0 / (double) totalPositive;
            }
        }
        double initialValue = computeBinaryInitialScore(totalPositive, totalNegative);
        Arrays.fill(this.trainPredictions, 0, size, initialValue);
        if (this.curValidSet != null) {
            Arrays.fill(this.validPredictions, 0, this.curValidSet.size, initialValue);
        }
    }

    /**
     * Returns the constant score minimizing {@code sum(w * log(1 + exp(-2yF)))} over the training set.
     *
     * <p>The minimizer is {@code 0.5 * ln(W+ / W-)}, where {@code W+} and {@code W-} are the total
     * class weights. Special cases:
     * <ul>
     *   <li>With imbalance adjustment both weights are 1, so the result is 0.</li>
     *   <li>If either class is absent, the minimizer is infinite, so 0 is returned instead.</li>
     * </ul>
     */
    private double computeBinaryInitialScore(int totalPositive, int totalNegative) {
        if (totalPositive == 0 || totalNegative == 0 || this.imbalanceCostAdjustment) {
            return 0.0;
        }
        return 0.5 * Math.log((double) totalPositive / (double) totalNegative);
    }

    /** Class convention shared by every method here: a target of exactly 0.0 is the negative class. */
    private static boolean isNegativeLabel(double target) {
        return target == 0.0;
    }

    /**
     * Converts the validation scores to probabilities and evaluates them.
     *
     * @return the value of {@code evaluationMetric} on the validation set
     * @throws Exception propagated from {@code Sample.evaluate}
     */
    protected double getValidMeasurement() throws Exception {
        LearningUtils.updateProbabilities(this.validProb, this.validPredictions, this.curValidSet.size);
        return this.curValidSet.evaluate(this.validProb, this.evaluationMetric);
    }

    /**
     * Computes pseudo-residuals and second-order weights, then draws the sub-sample for the next tree.
     *
     * <p>The returned sample's {@code targets} array is replaced by {@code residuals}. The sample's
     * positions in the current training set are recorded for use by {@link #getAdjustedOutput}.
     *
     * @return a clone of a random sub-sample of the training set (rate {@code samplingRate})
     */
    protected Sample getSubLearnerSample() {
        for (int d = 0; d < this.curTrainSet.size; ++d) {
            int instance = this.curTrainSet.indicesInDataset[d];
            double target = isNegativeLabel(this.curTrainSet.targets[d]) ? -1.0 : 1.0;
            // Negative gradient of log(1 + exp(-2yF)) with respect to F.
            this.residuals[instance] = 2.0 * target / (1.0 + Math.exp(2.0 * target * this.trainPredictions[d]));
            double responseAbs = Math.abs(this.residuals[instance]);
            this.weights[instance] = responseAbs * (2.0 - responseAbs);
        }
        Sample subLearnerSample = this.curTrainSet.getRandomSubSample(this.samplingRate, this.rnd).getClone();
        subLearnerSample.targets = this.residuals;
        System.arraycopy(subLearnerSample.indicesInParentSample, 0,
                this.subLearnerSampleIndicesInTrainSet, 0, subLearnerSample.size);
        return subLearnerSample;
    }

    /**
     * Computes the value of a single leaf.
     *
     * <p>The value is {@code learningRate * sum(r * b) / sum(w * b)} over the leaf's instances, where
     * {@code r} is the residual, {@code w} the second-order weight and {@code b} the balancing factor.
     *
     * @param leafInstances the instances that fell into the leaf
     * @return the new output for the leaf
     */
    protected double getAdjustedOutput(LeafInstances leafInstances) {
        double numerator = 0.0;
        double denominator = 0.0;
        for (int i = leafInstances.begin; i < leafInstances.end; ++i) {
            int instance = this.subLearnerSampleIndicesInTrainSet[leafInstances.indices[i]];
            // NOTE(review): getSubLearnerSample writes residuals/weights at dataset indices
            // (curTrainSet.indicesInDataset[d]), but they are read here at train-set positions.
            // These agree only when curTrainSet.indicesInDataset is the identity mapping — confirm.
            numerator += this.residuals[instance] * this.balancingFactors[instance];
            denominator += this.weights[instance] * this.balancingFactors[instance];
        }
        return this.learningRate * ((numerator + LEAF_OUTPUT_SMOOTHING) / (denominator + LEAF_OUTPUT_SMOOTHING));
    }

    /**
     * Overwrites every leaf output of {@code tree} with {@link #getAdjustedOutput}.
     *
     * @param tree the freshly fitted tree; it is cast to {@code RegressionTree}
     * @param treeLeafInstances the instance-to-leaf assignment for {@code tree}
     */
    protected void adjustOutputs(Tree tree, TreeLeafInstances treeLeafInstances) {
        LeafInstances leafInstances = new LeafInstances();
        for (int l = 0; l < tree.numLeaves; ++l) {
            treeLeafInstances.loadLeafInstances(l, leafInstances);
            ((RegressionTree) tree).setLeafOutput(l, this.getAdjustedOutput(leafInstances));
        }
    }

    /** Refreshes {@code prob} from the current training scores. */
    protected void postProcessScores() {
        LearningUtils.updateProbabilities(this.prob, this.trainPredictions, this.curTrainSet.size);
    }
}

