package edu.uci.jforests.learning.trees.regression;

import edu.uci.jforests.dataset.Dataset;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeSplit;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ArraysUtil;

/* loaded from: input_file:edu/uci/jforests/learning/trees/regression/RegressionTree.class */
public class RegressionTree extends Tree {
    /** Output value of each leaf, indexed by leaf number. */
    private double[] leafOutputs;
    /** When positive, leaf outputs are clamped to [-maxLeafOutput, maxLeafOutput]. */
    private double maxLeafOutput;

    /**
     * Returns a copy of this tree; the structure is copied via {@code copyTo} and the leaf outputs are
     * deep-copied.
     *
     * <p>The copy's leaf-output array is sized to {@code leftChild.length + 1}. Only the entries that
     * actually exist in this tree's array are copied, because {@link #loadCustomData} may have sized it
     * to exactly {@code numLeaves}.
     */
    @Override
    public Object clone() {
        RegressionTree copy = new RegressionTree();
        super.copyTo(copy);
        copy.maxLeafOutput = this.maxLeafOutput;
        copy.leafOutputs = new double[this.leftChild.length + 1];
        System.arraycopy(this.leafOutputs, 0, copy.leafOutputs, 0,
                Math.min(this.leafOutputs.length, copy.leafOutputs.length));
        return copy;
    }

    /**
     * Initializes the tree.
     *
     * @param i value forwarded to {@code Tree.init}; also the size of the leaf-output array
     *          (presumably the maximum number of leaves — TODO confirm against {@code Tree.init})
     * @param d clamp bound for leaf outputs; values {@code <= 0} disable clamping
     */
    public void init(int i, double d) {
        super.init(i);
        this.leafOutputs = new double[i];
        this.maxLeafOutput = d;
    }

    /**
     * Recursively marks every internal node and leaf reachable from {@code node}.
     *
     * <p>Child pointers are non-negative for internal nodes and {@code ~leaf} (negative) for leaves.
     */
    private void markPresenceOfNodesInSubtree(int node, boolean[] nodeSeen, boolean[] leafSeen) {
        if (node < 0) {
            leafSeen[~node] = true;
            return;
        }
        nodeSeen[node] = true;
        markPresenceOfNodesInSubtree(this.leftChild[node], nodeSeen, leafSeen);
        markPresenceOfNodesInSubtree(this.rightChild[node], nodeSeen, leafSeen);
    }

