
Refactoring AssistedCandidateGenerator for Improved Modularity and Reusability #35009

Open · wants to merge 14 commits into base: main

Conversation

@keyboardAnt (Contributor) commented Nov 28, 2024

What does this PR do?

This PR refactors the AssistedCandidateGenerator and AssistedCandidateGeneratorDifferentTokenizers classes to improve code readability, maintainability, and reusability. By breaking down large methods into smaller helper functions and reusing common logic, we enhance the modularity of these classes. This refactoring lays a cleaner foundation for upcoming features like UniversalSpeculativeDecodingGenerator without introducing any new functionality.

Background

While working on Universal Speculative Decoding (#34760), which introduces the UniversalSpeculativeDecodingGenerator, we identified opportunities to refactor existing code. The goal is to reuse core logic across different candidate generators and simplify the integration of new features that enable speculative decoding across models with different tokenizers.

By submitting this refactoring as a separate PR, we aim to:

  • Streamline the review process: Focus on structural improvements without the complexity of new features.
  • Accelerate availability: Make these improvements accessible to the community sooner.
  • Facilitate future development: Provide a robust base for upcoming enhancements in the generation capabilities.

This refactor is a collaboration with @jmamou, who has already reviewed it (keyboardAnt#1).

Key Changes

1. Code Restructuring

  • Decomposed Large Methods: Broke down the get_candidates methods in both classes into smaller, focused helper functions.
    • Example helper functions:
      • _calculate_new_tokens
      • _update_past_and_masks
      • _prepare_generation_args
      • _generate_candidates
  • Simplified Initialization: Streamlined the __init__ methods to remove redundancy and enhance clarity.
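To illustrate the decomposition pattern described above, here is a minimal, hypothetical sketch. It is not the actual transformers implementation (the real helpers operate on torch tensors, `past_key_values`, and attention masks); only the helper names mirror the ones listed, and the bodies are deliberately simplified.

```python
class ToyAssistedCandidateGenerator:
    """Toy generator whose get_candidates is composed of small helpers."""

    def __init__(self, max_new_tokens: int = 5):
        self.max_new_tokens = max_new_tokens
        self.past_len = 0  # stands in for the cached past_key_values length

    def _calculate_new_tokens(self, input_ids: list) -> int:
        # Decide how many candidate tokens to draft this iteration.
        return min(self.max_new_tokens, len(input_ids))

    def _update_past_and_masks(self, input_ids: list) -> None:
        # Track how much of the input is already covered by the cache.
        self.past_len = len(input_ids)

    def _prepare_generation_args(self, input_ids: list, n_new: int) -> dict:
        # Collect everything the draft-model call needs in one place.
        return {"input_ids": input_ids, "num_new_tokens": n_new}

    def _generate_candidates(self, generation_args: dict) -> list:
        # Stand-in for the assistant model call: echo the last token.
        last = generation_args["input_ids"][-1]
        return [last] * generation_args["num_new_tokens"]

    def get_candidates(self, input_ids: list) -> list:
        # The public method is now a thin pipeline over the helpers.
        n_new = self._calculate_new_tokens(input_ids)
        self._update_past_and_masks(input_ids)
        args = self._prepare_generation_args(input_ids, n_new)
        return input_ids + self._generate_candidates(args)
```

The point of the pipeline shape is that each helper can be tested or overridden in isolation, which is what makes the subclassing described in the next section cheap.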

2. Improved Reusability

  • Modularized Operations: Encapsulated unique functionalities in dedicated methods within each class.
    • For AssistedCandidateGeneratorDifferentTokenizers, methods like _prepare_assistant_input_ids and _process_assistant_outputs handle tokenizer-specific logic.
  • Reused Common Logic: Leveraged inheritance and method overriding to share code between classes and reduce duplication.
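The reuse pattern can be sketched as follows. This is a hypothetical toy, not the real classes (which operate on token-id tensors across two tokenizers): the subclass inherits the shared pipeline and overrides only the tokenizer-specific hooks, mirroring how `_prepare_assistant_input_ids` and `_process_assistant_outputs` isolate the different-tokenizer logic.

```python
class ToyBaseGenerator:
    """Base generator: target and assistant share a 'tokenizer'."""

    def _prepare_assistant_input_ids(self, text: str) -> list:
        # Base case: trivial whitespace tokenization.
        return text.split()

    def _process_assistant_outputs(self, tokens: list) -> str:
        return " ".join(tokens)

    def get_candidates(self, text: str) -> str:
        # Shared pipeline; subclasses override only the hooks above.
        tokens = self._prepare_assistant_input_ids(text)
        tokens.append(tokens[-1])  # stand-in for drafting one token
        return self._process_assistant_outputs(tokens)


class ToyDifferentTokenizerGenerator(ToyBaseGenerator):
    """Subclass: only the tokenizer-specific hooks change."""

    def _prepare_assistant_input_ids(self, text: str) -> list:
        # Pretend the assistant tokenizer is lowercase-only.
        return [t.lower() for t in text.split()]

    def _process_assistant_outputs(self, tokens: list) -> str:
        # Map the assistant's outputs back to the target's convention.
        return " ".join(tokens).upper()
```

Because `get_candidates` is inherited unchanged, any fix to the shared pipeline automatically benefits both classes.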

3. Enhanced Readability

  • Clear Naming Conventions: Renamed variables and methods for better clarity and understanding.
  • Consistent Formatting: Applied uniform code formatting and style guidelines across the classes.
  • Documentation: Added docstrings and comments to explain the purpose and functionality of methods.

Motivation

This refactoring is motivated by the need to:

  • Reduce Technical Debt: Clean up the codebase to make it more maintainable.
  • Facilitate Future Features: Prepare the existing classes for integration with upcoming speculative decoding functionalities.
  • Encourage Community Contributions: Make the code more approachable for contributors by improving readability and structure.

By isolating these changes, we enable reviewers to focus solely on the structural improvements without the added complexity of new features. This approach helps maintain a high code quality standard and simplifies the review and merging process.

Dependencies

  • No New Dependencies: This PR does not introduce any new dependencies. It solely refactors existing code.


Who Can Review?

The following reviewers are well-suited to review this PR: @gante, @ArthurZucker


This PR aims to strengthen the foundation for speculative decoding and other future enhancements by improving the existing code's structure and maintainability. We appreciate your time and look forward to your feedback.

@jmamou (Contributor) commented Nov 29, 2024

@zucchini-nlp feel free to review it :-)

@zucchini-nlp (Member) left a comment

Thanks a lot for making the code more composable! LGTM as nothing changed in terms of functionality. Just left one comment as we had several PRs in parallel modifying the assisted generation code :)

Review comment on src/transformers/generation/candidate_generator.py (outdated; resolved)
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member) left a comment

Thanks @keyboardAnt ! LGTM as nothing was changed except for composing code into smaller functions.

The only question is the use of self.prev_target_ids, which I mentioned was removed in a previous PR because it is not needed. This is what I got from reviewing that PR:

> hmm, so to make sure, that means the prev impl when we checked prev_target_ids was not really correct? And we should check the length of already accepted input_ids

> yes

So do we need to save the prev token ids from the target model, or can we re-use the current token ids? The current token ids will in any case have the prev token ids as a prefix, with the newly accepted tokens appended at the end.

@jmamou (Contributor) commented Dec 3, 2024

> Thanks @keyboardAnt ! LGTM as nothing was changed except for composing code into smaller functions.
>
> The only question is the use of self.prev_target_ids which I've mentioned was removed in prev PR explaining it is not needed. This is what I got from reviewing the prev PR
>
> hmm, so to make sure, that means the prev impl when we checked prev_target_ids was not really correct? And we should check the length of already accepted input_ids
>
> yes
>
> So do we need to save prev token ids from target model or we can re-use the current token ids, because the current token ids in any case will have the prev token ids as prefix with new accepted tokens appended at the end

For UAG, we need just the number of tokens in self.prev_target_ids (prev_target_ids.shape[1]) in order to know which target tokens were added after the target validation of the last speculative iteration, and to convert only those to the corresponding draft tokens (instead of converting from the beginning).
I propose using a similar approach to USD, where we store only the previous number of target tokens (_prev_target_seq_len).
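The bookkeeping change being proposed can be sketched with a hypothetical toy (names are illustrative, not the transformers API): instead of keeping a copy of the previous target ids, store only their length. Because the current ids always have the previous ids as a prefix, with newly accepted tokens appended at the end, the stored length is enough to slice out the new tokens.

```python
class ToyTargetTracker:
    """Tracks only the previous target sequence length, not the ids."""

    def __init__(self):
        # Replaces storing a full copy of prev_target_ids.
        self._prev_target_seq_len = 0

    def new_target_tokens(self, target_ids: list) -> list:
        # target_ids always extends the previous ids (prefix property),
        # so the stored length identifies exactly the new tokens.
        new = target_ids[self._prev_target_seq_len:]
        self._prev_target_seq_len = len(target_ids)
        return new
```

This also drops the memory cost of duplicating the target ids on every speculative iteration.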

@zucchini-nlp (Member)

> we need just the number of tokens in self.prev_target_ids (prev_target_ids.shape[1])

Ah that makes sense if we don't care about the actual token ids used previously, because the tokens should be available without storing them. Then we can indeed store only the prev length

@keyboardAnt (Contributor, Author) commented Dec 5, 2024

> we need just the number of tokens in self.prev_target_ids (prev_target_ids.shape[1])
>
> Ah that makes sense if we don't care about the actual token ids used previously, because the tokens should be available without storing them. Then we can indeed store only the prev length

Thanks, @zucchini-nlp. I've replaced self.prev_target_ids with self.prev_target_ids_len and rebased main to keep the history clean. Can you please approve the awaiting workflows?


@jmamou (Contributor) previously requested changes, Dec 5, 2024

@keyboardAnt please apply the same fix as in keyboardAnt#4; otherwise SD will not work when the target and assistant models are not on the same device.

@jmamou (Contributor) commented Dec 6, 2024

> @keyboardAnt please apply the same fix as at keyboardAnt#4. else SD will not work when target and assistant are not on the same device

@keyboardAnt PR #35116 fixes the multi-GPU issue.

@zucchini-nlp (Member)

You can now request review from the core maintainer to merge this :)

@jmamou (Contributor) commented Dec 8, 2024

> You can now request review from the core maintainer to merge this :)

@ArthurZucker

@ArthurZucker self-requested a review, December 9, 2024, 12:34
@ArthurZucker (Collaborator) left a comment

Looks great! Thanks for the detailed explanations 🤗
Merging!

@ArthurZucker dismissed jmamou's stale review, December 10, 2024, 08:11: "Need to dismiss it to merge!"

@ArthurZucker (Collaborator)

Ah can you resolve conflicts? 🤗

@jmamou (Contributor) commented Dec 10, 2024

> Ah can you resolve conflicts? 🤗

@ArthurZucker done!
