System Info
- huggingface_hub version: 0.24.7
- Platform: Linux-6.1.85+-x86_64-with-glibc2.35
- Python version: 3.10.12
- Running in iPython ?: No
- Running in notebook ?: No
- Running in Google Colab ?: Yes
- Token path ?: /root/.cache/huggingface/token
- Has saved token ?: False
- Configured git credential helpers:
- FastAI: 2.7.18
- Tensorflow: 2.17.0
- Torch: 2.5.0+cu121
- Jinja2: 3.1.4
- Graphviz: 0.20.3
- keras: 3.4.1
- Pydot: 3.0.2
- Pillow: 10.4.0
- hf_transfer: N/A
- gradio: N/A
- tensorboard: N/A
- numpy: 1.26.4
- pydantic: 2.9.2
- aiohttp: 3.10.10
- ENDPOINT: https://huggingface.co/
- HF_HUB_CACHE: /root/.cache/huggingface/hub
- HF_ASSETS_CACHE: /root/.cache/huggingface/assets
- HF_TOKEN_PATH: /root/.cache/huggingface/token
- HF_HUB_OFFLINE: False
- HF_HUB_DISABLE_TELEMETRY: False
- HF_HUB_DISABLE_PROGRESS_BARS: None
- HF_HUB_DISABLE_SYMLINKS_WARNING: False
- HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False
- HF_HUB_DISABLE_IMPLICIT_TOKEN: False
- HF_HUB_ENABLE_HF_TRANSFER: False
- HF_HUB_ETAG_TIMEOUT: 10
- HF_HUB_DOWNLOAD_TIMEOUT: 10
Who can help?
Reproduction
- Open @sanchit-gandhi's notebook benchmarking Flax and PyTorch Whisper on a Colab T4 at this address
- Change the model_id to "openai/whisper-large-v3"
- Run all the cells (a minimal sketch of these steps is included below)
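For reference, here is a minimal sketch of what those cells boil down to after the change. The dummy input is mine and only serves to trigger one generation pass; the rest follows the usual transformers Flax API:

```python
import jax.numpy as jnp
import numpy as np
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

model_id = "openai/whisper-large-v3"

processor = WhisperProcessor.from_pretrained(model_id)
# Load the PyTorch checkpoint and convert it to Flax, requesting half precision
model = FlaxWhisperForConditionalGeneration.from_pretrained(
    model_id, dtype=jnp.float16, from_pt=True
)

# Dummy 30-second clip at 16 kHz, just to run generation once
dummy_audio = np.zeros(16_000 * 30, dtype=np.float32)
inputs = processor(dummy_audio, sampling_rate=16_000, return_tensors="np")
pred_ids = model.generate(inputs.input_features).sequences
print(processor.batch_decode(pred_ids, skip_special_tokens=True))
```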
Expected behavior
Even after setting the model dtype to jnp.float16, the Flax model consumes about 14 GB of RAM, which is much higher than the PyTorch version at about 6 GB.
I tried the original whisper-tiny sample and the results were reproducible: about a 500% increase in inference speed with Flax.
However, after switching the checkpoint to large-v3, we hit an out-of-memory error.
One thing I observed in the source code here is that the dtype of every model component is set to jnp.float32.
Two possible explanations I came up with are:
- Even when FlaxWhisperForConditionalGeneration is constructed with dtype=jnp.float16, the model components (and the stored parameters) may still use jnp.float32.
- When converting the model with from_pt, the original PyTorch weights may stay in memory alongside the converted JAX arrays.
- I may have misread some of the scripts here; if so, please point it out (see the sketch after this list for how I checked).
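To check the first point, I would inspect the dtypes of the stored parameters and then cast them explicitly. This is a rough sketch, assuming the public `to_fp16` helper behaves as documented and that a `gc.collect()` is enough to drop any leftover float32 / PyTorch copies (second point):

```python
import gc

import jax
import jax.numpy as jnp
from transformers import FlaxWhisperForConditionalGeneration

model = FlaxWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3", dtype=jnp.float16, from_pt=True
)

# dtype= seems to control the computation dtype only, so I expect the
# stored weights to still be float32 here
param_dtypes = {str(p.dtype) for p in jax.tree_util.tree_leaves(model.params)}
print(param_dtypes)

# Cast the stored weights to half precision explicitly and try to release
# any intermediate copies left over from the PyTorch -> Flax conversion
model.params = model.to_fp16(model.params)
gc.collect()
```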
Any idea on this issue would be helpful. Thanks!