package edu.uci.jforests.learning.trees.decision;

import edu.uci.jforests.dataset.Dataset;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeSplit;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ArraysUtil;
import edu.uci.jforests.util.MathUtil;

/* loaded from: input_file:edu/uci/jforests/learning/trees/decision/DecisionTree.class */
/**
 * Classification tree whose leaves each store a per-class target distribution.
 *
 * <p>Leaves are addressed by index {@code 0..numLeaves-1}. When walking towards the root, a leaf
 * {@code l} is passed to {@code getParent} as {@code ~l}. Internal node indices are non-negative,
 * and the walk stops once {@code getParent} returns a negative value.
 */
public class DecisionTree extends Tree {
    /** Per-leaf class distributions, indexed as {@code [leaf][class]}. */
    private double[][] leafTargetDistributions;
    /** Number of target classes, i.e. the number of columns used in each distribution row. */
    private int numClasses;

    /**
     * Returns a deep copy of this tree, including the leaf distributions and the class count.
     *
     * @return a new {@code DecisionTree} with copied state
     */
    @Override
    public Object clone() {
        DecisionTree copy = new DecisionTree();
        super.copyTo(copy);
        // numClasses must be copied too; otherwise the clone's split/addCustomData/backfit
        // silently operate on zero classes.
        copy.numClasses = this.numClasses;
        copy.leafTargetDistributions = MathUtil.cloneDoubleMatrix(this.leafTargetDistributions);
        return copy;
    }

    /**
     * Initializes the tree and allocates one zeroed distribution row per leaf slot.
     *
     * @param maxLeaves  capacity forwarded to {@code Tree.init(int)}; also the number of
     *                   distribution rows allocated (presumably the maximum leaf count)
     * @param numClasses number of target classes
     */
    public void init(int maxLeaves, int numClasses) {
        super.init(maxLeaves);
        this.numClasses = numClasses;
        this.leafTargetDistributions = new double[maxLeaves][numClasses];
    }

    /**
     * Returns the distribution row of a leaf.
     *
     * @param leaf leaf index
     * @return the live internal array (not a copy); mutations affect this tree
     */
    public double[] getLeafTargetDistribution(int leaf) {
        return this.leafTargetDistributions[leaf];
    }

    /**
     * Copies {@code distribution} into the stored row of {@code leaf}.
     *
     * @param leaf         leaf index
     * @param distribution values to copy; all {@code distribution.length} entries are copied
     */
    public void setLeafTargetDistribution(int leaf, double[] distribution) {
        System.arraycopy(distribution, 0, this.leafTargetDistributions[leaf], 0, distribution.length);
    }

    /**
     * Predicts the class of one instance.
     *
     * @param dataset  source dataset
     * @param instance instance index within {@code dataset}
     * @return index of the largest entry in the instance's leaf distribution
     *         (via {@code ArraysUtil.findMaxIndex})
     */
    public int classify(Dataset dataset, int instance) {
        return ArraysUtil.findMaxIndex(this.leafTargetDistributions[getLeaf(dataset, instance)]);
    }

    /**
     * Returns the distribution of the leaf that {@code instance} falls into.
     *
     * @param dataset  source dataset
     * @param instance instance index within {@code dataset}
     * @return the live internal leaf row (not a copy)
     */
    public double[] getDistributionForInstance(Dataset dataset, int instance) {
        return this.leafTargetDistributions[getLeaf(dataset, instance)];
    }

    /**
     * Classifies every instance of {@code dataset}.
     *
     * @param dataset source dataset
     * @return predicted class index for each instance, in dataset order
     */
    public int[] getPredictions(Dataset dataset) {
        int[] predictions = new int[dataset.numInstances];
        for (int instance = 0; instance < predictions.length; instance++) {
            predictions[instance] = classify(dataset, instance);
        }
        return predictions;
    }

    /**
     * Splits a leaf and stores the normalized class distributions of its two children.
     *
     * @param leaf      leaf to split
     * @param treeSplit split description; must be a {@link DecisionTreeSplit}
     * @return whatever {@code Tree.split} returns
     */
    @Override // edu.uci.jforests.learning.trees.Tree
    public int split(int leaf, TreeSplit treeSplit) {
        int result = super.split(leaf, treeSplit);
        DecisionTreeSplit dtSplit = (DecisionTreeSplit) treeSplit;
        // The left child keeps index `leaf`; the right child is the most recently added leaf.
        // This assumes super.split appended it at numLeaves - 1.
        int rightLeaf = this.numLeaves - 1;
        for (int c = 0; c < this.numClasses; c++) {
            this.leafTargetDistributions[leaf][c] = dtSplit.leftTargetDist[c];
            this.leafTargetDistributions[rightLeaf][c] = dtSplit.rightTargetDist[c];
        }
        normalizeLeafTargetDistributions(leaf);
        normalizeLeafTargetDistributions(rightLeaf);
        return result;
    }

    /**
     * Scales the first {@code numClasses} entries of a leaf row so that they sum to 1.
     * A row whose entries sum to exactly zero is left unchanged, instead of becoming NaN.
     */
    private void normalizeLeafTargetDistributions(int leaf) {
        double[] distribution = this.leafTargetDistributions[leaf];
        double total = rowSum(distribution, this.numClasses);
        if (total == 0.0d) {
            return;
        }
        for (int c = 0; c < this.numClasses; c++) {
            distribution[c] /= total;
        }
    }

