Skip to content

Commit 22a0239

Browse files
Display Split conditions (#162)
* add nightly ci * display split conditions in Term.Tree, update docs * remove nightly ci (added to the wrong branch) * remove """debugging""" if statement * bump Project to `v2.2.2`
1 parent 6bce63e commit 22a0239

File tree

4 files changed

+73
-60
lines changed

4 files changed

+73
-60
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "XGBoost"
22
uuid = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
3-
version = "2.2.1"
3+
version = "2.2.2"
44

55
[deps]
66
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

docs/src/features.md

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -27,47 +27,52 @@ will use assigned feature names, for example
2727
```julia
2828
julia> df = DataFrame(randn(10,3), ["kirk", "spock", "bones"])
2929
10×3 DataFrame
30-
Row │ kirk spock bones
31-
│ Float64 Float64 Float64
30+
Row │ kirk spock bones
31+
│ Float64 Float64 Float64
3232
─────┼───────────────────────────────────
33-
10.731406 -0.53631 0.465881
34-
20.553427 -0.787531 -0.838059
35-
31.30724 -2.38111 -1.1979
36-
4 0.0759902 0.418856 1.49618
37-
5-0.426773 -0.32008 -0.773329
38-
6-1.36495 -0.105646 1.08546
39-
7 0.476315 -0.080163 -1.4846
40-
8 0.144403 0.344307 -0.0301839
41-
9 0.593969 0.165502 1.31196
42-
102.15151 0.584925 -0.709128
43-
44-
julia> bst = xgboost((df, randn(10)), 10)
33+
10.663934 -0.419345 -0.489801
34+
21.19064 0.420935 -0.321852
35+
30.713867 0.293724 0.0450463
36+
4-1.3474 -0.402996 1.50831
37+
5-0.458164 0.0399281 -0.83443
38+
6-0.277555 0.149485 0.408656
39+
7-1.79885 -1.1535 0.99213
40+
8-0.177408 -0.818639 0.280188
41+
9-1.26053 -1.60734 2.21421
42+
100.30378 -0.299256 0.384029
43+
44+
julia> bst = xgboost((df, randn(10)), num_round=10)
4545
[ Info: XGBoost: starting training.
46-
[ Info: [1] train-rmse:0.71749003518059951
47-
[ Info: [2] train-rmse:0.57348349389049413
48-
[ Info: [3] train-rmse:0.46118182517533174
49-
[ Info: [4] train-rmse:0.37161911786076596
50-
[ Info: [5] train-rmse:0.29986573085749962
51-
[ Info: [6] train-rmse:0.24238347776088820
52-
[ Info: [7] train-rmse:0.19544715478958452
53-
[ Info: [8] train-rmse:0.15795933989281422
54-
[ Info: [9] train-rmse:0.12805284613811851
55-
[ Info: [10] train-rmse:0.10467078844629517
46+
[ Info: [1] train-rmse:0.57998637329114211
47+
[ Info: [2] train-rmse:0.48232409595403752
48+
[ Info: [3] train-rmse:0.40593080843433427
49+
[ Info: [4] train-rmse:0.34595769369793850
50+
[ Info: [5] train-rmse:0.29282108263987289
51+
[ Info: [6] train-rmse:0.24862819795032731
52+
[ Info: [7] train-rmse:0.21094418685218519
53+
[ Info: [8] train-rmse:0.17903024616536045
54+
[ Info: [9] train-rmse:0.15198720040980171
55+
[ Info: [10] train-rmse:0.12906074380448287
5656
[ Info: Training rounds complete.
5757
╭──── XGBoost.Booster ─────────────────────────────────────────────────────────────────╮
5858
│ Features: ["kirk", "spock", "bones"] │
59+
│ │
60+
│ Parameter Value │
61+
│ ───────────────────────────────── │
62+
│ validate_parameters true
63+
│ │
5964
╰──── boosted rounds: 10 ──────────────────────────────────────────────────────────────╯
6065

6166
julia> importancereport(bst)
62-
╭───────────┬────────────┬──────────┬───────────┬──────────────┬───────────────╮
63-
│ feature │ gain │ weight │ cover │ total_gain │ total_cover │
64-
├───────────┼────────────┼──────────┼───────────┼──────────────┼───────────────┤
65-
"bones"0.229349 17.07.647063.89893130.0
66-
├───────────┼────────────┼──────────┼───────────┼──────────────┼───────────────┤
67-
"spock"0.176391 18.04.77778 3.1750386.0
68-
├───────────┼────────────┼──────────┼───────────┼──────────────┼───────────────┤
69-
"kirk"0.11505513.03.384621.49572 44.0
70-
╰───────────┴────────────┴──────────┴───────────┴──────────────┴───────────────╯
67+
╭───────────┬────────────┬──────────┬───────────┬──────────────┬───────────────╮
68+
│ feature │ gain │ weight │ cover │ total_gain │ total_cover │
69+
├───────────┼────────────┼──────────┼───────────┼──────────────┼───────────────┤
70+
"bones"0.358836 15.08.533335.38254128.0
71+
├───────────┼────────────┼──────────┼───────────┼──────────────┼───────────────┤
72+
"spock"0.157437 16.0 4.75 2.5189976.0
73+
├───────────┼────────────┼──────────┼───────────┼──────────────┼───────────────┤
74+
"kirk"0.012854634.02.911760.43705699.0
75+
╰───────────┴────────────┴──────────┴───────────┴──────────────┴───────────────╯
7176
```
7277
7378
### Tree Inspection
@@ -81,39 +86,43 @@ interface.
8186
```julia
8287
julia> ts = trees(bst)
8388
10-element Vector{XGBoost.Node}:
84-
XGBoost.Node(split_feature="f1")
85-
XGBoost.Node(split_feature="f1")
86-
XGBoost.Node(split_feature="f1")
87-
XGBoost.Node(split_feature="f1")
88-
XGBoost.Node(split_feature="f1")
89-
XGBoost.Node(split_feature="f1")
90-
XGBoost.Node(split_feature="f1")
91-
XGBoost.Node(split_feature="f1")
92-
XGBoost.Node(split_feature="f1")
93-
XGBoost.Node(split_feature="f1")
89+
XGBoost.Node(split_feature="bones")
90+
XGBoost.Node(split_feature="bones")
91+
XGBoost.Node(split_feature="bones")
92+
XGBoost.Node(split_feature="bones")
93+
XGBoost.Node(split_feature="bones")
94+
XGBoost.Node(split_feature="bones")
95+
XGBoost.Node(split_feature="bones")
96+
XGBoost.Node(split_feature="bones")
97+
XGBoost.Node(split_feature="bones")
98+
XGBoost.Node(split_feature="bones")
9499

95100
julia> ts[1]
96101
╭──── XGBoost.Node (id=0, depth=0) ────────────────────────────────────────────────────╮
97102
│ │
98-
│ split_condition yes no nmissing gain cover │
99-
│ ────────────────────────────────────────────────────────────────────────
100-
-0.267610937 1 2 1 0.284702361 10.0
103+
│ split_condition yes no nmissing gain cover
104+
│ ────────────────────────────────────────────────────────────────────────
105+
0.396342576 1 2 1 1.86042714 10.0
101106
│ │
102107
│ XGBoost Tree (from this node) │
103108
│ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ │
104109
│ │ │
105-
│ ├── f0 (1) │
106-
│ │ ├── f0 (1) │
107-
│ │ │ ├── (1): XGBoost.Node(leaf=0.042126134) │
108-
│ │ │ └── (2): XGBoost.Node(leaf=-0.0647352263) │
109-
│ │ └── (2): XGBoost.Node(leaf=0.0405130237) │
110-
│ └── (2): XGBoost.Node(leaf=-0.0718128532) │
110+
│ ├── bones < 0.396
111+
│ │ ├── bones < 0.332: XGBoost.Node(leaf=-0.159539297) │
112+
│ │ └── bones 0.332: XGBoost.Node(leaf=-0.0306737479) │
113+
│ └── bones 0.396
114+
│ ├── spock < -0.778
115+
│ │ ├── kirk < -1.53: XGBoost.Node(leaf=-0.0544514731) │
116+
│ │ └── kirk -1.53: XGBoost.Node(leaf=0.00967349485) │
117+
│ └── spock -0.778
118+
│ ├── kirk < -0.812: XGBoost.Node(leaf=0.0550933369) │
119+
│ └── kirk -0.812: XGBoost.Node(leaf=0.228843644) │
111120
╰──── 2 children ──────────────────────────────────────────────────────────────────────╯
112121

113122
julia> using AbstractTrees; children(ts[1])
114123
2-element Vector{XGBoost.Node}:
115-
XGBoost.Node(split_feature="f0")
116-
XGBoost.Node(leaf=-0.0718128532)
124+
XGBoost.Node(split_feature="bones")
125+
XGBoost.Node(split_feature="spock")
117126
```
118127
119128
## Setting a Custom Objective Function

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,6 @@ bst = xgboost((X, y), num_round=20)
179179
```
180180
is equivalent to
181181
```julia
182-
bst = xgboost((X, y), nun_round=10)
182+
bst = xgboost((X, y), num_round=10)
183183
update!(bst, (X, y), num_round=10)
184184
```

src/show.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,17 +92,21 @@ function importancereport(b::Booster)
9292
end
9393
end
9494

95-
function _tree_display_branch_string(split, j::Integer)
96-
o = "($j)"
97-
isnothing(split) ? o : string(split, " ", o)
95+
96+
function _tree_display_branch_string(node, child_id::Integer)
97+
if node.yes == child_id
98+
string(node.split, " < ", round(node.split_condition, digits=3))
99+
else
100+
string(node.split, "", round(node.split_condition, digits=3))
101+
end
98102
end
99103

100104
function _tree_display(node::Node)
101105
ch = children(node)
102106
if isempty(ch)
103107
sprint(show, node)
104108
else
105-
OrderedDict(_tree_display_branch_string(ch[j].split, j)=>_tree_display(ch[j]) for j 1:length(ch))
109+
OrderedDict(_tree_display_branch_string(node, ch[j].id)=>_tree_display(ch[j]) for j 1:length(ch))
106110
end
107111
end
108112

0 commit comments

Comments
 (0)