
Conversation

@qubvel
Contributor

@qubvel qubvel commented Dec 6, 2024

What does this PR do?

Add a common slow test to check if a model can be exported with no issues using torch.export.export

  1. Add an optional test; to enable it, set the test_torch_exportable = True flag in the model-specific test class (see the sketch after this list).
  2. Enable test for vision and video models
  3. Fix most of the vision models
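
A minimal sketch of how a model-specific test opts in (the class and model names below are illustrative placeholders, not code from this PR):

class MyVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (MyVisionModel,) if is_torch_available() else ()
    # Opt in to the new common slow test that exports the model with
    # torch.export.export and compares outputs against eager mode.
    test_torch_exportable = True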

The main fixes include (illustrated briefly in the sketch after this list):

  • Use a compile-compatible LRU cache for models.
  • Avoid modifying model parameters in the forward pass (e.g. self.param = self.param + x).
  • Avoid in-place modification of leaf tensors created in the forward pass.
  • Avoid creating tensors with requires_grad=True in the forward pass.
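
A rough illustration of the patterns above (a simplified sketch, not taken from any particular model in this PR):

import torch
from torch import nn

class ExportFriendlyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Avoided: reassigning a parameter inside forward,
        # e.g. `self.scale = self.scale + x.mean()`.
        # Avoided: creating tensors with requires_grad=True inside forward,
        # e.g. `mask = torch.ones_like(x, requires_grad=True)`.
        # Instead, build new tensors functionally and leave parameters untouched.
        mask = (x > 0).to(x.dtype)
        return x * self.scale * mask

block = ExportFriendlyBlock(8).eval()
exported = torch.export.export(block, args=(torch.randn(2, 8),))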

Testing is not exhaustive; there might still be code paths that can't be exported. I did additional testing with specific checkpoints, and in most cases we are safe. The only two situations I found where the tests pass but exporting the checkpoint does not are:

  • beit (fixed)
  • zoedepth (not fixed)

Results

✅ - can be exported with torch.export.export
🔵 - export fixed in this PR
❌ - can't be exported

Vision models

  • 🔵 beit
  • 🔵 bit
  • 🔵 conditional_detr
  • ✅ convnext
  • ✅ convnextv2
  • ✅ cvt
  • ✅ dab_detr
  • 🔵 deformable_detr
  • ✅ deit
  • ✅ depth_anything
  • ✅ depth_pro
  • 🔵 detr
  • ✅ dinat
  • ✅ dinov2
  • ✅ dinov2_with_registers
  • ✅ dit
  • ✅ dpt
  • ✅ efficientnet
  • 🔵 focalnet
  • ✅ glpn
  • ✅ hiera
  • ✅ ijepa
  • 🔵 imagegpt
  • ❌ levit (low usage, won't fix)
  • ✅ mask2former
  • 🔵 maskformer
  • ✅ mobilenet_v1
  • ✅ mobilenet_v2
  • ✅ mobilevit
  • ✅ mobilevitv2
  • ✅ poolformer
  • ✅ pvt
  • ✅ pvt_v2
  • ✅ regnet
  • ✅ resnet
  • ✅ rt_detr
  • 🔵 rt_detr_v2
  • ✅ segformer
  • 🔵 seggpt
  • ❌ superpoint (data-dependent expression)
  • ✅ swiftformer
  • ✅ swin
  • ✅ swinv2
  • 🔵 swin2sr
  • ✅ table_transformer
  • ✅ textnet
  • ✅ upernet
  • ✅ vit
  • ✅ vitdet
  • ✅ vit_mae
  • ✅ vitmatte
  • ✅ vit_msn
  • ✅ vitpose
  • ✅ vitpose_backbone
  • ✅ yolos
  • ❌ zoedepth (data-dependent expression; the test config passes but the checkpoint does not)

Video models

  • ✅ timesformer
  • ✅ vivit
  • ✅ videomae

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qubvel qubvel added the Vision label Dec 6, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qubvel
Contributor Author

qubvel commented Dec 11, 2024

@guangy10 please have a look if you have bandwidth! Do you have anything in mind that should be added to the common test on the transformers side to ensure the model is exportable and ExecuTorch compatible?

@qubvel qubvel requested a review from ydshieh December 17, 2024 22:32
@qubvel qubvel added run-slow torch export Issues and PRs related to torch.export compatibility labels Dec 17, 2024
Collaborator

@ydshieh ydshieh left a comment

Thanks @qubvel for working on this!

I haven't checked the changes in the models, but I left a few comments in test_modeling_common.py.

Collaborator

@ydshieh ydshieh left a comment

Just left a few tiny comments.

Could we also trigger slow CI for the modified models?

You can rebase on main to use the new way to trigger slow CI.

Collaborator

@ydshieh ydshieh left a comment

Make sure to run a slow CI 🙏

Thanks for the work.

@qubvel
Contributor Author

qubvel commented Dec 18, 2024

Ok, sure, what is the new way to trigger slow tests?

Thanks for the review 🤗

@ydshieh
Collaborator

ydshieh commented Dec 18, 2024

Contributor

@guangy10 guangy10 left a comment

@qubvel Great work, and thank you for standardizing the test for exportability! Does this PR cover all existing vision models in transformers? Is there a plan to set up the same standard for audio models (in separate PRs)?

test_mismatched_shapes = True
test_missing_keys = True
test_model_parallel = False
test_torch_exportable = False
Contributor

If the vast majority of vision models are exportable, should we strategically flip this flag to True? New models are typically more popular and important than old ones, so the biggest benefit of reversing the default to True is to "softly enforce" exportability for new models, naturally growing this path to become the default and hard-enforced over time. By softly enforce, I mean a new model can still disable the export test if the failure is confirmed to be non-obvious to fix, with a GitHub issue filed in the backlog. But I do think most failures would be common and easily fixable given the work you did in this PR.

Contributor Author

Common tests are applied to all modalities, so for now, I suppose we will set it to False before updating for other models.

Contributor

@guangy10 guangy10 Jan 31, 2025

Including text models, right? @qubvel Could you guide me on how I can leverage this common test to cover more models, like these: https://github.com/search?q=repo%3Ahuggingface%2Ftransformers%20test_export_static_cache&type=code

Collaborator

It's fine to have it, as we want to reach 100% exportability, no?

for key in eager_outputs:
is_tested = is_tested or recursively_check(eager_outputs[key], exported_outputs[key])
return is_tested
return is_tested
Contributor

A nit for debuggability: is the silent return here expected? I think we should explicitly report or raise an error instead, since it bypasses all checks when there is an output type mismatch.

Contributor Author

I suppose the main targets are covered; if no types match, is_tested will be False and the error will be raised below.
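
For context, a rough sketch of the kind of recursive comparison being discussed (simplified; tolerances and type handling may differ from the actual test):

import torch

def recursively_check(eager_outputs, exported_outputs) -> bool:
    is_tested = False
    if isinstance(eager_outputs, torch.Tensor):
        torch.testing.assert_close(eager_outputs, exported_outputs)
        return True
    if isinstance(eager_outputs, (list, tuple)):
        for eager, exported in zip(eager_outputs, exported_outputs):
            is_tested = recursively_check(eager, exported) or is_tested
        return is_tested
    if isinstance(eager_outputs, dict):
        for key in eager_outputs:
            is_tested = recursively_check(eager_outputs[key], exported_outputs[key]) or is_tested
        return is_tested
    # Unmatched types fall through; the caller asserts that is_tested is True,
    # which is where a silent mismatch would ultimately surface.
    return is_tested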

model,
args=(),
kwargs=inputs_dict,
strict=True,
Contributor

@gmagogsfm Oh, I think I recommended strict=True somewhere. The reason is safety. If some Python code block is unsupported by torchdynamo and gets traced out in non-strict mode, instead of proceeding silently and only later seeing inconsistent results in the exported graph, a better process is to first surface the error with strict mode, let the model author or reviewers see the source and confirm it's safe, and then turn the strict flag to False to unblock. cc: @ydshieh @qubvel
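
For reference, the export call under discussion looks roughly like this (a sketch; `model` and `inputs_dict` come from the common test setup):

exported_program = torch.export.export(
    model,
    args=(),
    kwargs=inputs_dict,
    strict=True,  # surface untraceable Python code instead of silently tracing it out
)
exported_outputs = exported_program.module()(**inputs_dict)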

with tempfile.TemporaryDirectory() as tmpdirname:
save_path = os.path.join(tmpdirname, "exported_model.pt2")
torch.export.save(exported_model, save_path)
exported_model = torch.export.load(save_path)
Contributor

@qubvel You may want to add a check to ensure the loaded exported model is identical, to avoid weird issues due to bugs in (de)serialization, since you are checking the export outputs against eager using the loaded artifact.

In a workflow (e.g. lowering to ExecuTorch for on-device inference) where the exported model is just an intermediate representation, there is no need to save the exported program to a local filesystem first. We don't want a bug in (de)serialization of the exported IR to affect the testability of the downstream workflow, so it's better to first compare the exported and eager outputs, then test the (de)serialization (roughly the ordering sketched below).
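
A sketch of that ordering (illustrative only; `recursively_check` is a placeholder for whatever comparison helper the test uses, and `exported_program` comes from torch.export.export as above):

import os
import tempfile
import torch

# 1) Validate the exported program against eager first...
eager_outputs = model(**inputs_dict)
exported_outputs = exported_program.module()(**inputs_dict)
recursively_check(eager_outputs, exported_outputs)

# 2) ...then, separately, exercise (de)serialization of the artifact.
with tempfile.TemporaryDirectory() as tmpdirname:
    save_path = os.path.join(tmpdirname, "exported_model.pt2")
    torch.export.save(exported_program, save_path)
    reloaded_program = torch.export.load(save_path)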

Contributor Author

Ok, basically, there is almost nothing we can do on our side in case of a deserialization error (just report it to the torch team). So I suppose we can remove this entirely, WDYT?

Contributor

Sure, we can remove this entirely

raise ValueError(f"Unsupported parallel style value: {style}")


def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
Contributor

@qubvel Do you mind elaborating a bit more on how this decorator helps with export?

Contributor Author

Please see below.

def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
"""
LRU cache decorator from standard functools library, but with a workaround to disable
caching when torchdynamo is compiling. Expected to work with class methods.
Contributor

What particular caching is disabled when torchdynamo is tracing?

Contributor Author

@qubvel qubvel Jan 31, 2025

We turn off the standard lru_cache when torchdynamo is compiling.

We had this decorator previously only for RT-DETR; I just moved it to make it common:

RT-DETR (afair) caches anchors to avoid recreating them for the same image size.
OmDet-Turbo (zero-shot object detection, out of scope for this PR) caches text label embeddings to avoid recomputing them when the same label is passed again.

Not much speedup comes from these optimizations tbh, so I suppose it's fine to turn the caching off to enable compile/export.
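
Roughly, the idea behind the decorator is the following (a simplified sketch; the actual helper in this PR may differ in details):

from functools import lru_cache, wraps

import torch

def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
    def decorator(func):
        cached_func = lru_cache(*lru_args, **lru_kwargs)(func)

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            # While torchdynamo is compiling/exporting, bypass the cache and
            # call the plain function so tracing only sees ordinary tensor ops.
            if torch.compiler.is_compiling():
                return func(self, *args, **kwargs)
            return cached_func(self, *args, **kwargs)

        return wrapper

    return decorator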


default_config, default_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config = config or default_config
inputs_dict = inputs_dict or default_inputs_dict
Contributor

What are the required parameters to be passed to forward()? I think we should standardize the forward signature so that we can simplify and standardize developing an ExecuTorch runtime for all vision models, like what we did for text models in transformers.integrations.executorch, where only input_ids and the current cache_position are required because the model is always exported with the static cache. We don't have to maintain backwards compatibility (BC) for all arguments.

Contributor Author

It might be different for different vision models, but most of them should have only pixel_values as the required parameter.
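
For example, a typical export call for such a model might look like this (a sketch only; the checkpoint is just one of the models marked ✅ above, and shapes are taken from its config):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("facebook/dinov2-small").eval()
size = model.config.image_size
pixel_values = torch.randn(1, model.config.num_channels, size, size)

exported_program = torch.export.export(
    model, args=(), kwargs={"pixel_values": pixel_values}, strict=True
)
print(exported_program.graph_signature)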

Contributor

@guangy10 guangy10 left a comment

Looks good! 🚀 🚀

Collaborator

@ArthurZucker ArthurZucker left a comment

Missing one big thing: documentation about supported models for export!
otherwise 🚀 ! Kudos


@guangy10
Contributor

guangy10 commented Feb 4, 2025

@qubvel Let me know the timeline. I can start working on e2e enablement by connecting these models to optimum-executorch once this PR is merged.

@qubvel
Contributor Author

qubvel commented Feb 4, 2025

Hey @guangy10, going to finish this week!

@guangy10
Contributor

Hey @guangy10, going to finish this week!

@qubvel Just a friendly reminder to merge this PR. Let me know how I can help if there is any blocker.

@chrsmcgrr
Contributor

Nice work in bringing the torch.export coverage up.

Are the export tests executed with torch==2.6.0?

@qubvel
Contributor Author

qubvel commented Feb 11, 2025

Hey @chrsmcgrr, yes we have 2.6.0 in CI. I have also run these tests with torch 2.5.0 locally

@qubvel
Contributor Author

qubvel commented Feb 11, 2025

run-slow: rt_detr_v2, dab_detr, beit, conditional_detr, deformable_detr, detr, swin2sr

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/beit', 'models/conditional_detr', 'models/dab_detr', 'models/deformable_detr', 'models/detr', 'models/rt_detr_v2', 'models/swin2sr']
quantizations: [] ...

@qubvel
Contributor Author

qubvel commented Feb 11, 2025

Test failures are unrelated (fixed in #35654).

Merging this PR to unblock @guangy10's work. I will work on docs in a follow-up PR asap. cc @stevhliu in case you have suggestions on how to organize it better.

@qubvel qubvel merged commit f42d46c into huggingface:main Feb 11, 2025
25 of 26 checks passed
@stevhliu
Member

Depending on when the new docs ship (hopefully this week or the next), you can add them here I think :)

If it takes longer, then feel free to add them here and I can move them later!

@guangy10
Contributor

Depending on when the new docs ship (hopefully this week or the next), you can add them here I think :)

If it takes longer, then feel free to add them here and I can move them later!

Feel free to subscribe me on the new PR for the doc fix.

@guangy10
Contributor

@qubvel While I'm working on lowering these vision models e2e to ExecuTorch in optimum, would you like to start expanding export coverage to audio models as well? I'd really appreciate the effort!

@qubvel
Contributor Author

qubvel commented Feb 12, 2025

@guangy10 I'm not sure if I will have bandwidth in the coming few weeks, but maybe someone from the audio team can have a look
cc @eustlb

@guangy10
Contributor

@qubvel Thanks for looping in the audio team!

👋 @eustlb, nice to e-meet you! I'm from the PyTorch team at Meta. We've been collaborating with 🤗 to expand torch.export coverage of transformers models since last year (FYI, there is a parallel effort focusing on torch.compile coverage). So far we have covered text and vision models, but not audio models yet. It seems like your team is the right one to talk to. I'd be happy to share more context on Slack.

@eustlb
Contributor

eustlb commented Feb 13, 2025

Hey @guangy10, nice to e-meet you too! It would be a pleasure to help 🤗

sbucaille pushed a commit to sbucaille/transformers that referenced this pull request Feb 16, 2025
…gface#35124)

* Add is_torch_greater_or_equal test decorator

* Add common test for torch.export

* Fix bit

* Fix focalnet

* Fix imagegpt

* Fix seggpt

* Fix swin2sr

* Enable torch.export test for vision models

* Enable test for video models

* Remove json

* Enable for hiera

* Enable for ijepa

* Fix detr

* Fic conditional_detr

* Fix maskformer

* Enable test maskformer

* Fix test for deformable detr

* Fix custom kernels for export in rt-detr and deformable-detr

* Enable test for all DPT

* Remove custom test for deformable detr

* Simplify test to use only kwargs for export

* Add comment

* Move compile_compatible_method_lru_cache to utils

* Fix beit export

* Fix deformable detr

* Fix copies data2vec<->beit

* Fix typos, update test to work with dict

* Add seed to the test

* Enable test for vit_mae

* Fix beit tests

* [run-slow] beit, bit, conditional_detr, data2vec, deformable_detr, detr, focalnet, imagegpt, maskformer, rt_detr, seggpt, swin2sr

* Add vitpose test

* Add textnet test

* Add dinov2 with registers

* Update tests/test_modeling_common.py

* Switch to torch.testing.assert_close

* Fix masformer

* Remove save-load from test

* Add dab_detr

* Add depth_pro

* Fix and test RT-DETRv2

* Fix dab_detr
