[Idefics 3] Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1 #35031

Closed
@shyshin

Description

System Info

  • transformers version: 4.47.0.dev0
  • Platform: Linux-4.18.0-553.27.1.el8_10.x86_64-x86_64-with-glibc2.28
  • Python version: 3.11.6
  • Huggingface_hub version: 0.26.2
  • Safetensors version: 0.4.3
  • Accelerate version: 1.1.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.6.1 (gpu)
  • Jax version: 0.4.4
  • JaxLib version: 0.4.3
  • Using distributed or parallel set-up in script?: Yes
  • Using GPU in script?: Yes, 4 GPUs
  • GPU type: NVIDIA H100

Who can help?

@qubvel

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I was trying to fine-tune the SmolVLM-Instruct model on a setup with 4 H100 GPUs, using the fine-tuning code provided for SmolVLM-Base here. Steps to reproduce:

  1. Change model_id to "HuggingFaceTB/SmolVLM-Instruct"
  2. Use PubMedVision instead of VQAv2 dataset.
  3. Change collate_fn to accommodate PubMedVision (a rough sketch of steps 1-3 is given after this list).
  4. Run the script as is.
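
For reference, a rough sketch of the kind of changes steps 1-3 describe (illustrative only; the PubMedVision repo id and its field names such as "image", "conversations", "from" and "value" are assumptions about the dataset schema, and the collate_fn is simplified):

import torch
from datasets import load_dataset
from transformers import AutoProcessor, Idefics3ForConditionalGeneration

model_id = "HuggingFaceTB/SmolVLM-Instruct"                         # step 1
processor = AutoProcessor.from_pretrained(model_id)
model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # shards the model across the 4 GPUs
)

# Step 2: dataset repo id is an assumption; adjust to however PubMedVision is loaded.
ds = load_dataset("FreedomIntelligence/PubMedVision", split="train")

# Step 3: simplified collate_fn; the "image"/"conversations"/"from"/"value" field
# names are assumptions about the PubMedVision schema, and "image" is assumed to
# already be a PIL image (it may be a path that needs loading first).
def collate_fn(examples):
    texts, images = [], []
    for ex in examples:
        messages = []
        for turn in ex["conversations"]:
            role = "user" if turn["from"] == "human" else "assistant"
            content = [{"type": "text", "text": turn["value"]}]
            if role == "user" and not any(m["role"] == "user" for m in messages):
                content.insert(0, {"type": "image"})  # one image per sample
            messages.append({"role": role, "content": content})
        texts.append(processor.apply_chat_template(messages, add_generation_prompt=False))
        images.append([ex["image"]])
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # ignore padding in the loss
    batch["labels"] = labels
    return batch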

Running trainer.train() raised RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1.

TL;DR

File ~/.local/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:904, in Idefics3Model.inputs_merger(self, input_ids, inputs_embeds, image_hidden_states)
    902 # cast to the dtype of the input_embeds to support quantized models
    903 reshaped_image_hidden_states = reshaped_image_hidden_states.to(inputs_embeds.dtype)
--> 904 new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
    905 return new_inputs_embeds

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1
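
The failing line is a boolean-mask assignment; as a rough illustration (not the model code), the same error can be reproduced with two tensors on different GPUs:

import torch

# Minimal illustration only (assumes at least two visible GPUs): a masked
# assignment requires both tensors to be on the same device.
emb = torch.zeros(4, 8, device="cuda:0")      # stands in for new_inputs_embeds
feats = torch.randn(2, 8, device="cuda:1")    # stands in for reshaped_image_hidden_states
mask = torch.tensor([True, False, True, False], device="cuda:0")
emb[mask] = feats  # RuntimeError: Expected all tensors to be on the same device ...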

Actual traceback:

RuntimeError                              Traceback (most recent call last)
Cell In[47], line 1
----> 1 trainer.train()

File ~/.local/lib/python3.11/site-packages/transformers/trainer.py:2163, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2161         hf_hub_utils.enable_progress_bars()
   2162 else:
-> 2163     return inner_training_loop(
   2164         args=args,
   2165         resume_from_checkpoint=resume_from_checkpoint,
   2166         trial=trial,
   2167         ignore_keys_for_eval=ignore_keys_for_eval,
   2168     )

File ~/.local/lib/python3.11/site-packages/transformers/trainer.py:2521, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2515 context = (
   2516     functools.partial(self.accelerator.no_sync, model=model)
   2517     if i != len(batch_samples) - 1
   2518     else contextlib.nullcontext
   2519 )
   2520 with context():
-> 2521     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
   2523 if (
   2524     args.logging_nan_inf_filter
   2525     and not is_torch_xla_available()
   2526     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2527 ):
   2528     # if loss is nan or inf simply add the average of previous logged losses
   2529     tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File ~/.local/lib/python3.11/site-packages/transformers/trainer.py:3651, in Trainer.training_step(self, model, inputs, num_items_in_batch)
   3648     return loss_mb.reduce_mean().detach().to(self.args.device)
   3650 with self.compute_loss_context_manager():
-> 3651     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
   3653 del inputs
   3654 if (
   3655     self.args.torch_empty_cache_steps is not None
   3656     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3657 ):

File ~/.local/lib/python3.11/site-packages/transformers/trainer.py:3705, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3703         loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3704     inputs = {**inputs, **loss_kwargs}
-> 3705 outputs = model(**inputs)
   3706 # Save past state if it exists
   3707 # TODO: this needs to be fixed and made cleaner later.
   3708 if self.args.past_index >= 0:

File /cvmfs/jupyter.hpc.rwth.de/clients/pytorch-ifs-3a-c23/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /cvmfs/jupyter.hpc.rwth.de/clients/pytorch-ifs-3a-c23/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.local/lib/python3.11/site-packages/accelerate/utils/operations.py:823, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    822 def forward(*args, **kwargs):
--> 823     return model_forward(*args, **kwargs)

File ~/.local/lib/python3.11/site-packages/accelerate/utils/operations.py:811, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    810 def __call__(self, *args, **kwargs):
--> 811     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /cvmfs/jupyter.hpc.rwth.de/clients/pytorch-ifs-3a-c23/lib/python3.11/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)

File ~/.local/lib/python3.11/site-packages/accelerate/utils/operations.py:823, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    822 def forward(*args, **kwargs):
--> 823     return model_forward(*args, **kwargs)

File ~/.local/lib/python3.11/site-packages/accelerate/utils/operations.py:811, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    810 def __call__(self, *args, **kwargs):
--> 811     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /cvmfs/jupyter.hpc.rwth.de/clients/pytorch-ifs-3a-c23/lib/python3.11/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)

File ~/.local/lib/python3.11/site-packages/peft/peft_model.py:812, in PeftModel.forward(self, *args, **kwargs)
    810 with self._enable_peft_forward_hooks(*args, **kwargs):
    811     kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
--> 812     return self.get_base_model()(*args, **kwargs)

File /cvmfs/jupyter.hpc.rwth.de/clients/pytorch-ifs-3a-c23/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /cvmfs/jupyter.hpc.rwth.de/clients/pytorch-ifs-3a-c23/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.local/lib/python3.11/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.local/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:1196, in Idefics3ForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1193 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1195 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1196 outputs = self.model(
   1197     input_ids=input_ids,
   1198     attention_mask=attention_mask,
   1199     position_ids=position_ids,
   1200     past_key_values=past_key_values,
   1201     inputs_embeds=inputs_embeds,
   1202     pixel_values=pixel_values,
   1203     pixel_attention_mask=pixel_attention_mask,
   1204     image_hidden_states=image_hidden_states,
   1205     use_cache=use_cache,
   1206     output_attentions=output_attentions,
   1207     output_hidden_states=output_hidden_states,
   1208     return_dict=return_dict,
   1209 )
   1211 hidden_states = outputs[0]
   1212 logits = self.lm_head(hidden_states)

File /cvmfs/jupyter.hpc.rwth.de/clients/pytorch-ifs-3a-c23/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /cvmfs/jupyter.hpc.rwth.de/clients/pytorch-ifs-3a-c23/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.local/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:1014, in Idefics3Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, use_cache, output_attentions, output_hidden_states, return_dict)
   1009     image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
   1011 if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
   1012     # When we generate, we don't want to replace the potential image_token_id that we generated by images
   1013     # that simply don't exist
-> 1014     inputs_embeds = self.inputs_merger(
   1015         input_ids=input_ids,
   1016         inputs_embeds=inputs_embeds,
   1017         image_hidden_states=image_hidden_states,
   1018     )
   1020 outputs = self.text_model(
   1021     inputs_embeds=inputs_embeds,
   1022     attention_mask=attention_mask,
   (...)
   1028     return_dict=return_dict,
   1029 )
   1031 if not return_dict:

File ~/.local/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:904, in Idefics3Model.inputs_merger(self, input_ids, inputs_embeds, image_hidden_states)
    902 # cast to the dtype of the input_embeds to support quantized models
    903 reshaped_image_hidden_states = reshaped_image_hidden_states.to(inputs_embeds.dtype)
--> 904 new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
    905 return new_inputs_embeds

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1

Expected behavior

No runtime error; training should run across all 4 GPUs.

Potential Fix

I looked at a couple of similar issues (e.g. #24410) and came up with the following fix:

File ~/.local/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:904

new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(self.device)

This fixed the issue for me. Do let me know if this works.
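
For context, the change sits at the end of Idefics3Model.inputs_merger; a sketch following the variable names in the traceback above, where only the trailing .to(self.device) is added:

# End of Idefics3Model.inputs_merger; only the trailing .to(self.device) is new.
reshaped_image_hidden_states = reshaped_image_hidden_states.to(inputs_embeds.dtype)
# Move the image features onto the module's device before the masked assignment,
# so device_map="auto" setups no longer mix cuda:0 and cuda:1 here.
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(self.device)
return new_inputs_embeds

Targeting new_inputs_embeds.device instead of self.device may be an equivalent choice if the embedding layer does not sit on the module's first device, but the one-liner above is what was tested here.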
