package edu.uci.jforests.learning.trees.regression;

import edu.uci.jforests.dataset.Dataset;
import edu.uci.jforests.dataset.Feature;
import edu.uci.jforests.dataset.Histogram;
import edu.uci.jforests.learning.trees.CandidateSplitsForLeaf;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeLearner;
import edu.uci.jforests.learning.trees.TreeSplit;
import edu.uci.jforests.util.ConfigHolder;

/* loaded from: input_file:edu/uci/jforests/learning/trees/regression/RegressionTreeLearner.class */
/**
 * Tree learner that grows {@link RegressionTree}s.
 *
 * <p>Split thresholds are chosen from per-feature {@link RegressionHistogram}s. For a candidate
 * threshold bin {@code t}, the left-child statistics are accumulated over histogram bins
 * {@code 0..t}, and the split threshold becomes {@code feature.upperBounds[t]}. A candidate is scored
 * as {@code S_L^2 / W_L + S_R^2 / W_R}, where {@code S} is the sum of targets and {@code W} the
 * weighted instance count on each side. The reported gain is that score minus {@code S^2 / W}
 * of the unsplit node.
 */
public class RegressionTreeLearner extends TreeLearner {

    /**
     * Value of {@code RegressionTreesConfig.maxLeafOutput}, passed to every tree created by
     * {@link #getNewTree()}.
     */
    protected double maxLeafOutput;

    /** Creates the learner, passing the name {@code "RegressionTree"} to the base class. */
    public RegressionTreeLearner() {
        super("RegressionTree");
    }

    /**
     * Initializes the base learner and caches the configured maximum leaf output.
     *
     * @param dataset training dataset, forwarded to {@link TreeLearner#init}
     * @param configHolder configuration source; expected to provide a {@link RegressionTreesConfig}
     * @param i forwarded unchanged to {@link TreeLearner#init}
     * @throws Exception if the base-class initialization fails
     */
    @Override
    public void init(Dataset dataset, ConfigHolder configHolder, int i) throws Exception {
        super.init(dataset, configHolder, i);
        RegressionTreesConfig config =
                (RegressionTreesConfig) configHolder.getConfig(RegressionTreesConfig.class);
        this.maxLeafOutput = config.maxLeafOutput;
    }

    /** Creates a {@link RegressionTree} initialized with {@code maxLeaves} and {@code maxLeafOutput}. */
    @Override
    protected Tree getNewTree() {
        RegressionTree regressionTree = new RegressionTree();
        regressionTree.init(this.maxLeaves, this.maxLeafOutput);
        return regressionTree;
    }

    /** Creates an empty {@link RegressionTreeSplit}; {@link #setBestThresholdForSplit} relies on this subtype. */
    @Override
    protected TreeSplit getNewSplit() {
        return new RegressionTreeSplit();
    }

    /**
     * Creates a {@link RegressionCandidateSplitsForLeaf}.
     *
     * @param i forwarded unchanged as the first constructor argument
     * @param i2 forwarded unchanged as the second constructor argument
     */
    @Override
    protected CandidateSplitsForLeaf getNewCandidateSplitsForLeaf(int i, int i2) {
        return new RegressionCandidateSplitsForLeaf(i, i2);
    }

    /** Creates a {@link RegressionHistogram}; {@link #setBestThresholdForSplit} relies on this subtype. */
    @Override
    protected Histogram getNewHistogram(Feature feature) {
        return new RegressionHistogram(feature);
    }

