Add TimesFM Time Series Forecasting Model #34082
Conversation
Awesome work.
One general question: are there plans to support decoder-only training and the associated loss functions?
Great work! I left some minor comments.
Would this PR add fine-tuning support for the model?
Marking @kashif as the HF point of contact for this!
This reverts commit def36c4.
Alright, super super nice! 🤗 Love that you already integrated the new init standard!! Left some final comments but we can merge right afterwards!
Thanks a lot, and sorry again for the delay in the reviews!
Alright, last nits then merging!! 🤗
All right, LGTM!! Merging, thank you for being patient and bearing with me along the way!! 🤗
Thanks for the very nice addition!
* initial documentation
* rename mask to attention_mask
* smaller tests
* fixup
* fix copies
* move to time series section
* sort docs
* isort fix
* batch_size is not a configuration
* rename to TimesFMModelForPrediction
* initial script
* add check_outputs
* remove dropout_rate
* works with torch.Tensor inputs
* rename script
* fix docstrings
* fix freq when window_size is given
* add loss
* fix _quantile_loss
* formatting
* fix isort
* add weight init
* add support for sdpa and flash_attention_2
* fixes for flash_attention
* formatting
* remove flash_attention
* fix tests
* fix file name
* fix quantile loss
* added initial TimesFMModelIntegrationTests
* fix formatting
* fix import order
* fix _quantile_loss
* add doc for SDPA
* use timesfm 2.0
* bug fix in timesfm decode function
* compare mean forecasts
* refactor type hints, use CamelCase
* consolidate decode func
* more readable code for weight conversion
* fix-copies
* simpler init
* rename TimesFmMLP
* use T5LayerNorm
* fix tests
* use initializer_range
* TimesFmModel instead of TimesFmDecoder
* TimesFmPositionalEmbedding takes config for its init
* 2.0-500m-pytorch default configs
* use TimesFmModel
* fix formatting
* ignore TimesFmModel for testing
* fix docstring
* override generate as it's not needed
* add doc strings
* fix logging
* add docstrings to output data classes
* initial copy from t5
* added config and attention layers
* add TimesFMPositionalEmbedding
* calculate scale_factor once
* add more configs and TimesFMResidualBlock
* fix input_dims
* standardize code format with black
* remove unneeded modules
* TimesFM Model
* order of imports
* copy from Google official implementation
* remove covariate forecasting
* Adapting TimesFM to HF format
* restructuring in progress
* adapted to HF convention
* timesfm test
* the model runs
* fixing unit tests
* fixing unit tests in progress
* add post_init
* do not change TimesFMOutput
* fixing unit tests
* all unit tests passed
* remove timesfm_layers
* add intermediate_size and initialize with config
* add _CHECKPOINT_FOR_DOC
* fix comments
* Revert "fix comments" This reverts commit 8deeb3e.
* add _prepare_4d_attention_mask
* we do not have generative model classes
* use Cache
* return past_key_values
* modules initialized with config only
* update year
* Update docs/source/en/model_doc/timesfm.md Co-authored-by: Steven Liu <[email protected]>
* add layer_idx to cache
* modular timesfm
* fix test
* unwrap sequential class
* fix toctree
* remove TimesFmOnnxConfig
* fix modular
* remove TimesFmStackedDecoder
* split qkv layer into individual layers
* rename projection layers
* use ALL_ATTENTION_FUNCTIONS
* is_causal is True
* rename config
* does not support flash_attn_2
* formatting
* fix typo in docstring
* rename inputs
* add time series mapping
* Update src/transformers/models/olmo2/modeling_olmo2.py
* Update src/transformers/models/moonshine/modeling_moonshine.py
* use updated arguments
* fix class name
* add MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING
* isort
* consolidate _preprocess into forward
* fix a typo
* fix a typo
* fix toc
* fix modular
* remove asserts
* use self.config._attn_implementation
* move to _postprocess_output
* remove timesfm_get_large_negative_number
* use view instead of multiple unsqueeze
* make helpers static methods of the Model
* use to_tuple
* use to_tuple if not return_dict
* remove unused initialization block as it's incorporated in nn.Linear
* remove unused num_key_value_groups
* use the same convention as the masking method
* update modular
* do not use unsqueeze
* use view instead of unsqueeze
* use buffer for inv_timescales
* formatting
* modular conversion
* remove unneeded initialization
* add missing docstrings
* remove cache
* use simple_eager_attention_forward
* support tp_plan
* support for flex and flash attention masks
* Revert "support for flex and flash attention masks" This reverts commit def36c4.
* fix device
* fix tests on gpu
* remove unused large model test
* removed unneeded comments
* add example usage
* fix style
* add import
* Update docs/source/en/model_doc/timesfm.md Co-authored-by: Cyril Vallez <[email protected]>
* inherit from LlamaRMSNorm
* use can_return_tuple decorator
* remove return_dict
* fix year
* Update docs/source/en/model_doc/timesfm.md Co-authored-by: Cyril Vallez <[email protected]>
* pretrained does not inherit from GenerationMixin
* use model for integration test

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Rajat Sen <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Cyril Vallez <[email protected]>
Co-authored-by: Cyril Vallez <[email protected]>
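Several commits in the log above touch the training objective ("add loss", "fix _quantile_loss", "fix quantile loss"). As background, the quantile (pinball) loss these commits refer to has a standard form; the sketch below is a generic PyTorch version under that assumption, with hypothetical names, not necessarily the PR's exact implementation.

```python
import torch


def quantile_loss(predictions: torch.Tensor, targets: torch.Tensor, quantiles) -> torch.Tensor:
    """Pinball loss averaged over a set of quantiles.

    predictions: (batch, horizon, num_quantiles)
    targets:     (batch, horizon)
    """
    losses = []
    for i, q in enumerate(quantiles):
        errors = targets - predictions[..., i]
        # Penalize under-prediction by q and over-prediction by (1 - q).
        losses.append(torch.max(q * errors, (q - 1) * errors))
    return torch.stack(losses, dim=-1).mean()


# Example: loss over the 10th, 50th, and 90th percentile heads.
preds = torch.randn(4, 128, 3)
targets = torch.randn(4, 128)
loss = quantile_loss(preds, targets, quantiles=(0.1, 0.5, 0.9))
```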
What does this PR do?
Fixes #33745
This PR adds a new model, TimesFM, to Hugging Face Transformers. TimesFM is a decoder-only time series forecasting model based on the transformer architecture, proposed in the paper "A decoder-only foundation model for time-series forecasting".
Code Example
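A minimal usage sketch, assuming the API surfaced in the commit log above (`TimesFmModelForPrediction`, plain `torch.Tensor` inputs, mean and quantile forecasts, and the "2.0-500m-pytorch" default configs). The exact argument and output names (`past_values`, `freq`, `mean_predictions`, `full_predictions`) and the checkpoint id should be verified against the merged model doc.

```python
import torch
from transformers import TimesFmModelForPrediction

# Checkpoint id inferred from the commit log ("2.0-500m-pytorch default configs").
model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
model.eval()

# A batch of three context windows, passed as plain torch.Tensor inputs.
past_values = torch.randn(3, 512)                 # (batch_size, context_length)
freq = torch.tensor([0, 1, 2], dtype=torch.long)  # per-series frequency bucket: 0=high, 1=medium, 2=low

with torch.no_grad():
    outputs = model(past_values=past_values, freq=freq, return_dict=True)

mean_forecast = outputs.mean_predictions      # point forecast, (batch_size, horizon)
quantile_forecast = outputs.full_predictions  # mean forecast plus the quantile heads
```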
TODOs