
Commit 4ead83f

Add functionality for early stopping rounds. (#193)
* add functionality for early stopping
* remove version word
* evaluation msg into a parsing function and add back evaluation to updateone
* Updated the call to updateone! to pass in the watchlist so it can be used by early stopping round logic.
* Added comments, additional examples, fixed issues with watchlist ordering as a Dict.
* Added functionality to extract the best iteration round with examples. Included additional test case coverage.
* Cleaned up some lingering test cases.
* Updated doc to include early stopping example.
* Added additional info on data types for watchlist
* Annotated OrderedDict to be more obvious.
* Included using statement for OrderedCollection
* Moved log message parsing to update! instead of updateone
* Updated documentation and tests.
* Altered the XGBoost method definition to reflect exception states for early stopping rounds and watchlist.
* Created exception if extract_metric_value could not find a match when parsing XGBoost logs.

---------

Co-authored-by: Wilan Wong <[email protected]>
Co-authored-by: wilan-wong-1 <[email protected]>
1 parent c365c78 commit 4ead83f

File tree

3 files changed: +305 -13 lines changed


docs/src/index.md

Lines changed: 41 additions & 0 deletions
@@ -127,6 +127,7 @@ Unlike feature data, label data can be extracted after construction of the `DMat
 [`XGBoost.getlabel`](@ref).
 
 
+
 ## Booster
 The [`Booster`](@ref) object holds model data. They are created with training data. Internally
 this is always a `DMatrix` but arguments will be automatically converted.
@@ -182,3 +183,43 @@ is equivalent to
 bst = xgboost((X, y), num_round=10)
 update!(bst, (X, y), num_round=10)
 ```
+
+### Early Stopping
+To help prevent overfitting to the training set, it is useful to evaluate against a validation set so that the XGBoost iterations continue to generalise beyond reducing the training loss. Early stopping provides
+a convenient way to automatically stop the boosting process if the generalisation capability of the model does not improve for `k` rounds.
+
+If there is more than one element in the watchlist, by default the last element will be used. In this case you must use an ordered data structure (`OrderedDict`) rather than a standard unordered
+dictionary: an error is raised if the early stopping mechanism is enabled (`early_stopping_rounds > 0`) but the watchlist is a `Dict` with
+more than one element.
+
+Similarly, if there is more than one element in `eval_metric`, by default the last element will be used.
+
+For example:
+
+```julia
+using LinearAlgebra
+using OrderedCollections
+
+𝒻(x) = 2norm(x)^2 - norm(x)
+
+X = randn(100,3)
+y = 𝒻.(eachrow(X))
+
+dtrain = DMatrix((X, y))
+
+X_valid = randn(50,3)
+y_valid = 𝒻.(eachrow(X_valid))
+
+dvalid = DMatrix((X_valid, y_valid))
+
+bst = xgboost(dtrain, num_round = 100, eval_metric = "rmse", watchlist = OrderedDict(["train" => dtrain, "eval" => dvalid]), early_stopping_rounds = 5, max_depth=6, η=0.3)
+
+# get the best iteration and use it for prediction
+ŷ = predict(bst, X_valid, ntree_limit = bst.best_iteration)
+
+using Statistics
+println("RMSE from model prediction $(round((mean((ŷ - y_valid).^2).^0.5), digits = 8)).")
+
+# we can also retain / use the best score (based on eval_metric) which is stored in the booster
+println("Best RMSE from model training $(round((bst.best_score), digits = 8)).")
+```
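
As a complement to the documentation example above, here is a minimal hedged sketch (reusing `dtrain`, `dvalid` and `X_valid` from that example) showing that a plain `Dict` is accepted when the watchlist contains a single entry; the `OrderedDict` requirement only applies when early stopping is enabled and more than one dataset is watched.

```julia
# Single-entry watchlist: a plain Dict is unambiguous, so it is accepted.
# Only the validation set is evaluated and used for the stopping criterion.
bst2 = xgboost(dtrain, num_round = 100, eval_metric = "rmse",
               watchlist = Dict("eval" => dvalid),
               early_stopping_rounds = 5, max_depth = 6, η = 0.3)

# if stopping triggered, the recorded best iteration can be used for prediction
ŷ2 = predict(bst2, X_valid, ntree_limit = bst2.best_iteration)
```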

src/booster.jl

Lines changed: 147 additions & 13 deletions
@@ -1,4 +1,3 @@
-
 """
     Booster
 
@@ -50,11 +49,17 @@ mutable struct Booster
     # out what the hell is happening, it's never used for program logic
     params::Dict{Symbol,Any}
 
-    function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict())
+    # store early stopping information
+    best_iteration::Union{Int64, Missing}
+    best_score::Union{Float64, Missing}
+
+    function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict(), best_iteration::Union{Int64, Missing}=missing,
+                     best_score::Union{Float64, Missing}=missing)
         finalizer(x -> xgbcall(XGBoosterFree, x.handle), new(h, fsn, params))
     end
 end
 
+
 """
     setparam!(b::Booster, name, val)
 
@@ -366,7 +371,6 @@ function updateone!(b::Booster, Xy::DMatrix;
                     update_feature_names::Bool=false,
                     )
     xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle)
-    isempty(watchlist) || logeval(b, watchlist, round_number)
     _maybe_update_feature_names!(b, Xy, update_feature_names)
     b
 end
@@ -382,7 +386,6 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr
     g = convert(Vector{Cfloat}, g)
     h = convert(Vector{Cfloat}, h)
     xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g))
-    isempty(watchlist) || logeval(b, watchlist, round_number)
     _maybe_update_feature_names!(b, Xy, update_feature_names)
     b
 end
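
The two hunks above remove per-round evaluation logging from `updateone!`; the evaluation string is now produced in `update!` via `evaliter` and parsed there for early stopping. A hedged sketch of what a single manual round looks like under that split (assuming a `Booster` `b` plus `dtrain`/`dvalid` as in the examples, and qualifying the internal helpers in case they are not exported):

```julia
using OrderedCollections

watchlist = OrderedDict("train" => dtrain, "eval" => dvalid)

round_number = XGBoost.getnrounds(b) + 1
updateone!(b, dtrain; round_number, watchlist)       # boosts one round; no evaluation log here any more
msg = XGBoost.evaliter(b, watchlist, round_number)   # evaluation string such as "train-rmse:... eval-rmse:..."
@info msg
```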
@@ -422,14 +425,105 @@ Run `num_round` rounds of gradient boosting on [`Booster`](@ref) `b`.
 The first and second derivatives of the loss function (`ℓ′` and `ℓ″` respectively) can be provided
 for custom loss.
 """
-function update!(b::Booster, data, a...; num_round::Integer=1, kw...)
+function update!(b::Booster, data, a...;
+        num_round::Integer=1,
+        watchlist::Any = Dict("train" => data),
+        early_stopping_rounds::Integer=0,
+        maximize=false,
+        kw...,
+    )
+
+    if !isempty(watchlist) && early_stopping_rounds > 0
+        @info("Will train until there has been no improvement in $early_stopping_rounds rounds.\n")
+        best_round = 0
+        best_score = maximize ? -Inf : Inf
+    end
+
     for j ∈ 1:num_round
         round_number = getnrounds(b) + 1
-        updateone!(b, data, a...; round_number, kw...)
+
+        updateone!(b, data, a...; round_number, watchlist, kw...)
+
+        # Evaluate if watchlist is not empty
+        if !isempty(watchlist)
+            msg = evaliter(b, watchlist, round_number)
+            @info msg
+            if early_stopping_rounds > 0
+                score, dataset, metric = extract_metric_value(msg)
+                if (maximize && score > best_score) || (!maximize && score < best_score)
+                    best_score = score
+                    best_round = j
+                elseif j - best_round >= early_stopping_rounds
+                    @info(
+                        "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds."
+                    )
+                    # add additional fields to record the best iteration
+                    b.best_iteration = best_round
+                    b.best_score = best_score
+                    return b
+                end
+            end
+        end
     end
     b
 end
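
For orientation, a hedged sketch of driving this `update!` method directly on an existing `Booster` (reusing `dtrain`/`dvalid` from the examples; constructing the `Booster` from a `DMatrix` with keyword parameters is used here the same way `xgboost` uses it internally):

```julia
using OrderedCollections

b = Booster(dtrain; eval_metric = "rmse")
update!(b, dtrain;
        num_round = 50,
        watchlist = OrderedDict("train" => dtrain, "eval" => dvalid),
        early_stopping_rounds = 3)

# If stopping triggers, update! records the best round/score on the booster and
# returns early; otherwise all 50 rounds run and those fields are left unset.
```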
 
+
+
+"""
+    extract_metric_value(msg, dataset=nothing, metric=nothing)
+
+Extracts a numeric value from a message based on the specified dataset and metric.
+If dataset or metric is not provided, the function will automatically find the last
+mentioned dataset or metric in the message.
+
+# Arguments
+- `msg::AbstractString`: The message containing the numeric values.
+- `dataset::Union{AbstractString, Nothing}`: The dataset to extract values for (default: `nothing`).
+- `metric::Union{AbstractString, Nothing}`: The metric to extract values for (default: `nothing`).
+
+# Returns
+- Returns a `(value, dataset, metric)` tuple with the parsed `Float64` value if a match is found, otherwise throws an error.
+
+# Examples
+```julia
+msg = "train-rmsle:0.09516384803222511 train-rmse:0.12458323318968342 eval-rmsle:0.09311178520817574 eval-rmse:0.12088154560829874"
+
+# Without specifying dataset and metric
+value_without_params = extract_metric_value(msg)
+println(value_without_params) # Output: (0.12088154560829874, "eval", "rmse")
+
+# With specifying dataset and metric
+value_with_params = extract_metric_value(msg, "train", "rmsle")
+println(value_with_params) # Output: (0.0951638480322251, "train", "rmsle")
+```
+"""
+function extract_metric_value(msg, dataset=nothing, metric=nothing)
+    if isnothing(dataset)
+        # Find the last mentioned dataset - whilst retaining order
+        datasets = unique([m.match for m in eachmatch(r"\w+(?=-)", msg)])
+        dataset = last(collect(datasets))
+    end
+
+    if isnothing(metric)
+        # Find the last mentioned metric - whilst retaining order
+        metrics = unique([m.match for m in eachmatch(r"(?<=-)\w+", msg)])
+        metric = last(collect(metrics))
+    end
+
+    pattern = Regex("$dataset-$metric:([\\d.]+)")
+
+    match_result = match(pattern, msg)
+
+    if !isnothing(match_result)
+        parsed_value = parse(Float64, match_result.captures[1])
+        return parsed_value, dataset, metric
+    end
+
+    # there was no match result - should error out
+    error("No match found for pattern: $dataset-$metric in message: $msg")
+end
+
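
A brief illustrative check of the parsing behaviour above (hypothetical REPL usage; if `extract_metric_value` is not exported it can be qualified as shown): asking for a dataset/metric pair that is absent from the log line raises an error rather than returning `nothing`.

```julia
msg = "train-rmse:0.10 eval-rmse:0.12"

XGBoost.extract_metric_value(msg)                   # falls back to the last pair: (0.12, "eval", "rmse")
XGBoost.extract_metric_value(msg, "train", "rmse")  # explicit selection: (0.1, "train", "rmse")

try
    XGBoost.extract_metric_value(msg, "eval", "auc")  # there is no eval-auc entry in msg
catch err
    @warn "metric could not be parsed from the evaluation log" err
end
```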
 """
     xgboost(data; num_round=10, watchlist=Dict(), kw...)
     xgboost(data, ℓ′, ℓ″; kw...)
@@ -439,7 +533,19 @@ This is essentially an alias for constructing a [`Booster`](@ref) with `data` an
 followed by [`update!`](@ref) for `nrounds`.
 
 `watchlist` is a dict the keys of which are strings giving the name of the data to watch
-and the values of which are [`DMatrix`](@ref) objects containing the data.
+and the values of which are [`DMatrix`](@ref) objects containing the data. An `OrderedDict` must be used
+when `early_stopping_rounds` is enabled and `watchlist` has more than one element, so that XGBoost evaluates
+the intended dataset for early stopping.
+
+`early_stopping_rounds` activates early stopping if set to > 0. The validation metric must improve at
+least once in every `early_stopping_rounds` rounds. If `watchlist` is not explicitly provided, the training dataset
+is used to evaluate the stopping criterion; otherwise the last data element in `watchlist` and the
+last metric in `eval_metric` (if more than one is given) are used. Note that `watchlist` cannot be empty if
+`early_stopping_rounds` is enabled.
+
+`maximize` controls the direction of improvement when `early_stopping_rounds` is set: when `false` (the default),
+a smaller evaluation score is better; when `true`,
+a larger evaluation score is better.
 
 All other keyword arguments are passed to [`Booster`](@ref). With few exceptions these are model
 training hyper-parameters, see [here](https://xgboost.readthedocs.io/en/stable/parameter.html) for
450556
451557
## Examples
452558
```julia
559+
# Example 1: Basic usage of XGBoost
453560
(X, y) = (randn(100,3), randn(100))
454561
455-
b = xgboost((X, y), 10, max_depth=10, η=0.1)
562+
b = xgboost((X, y), num_round=10, max_depth=10, η=0.1)
456563
457564
ŷ = predict(b, X)
565+
566+
# Example 2: Using early stopping (using a validation set) with a watchlist
567+
dtrain = DMatrix((randn(100,3), randn(100)))
568+
dvalid = DMatrix((randn(100,3), randn(100)))
569+
570+
watchlist = OrderedDict(["train" => dtrain, "valid" => dvalid])
571+
572+
b = xgboost(dtrain, num_round=10, early_stopping_rounds = 2, watchlist = watchlist, max_depth=10, η=0.1)
573+
574+
# note that ntree_limit in the predict function helps assign the upper bound for iteration_range in the XGBoost API 1.4+
575+
ŷ = predict(b, dvalid, ntree_limit = b.best_iteration)
458576
```
459577
"""
460578
function xgboost(dm::DMatrix, a...;
461-
num_round::Integer=10,
462-
watchlist=Dict("train"=>dm),
463-
kw...
464-
)
579+
num_round::Integer=10,
580+
watchlist::AbstractDict = Dict("train" => dm),
581+
early_stopping_rounds::Integer=0,
582+
maximize=false,
583+
kw...
584+
)
585+
465586
Xy = DMatrix(dm)
466587
b = Booster(Xy; kw...)
588+
589+
# We have a watchlist - give a warning if early stopping is provided and watchlist is a Dict type with length > 1
590+
if isa(watchlist, Dict)
591+
if early_stopping_rounds > 0 && length(watchlist) > 1
592+
error("You must supply an OrderedDict type for watchlist if early stopping rounds is enabled and there is more than one element in watchlist.")
593+
end
594+
end
595+
596+
if isempty(watchlist) && early_stopping_rounds > 0
597+
error("Watchlist must be supplied if early_stopping_rounds is enabled.")
598+
end
599+
467600
isempty(watchlist) || @info("XGBoost: starting training.")
468-
update!(b, Xy, a...; num_round, watchlist)
601+
update!(b, Xy, a...; num_round, watchlist, early_stopping_rounds, maximize)
469602
isempty(watchlist) || @info("Training rounds complete.")
470603
b
471604
end
605+
472606
xgboost(data, a...; kw...) = xgboost(DMatrix(data), a...; kw...)
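
As an illustrative aside, here is a hedged sketch of combining `maximize = true` with a metric where larger is better, such as AUC under a binary objective; the `objective` and `eval_metric` keywords are standard XGBoost training parameters passed through to the `Booster`, and the data below is invented purely for illustration:

```julia
using XGBoost, OrderedCollections

Xb = randn(200, 3)
yb = Float64.(Xb[:, 1] .+ 0.3 .* randn(200) .> 0)   # noisy binary labels for illustration

dtrainb = DMatrix((Xb[1:150, :], yb[1:150]))
dvalidb = DMatrix((Xb[151:end, :], yb[151:end]))

# AUC increases as the model improves, so the early-stopping comparison is
# flipped with maximize = true; the last watchlist entry ("eval") is monitored.
bst = xgboost(dtrainb, num_round = 100,
              objective = "binary:logistic", eval_metric = "auc",
              watchlist = OrderedDict("train" => dtrainb, "eval" => dvalidb),
              early_stopping_rounds = 5, maximize = true)
```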