    /**
     * Chooses a threshold bin for {@code treeSplit.feature} and fills in the split's threshold,
     * child outputs and gain.
     *
     * <p>A threshold bin {@code t} is valid when both sides are non-empty and each holds at
     * least {@code minInstancesPerLeaf} instances. The left side is bins {@code 0..t}; the right
     * side is the rest. With {@code randomizedSplits}, one valid bin is drawn uniformly using
     * {@code rand}. Otherwise every valid bin is scored and the best one is kept.
     * {@code histogram.splittable} is set to whether any valid bin exists. If no candidate is
     * recorded, the split gets gain {@code -Infinity} and NaN outputs.
     *
     * @param treeSplit split to populate; must be a {@link RegressionTreeSplit}
     * @param histogram per-bin statistics of the split feature; must be a
     *     {@link RegressionHistogram}
     */
    @Override
    protected void setBestThresholdForSplit(TreeSplit treeSplit, Histogram histogram) {
        RegressionHistogram regressionHistogram = (RegressionHistogram) histogram;
        double sumTargets = regressionHistogram.sumTargets;
        double totalWeightedCount = histogram.totalWeightedCount;

        // Best candidate so far; these initial values are reported when nothing is recorded.
        double bestSumLeftTargets = Double.NaN;
        double bestWeightedLeftCount = -1.0d;
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestThreshold = 0;
        histogram.splittable = false;

        if (this.randomizedSplits) {
            // Locate the range [firstValid, lastValid] of valid threshold bins. Assuming
            // non-negative per-bin counts, the left count only grows and the right count only
            // shrinks as the threshold moves right, so the valid bins are contiguous and the scan
            // can stop once the right side becomes too small.
            int firstValid = -1;
            int lastValid = -1;
            int leftCount = 0;
            for (int bin = 0; bin < histogram.numValues - 1; bin++) {
                leftCount += histogram.perValueCount[bin];
                if (!hasEnoughInstances(leftCount)) {
                    continue;
                }
                if (!hasEnoughInstances(histogram.totalCount - leftCount)) {
                    break;
                }
                if (firstValid < 0) {
                    firstValid = bin;
                }
                lastValid = bin;
            }

            if (firstValid >= 0) {
                histogram.splittable = true;
                // The bound is >= 1 here, so nextInt cannot throw.
                int threshold = firstValid + this.rand.nextInt(lastValid - firstValid + 1);
                double sumLeftTargets = 0.0d;
                double weightedLeftCount = 0.0d;
                // Inclusive of the threshold bin, matching the exhaustive branch, which pairs
                // prefix statistics over 0..t with upperBounds[t].
                for (int bin = 0; bin <= threshold; bin++) {
                    sumLeftTargets += regressionHistogram.perValueSumTargets[bin];
                    weightedLeftCount += histogram.perValueWeightedCount[bin];
                }
                double score =
                        splitScore(sumLeftTargets, weightedLeftCount, sumTargets, totalWeightedCount);
                if (score > bestScore) { // skips NaN scores (e.g. a zero-weight side)
                    bestSumLeftTargets = sumLeftTargets;
                    bestWeightedLeftCount = weightedLeftCount;
                    bestThreshold = threshold;
                    bestScore = score;
                }
            }
        } else {
            int leftCount = 0;
            double weightedLeftCount = 0.0d;
            double sumLeftTargets = 0.0d;
            // The last bin is never a threshold: its right side would be empty.
            for (int bin = 0; bin < histogram.numValues - 1; bin++) {
                leftCount += histogram.perValueCount[bin];
                weightedLeftCount += histogram.perValueWeightedCount[bin];
                sumLeftTargets += regressionHistogram.perValueSumTargets[bin];
                if (!hasEnoughInstances(leftCount)) {
                    continue;
                }
                if (!hasEnoughInstances(histogram.totalCount - leftCount)) {
                    break;
                }
                histogram.splittable = true;
                double score =
                        splitScore(sumLeftTargets, weightedLeftCount, sumTargets, totalWeightedCount);
                if (score > bestScore) {
                    bestSumLeftTargets = sumLeftTargets;
                    bestWeightedLeftCount = weightedLeftCount;
                    bestThreshold = bin;
                    bestScore = score;
                }
            }
        }

        Feature feature = this.curTrainSet.dataset.features[treeSplit.feature];
        treeSplit.threshold = feature.upperBounds[bestThreshold];
        treeSplit.originalThreshold = feature.getOriginalValue(treeSplit.threshold);
        RegressionTreeSplit regressionTreeSplit = (RegressionTreeSplit) treeSplit;
        regressionTreeSplit.leftOutput = bestSumLeftTargets / bestWeightedLeftCount;
        regressionTreeSplit.rightOutput =
                (sumTargets - bestSumLeftTargets) / (totalWeightedCount - bestWeightedLeftCount);
        treeSplit.gain = bestScore - ((sumTargets * sumTargets) / totalWeightedCount);
    }

    /**
     * Returns whether a child holding {@code count} instances is non-empty and satisfies
     * {@code minInstancesPerLeaf}.
     */
    private boolean hasEnoughInstances(int count) {
        return count != 0 && count >= this.minInstancesPerLeaf;
    }

    /**
     * Scores a binary partition as {@code S_L^2 / W_L + S_R^2 / W_R}, deriving the right-side
     * statistics from the node totals. The divisions are unguarded: a zero-weight side yields
     * IEEE-754 NaN or infinity.
     */
    private static double splitScore(
            double sumLeftTargets,
            double weightedLeftCount,
            double sumTargets,
            double totalWeightedCount) {
        double sumRightTargets = sumTargets - sumLeftTargets;
        double weightedRightCount = totalWeightedCount - weightedLeftCount;
        return ((sumLeftTargets * sumLeftTargets) / weightedLeftCount)
                + ((sumRightTargets * sumRightTargets) / weightedRightCount);
    }
}
