[RFC] Modularize the tree class in both Python and Cython to enable easy extensions #24000
Description
Describe the workflow you want to enable
As we wait for a reviewer to review #22754, @thomasjpfan suggested we move forward with our goal of creating a package of more exotic tree splits, e.g. https://arxiv.org/abs/1909.11799. While we wait for reviewers, the suggestion is to make this package within scikit-learn-contrib.
Although we would like #22754 to be merged into scikit-learn eventually, we understand that the reviewer backlog is an issue. To move forward while reviews occur, we would need to subclass existing scikit-learn code. Ideally, we would like to introduce minor refactoring changes that would make this task significantly easier.
We would like to subclass directly from scikit-learn without having to keep an up-to-date fork of scikit-learn that tracks all the bug fixes and maintenance the dev team here does. We can limit that burden by modularizing the Python/Cython functions inside the `sklearn/tree` module.
Describe your proposed solution
I am proposing three refactoring modifications that have no impact on the performance of the current tree estimators in scikit-learn.
- Refactor the `BaseDecisionTree` Python class to have the following functions, which can be overridden in a subclass:
  - `_set_tree_class`: sets the `Tree` Cython class that the Python API uses
  - `_set_splitter`: sets the `Splitter` Cython class that the Python API uses

  For example, this makes subclassing `BaseDecisionTree` cleaner; see scikit-learn/sklearn/tree/_classes.py, lines 410 to 416 in de06afa.
- Refactor the `Tree` Cython class to have the following functions:
  - `_set_node_values`: transfers split node values to the storage node
  - `_compute_feature_value`: uses the storage node and the input data to compute the feature value to split on

  For example, see scikit-learn/sklearn/tree/_tree.pyx, lines 770 to 787 in de06afa.
- Refactor the `TreeBuilder` Cython class to pass around a `Split` pointer rather than the struct itself. This will enable someone to use C-level functions to pass around another struct with a structure similar to `Split`.

  For example, see scikit-learn/sklearn/tree/_tree.pyx, line 193 in de06afa.
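To illustrate the first refactor (the `BaseDecisionTree` hooks), here is a minimal pure-Python sketch of the pattern. Everything in it is a hypothetical stand-in: the real `BaseDecisionTree.fit` does far more, and the assumption that each hook *returns* the class or instance to use (rather than assigning an attribute) is ours. `ObliqueTree` and `ObliqueSplitter` are illustrative names only.

```python
# Hypothetical stand-ins for the Cython Tree and Splitter classes
# (not scikit-learn's actual classes).
class Tree:
    def __init__(self, n_features):
        self.n_features = n_features


class Splitter:
    def __init__(self, criterion):
        self.criterion = criterion


class ObliqueTree(Tree):
    """Hypothetical tree that stores oblique (multi-feature) splits."""


class ObliqueSplitter(Splitter):
    """Hypothetical splitter that searches over feature combinations."""


class BaseDecisionTree:
    """Simplified stand-in; the real class has a much larger fit()."""

    def __init__(self, criterion="gini"):
        self.criterion = criterion

    def _set_tree_class(self):
        # Hook: the Tree class that backs the fitted estimator.
        return Tree

    def _set_splitter(self):
        # Hook: the Splitter instance used while building the tree.
        return Splitter(self.criterion)

    def fit(self, X):
        # fit() only consults the hooks, so subclasses never copy it.
        self.splitter_ = self._set_splitter()
        self.tree_ = self._set_tree_class()(len(X[0]))
        return self


class ObliqueDecisionTree(BaseDecisionTree):
    # Only the two hooks change; fit() is inherited unchanged.
    def _set_tree_class(self):
        return ObliqueTree

    def _set_splitter(self):
        return ObliqueSplitter(self.criterion)


est = ObliqueDecisionTree().fit([[0.0, 1.0], [1.0, 0.0]])
print(type(est.tree_).__name__, type(est.splitter_).__name__)
# ObliqueTree ObliqueSplitter
```

The point of the design is that a contrib package overrides two small methods instead of re-implementing `fit`, so upstream fixes to `fit` are picked up automatically.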
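The second refactor (the `Tree` hooks) can be sketched the same way. In this hypothetical pure-Python `TreeSketch`, `_set_node_values` copies a split record into node storage and `apply` routes a sample through `_compute_feature_value`. The base version reads a single feature (an axis-aligned split); a hypothetical oblique subclass instead computes a weighted sum of features, the kind of more exotic split the proposed package targets. The node layout, the hook signatures (taking a `node_id`) and the `apply` loop are all assumptions, not the code in `_tree.pyx`.

```python
from dataclasses import dataclass, field


@dataclass
class SplitInfo:
    """Hypothetical split record handed from the splitter to the tree."""
    feature: int
    threshold: float


@dataclass
class ObliqueSplitInfo(SplitInfo):
    """Hypothetical oblique split: one weight per feature."""
    weights: dict = field(default_factory=dict)


@dataclass
class Node:
    """Hypothetical storage node; left == -1 marks a leaf."""
    left: int = -1
    right: int = -1
    feature: int = -1
    threshold: float = 0.0


class TreeSketch:
    def __init__(self):
        self.nodes = []

    def add_node(self, split, left=-1, right=-1):
        self.nodes.append(Node(left=left, right=right))
        node_id = len(self.nodes) - 1
        if split is not None:
            self._set_node_values(split, node_id)
        return node_id

    # Hook 1: transfer split values into the storage node.
    def _set_node_values(self, split, node_id):
        node = self.nodes[node_id]
        node.feature = split.feature
        node.threshold = split.threshold

    # Hook 2: compute the value that is compared against the threshold.
    def _compute_feature_value(self, x, node_id):
        return x[self.nodes[node_id].feature]

    def apply(self, x):
        """Return the id of the leaf that sample x falls into."""
        node_id = 0
        while self.nodes[node_id].left != -1:
            node = self.nodes[node_id]
            if self._compute_feature_value(x, node_id) <= node.threshold:
                node_id = node.left
            else:
                node_id = node.right
        return node_id


class ObliqueTreeSketch(TreeSketch):
    def __init__(self):
        super().__init__()
        self.proj_weights = {}  # node_id -> {feature index: weight}

    def _set_node_values(self, split, node_id):
        super()._set_node_values(split, node_id)
        self.proj_weights[node_id] = dict(split.weights)

    def _compute_feature_value(self, x, node_id):
        # Weighted sum of features instead of a single feature lookup.
        return sum(w * x[j] for j, w in self.proj_weights[node_id].items())


# Usage: one split node (id 0) with two leaves (ids 1 and 2).
axis = TreeSketch()
axis.add_node(SplitInfo(feature=1, threshold=0.5), left=1, right=2)
axis.add_node(None)
axis.add_node(None)
print(axis.apply([10.0, 0.2]))  # 1, because x[1] = 0.2 <= 0.5

oblique = ObliqueTreeSketch()
oblique.add_node(ObliqueSplitInfo(feature=-1, threshold=1.0,
                                  weights={0: 1.0, 1: 1.0}), left=1, right=2)
oblique.add_node(None)
oblique.add_node(None)
print(oblique.apply([0.6, 0.5]))  # 2, because 0.6 + 0.5 > 1.0
```

Because traversal only ever calls the two hooks, the subclass changes how a split is stored and evaluated without touching the traversal logic.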
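The third refactor concerns C-level data layout, which Python cannot express directly. As a rough analogue, the standard-library `ctypes` module can show the underlying idea: a function that receives a pointer to `Split` can also be handed a pointer to a larger struct that begins with the same fields (a common C idiom based on a shared initial layout). Passing the struct by value would copy only the `Split`-sized part, and any extra fields would be dropped. The field names below and the `ObliqueSplit` struct are assumptions for illustration, not scikit-learn's actual struct definitions.

```python
import ctypes


# Hypothetical stand-in for the Split struct; field names are assumptions.
class Split(ctypes.Structure):
    _fields_ = [
        ("feature", ctypes.c_ssize_t),
        ("pos", ctypes.c_ssize_t),
        ("threshold", ctypes.c_double),
        ("improvement", ctypes.c_double),
    ]


# Hypothetical extended struct: identical leading fields, then extra data
# that an oblique splitter might need.
class ObliqueSplit(ctypes.Structure):
    _fields_ = Split._fields_ + [
        ("n_nonzeros", ctypes.c_ssize_t),
    ]


def read_common_fields(split_ptr):
    """Builder-side code: only touches fields shared by both structs."""
    s = split_ptr.contents
    return s.feature, s.pos, s.threshold


def as_split_ptr(struct):
    """Reinterpret a pointer to a Split-prefixed struct as POINTER(Split)."""
    return ctypes.cast(ctypes.pointer(struct), ctypes.POINTER(Split))


plain = Split(feature=2, pos=10, threshold=0.5, improvement=0.1)
oblique = ObliqueSplit(feature=-1, pos=7, threshold=1.25, improvement=0.3,
                       n_nonzeros=3)

print(read_common_fields(ctypes.pointer(plain)))  # (2, 10, 0.5)
print(read_common_fields(as_split_ptr(oblique)))  # (-1, 7, 1.25)
```

The builder-side function is written once against `Split`, yet a contrib package can route its own, larger split struct through it.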
Describe alternatives you've considered, if relevant
Alternatives would require maintaining a copy of the `sklearn/tree` module and keeping it up to date with sklearn changes. If this were just one Cython file, I would say it is possible, but the necessary ingredients span several parts of the underlying private API, making this a very time-consuming task. Introducing modularity into the private API without impacting existing performance therefore seems to be the best path forward.
Moreover, introducing these refactoring changes first would give #22754 a smaller diff and make it less costly to review.
Additional context
#22754 demonstrates that introducing these changes causes no performance regression and no issues with the existing DecisionTree or RandomForest estimators.