    /**
     * Compacts the tree in place so that only nodes and leaves reachable from the root remain.
     *
     * <p>Survivors are renumbered {@code 0..k-1}, preserving their relative order. Child pointers,
     * split features, thresholds and leaf outputs are rewritten accordingly, and {@code numLeaves} is
     * set to the number of reachable leaves.
     *
     * <p>Because survivors keep their relative order, every new index is &lt;= its old index, so the
     * forward in-place copies below never overwrite a value that is still needed.
     */
    public void normalizeNodeNames() {
        if (this.numLeaves <= 1) {
            // A single-leaf tree has no internal nodes: node 0 is not a real root and walking it would
            // read uninitialized child pointers. Such a tree is already normalized.
            return;
        }
        int leafCapacity = this.leftChild.length + 1;
        int nodeCapacity = leafCapacity - 1;
        boolean[] nodeReachable = new boolean[nodeCapacity];
        boolean[] leafReachable = new boolean[leafCapacity];
        markPresenceOfNodesInSubtree(0, nodeReachable, leafReachable);

        int reachableLeaves = 0;
        for (boolean reachable : leafReachable) {
            if (reachable) {
                reachableLeaves++;
            }
        }
        // In a full binary tree, internal nodes = leaves - 1.
        int reachableNodes = reachableLeaves - 1;

        // Old -> new and new -> old index maps for internal nodes.
        int[] newNodeIndex = new int[nodeCapacity];
        int[] oldNodeIndex = new int[reachableNodes];
        for (int oldIdx = 0, next = 0; oldIdx < nodeCapacity && next < reachableNodes; oldIdx++) {
            if (nodeReachable[oldIdx]) {
                newNodeIndex[oldIdx] = next;
                oldNodeIndex[next] = oldIdx;
                next++;
            }
        }

        // Old -> new and new -> old index maps for leaves.
        int[] newLeafIndex = new int[leafCapacity];
        int[] oldLeafIndex = new int[reachableLeaves];
        for (int oldIdx = 0, next = 0; oldIdx < leafCapacity && next < reachableLeaves; oldIdx++) {
            if (leafReachable[oldIdx]) {
                newLeafIndex[oldIdx] = next;
                oldLeafIndex[next] = oldIdx;
                next++;
            }
        }

        // Child pointers are computed into scratch arrays first because they reference old indices.
        int[] newLeft = new int[reachableNodes];
        int[] newRight = new int[reachableNodes];
        for (int oldIdx = 0; oldIdx < nodeCapacity; oldIdx++) {
            if (nodeReachable[oldIdx]) {
                int target = newNodeIndex[oldIdx];
                newLeft[target] = remapChild(this.leftChild[oldIdx], newNodeIndex, newLeafIndex);
                newRight[target] = remapChild(this.rightChild[oldIdx], newNodeIndex, newLeafIndex);
            }
        }

        for (int n = 0; n < reachableNodes; n++) {
            this.leftChild[n] = newLeft[n];
            this.rightChild[n] = newRight[n];
            this.splitFeatures[n] = this.splitFeatures[oldNodeIndex[n]];
            this.thresholds[n] = this.thresholds[oldNodeIndex[n]];
        }
        for (int l = 0; l < reachableLeaves; l++) {
            this.leafOutputs[l] = this.leafOutputs[oldLeafIndex[l]];
        }
        this.numLeaves = reachableLeaves;
    }

    /** Translates one child pointer (node index, or {@code ~leaf}) through the renumbering maps. */
    private static int remapChild(int child, int[] newNodeIndex, int[] newLeafIndex) {
        return child < 0 ? ~newLeafIndex[~child] : newNodeIndex[child];
    }

    /** Returns the stored output of leaf {@code i}. */
    public double getLeafOutput(int i) {
        return this.leafOutputs[i];
    }

    /**
     * Sets the output of leaf {@code i}, clamping it to [-maxLeafOutput, maxLeafOutput] when
     * {@code maxLeafOutput > 0}.
     */
    public void setLeafOutput(int i, double d) {
        double value = d;
        if (this.maxLeafOutput > 0.0d) {
            if (value > this.maxLeafOutput) {
                value = this.maxLeafOutput;
            } else if (value < -this.maxLeafOutput) {
                value = -this.maxLeafOutput;
            }
        }
        this.leafOutputs[i] = value;
    }

    /** Multiplies every leaf output by {@code d} (re-applying the clamp); a factor of 1 is a no-op. */
    public void multiplyLeafOutputs(double d) {
        if (d == 1.0d) {
            return;
        }
        for (int leaf = 0; leaf < this.numLeaves; leaf++) {
            setLeafOutput(leaf, this.leafOutputs[leaf] * d);
        }
    }

    /** Adds {@code d} to every leaf output (re-applying the clamp); an increment of 0 is a no-op. */
    public void incrementLeafOutputs(double d) {
        if (d == 0.0d) {
            return;
        }
        for (int leaf = 0; leaf < this.numLeaves; leaf++) {
            setLeafOutput(leaf, this.leafOutputs[leaf] + d);
        }
    }

    /** Returns the output of the leaf that instance {@code i} of {@code dataset} falls into. */
    public double getOutput(Dataset dataset, int i) {
        return this.leafOutputs[getLeaf(dataset, i)];
    }

    /** Returns the tree's output for every instance in {@code dataset}, in instance order. */
    public double[] getOutputs(Dataset dataset) {
        double[] outputs = new double[dataset.numInstances];
        for (int i = 0; i < dataset.numInstances; i++) {
            outputs[i] = getOutput(dataset, i);
        }
        return outputs;
    }

