logged loss is not correct with gradient accumulation #35204

@Jack47

System Info

transformers v4.46.3

Who can help?

@muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The logged loss is currently not divided by the number of gradient accumulation steps, so it is larger than expected:
https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L2521-L2536

                    with context():
                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

                    if (
                        args.logging_nan_inf_filter
                        and not is_torch_xla_available()
                        and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                    ):
                        # if loss is nan or inf simply add the average of previous logged losses
                        tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                    else:
                        if tr_loss.device != tr_loss_step.device:
                            raise ValueError(
                                f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
                            )
                        tr_loss = tr_loss + tr_loss_step
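
To make the arithmetic concrete, here is a minimal toy sketch of the behavior described above. It is not Trainer code: `logged_loss` and its parameters are hypothetical names, and it assumes each micro-batch loss is added to the running total undivided while the total is later averaged over optimizer steps only.

```python
def logged_loss(micro_batch_losses, ga_steps, divide_by_ga_steps):
    """Toy model of the logging arithmetic (hypothetical helper, not Trainer code)."""
    # One optimizer step is counted for every ga_steps micro-batches.
    optimizer_steps = len(micro_batch_losses) // ga_steps
    # Mirrors `tr_loss = tr_loss + tr_loss_step` for every micro-batch.
    tr_loss = sum(micro_batch_losses)
    if divide_by_ga_steps:
        # The extra division proposed in this issue.
        tr_loss /= ga_steps
    # Mirrors dividing by (global_step - _globalstep_last_logged).
    return round(tr_loss / optimizer_steps, 4)


losses = [2.0] * 8  # 8 micro-batches, each with a mean loss of 2.0
ga_steps = 4        # so 2 optimizer steps in total

print(logged_loss(losses, ga_steps, divide_by_ga_steps=False))  # 8.0
print(logged_loss(losses, ga_steps, divide_by_ga_steps=True))   # 2.0
```

Under these assumptions the undivided value is `ga_steps` times the true mean loss, which matches the inflation reported here.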

Expected behavior

The logged loss should be divided by the number of gradient accumulation steps. A possible fix:

diff --git a/trainer.py b/trainer.py
index 1b9b80f..043c6c9 100755
--- a/trainer.py
+++ b/trainer.py
@@ -2546,7 +2546,7 @@ class Trainer:
                         self.state.global_step += 1
                         self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                         self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-                        self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
+                        self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, num_batches)
                     else:
                         self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
 
@@ -2571,7 +2571,7 @@ class Trainer:
                 self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
+            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, self.args.gradient_accumulation_steps)
 
             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 if is_torch_xla_available():
@@ -2976,7 +2976,7 @@ class Trainer:
                 ) from exc
         return metrics
 
-    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
+    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, ga_steps):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             if is_torch_xla_available():
                 xm.mark_step()
@@ -2990,7 +2990,7 @@ class Trainer:
             # reset tr_loss to zero
             tr_loss -= tr_loss
 
-            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+            logs["loss"] = round(tr_loss_scalar / ga_steps / (self.state.global_step - self._globalstep_last_logged), 4)
             if grad_norm is not None:
                 logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
