Description
System Info
- transformers version: 4.46.3
- Platform: Linux-6.1.0-28-cloud-amd64-x86_64-with-glibc2.36
- Python version: 3.12.7
- Huggingface_hub version: 0.26.3
- Safetensors version: 0.4.5
- Accelerate version: 1.1.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: parallel (device_map="auto")
- Using GPU in script?: yes
- GPU type: NVIDIA L4
Who can help?
While trying to train Qwen/Qwen2.5-7B-Instruct or meta-llama/Llama-3.1-8B-Instruct using the SFTTrainer of the trl library on a machine with 4 L4 GPUs, the following error occurs during the forward pass, when the loss is about to be calculated:
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/xspirus/sample-project/sample.py", line 179, in <module>
sft()
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 1157, in __call__
return self.main(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 1078, in main
rv = self.invoke(ctx)
^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 1434, in invoke
return ctx.invoke(self.callback, **ctx.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/click/core.py", line 783, in invoke
return __callback(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/src/llms/sft/__main__.py", line 168, in sft
trainer.train()
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2123, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3579, in training_step
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3633, in compute_loss
outputs = model(**inputs)
^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 823, in forward
return model_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 811, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/peft/peft_model.py", line 1644, in forward
return self.base_model(
^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
return self.model.forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1183, in forward
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/loss/loss_utils.py", line 46, in ForCausalLMLoss
loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xspirus/sample-project/.venv/lib/python3.12/site-packages/transformers/loss/loss_utils.py", line 28, in fixed_cross_entropy
loss = loss / num_items_in_batch
~~~~~^~~~~~~~~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0!
This stems from the changes in version 4.46.0, where #34191 was introduced. What I suspect is happening: num_items_in_batch is calculated from the batch sampler over the inputs (which are probably placed on cuda:0), while the loss is calculated from the outputs of the model (which are placed on the last GPU, cuda:3), so the division mixes tensors that live on two different devices.
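To show the suspected mismatch in isolation, here is a minimal sketch that reproduces the same RuntimeError outside the Trainer (it assumes at least two visible CUDA devices; cuda:1 stands in for cuda:3 here):

import torch

# Stand-ins for the two tensors involved in the division above:
loss = torch.tensor(4.2, device="cuda:1")               # summed loss on the last pipeline GPU
num_items_in_batch = torch.tensor(8, device="cuda:0")   # token count derived from the inputs

loss = loss / num_items_in_batch  # RuntimeError: Expected all tensors to be on the same device ...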
I am not sure if the fix is as simple as the following piece of code:
from torch import nn


def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        # num_items_in_batch is a tensor here in practice; move it to the loss's device before dividing
        loss = loss / num_items_in_batch.to(loss.device)
    return loss
where the fix is the num_items_in_batch.to(loss.device) call.
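Until this is properly fixed upstream, a possible temporary workaround could be to monkey-patch the loss helper before calling trainer.train(). This is only a sketch based on the module layout shown in the traceback (transformers/loss/loss_utils.py in v4.46.3) and on the observation that num_items_in_batch arrives as a tensor; it is not an official extension point:

import torch
from torch import nn

import transformers.loss.loss_utils as loss_utils


def patched_fixed_cross_entropy(source, target, num_items_in_batch=None, ignore_index=-100, **kwargs):
    # Same logic as the stock helper, but the divisor is moved to the loss's device first.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        if isinstance(num_items_in_batch, torch.Tensor):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss


# ForCausalLMLoss resolves fixed_cross_entropy through the module globals at call
# time, so rebinding the module attribute before trainer.train() should be enough.
loss_utils.fixed_cross_entropy = patched_fixed_cross_entropy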
I am opening this issue so that you, who are more familiar with this part of the codebase, can address it with more confidence.
NOTE: the error does not occur on transformers v4.45.2.
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
import tempfile
from pathlib import Path
from typing import cast
import click
import torch
from datasets import load_dataset
from peft.mapping import get_peft_model
from peft.tuners.lora import LoraConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedModel
from transformers import PreTrainedTokenizerFast
from transformers.training_args import IntervalStrategy
from trl import SFTConfig
from trl import SFTTrainer
@click.command(name="sft", help="Train LLM with SFT technique.")
@click.option("--model", "model_name", type=str, required=True, help="Base model name.")
@click.option(
    "--output",
    type=click.Path(exists=False, dir_okay=True, writable=True, resolve_path=True, path_type=Path),
    required=True,
    help="Output directory.",
)
@click.option(
    "--use-lora",
    is_flag=True,
    help="Whether to use LoRA or not.",
)
def sft(model_name: str, output: Path, use_lora: bool = False):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        raise ValueError("Tokenizer must be fast tokenizer.")
    if not tokenizer.pad_token:
        if tokenizer.unk_token:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "right"
    data = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    data = data.shuffle()
    model = cast(
        PreTrainedModel,
        AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
            attn_implementation="flash_attention_2",
            trust_remote_code=True,
            use_cache=False,
        ),
    )
    if use_lora:
        lora_config = LoraConfig(
            task_type="CAUSAL_LM",
            r=32,
            lora_alpha=64,
            lora_dropout=0.05,
            target_modules="all-linear",
        )
        model = get_peft_model(model, lora_config)
    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = SFTConfig(
            output_dir=tmp_dir,
            bf16=model.dtype == torch.bfloat16,
            fp16=model.dtype == torch.float16,
            learning_rate=2e-4,
            neftune_noise_alpha=5,
            num_train_epochs=3,
            packing=False,
            per_device_eval_batch_size=4,
            per_device_train_batch_size=4,
            save_strategy=IntervalStrategy.NO,
            warmup_ratio=0.03,
            weight_decay=1e-3,
        )
        trainer = SFTTrainer(
            model=model,
            processing_class=tokenizer,
            args=training_args,
            train_dataset=data,
        )
        trainer.train()
    tokenizer.padding_side = original_padding_side
    tokenizer.save_pretrained(str(output))
    model.save_pretrained(str(output))


if __name__ == "__main__":
    sft()
and the requirements:
accelerate==1.1.1
aiohappyeyeballs==2.4.4
aiohttp==3.11.9
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.6.2.post1
attrs==24.2.0
bitsandbytes==0.44.1
certifi==2024.8.30
charset-normalizer==3.4.0
click==8.1.7
colorama==0.4.6 ; platform_system == 'Windows'
datasets==3.1.0
dill==0.3.8
distro==1.9.0
einops==0.8.0
filelock==3.16.1
flash-attn==2.7.0.post2
frozenlist==1.5.0
fsspec==2024.9.0
greenlet==3.1.1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
h11==0.14.0
httpcore==1.0.7
httpx==0.28.0
huggingface-hub==0.26.3
idna==3.10
jinja2==3.1.4
jiter==0.8.0
joblib==1.4.2
jsonpatch==1.33
jsonpointer==3.0.0
langchain==0.3.9
langchain-core==0.3.21
langchain-openai==0.2.11
langchain-text-splitters==0.3.2
langsmith==0.1.147
markdown-it-py==3.0.0
markupsafe==3.0.2
mdurl==0.1.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
networkx==3.4.2
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
openai==1.56.2
orjson==3.10.12 ; platform_python_implementation != 'PyPy'
packaging==24.2
pandas==2.2.3
peft==0.13.2
propcache==0.2.1
psutil==6.1.0
pyarrow==18.1.0
pydantic==2.10.3
pydantic-core==2.27.1
pygments==2.18.0
python-dateutil==2.9.0.post0
pytz==2024.2
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
requests-toolbelt==1.0.0
rich==13.9.4
ruff==0.8.1
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
setuptools==75.6.0
six==1.16.0
sniffio==1.3.1
sqlalchemy==2.0.36
sympy==1.13.1
tenacity==9.0.0
threadpoolctl==3.5.0
tiktoken==0.8.0
tokenizers==0.20.3
torch==2.5.1
tqdm==4.67.1
transformers==4.46.3
triton==3.1.0 ; platform_machine == 'x86_64' and platform_system == 'Linux'
trl==0.12.1
typing-extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
xxhash==3.5.0
yarl==1.18.3
Expected behavior
Training should proceed normally, as it does on v4.45.2.