CausalLM loss function throws runtime error in multi-gpu setup #35086

Closed
@xspirus

Description

System Info

  • transformers version: 4.46.3
  • Platform: Linux-6.1.0-28-cloud-amd64-x86_64-with-glibc2.36
  • Python version: 3.12.7
  • Huggingface_hub version: 0.26.3
  • Safetensors version: 0.4.5
  • Accelerate version: 1.1.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: parallel (device_map="auto")
  • Using GPU in script?: yes
  • GPU type: NVIDIA L4

Who can help?

@ArthurZucker @muellerzr

While trying to train Qwen/Qwen2.5-7B-Instruct or meta-llama/Llama-3.1-8B-Instruct with the trl library's SFTTrainer on a machine with 4 L4 GPUs, the following error occurs during the forward pass, at the point where the loss is computed:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/xspirus/sample-project/sample.py", line 179, in <module>
    sft()
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
         ^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/src/llms/sft/__main__.py", line 168, in sft
    trainer.train()
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3579, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3633, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 823, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 811, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/peft/peft_model.py", line 1644, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1183, in forward
    loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/loss/loss_utils.py", line 46, in ForCausalLMLoss
    loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/loss/loss_utils.py", line 28, in fixed_cross_entropy
    loss = loss / num_items_in_batch
           ~~~~~^~~~~~~~~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0!

This stems from the changes introduced in version 4.46.0 by #34191. What I suspect is happening: num_items_in_batch is computed from the batch sampler's inputs, which are presumably placed on cuda:0, while the loss is computed from the model's outputs, which sit on the last GPU (cuda:3). Dividing one by the other therefore mixes devices, producing the error.
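For illustration, the same failure can be reproduced in isolation (a minimal sketch, assuming at least two visible CUDA devices; the device indices are arbitrary):

import torch

# the summed loss comes off the model's last shard, while
# num_items_in_batch stays on the device the inputs were placed on
loss = torch.tensor(4.2, device="cuda:1")
num_items_in_batch = torch.tensor(16, device="cuda:0")

loss / num_items_in_batch  # RuntimeError: Expected all tensors to be on the same device, ...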

I am not sure if the fix is as simple as the following piece of code:

import torch
from torch import nn

def fixed_cross_entropy(source, target, num_items_in_batch: torch.Tensor | None = None, ignore_index: int = -100, **kwargs):
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        # num_items_in_batch arrives on the inputs' device; move it to the
        # device the loss lives on before dividing
        loss = loss / num_items_in_batch.to(loss.device)
    return loss

The fix here is the num_items_in_batch.to(loss.device) call, which moves the divisor onto the same device as the loss. (I also changed the annotation from int to torch.Tensor, since a tensor is what actually gets passed at runtime.)
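Until a proper fix lands upstream, something along these lines could serve as a stopgap, monkey-patching the helper before training starts. This is only a sketch, not an official API; it wraps the existing loss_utils.fixed_cross_entropy and assumes ForCausalLMLoss resolves that name through the module's globals on each call:

import torch
import transformers.loss.loss_utils as loss_utils

_original_fixed_cross_entropy = loss_utils.fixed_cross_entropy

def _patched_fixed_cross_entropy(source, target, num_items_in_batch=None, ignore_index: int = -100, **kwargs):
    # move the divisor onto the logits' device before the upstream helper
    # divides the summed loss by it
    if isinstance(num_items_in_batch, torch.Tensor):
        num_items_in_batch = num_items_in_batch.to(source.device)
    return _original_fixed_cross_entropy(source, target, num_items_in_batch, ignore_index, **kwargs)

loss_utils.fixed_cross_entropy = _patched_fixed_cross_entropy

Applying this anywhere before trainer.train() should be enough, since the name is looked up on each forward pass.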

I am opening this issue so that those of you more familiar with this part of the codebase can fix it with confidence.

NOTE: the error does not occur on transformers v4.45.2.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import tempfile
from pathlib import Path
from typing import cast

import click
import torch
from datasets import load_dataset
from peft.mapping import get_peft_model
from peft.tuners.lora import LoraConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedModel
from transformers import PreTrainedTokenizerFast
from transformers.training_args import IntervalStrategy
from trl import SFTConfig
from trl import SFTTrainer


@click.command(name="sft", help="Train LLM with SFT technique.")
@click.option("--model", "model_name", type=str, required=True, help="Base model name.")
@click.option(
    "--output",
    type=click.Path(exists=False, dir_okay=True, writable=True, resolve_path=True, path_type=Path),
    required=True,
    help="Output directory.",
)
@click.option(
    "--use-lora",
    is_flag=True,
    help="Whether to use LoRA or not.",
)
def sft(model_name: str, output: Path, use_lora: bool = False):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        raise ValueError("Tokenizer must be fast tokenizer.")
    if not tokenizer.pad_token:
        if tokenizer.unk_token:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "right"

    data = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    data = data.shuffle()

    model = cast(
        PreTrainedModel,
        AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
            attn_implementation="flash_attention_2",
            trust_remote_code=True,
            use_cache=False,
        ),
    )

    if use_lora:
        lora_config = LoraConfig(
            task_type="CAUSAL_LM",
            r=32,
            lora_alpha=64,
            lora_dropout=0.05,
            target_modules="all-linear",
        )
        model = get_peft_model(model, lora_config)

    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = SFTConfig(
            output_dir=tmp_dir,
            bf16=model.dtype == torch.bfloat16,
            fp16=model.dtype == torch.float16,
            learning_rate=2e-4,
            neftune_noise_alpha=5,
            num_train_epochs=3,
            packing=False,
            per_device_eval_batch_size=4,
            per_device_train_batch_size=4,
            save_strategy=IntervalStrategy.NO,
            warmup_ratio=0.03,
            weight_decay=1e-3,
        )

        trainer = SFTTrainer(
            model=model,
            processing_class=tokenizer,
            args=training_args,
            train_dataset=data,
        )
        trainer.train()

    tokenizer.padding_side = original_padding_side
    tokenizer.save_pretrained(str(output))
    model.save_pretrained(str(output))


if __name__ == "__main__":
    sft()
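Running this on the 4×L4 machine as, for example, python sample.py --model Qwen/Qwen2.5-7B-Instruct --output ./out (optionally with --use-lora) triggers the error during the first training step.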

Requirements:

accelerate==1.1.1
aiohappyeyeballs==2.4.4
aiohttp==3.11.9
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.6.2.post1
attrs==24.2.0
bitsandbytes==0.44.1
certifi==2024.8.30
charset-normalizer==3.4.0
click==8.1.7
colorama==0.4.6 ; platform_system == 'Windows'
datasets==3.1.0
dill==0.3.8
distro==1.9.0
einops==0.8.0
filelock==3.16.1
flash-attn==2.7.0.post2
frozenlist==1.5.0
fsspec==2024.9.0
greenlet==3.1.1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
h11==0.14.0
httpcore==1.0.7
httpx==0.28.0
huggingface-hub==0.26.3
idna==3.10
jinja2==3.1.4
jiter==0.8.0
joblib==1.4.2
jsonpatch==1.33
jsonpointer==3.0.0
langchain==0.3.9
langchain-core==0.3.21
langchain-openai==0.2.11
langchain-text-splitters==0.3.2
langsmith==0.1.147
markdown-it-py==3.0.0
markupsafe==3.0.2
mdurl==0.1.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
networkx==3.4.2
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
openai==1.56.2
orjson==3.10.12 ; platform_python_implementation != 'PyPy'
packaging==24.2
pandas==2.2.3
peft==0.13.2
propcache==0.2.1
psutil==6.1.0
pyarrow==18.1.0
pydantic==2.10.3
pydantic-core==2.27.1
pygments==2.18.0
python-dateutil==2.9.0.post0
pytz==2024.2
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
requests-toolbelt==1.0.0
rich==13.9.4
ruff==0.8.1
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
setuptools==75.6.0
six==1.16.0
sniffio==1.3.1
sqlalchemy==2.0.36
sympy==1.13.1
tenacity==9.0.0
threadpoolctl==3.5.0
tiktoken==0.8.0
tokenizers==0.20.3
torch==2.5.1
tqdm==4.67.1
transformers==4.46.3
triton==3.1.0 ; platform_machine == 'x86_64' and platform_system == 'Linux'
trl==0.12.1
typing-extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
xxhash==3.5.0
yarl==1.18.3

Expected behavior

Training should proceed normally, as it does on v4.45.2.