    /**
     * Restores the leaf distributions from the serialized {@code LeafTargetDistributions} element.
     * NOTE(review): this reads {@code numLeaves} and {@code numClasses} as they currently are;
     * confirm that both are set before this method is called.
     */
    @Override // edu.uci.jforests.learning.trees.Tree
    public void loadCustomData(String str) throws Exception {
        this.leafTargetDistributions = ArraysUtil.loadDoubleMatrixFromLine(
                removeXmlTag(str, "LeafTargetDistributions"), this.numLeaves, this.numClasses);
    }

    /**
     * Appends the leaf distributions as a single {@code <LeafTargetDistributions>} line.
     * Values are space-separated in row-major {@code [leaf][class]} order.
     *
     * @param indent prefix written before the tag
     * @param sb     destination buffer
     */
    @Override // edu.uci.jforests.learning.trees.Tree
    protected void addCustomData(String indent, StringBuilder sb) {
        StringBuilder values = new StringBuilder();
        for (int leaf = 0; leaf < this.numLeaves; leaf++) {
            for (int c = 0; c < this.numClasses; c++) {
                values.append(' ').append(this.leafTargetDistributions[leaf][c]);
            }
        }
        sb.append('\n').append(indent).append("\t<LeafTargetDistributions>")
                .append(values.toString().trim()).append("</LeafTargetDistributions>");
    }

    /**
     * Recomputes the leaf distributions from {@code sample} while keeping the tree structure.
     *
     * <p>A leaf with a positive total weight receives its own weighted class counts. A leaf with
     * zero total weight receives the counts of its nearest ancestor that has a positive total.
     * If no such ancestor exists, the leaf keeps its previous row.
     *
     * <p>NOTE(review): unlike {@link #split}, the rows stored here are raw weighted counts and are
     * not normalized. Confirm whether callers of {@code getDistributionForInstance} expect
     * probabilities.
     */
    @Override // edu.uci.jforests.learning.trees.Tree
    public void backfit(Sample sample) {
        double[][] leafCounts = accumulateLeafCounts(sample);
        double[] leafTotals = new double[this.numLeaves];
        for (int leaf = 0; leaf < this.numLeaves; leaf++) {
            leafTotals[leaf] = rowSum(leafCounts[leaf], this.numClasses);
        }

        // One row per internal node (a binary tree has numLeaves - 1 of them), holding the
        // summed counts of its non-empty descendant leaves.
        double[][] nodeCounts = new double[this.numLeaves - 1][this.numClasses];
        boolean hasEmptyLeaf = false;
        for (int leaf = 0; leaf < this.numLeaves; leaf++) {
            if (leafTotals[leaf] > 0.0d) {
                setLeafTargetDistribution(leaf, leafCounts[leaf]);
                // Bounded walk to the root; the previous while(true) form never terminated.
                for (int node = getParent(~leaf); node >= 0; node = getParent(node)) {
                    addInto(nodeCounts[node], leafCounts[leaf], this.numClasses);
                }
            } else {
                hasEmptyLeaf = true;
            }
        }

        if (hasEmptyLeaf) {
            fillEmptyLeavesFromAncestors(leafTotals, nodeCounts);
        }
    }

    /** Builds each leaf's class histogram, weighted by the sample weights. */
    private double[][] accumulateLeafCounts(Sample sample) {
        double[][] counts = new double[this.numLeaves][this.numClasses];
        for (int i = 0; i < sample.size; i++) {
            int leaf = getLeaf(sample.dataset, sample.indicesInDataset[i]);
            // Targets hold class labels; the cast yields the class index.
            counts[leaf][(int) sample.targets[i]] += sample.weights[i];
        }
        return counts;
    }

    /** Gives each zero-weight leaf the counts of its nearest ancestor with positive weight. */
    private void fillEmptyLeavesFromAncestors(double[] leafTotals, double[][] nodeCounts) {
        double[] nodeTotals = new double[nodeCounts.length];
        for (int node = 0; node < nodeCounts.length; node++) {
            nodeTotals[node] = rowSum(nodeCounts[node], this.numClasses);
        }
        for (int leaf = 0; leaf < this.numLeaves; leaf++) {
            if (leafTotals[leaf] != 0.0d) {
                continue;
            }
            for (int node = getParent(~leaf); node >= 0; node = getParent(node)) {
                if (nodeTotals[node] > 0.0d) {
                    setLeafTargetDistribution(leaf, nodeCounts[node]);
                    break;
                }
            }
        }
    }

    /** Returns the sum of the first {@code length} entries of {@code row}, in index order. */
    private static double rowSum(double[] row, int length) {
        double total = 0.0d;
        for (int c = 0; c < length; c++) {
            total += row[c];
        }
        return total;
    }

    /** Adds the first {@code length} entries of {@code source} element-wise into {@code target}. */
    private static void addInto(double[] target, double[] source, int length) {
        for (int c = 0; c < length; c++) {
            target[c] += source[c];
        }
    }
}
