[MRG] Adds Minimal Cost-Complexity Pruning to Decision Trees #12887
Conversation
Some API comments.
Some questions for the novice:

- does calling `prune_tree` with the same alpha repeatedly return the same tree?
- does calling `prune_tree` with increasing alpha return a strict sub-tree?
sklearn/tree/tree.py (Outdated):
```python
@@ -510,6 +515,110 @@ def decision_path(self, X, check_input=True):
        X = self._validate_X_predict(X, check_input)
        return self.tree_.decision_path(X)

    def prune_tree(self):
```
I think this needs to be called after `fit` automatically to facilitate cross validation etc.
I wonder if this should instead be a public function?
Currently, `prune_tree` is a public function that is called at the end of `fit`. It should work with our cross validation classes/functions.

As long as the original tree is the same, using the same alpha will return the same tree. I will add a test for this behavior.

When alpha gets high enough, the entire tree can be pruned, leaving just the root node.
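A minimal sketch of the test promised here, written against the `ccp_alpha` parameter this PR ultimately exposes (the grid of alphas is arbitrary; assumes scikit-learn >= 0.22):

```python
# Sketch of the determinism/subtree checks discussed above.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Same alpha twice -> identical pruned tree (same node count and splits).
tree_a = DecisionTreeClassifier(random_state=0, ccp_alpha=0.01).fit(X, y)
tree_b = DecisionTreeClassifier(random_state=0, ccp_alpha=0.01).fit(X, y)
assert tree_a.tree_.node_count == tree_b.tree_.node_count
np.testing.assert_array_equal(tree_a.tree_.feature, tree_b.tree_.feature)

# Increasing alpha -> a subtree (never more nodes), though not necessarily
# a *strict* subtree, as discussed later in this thread.
prev_count = np.inf
for alpha in [0.0, 0.01, 0.05, 0.1]:
    t = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X, y)
    assert t.tree_.node_count <= prev_count
    prev_count = t.tree_.node_count
```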
```python
###############################################################################
# Plot training and test scores vs alpha
# --------------------------------------
# Calcuate and plot the the training scores and test accuracy scores
```
Just a few typos (I'll document myself on tree pruning and will try to provide a more in-depth review later):

- Calcuate
- the the
- above, "a smaller trees"

Also, I think you should avoid the `:math:` notation in comments.
The `:math:` notation is currently used in other examples such as: https://github.com/scikit-learn/scikit-learn/blob/master/examples/svm/plot_svm_scale_c.py. Are we discouraging the usage of `:math:` in our examples?
It's OK in the docstrings since it will be rendered like regular rst by sphinx, but in the comments it is not necessary.
These comments are rendered into html: https://42001-843222-gh.circle-artifacts.com/0/doc/auto_examples/tree/plot_cost_complexity_pruning.html
Ooh ok I didn't know it worked like that, sorry
A first pass of cosmetic comments.

@thomasjpfan it seems to me that in general, and unless there's a compelling reason not to, sklearn code uses (potentially long) descriptive variable names. For example, `par_idx` could be renamed to `parent_idx`. Same for `cur_alpha`, `cur_idx`, etc.
sklearn/tree/tree.py (Outdated):
```python
# bubble up values to ancestor nodes
for idx in leaf_idicies:
    cur_R = r_node[idx]
```
Avoid upper-case in variable names (same for `R_diff`).
sklearn/tree/tree.py (Outdated):
```python
leaves_in_subtree = np.zeros(shape=n_nodes, dtype=np.uint8)

stack = [(0, -1)]
while len(stack) > 0:
```
`while stack:` is more pythonic (same below).
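For context, a minimal illustration of the idiom (illustrative only):

```python
# Empty containers are falsy in Python, so the explicit length check
# is redundant:
stack = [(0, -1)]
while stack:  # idiomatic; equivalent to `while len(stack) > 0:`
    node_id, parent = stack.pop()
```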
sklearn/tree/tree.py (Outdated):
```python
stack = [(0, -1)]
while len(stack) > 0:
    node_id, parent = stack.pop()
```
`node_idx` to stay consistent with the rest of the function. I would also suggest `parent_idx` instead of `parent`, and the `parents` array could just be named `parent`.
sklearn/tree/tree.py (Outdated):
```python
# computes number of leaves in all branches and the overall impurity of
# the branch. The overall impurity is the sum of r_node in its leaves.
n_leaves = np.zeros(shape=n_nodes, dtype=np.int32)
leaf_idicies, = np.where(leaves_in_subtree)
```
`leaf_indicies`
sklearn/tree/tree.py (Outdated):
```python
r_branch[leaf_idicies] = r_node[leaf_idicies]

# bubble up values to ancestor nodes
for idx in leaf_idicies:
```
`for leaf_idx in ...`?
sklearn/tree/tree.py (Outdated):
```python
# descendants of branch are not in subtree
stack = [cur_idx]
while len(stack) > 0:
```
`while stack:`
sklearn/tree/tree.py (Outdated):
```python
inner_nodes[idx] = False
leaves_in_subtree[idx] = 0
in_subtree[idx] = False
n_left, n_right = child_l[idx], child_r[idx]
```
Usually `n_something` denotes a count. Here these are just indices, right?
sklearn/tree/tree.py (Outdated):
```python
leaves_in_subtree[cur_idx] = 1

# updates number of leaves
cur_leaves, n_leaves[cur_idx] = n_leaves[cur_idx], 0
```
I would propose

```python
n_pruned_leaves = n_leaves[cur_idx] - 1
n_leaves[cur_idx] = 0
```

and accordingly update `n_leaves[cur_idx]` below.
sklearn/tree/tree.py (Outdated):
```python
# bubble up values to ancestors
cur_idx = parents[cur_idx]
while cur_idx != _tree.TREE_LEAF:
```
It's a bit weird to bubble up to a leaf. Whatever you're comparing to here should explicitly be the same value as what you used for defining the root's parent above (`stack = [(0, -1)]`). I would simply use `while cur_idx != -1:`.
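A minimal sketch of the suggested sentinel convention (the `parent` and `n_leaves` arrays are illustrative placeholders, not the PR's exact code):

```python
# Use the same sentinel (-1) when seeding the root's parent and when
# terminating the bubble-up loop.
import numpy as np

parent = np.array([-1, 0, 0, 1, 1])  # parent[i] of node i; -1 for the root
n_leaves = np.zeros(5, dtype=np.int32)

cur_idx = 3                # start from some leaf
while cur_idx != -1:       # stop after processing the root
    n_leaves[cur_idx] += 1 # bubble a value up to each ancestor
    cur_idx = parent[cur_idx]
```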
@jnothman, to add to @thomasjpfan's answers:

- The procedure is deterministic, so calling it with the same alpha repeatedly returns the same tree.
- A subtree yes, but not necessarily a strict one: with a larger alpha the pruned tree can come out unchanged.
sklearn/tree/tree.py (Outdated):
```python
in_subtree = np.ones(shape=n_nodes, dtype=np.bool)

cur_alpha = 0
while cur_alpha < self.alpha:
```
2 thoughts:

- On the resources that I read (here and here), the very first pruning step is to remove all the pure leaves (equivalent to using alpha=0, apparently). This is not done here since `cur_alpha` is immediately overwritten. I wonder if this is done by default in the tree growing algorithm.
- As you check for `cur_alpha < self.alpha` and `cur_alpha` is computed before the tree is pruned in the loop, this means that the `alpha` of the returned pruned tree will be greater than `self.alpha`. It would seem more natural to me to return a tree whose `alpha` is less than `self.alpha`. In any case we would need to explain how `alpha` is used in the docs, something like "subtrees whose scores are less than `alpha` are discarded. The score is computed as ..." (see the sketch after this list).
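A hedged sketch of the loop semantics being suggested, where pruning stops before the smallest effective alpha would exceed the user-supplied threshold (`weakest_link`, `collapse_into_leaf`, and `has_inner_nodes` are hypothetical helpers, not the PR's code):

```python
# Illustrative weakest-link pruning loop: the returned tree's alpha
# stays at or below the requested threshold.
def prune(self_alpha, weakest_link, collapse_into_leaf, has_inner_nodes):
    while has_inner_nodes():
        # effective alpha of the best pruning candidate t:
        #   (r_node[t] - r_branch[t]) / (n_leaves[t] - 1)
        node_idx, eff_alpha = weakest_link()
        if eff_alpha > self_alpha:
            break  # pruning this node would exceed the requested alpha
        collapse_into_leaf(node_idx)
```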
- When alpha = 0, the number of leaves does not contribute to the cost-complexity measure, which I interpreted as "do not prune". Removing the leaves when alpha=0 will increase the cost-complexity measure.
- Returning a tree whose `alpha` is less than `self.alpha` makes sense and should be documented.
> Removing the leaves when alpha=0 will increase the cost-complexity measure.

It cannot increase the cost-complexity: the first step is to prune (some of) the pure leaves. That is, if a node N has 2 child leaves where all the samples in both leaves belong to the same class, then the first step will remove those 2 leaves and make N a leaf (which will still be pure). The process is repeated with N and its sibling if needed.
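To spell out the arithmetic (a worked step using the cost-complexity definition from the docs under review; the symbols follow that discussion):

```latex
% Cost-complexity of a tree T with terminal-node set \tilde{T}:
R_\alpha(T) = R(T) + \alpha\,|\tilde{T}|
% Collapsing two pure sibling leaves into their (still pure) parent leaves
% R(T) unchanged and reduces |\tilde{T}| by one:
R_\alpha(T') = R(T) + \alpha\,(|\tilde{T}| - 1) \;\le\; R_\alpha(T)
% with equality exactly when \alpha = 0, so the measure cannot increase.
```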
> That is, if a node N has 2 child leaves where all the samples in both leaves belong to the same class, then the first step will remove those 2 leaves and make N a leaf

This makes sense. I will review the tree building code to see if this can happen.

To prevent future confusion, I want to get on the same page with our definition of a pure leaf. From my understanding, a pure leaf is a leaf whose samples belong to the same class, independent of all other leaves. From reading your response, you consider two leaves to be pure if they are siblings and their samples belong to the same class. Is this correct?
> From my understanding, a pure leaf is a leaf whose samples belong to the same class, independent of all other leaves

I meant this as well.
I just looked at the tree code; I think we can assume that this "first step" is not needed here after all, since a node is made a leaf according to the `min_impurity_decrease` (or deprecated `min_impurity_split`) parameter. That is, if a node is pure according to `min_impurity_decrease`, it will be made a leaf, and thus the example case I mentioned above (a node with 2 pure leaves) cannot exist.
Where I was going with my questions was the idea of masking.

I'm not sure if we want masking, but further pruning an existing tree might be reasonable.
Thanks Thomas, last minor comments but LGTM!
doc/modules/tree.rst (Outdated):
```rst
:math:`t`, and its branch, :math:`T_t`, can be equal depending on
:math:`\alpha`. We define the effective :math:`\alpha` of a node to be the
value where they are equal, :math:`R_\alpha(T_t)=R_\alpha(t)` or
:math:`\alpha_{eff}(t)=(R(t)-R(T_t))/(|\tilde{T}|-1)`. A non-terminal node
```
```diff
-:math:`\alpha_{eff}(t)=(R(t)-R(T_t))/(|\tilde{T}|-1)`. A non-terminal node
+:math:`\alpha_{eff}(t)=(R(t)-R(T_t))/(|T|-1)`. A non-terminal node
```

removed tilde
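For concreteness, a small numeric check of the effective-alpha formula (the numbers are made up for illustration):

```python
# alpha_eff(t) = (R(t) - R(T_t)) / (number of branch leaves - 1),
# with illustrative values:
r_node = 0.10        # R(t): weighted impurity if t were a leaf
r_branch = 0.04      # R(T_t): summed weighted impurity of the branch's leaves
n_branch_leaves = 3  # leaves in the branch rooted at t

alpha_eff = (r_node - r_branch) / (n_branch_leaves - 1)
print(alpha_eff)     # 0.03: node t is pruned once alpha reaches 0.03
```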
doc/modules/tree.rst (Outdated):
```rst
Minimal Cost-Complexity Pruning
===============================

Minimal cost-complexity pruning is an algorithm used to prune a tree after it
```
Please add a ref here to L. Breiman, J. Friedman, R. Olshen, and C. Stone, Classification and Regression Trees (Chapter 3).
```python
ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(ccp_alphas, train_scores, label="train", drawstyle="steps-post")
```
```diff
-ax.plot(ccp_alphas, train_scores, label="train", drawstyle="steps-post")
+ax.plot(ccp_alphas, train_scores, marker='o', label="train", drawstyle="steps-post")
```
Same below.
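Applying the suggestion to both curves might look like this (a self-contained sketch; `ccp_alphas`, `train_scores`, and `test_scores` are dummy stand-ins for the example's computed arrays):

```python
# Sketch of the suggested marker='o' change applied to both curves.
import matplotlib.pyplot as plt

ccp_alphas = [0.0, 0.01, 0.02, 0.04]    # stand-in pruning alphas
train_scores = [1.0, 0.98, 0.95, 0.90]  # stand-in accuracies
test_scores = [0.88, 0.91, 0.92, 0.89]

fig, ax = plt.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(ccp_alphas, train_scores, marker='o', label="train",
        drawstyle="steps-post")
ax.plot(ccp_alphas, test_scores, marker='o', label="test",
        drawstyle="steps-post")
ax.legend()
plt.show()
```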
sklearn/tree/_tree.pyx (Outdated):
```cython
                      Tree orig_tree,
                      DOUBLE_t ccp_alpha):
    """Build a pruned tree from the original tree by transforming the nodes in
    leaves_in_subtree into leaves.
```
Remove this one??
```cython
cdef:
    UINT32_t total_items = path_finder.count
    np.ndarray ccp_alphas = np.empty(shape=total_items,
```
ping but not super important I think
Incidental comments
doc/modules/tree.rst (Outdated):
```rst
where :math:`|\tilde{T}|` is the number of terminal nodes in :math:`T` and
:math:`R(T)` is traditionally defined as the total misclassification rate of
the terminal nodes. Alternatively, scikit-learn uses the total sample weighted
impurity of the terminal nodes for :math:`R(T)`. As shown in the previous
```
best to say "above" rather than "in the previous section" or link to it so it can withstand change.
doc/modules/tree.rst (Outdated):
```rst
:math:`t`, and its branch, :math:`T_t`, can be equal depending on
:math:`\alpha`. We define the effective :math:`\alpha` of a node to be the
value where they are equal, :math:`R_\alpha(T_t)=R_\alpha(t)` or
:math:`\alpha_{eff}=(R(t)-R(T_t))/(|\tilde{T}|-1)`. A non-terminal node with
```
Use `\frac` since this is not easily readable anyway.
doc/modules/tree.rst (Outdated):
```rst
===============================

Minimal cost-complexity pruning is an algorithm used to prune a tree, described
in Chapter 3 of [BRE]_. This algorithm is parameterized by :math:`\alpha\ge0`
```
It might be worth adding a small note to say that this is one method used to avoid over-fitting in trees.
doc/whats_new/v0.22.rst (Outdated):
```rst
:mod:`sklearn.tree`
...................

- |Feature| Adds minimal cost complexity pruning to
```
Might be worth mentioning what the public API is... i.e. is it just `ccp_alpha`?
Failure is unrelated and Joel's comments were addressed, so I guess it's good to merge 🎉 Thanks @thomasjpfan!
Exciting! Great work!!
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
- Fix deprecated imports (scikit-learn/scikit-learn#9250)

Do not add ccp_alpha to SurvivalTree, because it relies on node_impurity, which is not set for SurvivalTree.
I have a question: in the literature [1], the authors first prune the max grown tree and then prune it according to the different values of alpha, selecting among the resulting subtrees by cross-validation.

[1] L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification and Regression Trees. Wadsworth, Belmont, CA, 1984.
Yes, this needs to be done with our cross-validation tools. There is a more interesting way to do this by setting aside some of the training data for validation, in such a way that the tree can automatically find an alpha. This has not been implemented here.
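For instance, selecting `ccp_alpha` with the existing cross-validation tools might look like the following (a sketch against the merged public API; the dataset and grid handling are illustrative):

```python
# Sketch: picking ccp_alpha with scikit-learn's cross-validation tools.
# cost_complexity_pruning_path returns the effective alphas of the fitted
# tree, which make natural grid candidates.
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X, y)
# Clamp tiny negative values that can appear from floating-point noise.
candidate_alphas = [max(float(a), 0.0) for a in path.ccp_alphas]

search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={"ccp_alpha": candidate_alphas},
    cv=5,
)
search.fit(X, y)
print(search.best_params_)
```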
Thank you for your great work; it benefits me a lot.

I found some further discussion in "Performance Learning" (Johannes Fürnkranz, Eyke Hüllermeier), pp. 87-88.

Using the criterion allows pruning to be extended to regression trees (the criterion for classification defaults to gini impurity).

Thank you very much.
Reference Issues/PRs
Fixes #6557
What does this implement/fix? Explain your changes.
This PR implements Minimal Cost-Complexity Pruning based on L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification and Regression Trees", Wadsworth, Belmont, CA, 1984.
Most of this implementation is the same as the literature. There are two differences:

1. A cost complexity parameter, `alpha`, was added to `__init__` to control cost complexity pruning. The post pruning is done at the end of `fit`.
2. The code performing Minimal Cost-Complexity Pruning is mostly done in Python. The Python part produces the node ids that will become the leaves of the new subtree. These leaves are passed to a Cython function called `build_pruned_tree` that builds a tree. This was written in Cython since the tree building API is in Cython.

In Cython, the Stack class is used to go through the tree. Not all the fields of the StackRecord are used. This is a trade-off between the code complexity of adding yet another Stack class and being a little memory inefficient.
Currently, `prune_tree` is public, which allows for the following use case. If we prefer, we can make `prune_tree` private and not encourage this use case.
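A hypothetical sketch of the kind of use case this enables, assuming the PR-era `alpha` attribute and public `prune_tree` method described in this thread (this exact API did not ship; the merged version exposes `ccp_alpha` on `fit` instead):

```python
# Illustrative only: fit once, then re-prune the same fitted tree at
# increasing alphas without growing it again.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

for alpha in [0.0, 0.01, 0.05]:
    clf.alpha = alpha
    clf.prune_tree()  # prune the already-fitted tree in place
    print(alpha, clf.tree_.node_count)
```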