Output dicts support in text generation pipeline #35092
Conversation
Hi - firstly, thanks for the PR, but I notice that in the code there's a lot more than just adding output dict support! It seems like you're also changing the handling for batched generation a lot.
I can understand the rationale there, but I think it would be better to split that into a separate PR. In other words, make this into a smaller PR that just adds output dict support, with a second PR focused on batching. WDYT?
```python
for key, value in other_outputs.items():
    if isinstance(value, (torch.Tensor, tf.Tensor)) and value.shape[0] == out_b:
```
This line seems dangerous - remember that tf.Tensor may not be available or imported on most systems! You may have to make this a more complicated conditional that only checks tf.Tensor after checking is_tensorflow_available(), to ensure no issues on Torch-only systems (i.e. most of them).
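The guard the reviewer is asking for can be sketched like this. Note this is an illustrative sketch: the `is_torch_available` / `is_tf_available` helpers below are stand-ins defined locally, not the exact transformers utility functions.

```python
import importlib.util


def is_torch_available() -> bool:
    # Stand-in availability check (transformers ships its own helper for this).
    return importlib.util.find_spec("torch") is not None


def is_tf_available() -> bool:
    # Stand-in availability check for TensorFlow.
    return importlib.util.find_spec("tensorflow") is not None


def is_framework_tensor(value) -> bool:
    """Return True if `value` is a torch or TF tensor, importing each
    framework lazily so Torch-only systems never touch TensorFlow."""
    if is_torch_available():
        import torch
        if isinstance(value, torch.Tensor):
            return True
    if is_tf_available():
        import tensorflow as tf
        if isinstance(value, tf.Tensor):
            return True
    return False


print(is_framework_tensor([1, 2, 3]))  # → False (a plain list is never a tensor)
```

The key point is that `tf.Tensor` is only referenced inside a branch that first confirms TensorFlow is importable, so the `isinstance` check cannot raise `NameError` or `ImportError` on a Torch-only machine.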
```python
for key, value in other_outputs.items():
    if isinstance(value, (list, tuple)):
        record[key] = value[idx]
    elif isinstance(value, (torch.Tensor, tf.Tensor)):
```
Same issue here
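For context, the loop under review splits the extra entries that `generate()` returns for a whole batch into one record per generated sequence. A minimal self-contained sketch of that pattern (the `other_outputs` / `record` names follow the diff; the values are made up for illustration):

```python
# Batched extras as generate() might return them: one entry per sequence.
other_outputs = {
    "scores": [[0.1, 0.2], [0.3, 0.4]],  # list: index directly by sequence
    "lengths": (5, 7),                   # tuple: same treatment
}

records = []
for idx in range(2):  # one record per sequence in the batch
    record = {}
    for key, value in other_outputs.items():
        if isinstance(value, (list, tuple)):
            record[key] = value[idx]
    records.append(record)

print(records)
# → [{'scores': [0.1, 0.2], 'lengths': 5}, {'scores': [0.3, 0.4], 'lengths': 7}]
```

The tensor branch flagged in the review would do the same per-index slicing for framework tensors, which is why it needs the same availability guard.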
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@Rocketknight1 Hi. Thanks for looking at the PR and pointing that out. I agree with you; I got a bit ahead of myself. I've removed most of the changes and just made the changes to support dictionary output. I'd be happy if you could take a look. Thanks again!
@Rocketknight1 Any update on this? Thanks!
I think this looks good to me now, and sorry for the delay! @gante since you're probably more familiar with return_dict_in_generate, can you take a look and make sure it's all okay before I ping a core maintainer?
@gante Friendly reminder if you could take a look. Thank you very much!
@Rocketknight1 Can't seem to get a hold of @gante. Any idea how to proceed?
Hang on, I'll get him!
@Rocketknight1 Any success in finding him?
@jonasrohw I exist, but I'm being pinged in many places :D |
LGTM, thank you for expanding the capabilities of the pipeline :D
(fixed conflicts, will merge as soon as CI gets green)
(failing tests are unrelated to this PR, and are being fixed in other PRs. Waiting for them to be merged first)
* Support for generate_argument: return_dict_in_generate=True, instead of returning an error
* fix: call test with return_dict_in_generate=True
* fix: Only import torch if it is present
* update: Encapsulate output_dict changes
* fix: added back original comments

Co-authored-by: Joao Gante <[email protected]>
What does this PR do?
It is a minor fix to the text generation pipeline. When the pipeline is called with the generation argument return_dict_in_generate=True, the code breaks because it does not expect a dict from the model.generate(...) call. A dict is required if, for example, you want to view the logits. This fix handles return_dict_in_generate=True as a pipeline parameter.

Fixes # (issue)
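The failure mode the PR fixes can be illustrated without the library: the pipeline's postprocessing assumed a plain sequence of token ids, but with return_dict_in_generate=True the generate call returns a dict-like object instead. The `fake_generate` function and the token-id values below are made up for illustration, not the transformers API.

```python
def fake_generate(return_dict_in_generate=False):
    """Toy stand-in for model.generate(): returns token ids, or a dict
    wrapping them when return_dict_in_generate is True."""
    sequences = [[101, 7592, 102]]  # pretend generated token ids
    if return_dict_in_generate:
        return {"sequences": sequences, "scores": [0.9]}
    return sequences


out = fake_generate(return_dict_in_generate=True)
# The essence of the fix: unwrap `sequences` from the dict before the
# downstream code that expects raw token ids.
sequences = out["sequences"] if isinstance(out, dict) else out
print(sequences)  # → [[101, 7592, 102]]
```

With the unwrapping in place, the rest of the pipeline sees the same shape regardless of whether the caller asked for dict output.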
Who can review?
@Rocketknight1