
xpu device is not used running pipeline(device_map="auto") #31922

Closed
huggingface/accelerate
#3275
@dvrogozh

Description

Found with these code versions: 5258501, huggingface/accelerate@12a007d, pytorch/pytorch@3477ee3. This is an issue with XPU support in stock PyTorch (i.e. without using IPEX).

HF model pipelines with device_map="auto" (or device_map="sequential") do not actually run on XPU even if the model fits in device memory. I spotted this while trying to run Llama 3 models:

Example script:

import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
outputs = pipeline(
    messages,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
print(outputs[0]["generated_text"][-1])
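A quick way to confirm where device_map="auto" actually placed the model is to inspect the model's hf_device_map (the dict transformers/accelerate attach when a device map is used). The helper below is a minimal sketch; its name is ours, not a library API:

```python
# Sketch: summarize which devices a device_map placed modules on.
# `hf_device_map` is the dict attached to the model when device_map=
# is passed; the helper name is hypothetical, not a library API.

def devices_used(hf_device_map):
    """Return the sorted set of device names appearing in an hf_device_map."""
    return sorted({str(dev) for dev in hf_device_map.values()})

# Usage with the pipeline from the script above:
#   devices_used(pipeline.model.hf_device_map)
# If the bug triggers, no "xpu" device will appear in the result.
```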

Workarounds and findings:

  • If the model fits in device memory, changing device_map="auto" to device_map="xpu" allows it to run (easiest to verify with the 8B model)
  • The model also starts to work (but see the note below) if you add max_memory to the model kwargs:
model_kwargs={"torch_dtype": torch.bfloat16, "max_memory": {0: 5.0e+10}}, device_map="auto",
...
  File "/home/gta/git/huggingface/accelerate/src/accelerate/utils/offload.py", line 118, in __getitem__
    return self.dataset[f"{self.prefix}{key}"]
  File "/home/gta/git/huggingface/accelerate/src/accelerate/utils/offload.py", line 171, in __getitem__
    tensor = f.get_tensor(weight_info.get("weight_name", key))
  File "/home/gta/git/pytorch/pytorch/torch/cuda/__init__.py", line 305, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
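For reference, the max_memory workaround can be sketched as below. The helper name and the 0.9 fraction are our assumptions, not tuned values; on real hardware the per-device totals would come from torch.xpu (stock PyTorch with the XPU backend):

```python
# Sketch of building an explicit max_memory map for device_map="auto".
# The helper name and the 0.9 headroom fraction are illustrative guesses.

def build_max_memory(total_bytes_per_device, fraction=0.9):
    """Turn per-device memory totals (bytes) into the max_memory dict
    that accelerate's device placement planner accepts."""
    return {i: int(total * fraction)
            for i, total in enumerate(total_bytes_per_device)}

# On real hardware the totals could be gathered via torch.xpu, e.g.:
#   totals = [torch.xpu.get_device_properties(i).total_memory
#             for i in range(torch.xpu.device_count())]
# and then passed to the pipeline:
#   model_kwargs={"torch_dtype": torch.bfloat16,
#                 "max_memory": build_max_memory(totals)}
max_memory = build_max_memory([64 * 2**30])  # one 64 GiB XPU (illustrative)
```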

CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5 @sywangyi @yao-matrix



Labels: WIP
