
MAINT Refactor Splitter into a BaseSplitter #25101

Open · wants to merge 61 commits into main
Conversation

@jshinm (Contributor) commented Dec 2, 2022

Reference Issues/PRs

Fixes #24990

What does this implement/fix? Explain your changes.

Adds a BaseSplitter abstract class to be inherited by the Splitter class, which is now modularized without assuming a "supervised learning" setting. To achieve this, Criterion was moderately refactored: the resetting of the index pointers was separated from the initialization process into a new set_sample_pointers method, which is called by the child classes of Criterion.

This change is backwards compatible.
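
For illustration, here is a minimal Python sketch of the Criterion refactor described above. The real classes are Cython extension types in sklearn/tree/_criterion.pyx, so the shape below is a simplified assumption; only the names Criterion and set_sample_pointers come from the PR itself.

    class Criterion:
        """Impurity criterion; initialization now happens in two steps."""

        def init(self, y, sample_weight, sample_indices, start, end):
            # Full initialization: bind the targets and weights, then
            # delegate the index bookkeeping to set_sample_pointers.
            self.y = y
            self.sample_weight = sample_weight
            self.set_sample_pointers(sample_indices, start, end)

        def set_sample_pointers(self, sample_indices, start, end):
            # Reset only the [start, end) window over the sample indices.
            # This step never touches y, so code that does not assume a
            # supervised setting can call it on its own.
            self.sample_indices = sample_indices
            self.start = start
            self.end = end
            self.n_node_samples = end - start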

ASV benchmark results

(ndd-split) jshinm@jshinm-OMEN-by-HP-Laptop-16-b0xxx:~/Desktop/workstation/sklearn/asv_benchmarks$ asv compare main split

All benchmarks:

       before           after         ratio
     [743fe8e5]       [b83cbbc7]
     <main>           <split>   

             190M             190M     1.00  ensemble.RandomForestClassifierBenchmark.peakmem_fit('dense', 1)
             423M             423M     1.00  ensemble.RandomForestClassifierBenchmark.peakmem_fit('sparse', 1)
             190M             190M     1.00  ensemble.RandomForestClassifierBenchmark.peakmem_predict('dense', 1)
             407M             407M     1.00  ensemble.RandomForestClassifierBenchmark.peakmem_predict('sparse', 1)
       4.93±0.02s       5.08±0.01s     1.03  ensemble.RandomForestClassifierBenchmark.time_fit('dense', 1)
       6.30±0.01s          6.32±0s     1.00  ensemble.RandomForestClassifierBenchmark.time_fit('sparse', 1)
        131±0.9ms        131±0.7ms     0.99  ensemble.RandomForestClassifierBenchmark.time_predict('dense', 1)
          847±2ms          855±4ms     1.01  ensemble.RandomForestClassifierBenchmark.time_predict('sparse', 1)
  0.7552784412549299  0.7552784412549299     1.00  ensemble.RandomForestClassifierBenchmark.track_test_score('dense', 1)
  0.8656423941766682  0.8656423941766682     1.00  ensemble.RandomForestClassifierBenchmark.track_test_score('sparse', 1)
  0.9961421915584339  0.9961421915584339     1.00  ensemble.RandomForestClassifierBenchmark.track_train_score('dense', 1)
  0.9996123288718864  0.9996123288718864     1.00  ensemble.RandomForestClassifierBenchmark.track_train_score('sparse', 1)

Test Machine Spec

os [Linux 5.15.0-56-generic]
arch [x86_64]
cpu [11th Gen Intel(R) Core(TM) i7-11800H @ 2.30GHz]
num_cpu [16]
ram [65483276]

Any other comments?

@jshinm (Contributor, Author) commented Dec 12, 2022

@adam2392 this builds and passes all the tree tests under pytest. However, we get some errors with k-means, which I'm still trying to troubleshoot.

@adam2392 (Member) commented

What is the error from kmeans?

@jshinm (Contributor, Author) commented Dec 12, 2022

Turns out that if you accidentally put one extra " in your docstring, sklearn won't build. Because of this I had to restart my computer, which wiped my terminal, so I'm re-running pytest now 😃. It should be done in a few minutes.
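
(For illustration, a hypothetical snippet of the kind of typo described: the stray fourth quote after the closing triple quote opens a new, unterminated string literal, so compilation aborts with a syntax error.)

    def node_split(self):  # hypothetical method
        """Find the best split.""""  # extra " -> unterminated string literal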

@jshinm (Contributor, Author) commented Dec 12, 2022

@adam2392 The rest of the log is in [this doc]

sklearn/cluster/tests/test_k_means.py:390: AssertionError
_____________ test_kmeans_verbose[0-elkan] _____________

algorithm = 'elkan', tol = 0
capsys = <_pytest.capture.CaptureFixture object at 0x7fe92c5295e0>

    @pytest.mark.parametrize("algorithm", ["lloyd", "elkan"])
    @pytest.mark.parametrize("tol", [1e-2, 0])
    def test_kmeans_verbose(algorithm, tol, capsys):
        # Check verbose mode of KMeans for better coverage.
        X = np.random.RandomState(0).normal(size=(5000, 10))
    
        KMeans(
            algorithm=algorithm,
            n_clusters=n_clusters,
            random_state=42,
            init="random",
            n_init=1,
            tol=tol,
            verbose=1,
        ).fit(X)
    
        captured = capsys.readouterr()
    
        assert re.search(r"Initialization complete", captured.out)
        assert re.search(r"Iteration [0-9]+, inertia", captured.out)
    
        if tol == 0:
>           assert re.search(r"strict convergence", captured.out)
E           AssertionError: assert None
E            +  where None = <function search at 0x7fe9c8b85160>('strict convergence', 'Initialization complete\nIteration 0, inertia 64759.062154957086\nIteration 1, inertia 46751.170526691916\nIteration ... 297, inertia 45634.20523994775\nIteration 298, inertia 45040.45016538214\nIteration 299, inertia 47499.921554245804\n')
E            +    where <function search at 0x7fe9c8b85160> = re.search
E            +    and   'Initialization complete\nIteration 0, inertia 64759.062154957086\nIteration 1, inertia 46751.170526691916\nIteration ... 297, inertia 45634.20523994775\nIteration 298, inertia 45040.45016538214\nIteration 299, inertia 47499.921554245804\n' = CaptureResult(out='Initialization complete\nIteration 0, inertia 64759.062154957086\nIteration 1, inertia 46751.170526...rtia 45634.20523994775\nIteration 298, inertia 45040.45016538214\nIteration 299, inertia 47499.921554245804\n', err='').out

sklearn/cluster/tests/test_k_means.py:390: AssertionError
__________________ test_warm_start[1] __________________

seed = 1

    @pytest.mark.filterwarnings("ignore:.*did not converge.*")
    @pytest.mark.parametrize("seed", (0, 1, 2))
    def test_warm_start(seed):
        random_state = seed
        rng = np.random.RandomState(random_state)
        n_samples, n_features, n_components = 500, 2, 2
        X = rng.rand(n_samples, n_features)
    
        # Assert the warm_start give the same result for the same number of iter
        g = GaussianMixture(
            n_components=n_components,
            n_init=1,
            max_iter=2,
            reg_covar=0,
            random_state=random_state,
            warm_start=False,
        )
        h = GaussianMixture(
            n_components=n_components,
            n_init=1,
            max_iter=1,
            reg_covar=0,
            random_state=random_state,
            warm_start=True,
        )
    
        g.fit(X)
        score1 = h.fit(X).score(X)
        score2 = h.fit(X).score(X)
    
>       assert_almost_equal(g.weights_, h.weights_)
E       AssertionError: 
E       Arrays are not almost equal to 7 decimals
E       
E       Mismatched elements: 2 / 2 (100%)
E       Max absolute difference: 0.01287114
E       Max relative difference: 0.0258317
E        x: array([0.5111401, 0.4888599])
E        y: array([0.498269, 0.501731])

sklearn/mixture/tests/test_gaussian_mixture.py:841: AssertionError
= 15 failed, 24121 passed, 3119 skipped, 79 xfailed, 43 xpassed, 2437 warnings in 511.97s (0:08:31) =

@adam2392 (Member) commented

Seems unrelated. Did you run `make clean` and then rebuild?

@jshinm (Contributor, Author) commented Dec 12, 2022

@adam2392 I built from scratch (i.e., created a new conda env) when I benchmarked, but for these runs I just worked in my existing env (i.e., without creating a new env every time I build). So I just created a new env to check, and it's throwing the same error 😞

I'm not sure if it's relevant, but I forked from your crit branch, and if I'm not mistaken, you updated the branch after I forked. Could I bring my branch up to date with crit to see if that resolves it?

Update: I brought the branch up to date with crit, but I still get the same error.

@adam2392 (Member) commented Dec 12, 2022

Sure. It could've been resolved on main. It looks like there's something going on locally, because most of the CI checks pass.

@jshinm (Contributor, Author) commented Dec 12, 2022

@adam2392 I just ran the same tests on my other computer, and it passes everything without errors (both the current branch and main). So I deleted the whole folder on my work computer, reinstalled everything, and now it passes all tests. I think there was some path mix-up in my conda env.

TL;DR: it's fixed!

= 24136 passed, 3119 skipped, 79 xfailed, 43 xpassed, 2434 warnings in 524.07s (0:08:44) =

@jshinm (Contributor, Author) commented Dec 12, 2022

oh @adam2392 should I also change this PR from a draft to "Ready for review" for @jjerphan's review?

jshinm marked this pull request as ready for review on December 12, 2022 at 19:08.
@adam2392 (Member) commented

@jshinm please update the PR description to summarize the changes, and add the results of running the asv benchmarks to the PR description.

@adam2392 (Member) commented

@jjerphan this looks ready for an initial look if you're available.

Regarding the LOC diff, the Criterion PR needs to be merged first for this PR to make sense.

@adam2392 (Member) left a review comment


Some minor docstring changes.

sklearn/tree/_criterion.pyx — two review threads (outdated, resolved)
@jjerphan (Member) commented Jan 2, 2023

Happy New Year, @adam2392 and @jshinm.

Thank you for reporting the benchmark results.

I would wait for #24678 to be merged before pursuing this PR, so as to not go back and forth between changes.

In the meantime, if you want, you can:

  • rerun the benchmarks on the latest commits
  • make sure that docstring lines added or modified by this PR are at most 88 characters long

@jshinm (Contributor, Author) commented Jan 2, 2023

Happy New Year, @jjerphan and @adam2392! And thank you for your kind comment :)

I'll get started on those while we are waiting.

@adam2392 (Member) commented Jan 24, 2023

I've been doing some testing downstream to see how the trees can be refactored. I realized that, due to the inherent linkage between Splitter and Criterion, the current code would not work in practice. I will add it here for reference. See: cython/cython#5226 (comment).

I think the issue is masked here, possibly because TreeBuilder still assumes a Splitter, but overall it's not the "correct design": a splitter that inherits from BaseSplitter and uses a Criterion object would not work properly. A correct design would be for BaseSplitter to have no notion of a BaseCriterion/Criterion object. I implemented this in the downstream PR #25448 to show that the change works.

sample_code.txt
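
For reference, a minimal Python sketch of the design described above, using hypothetical method names (the real classes are Cython extension types): BaseSplitter carries no notion of a Criterion, and only the supervised Splitter subclass owns one.

    from abc import ABC, abstractmethod

    class BaseSplitter(ABC):
        """Knows about feature values and sample indices, but not about y."""

        @abstractmethod
        def node_split(self, impurity):
            """Find the best split for the current node."""

    class Splitter(BaseSplitter):
        """Supervised splitter: the Criterion lives here, not in the base."""

        def __init__(self, criterion):
            self.criterion = criterion

        def node_split(self, impurity):
            # A real implementation would score candidate thresholds
            # with self.criterion; elided here.
            pass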

@jjerphan (Member) commented

Thanks for the pointers. I will get to this soon.

Linked issue (may be closed by this PR): #24990 — MAINT Split Splitter into a BaseSplitter and a Splitter subclass to allow easier inheritance