
[Guideline] Training API Enhancement and Refactor: Use Callbacks #892

Closed

Description

@tqchen

There has been a series of changes proposed to enhance the training and cross-validation API in Python/R. Examples of these changes include:

  • Early stopping based on evaluation statistics.
  • Whether to save the evaluation results from cross validation or training.
  • Whether to print the evaluation results.
  • Whether to save and return the best model from cv or training.
  • Adapting the learning rate during training.

Currently, each of these proposals involves a change to the core training API: one argument needs to be added for each requirement. We need a better way to handle these cases, otherwise the training API will become extremely hard to maintain.

Use Callbacks to Handle These Cases

import sys

class StopTraining(Exception):
    """Raised by a callback to signal that training should stop."""
    pass

def early_stop_maximize(rounds, metric, verbose=True):
    """Example: early stopping callback that maximizes the given metric."""
    # mutable state shared across calls to the callback
    info = {"best_score": float("-inf"), "best_score_i": 0}
    def callback(iteration, booster, evaluation_results):
        """Callback function to do early stopping.
        iteration: int
            Current iteration number, equal to the total number of trees so far
            (when continuing from an existing model, its trees are included).
        booster: Booster
            Current booster model.
        evaluation_results: list of (str, float), evaluation results from watchlist.
        """
        score = dict(evaluation_results)[metric]
        if score > info["best_score"]:
            info["best_score"] = score
            info["best_score_i"] = iteration
        if iteration - info["best_score_i"] > rounds:
            booster.best_iteration = info["best_score_i"]
            if verbose:
                sys.stderr.write("Stopping at round %d\n" % iteration)
            raise StopTraining()
    return callback

def train(param, num_boost_round, callbacks):
    ...
    for i in range(num_boost_round):
        bst.update()
        try:
            for callback in callbacks:
                callback(i, bst, evaluation_results)
        except StopTraining:
            break

# call training
bst = train(param, num_boost_round, callbacks=[early_stop_maximize(3, 'test-auc')])
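
The other requirements in the list above fit the same pattern. As one more illustration, adapting the learning rate during training can also be expressed as a callback; the decay schedule and the use of Booster.set_param below are assumptions chosen only for this sketch, not a proposed API.

def learning_rate_decay(initial_eta, decay=0.99):
    """Illustrative sketch: shrink the learning rate as boosting progresses."""
    def callback(iteration, booster, evaluation_results):
        # set eta for the following boosting round (schedule is an assumption)
        booster.set_param('eta', initial_eta * (decay ** iteration))
    return callback

# multiple callbacks can be combined in one training call
bst = train(param, num_boost_round,
            callbacks=[early_stop_maximize(3, 'test-auc'),
                       learning_rate_decay(0.3)])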

TODO List

  • Add the callback API to the training and cv APIs.
    • This includes Python/R/Julia/JVM.
  • Add a callback function module to xgboost.
    • In the future we will only accept improvements as callbacks, and be more careful about changes to the training API.
    • Add callbacks to support early_stop, logging, best_model_save.
  • Use the callbacks to keep backward compatibility (see the sketch after this list).
    • For example, when early_stop_rounds is detected, add early_stop_maximize to the callback list at the beginning of the function.
    • Mark the newly added arguments as deprecated, and give a deprecation warning asking users to use the callback API.
    • We will consider removing some of the less important arguments after two major releases.
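
A minimal sketch of this backward-compatibility path, assuming the existing early_stop_rounds argument and hard-coding the metric name only for illustration (in practice it would be taken from the watchlist):

import warnings

def train(param, num_boost_round, early_stop_rounds=None, callbacks=None):
    callbacks = list(callbacks) if callbacks is not None else []
    if early_stop_rounds is not None:
        # legacy argument: translate it into a callback and warn the user
        warnings.warn("early_stop_rounds is deprecated; "
                      "pass callbacks=[early_stop_maximize(...)] instead",
                      DeprecationWarning)
        callbacks.append(early_stop_maximize(early_stop_rounds, 'test-auc'))
    # ... continue with the callback-based training loop shown above ...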
