There has been a series of changes to enhance the training and cross-validation API in Python/R. Examples of these changes include:
- Early stopping based on evaluation statistics.
- Whether to save the evaluation results from cross-validation or training.
- Whether to print the evaluation results.
- Whether to save and return the best model from cv or training.
- Adapting the learning rate during training.
Currently, each of these proposals changes the core training API, because every requirement adds one more argument. We need a better way to handle these cases; otherwise the training API will become extremely hard to maintain.
Use Callbacks to Handle These Cases
```python
import sys


class StopTraining(Exception):
    """Raised by a callback to stop the training loop."""


def early_stop_maximize(stopping_rounds, metric, verbose=True):
    """Example: early stopping to maximize `metric`.

    Training stops once `metric` has not improved for more than
    `stopping_rounds` iterations.
    """
    # namedtuple instances are immutable, so keep the mutable state in a
    # dict captured by the closure below.
    info = {"best_score": float("-inf"), "best_score_i": 0}

    def callback(iteration, booster, evaluation_results):
        """Callback function to do early stopping.

        iteration: int
            Current iteration number, equal to total number of trees so far.
            If training continues from an existing model, ...
        booster: Booster
            Current booster model.
        evaluation_results: list of (str, float)
            Evaluation results from the watchlist.
        """
        score = dict(evaluation_results)[metric]
        if score > info["best_score"]:
            info["best_score"] = score
            info["best_score_i"] = iteration
        if iteration - info["best_score_i"] > stopping_rounds:
            booster.best_iteration = info["best_score_i"]
            if verbose:
                sys.stderr.write("Stopping at round %d\n" % iteration)
            raise StopTraining()

    return callback
```
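Other items from the list at the top, such as saving and printing evaluation results, fit the same `callback(iteration, booster, evaluation_results)` signature. The following is a minimal sketch; `record_evaluation` and `print_evaluation` are hypothetical names chosen for illustration, not an existing API.

```python
import sys


def record_evaluation(history):
    """Hypothetical callback factory: append each round's evaluation
    results to `history`, a dict mapping metric name -> list of scores."""
    def callback(iteration, booster, evaluation_results):
        for name, score in evaluation_results:
            history.setdefault(name, []).append(score)
    return callback


def print_evaluation(period=1, stream=sys.stderr):
    """Hypothetical callback factory: write the evaluation results to
    `stream` every `period` iterations."""
    def callback(iteration, booster, evaluation_results):
        if iteration % period == 0:
            msg = "\t".join("%s:%g" % (name, score)
                            for name, score in evaluation_results)
            stream.write("[%d]\t%s\n" % (iteration, msg))
    return callback
```

Because the caller owns the `history` dict, returning evaluation results requires no extra argument on `train` itself.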
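```python
def train(param, num_boost_round, callbacks=()):
    ...  # booster `bst` and watchlist setup elided in the original
    for i in range(num_boost_round):
        bst.update()
        # evaluation_results: (name, score) pairs computed on the watchlist
        try:
            for callback in callbacks:
                callback(i, bst, evaluation_results)
        except StopTraining:
            break
    return bst


# Call training; `callbacks` is a list, so several callbacks can be combined.
bst = train(param, num_boost_round, callbacks=[early_stop_maximize(3, 'test-auc')])
```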
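Adapting the learning rate, the last item in the list at the top, fits the same loop. The sketch below runs without XGBoost: `StubBooster` and `learning_rate_schedule` are hypothetical stand-ins, and the booster is assumed to offer a `set_param(name, value)` method for changing `eta` between rounds (an assumption about the booster interface, not a confirmed signature).

```python
class StopTraining(Exception):
    """Raised by a callback to end training early."""


class StubBooster:
    """Hypothetical stand-in for a booster; it only logs the learning
    rate in effect for each round."""

    def __init__(self, eta=0.3):
        self.params = {"eta": eta}
        self.eta_history = []

    def update(self):
        # A real booster would add a tree here using the current eta.
        self.eta_history.append(self.params["eta"])

    def set_param(self, name, value):
        self.params[name] = value


def learning_rate_schedule(schedule):
    """Hypothetical callback factory: `schedule(i)` returns the learning
    rate for (zero-based) round i."""
    def callback(iteration, booster, evaluation_results):
        # Callbacks run after each update, so set the rate for the next round.
        booster.set_param("eta", schedule(iteration + 1))
    return callback


def train(booster, num_boost_round, callbacks=()):
    """Same loop as the training sketch, taking a ready-made booster."""
    for i in range(num_boost_round):
        booster.update()
        evaluation_results = []  # no watchlist in this sketch
        try:
            for callback in callbacks:
                callback(i, booster, evaluation_results)
        except StopTraining:
            break
    return booster


bst = train(StubBooster(eta=0.3), 4,
            callbacks=[learning_rate_schedule(lambda i: 0.3 * 0.5 ** i)])
print(bst.eta_history)  # [0.3, 0.15, 0.075, 0.0375]
```

Because the schedule is only a callback, `train` does not need a separate learning-rate argument.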
TODO List
- Add the callback API to the training and cv APIs.
  - This includes Python/R/Julia/JVM.
- Add a callback function module to xgboost.
  - In the future we will only accept improvements to callbacks, and we will be more careful about changes to the training API.
- Add callbacks to support early_stop, logging, and best_model_save.
- Use the callbacks to keep backward compatibility.
  - For example, when early_stop_rounds is detected, add early_stop_maximize to the callback list at the beginning of the function.
  - Mark the newly added arguments as deprecated, and give a deprecation warning asking users to switch to the callback API.
  - We will consider removing some of the less important arguments after two major releases.
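The backward-compatibility step can be sketched as a small helper that `train()` would call first. This is a minimal sketch under stated assumptions: `prepare_callbacks` and `early_stop_metric` are hypothetical names, and `early_stop_maximize` here is a no-op placeholder for the factory defined earlier, included only so the block runs on its own.

```python
import warnings


def early_stop_maximize(stopping_rounds, metric, verbose=True):
    """Placeholder for the early-stopping factory defined earlier."""
    def callback(iteration, booster, evaluation_results):
        pass  # the real version tracks `metric` and raises StopTraining
    return callback


def prepare_callbacks(callbacks=None, early_stop_rounds=None,
                      early_stop_metric=None):
    """Hypothetical helper called at the start of train()/cv(): convert a
    legacy keyword argument into a callback and warn about it."""
    callbacks = list(callbacks or [])
    if early_stop_rounds is not None:
        # stacklevel=3 points the warning at the code that called train().
        warnings.warn(
            "early_stop_rounds is deprecated; pass "
            "callbacks=[early_stop_maximize(...)] instead",
            DeprecationWarning, stacklevel=3)
        callbacks.append(early_stop_maximize(early_stop_rounds,
                                             early_stop_metric))
    return callbacks
```

Whether the converted callback runs before or after user-supplied ones is a design choice; this sketch appends it.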