    /**
     * Splits leaf {@code i} and records the outputs of the two resulting leaves.
     *
     * <p>The left child reuses index {@code i}. The right child is written at {@code numLeaves - 1}
     * (presumably {@code super.split} has just incremented {@code numLeaves} — TODO confirm in
     * {@code Tree.split}).
     */
    @Override // edu.uci.jforests.learning.trees.Tree
    public int split(int i, TreeSplit treeSplit) {
        int result = super.split(i, treeSplit);
        RegressionTreeSplit regressionSplit = (RegressionTreeSplit) treeSplit;
        this.leafOutputs[i] = regressionSplit.leftOutput;
        this.leafOutputs[this.numLeaves - 1] = regressionSplit.rightOutput;
        return result;
    }

    /** Parses the {@code <LeafOutputs>} element into an array of exactly {@code numLeaves} values. */
    @Override // edu.uci.jforests.learning.trees.Tree
    public void loadCustomData(String str) throws Exception {
        this.leafOutputs = ArraysUtil.loadDoubleArrayFromLine(removeXmlTag(str, "LeafOutputs"), this.numLeaves);
    }

    /**
     * Appends a {@code <LeafOutputs>} line containing the space-separated leaf outputs.
     *
     * @param str indentation prefix for the line
     * @param sb  destination builder
     */
    @Override // edu.uci.jforests.learning.trees.Tree
    protected void addCustomData(String str, StringBuilder sb) {
        StringBuilder values = new StringBuilder();
        for (int leaf = 0; leaf < this.numLeaves; leaf++) {
            if (leaf > 0) {
                values.append(' ');
            }
            values.append(this.leafOutputs[leaf]);
        }
        sb.append('\n').append(str).append("\t<LeafOutputs>").append(values).append("</LeafOutputs>");
    }

    /**
     * Refits every leaf output to the weighted mean target of the sample instances routed to it.
     *
     * <p>A leaf that receives no positive weight is instead given the weighted mean target of its
     * nearest ancestor whose subtree received positive weight. If no ancestor has any weight, the
     * leaf's output is left unchanged. All assignments go through {@link #setLeafOutput}, so clamping
     * applies.
     */
    @Override // edu.uci.jforests.learning.trees.Tree
    public void backfit(Sample sample) {
        double[] leafWeightedTargetSum = new double[this.numLeaves];
        double[] leafWeight = new double[this.numLeaves];
        for (int i = 0; i < sample.size; i++) {
            int leaf = getLeaf(sample.dataset, sample.indicesInDataset[i]);
            leafWeightedTargetSum[leaf] += sample.targets[i] * sample.weights[i];
            leafWeight[leaf] += sample.weights[i];
        }

        boolean hasEmptyLeaves = false;
        // Per internal node: weighted target sum and total weight over its subtree. Weights are kept
        // as doubles so fractional sample weights are not truncated.
        double[] nodeWeightedTargetSum = new double[this.numLeaves - 1];
        double[] nodeWeight = new double[this.numLeaves - 1];
        for (int leaf = 0; leaf < this.numLeaves; leaf++) {
            double weight = leafWeight[leaf];
            if (weight > 0.0d) {
                setLeafOutput(leaf, leafWeightedTargetSum[leaf] / weight);
                // Walk up to the root; getParent returns a negative value above the root.
                for (int node = getParent(~leaf); node >= 0; node = getParent(node)) {
                    nodeWeightedTargetSum[node] += leafWeightedTargetSum[leaf];
                    nodeWeight[node] += weight;
                }
            } else {
                hasEmptyLeaves = true;
            }
        }

        if (!hasEmptyLeaves) {
            return;
        }
        for (int leaf = 0; leaf < this.numLeaves; leaf++) {
            if (leafWeight[leaf] <= 0.0d) {
                for (int node = getParent(~leaf); node >= 0; node = getParent(node)) {
                    if (nodeWeight[node] > 0.0d) {
                        setLeafOutput(leaf, nodeWeightedTargetSum[node] / nodeWeight[node]);
                        break;
                    }
                }
            }
        }
    }
}
