
Update config validation #34726

Conversation


@Manalelaidouni Manalelaidouni commented Nov 14, 2024

What does this PR do?

Fixes #33690

This PR adds extra validation for generation parameters so that invalid values fail early instead of failing silently, as suggested in the issue.
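A hypothetical illustration of the intended fail-fast behavior (the error message and exact trigger are made up for the example, not taken from the PR):

from transformers import GenerationConfig

# With the extra validation, an out-of-range sampling parameter should raise
# at validation time rather than being silently ignored during generation.
try:
    config = GenerationConfig(do_sample=True, top_p=1.5)
    config.validate()
except ValueError as err:
    print(f"Caught early: {err}")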

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante @zucchini-nlp

@Manalelaidouni (Contributor Author)

The failing tests seem unrelated to my changes, and all tests passed locally for me. Is there something I'm missing here?

@zucchini-nlp (Member) left a comment

Cool, thanks for working on this! The failing tests are not related, and re-triggering the run should help :)

@Manalelaidouni (Contributor Author)

Hey @zucchini-nlp, thanks! I’ve added extra validations for the remaining generation parameters.

I also validated bad_words_ids inside GenerationConfig.validate(), so it is checked before the forward pass runs rather than at the end of it (when its NoBadWordsLogitsProcessor is executed). A sketch of this kind of early check follows below.

For now, I’ve left the _validate_arguments() method in NoBadWordsLogitsProcessor intact. I was considering removing it since it’s now redundant, but I’d like to get your green light before proceeding; if approved, I would apply the same approach to sequence_bias and adjust their tests accordingly :)
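A minimal sketch of the kind of early bad_words_ids check described above, assuming it lives in GenerationConfig.validate(); the names and messages are illustrative, not the PR's exact code:

import numpy as np

def _check_bad_words_ids(bad_words_ids):
    # Fail before the forward pass: require a non-empty list of lists of
    # non-negative integer token ids.
    if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
        raise ValueError(f"`bad_words_ids` must be a non-empty list of lists, but got {bad_words_ids}.")
    if any(not isinstance(word_ids, list) for word_ids in bad_words_ids):
        raise ValueError(f"`bad_words_ids` must be a list of lists, but got {bad_words_ids}.")
    if any(
        any(not isinstance(token_id, (int, np.integer)) or token_id < 0 for token_id in word_ids)
        for word_ids in bad_words_ids
    ):
        raise ValueError(f"Each token id in `bad_words_ids` must be a non-negative integer, but got {bad_words_ids}.")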

@zucchini-nlp (Member) left a comment

Wow, an extensive validation list! Thanks!

I was considering removing it since it’s now redundant

No, I think we still need it, because the logits processors can be initialized before the generate() call and don't necessarily require a generation config.

Adding a few tests should be enough at this point :)
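For context, a logits processor can indeed be constructed directly without any GenerationConfig, so its own _validate_arguments() remains the only guard on that path. A minimal sketch (the token ids are made up):

from transformers import LogitsProcessorList, NoBadWordsLogitsProcessor

# Constructed directly, outside of generate(), so GenerationConfig.validate()
# never runs -- the processor's _validate_arguments() must catch bad input.
processor = NoBadWordsLogitsProcessor(bad_words_ids=[[42], [7, 13]], eos_token_id=2)
processors = LogitsProcessorList([processor])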

@Manalelaidouni (Contributor Author)

Thanks @zucchini-nlp! I believe all generation parameters are validated now.

@zucchini-nlp (Member) left a comment

@Manalelaidouni cool, LGTM! As said before, adding a few tests should be enough at this point :)

You can add them in tests/generation/test_configuration_utils.py to make sure that incorrect values throw an error when passed to the generation config; a hypothetical example follows below. LMK if you need help with the tests. Then feel free to tag the core maintainer @ ArthurZucker for final review.
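A hypothetical test along those lines (names and the checked parameter are illustrative, not the PR's actual tests):

import unittest
from transformers import GenerationConfig

class GenerationConfigValidationTest(unittest.TestCase):
    def test_invalid_probability_raises(self):
        # After this PR, an out-of-range probability should raise instead of
        # failing silently during generation.
        with self.assertRaises(ValueError):
            GenerationConfig(do_sample=True, top_p=1.5).validate()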


Manalelaidouni commented Nov 29, 2024

Thank you @zucchini-nlp for reviewing! When adding tests I also updated some existing ones:

  • An older PR changed the sequence_bias format because the dictionary format didn’t serialize properly, but the docstring and code still point users to the dictionary format, which is confusing. Since _convert_list_arguments_into_dict() in SequenceBiasLogitsProcessor is only there for internal backwards compatibility, should _validate_arguments() raise a ValueError if the input is not in the list format? (Both formats are sketched below.)

  • I also adjusted the forced_bos_token_id and forced_eos_token_id tests and added the missing get_generation_mode one; the rest of the parameters I validated in the config seem to be tested already.

Finally, I’m not entirely sure how to address bad_words_ids in the failing CI/CD tests; any guidance here would be appreciated, thank you again!
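For reference, a sketch of the two sequence_bias formats in question (the token ids and bias values are illustrative):

# Old dict format: token-id tuples as keys. Tuples are not JSON-serializable,
# which is why serialization broke.
sequence_bias_dict = {(10, 20): -10.0, (30,): 2.0}

# Newer list format that serializes cleanly: [[token_ids, bias], ...]
sequence_bias_list = [[[10, 20], -10.0], [[30], 2.0]]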

@zucchini-nlp (Member) left a comment

Thanks so much for iterating on this!

since _convert_list_arguments_into_dict() in SequenceBiasLogitsProcessor is only there for internal backwards compatibility, should _validate_arguments() raise a ValueError if the input is not in the list format?

No, if we are keeping BC and the changes were made recently, we should not raise an error when users pass the old-style dict; it should be converted to the correct format before the logits processor uses it. But I see what you mean: we won't be able to serialize the old dict format, so we'd need a convert_to_list step in the configuration here, which ensures that specific values are converted to a serializable format in the code below. Let's keep this for another PR, as the current one is getting big and should only check that correct values are passed:

def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:

bad_words_ids in the failing CI/CD tests

The token id in these args can also be a numpy int; see how we check the values here:

if any(
    any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
    for bad_word_ids in bad_words_ids
):

@Manalelaidouni force-pushed the validate-generation-params branch from e4cccd6 to 7250e2f on November 30, 2024 at 08:41
@Manalelaidouni (Contributor Author)

cc @ArthurZucker, whenever you have a moment I'd appreciate your review. I wanted to see those CI tests green 😅 but the failures seem unrelated. Thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment

Hey! Thanks for the PR, but this seems quite inflated 😓
This is mostly not scalable 😢 so we should find another tool / a more efficient way to define boundaries and so on. Dataclasses might help!

@@ -587,17 +589,213 @@ def validate(self, is_init=False):
"""

# Validation of individual attributes
pos_int_args = [

Suggested change:

- pos_int_args = [
+ pos_init_args = [

was probably intended?


Manalelaidouni commented Jan 6, 2025

Thanks for reviewing this @ArthurZucker!

  • Dataclasses don't inherently validate data: a dataclass won't raise an error if passed an invalid data type, so we would need to implement the validation logic manually anyway (see the short demonstration at the end of this comment). GenerationConfig also relies on passing default model configs via kwargs, which makes fully converting it to a dataclass impractical.

  • Since you mentioned using a tool, a much simpler solution would be to use attrs, which scales easily, but it adds an extra dependency (Gante advised against using Pydantic, I assume because it significantly increases runtime).

  • If we don't want to add dependencies, we can refactor the logic into a mapping-based approach, which is scalable, something like this:


class _ConfigValidator:

    VALIDATORS = {
        "pos_int_arg": (
            lambda x: isinstance(x, int) and x >= 0,
            "`{arg}` must be a positive integer, but got {value}.",
        ),
        "probability_arg": (
            lambda x: isinstance(x, (int, float)) and 0.0 <= x <= 1.0,
            "`{arg}` must be a float between 0 and 1, but got {value}.",
        ),
        # ... more validator types
    }

    CONFIG_TYPES = {
        "top_k": "pos_int_arg",
        "top_p": "probability_arg",
        # ... more parameters
    }


class GenerationConfig:
    # ... existing attributes and methods

    def validate(self):
        for arg, value in self.__dict__.items():
            if arg in _ConfigValidator.CONFIG_TYPES:
                key = _ConfigValidator.CONFIG_TYPES[arg]
                validator, error_message = _ConfigValidator.VALIDATORS[key]
                if value is not None and not validator(value):
                    raise ValueError(error_message.format(arg=arg, value=value))

_ConfigValidator would also hold the more complex cross-parameter consistency validation, keeping GenerationConfig.validate() less inflated as future generation arguments are added. Let me know what you think before I move forward :)
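To back up the first point above, a minimal demonstration that dataclass type hints are not enforced at runtime:

from dataclasses import dataclass

@dataclass
class SamplingParams:
    top_p: float = 1.0

# No error is raised here: dataclasses do not check type hints at runtime,
# so the validation logic would still have to be written by hand.
params = SamplingParams(top_p="not a float")
print(params.top_p)  # prints "not a float"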

@ArthurZucker (Collaborator)

I really don't mind using pydantic, it's become quite important and we have so many configs that we need to validate! Do you mind drafting a PR for that? 🤗
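For illustration, a minimal pydantic sketch of the idea (not the actual follow-up PR; the field names and bounds are assumptions):

from pydantic import BaseModel, Field, ValidationError

class GenerationParams(BaseModel):
    top_k: int = Field(default=50, ge=0)
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)

try:
    GenerationParams(top_p=1.5)
except ValidationError as err:
    print(err)  # reports that top_p must be <= 1.0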

@Manalelaidouni (Contributor Author)

Hey @ArthurZucker, apologies for the delayed response, I just opened a PR for it!


gante commented Feb 14, 2025

@Manalelaidouni this PR is orthogonal to the pydantic PR, correct? Or should we merge one before the other?


Manalelaidouni commented Feb 14, 2025

@gante they're indeed orthogonal; that was the original plan, so I'll close this PR :)

Linked issue: Silent failure in generation parameters (#33690)