FlaxWhisperForConditionalGeneration Out Of Memory Error #34668

@heydaari

Description

System Info

  • huggingface_hub version: 0.24.7
  • Platform: Linux-6.1.85+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Running in iPython ?: No
  • Running in notebook ?: No
  • Running in Google Colab ?: Yes
  • Token path ?: /root/.cache/huggingface/token
  • Has saved token ?: False
  • Configured git credential helpers:
  • FastAI: 2.7.18
  • Tensorflow: 2.17.0
  • Torch: 2.5.0+cu121
  • Jinja2: 3.1.4
  • Graphviz: 0.20.3
  • keras: 3.4.1
  • Pydot: 3.0.2
  • Pillow: 10.4.0
  • hf_transfer: N/A
  • gradio: N/A
  • tensorboard: N/A
  • numpy: 1.26.4
  • pydantic: 2.9.2
  • aiohttp: 3.10.10
  • ENDPOINT: https://huggingface.co/
  • HF_HUB_CACHE: /root/.cache/huggingface/hub
  • HF_ASSETS_CACHE: /root/.cache/huggingface/assets
  • HF_TOKEN_PATH: /root/.cache/huggingface/token
  • HF_HUB_OFFLINE: False
  • HF_HUB_DISABLE_TELEMETRY: False
  • HF_HUB_DISABLE_PROGRESS_BARS: None
  • HF_HUB_DISABLE_SYMLINKS_WARNING: False
  • HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False
  • HF_HUB_DISABLE_IMPLICIT_TOKEN: False
  • HF_HUB_ENABLE_HF_TRANSFER: False
  • HF_HUB_ETAG_TIMEOUT: 10
  • HF_HUB_DOWNLOAD_TIMEOUT: 10

Who can help?

@sanchit-gandhi

Reproduction

  1. Open @sanchit-gandhi's notebook benchmarking Flax vs. Torch Whisper on a Colab T4 at this address (the loading step boils down to the sketch below)
  2. Change the model_id to "openai/whisper-large-v3"
  3. Run all the cells
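For context, the loading step in that notebook amounts to roughly the following (a sketch, not the notebook's exact code; model_id already changed per step 2):

```python
import jax.numpy as jnp
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

model_id = "openai/whisper-large-v3"

processor = WhisperProcessor.from_pretrained(model_id)
# from_pt=True converts the PyTorch checkpoint on the fly;
# dtype=jnp.float16 is the half-precision setting discussed below
model = FlaxWhisperForConditionalGeneration.from_pretrained(
    model_id, dtype=jnp.float16, from_pt=True
)
```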

Expected behavior

Even with the model dtype set to jnp.float16, the Flax model consumes about 14 GB of RAM, which is much higher than the Torch version's roughly 6 GB.
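If the dtype argument only controls the computation dtype (which is what the Flax convention suggests), the loaded parameters would still be float32, which would explain the memory gap. A sketch to check and cast them, continuing from the loading snippet above:

```python
import jax

# If dtype only affects computation, the parameters will still report float32
print(jax.tree_util.tree_map(lambda p: p.dtype, model.params))

# Casting the parameters themselves should roughly halve host memory
model.params = model.to_fp16(model.params)
```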

I tried the original whisper-tiny sample and the results were reproducible: roughly a 500% increase in inference speed with Flax.
But after changing the model to large-v3, we hit an Out Of Memory error.

One thing I observed in the source code here is that every model component's dtype defaults to jnp.float32.
Two possible explanations I came up with (a workaround sketch for the second follows this list):

  1. Even with FlaxWhisperForConditionalGeneration's dtype set to jnp.float16, the model components still use jnp.float32.
  2. When converting the model from_pt to JAX dtypes, the PyTorch weights may stay in memory alongside the converted ones.

I may have misread some of the scripts here; if so, please point it out.
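If the second point holds, one workaround I can think of (an untested sketch, using a hypothetical local path) is to run the PyTorch-to-Flax conversion once, save the Flax weights, and reload them after a runtime restart so no Torch tensors linger:

```python
import jax.numpy as jnp
from transformers import FlaxWhisperForConditionalGeneration

# One-time conversion (this step may itself need a high-RAM runtime)
model = FlaxWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3", dtype=jnp.float16, from_pt=True
)
model.params = model.to_fp16(model.params)
model.save_pretrained("./whisper-large-v3-flax")  # hypothetical local path

# After restarting the runtime, load the pure-Flax checkpoint (no from_pt)
model = FlaxWhisperForConditionalGeneration.from_pretrained(
    "./whisper-large-v3-flax", dtype=jnp.float16
)
```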

Any ideas on this issue would be helpful. Thanks!
