[Whisper] Fix whisper tokenizer #34537
Conversation
Justifications of the changes 🤗
```diff
@@ -629,7 +629,7 @@ def generate(
                 cur_bsz=cur_bsz,
                 batch_idx_map=batch_idx_map,
             )
-            time_offset = seek * time_precision / input_stride
+            time_offset = seek.to(torch.float64) * time_precision / input_stride
```
The original Whisper code base does such computations in float64. We need to ensure we do the same, especially since we are comparing against the original Whisper outputs in the tests.
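As a minimal sketch of why the cast matters (the numbers are made up, not taken from the PR): without the explicit cast, the integer `seek` tensor is promoted to torch's default float32 dtype, which drifts from the float64 reference computation.

```python
import torch

# illustrative values, not the PR's actual inputs
time_precision = 0.02
input_stride = 2
seek = torch.tensor([3001])  # integer tensor, as during generation

old = seek * time_precision / input_stride                    # promoted to float32
new = seek.to(torch.float64) * time_precision / input_stride  # matches OAI's float64 math

print(old.item(), old.dtype)  # float32 result, drifts in the trailing digits
print(new.item(), new.dtype)  # float64, bit-exact with the reference computation
```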
```python
else:
    # we want to include the last timestamp token in the last segment to know it was no single ending
    slices[-1] += 1
```
That's the only way to know later that we have a double-token-ending segment.
So you're offsetting the last timestamp by one when the last two tokens are timestamps? Let's say we have `[..., T1, T2]`, you're doing `[..., T1, T2+1]`?
`slices` is a list of indexes along the generated sequence of tokens `seek_sequence`. Let's say we have 100 tokens and it was not single ending, meaning the last two tokens are timestamps `[..., T1, T2]`. For this reason, `slices[-1] = 99`, yet when we later slice out the segments with `seek_sequence[last_slice:current_slice]`, we want to make sure T2 is included in the slice (so that we can further know it is a double timestamp ending segment, as explained in the PR's comment) → by adding 1 to the last slice, we ensure the last iteration will slice `seek_sequence[last_slice:100]` and that T2 will get included.
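A minimal runnable sketch of this, with hypothetical token ids and slice indexes:

```python
# 100 generated tokens; indexes 98 and 99 hold the two ending timestamps T1, T2
seek_sequence = list(range(100))
slices = [50, 99]                  # 99 is the index of T2
single_timestamp_ending = False

if not single_timestamp_ending:
    # we want to include the last timestamp token in the last segment
    slices[-1] += 1                # 99 -> 100

segments, last_slice = [], 0
for current_slice in slices:
    segments.append(seek_sequence[last_slice:current_slice])
    last_slice = current_slice

assert segments[-1][-1] == 99      # T2 made it into the last segment
```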
```python
start_timestamp_position * time_precision + prev_segments_len * time_precision,
end_timestamp_position * time_precision + prev_segments_len * time_precision,
```
It's done this way in the original codebase: summing float64 terms after multiplying each position by the precision. This way we avoid annoying floating-point arithmetic issues.
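A classic illustration of why the order of floating-point operations matters for bit-exact parity (numbers are illustrative, unrelated to Whisper's actual values):

```python
p = 0.1

# mathematically identical, but not bitwise identical in floating point
print(sum(p for _ in range(10)))  # 0.9999999999999999
print(10 * p)                     # 1.0
```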
LGTM @eustlb! Thanks for working on this!
It's quite surprising to me that the 2nd case wasn't handled well by the current integration! However, the proposed solution sounds ok, as long as it passes the new integration tests!
```python
last_was_single_ending = i >= 2 and not (
    token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
)
```
Are `token_ids` always in ascending order? In that case, don't we always have `token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin` when entering the `if token >= timestamp_begin` loop, if `i >= 2`?
Aren't we supposed to always have two subsequent timestamp tokens now that you've dealt with single timestamp tokens in the generation file?
Not necessarily. After concatenating all the generated sequences for each 30sec segment, at the tokenizer decoding phase we have two possibilities:
1. `[..., <t1>, <t2>, <0.0>, ...]` (last segment was double timestamp ending)
2. `[..., <t1>, <0.0>, ...]` (last segment was single timestamp ending)
Note that the only reason we can differentiate those two cases here is thanks to the above `slices[-1] += 1` that ensures `<t2>` is included when we are not single timestamp ending. So in 2. we have `token >= timestamp_begin` (the `<0.0>`) and `token_ids[i - 1] >= timestamp_begin`, but not `token_ids[i - 2] >= timestamp_begin`.
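A small self-contained sketch of that check (`timestamp_begin` and the token ids here are hypothetical stand-ins):

```python
timestamp_begin = 50364                     # hypothetical first timestamp id

def last_was_single_ending(token_ids, i):
    # True when token_ids[i] starts a new 30s window after a *single*
    # timestamp ending, i.e. the two tokens before it are not both timestamps
    return i >= 2 and not (
        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
    )

text = 100                                  # stand-in for a regular text token
double = [text, timestamp_begin + 50, timestamp_begin + 50, timestamp_begin]  # <t1><t2><0.0>
single = [text, timestamp_begin + 50, timestamp_begin]                        # <t1><0.0>

print(last_was_single_ending(double, 3))    # False: double timestamp ending
print(last_was_single_ending(single, 2))    # True: single timestamp ending
```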
```python
for segments in active_segments:
    for seg in segments:
        if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
            # the segment finishes with two timestamp tokens
            # we need to ignore the last timestamp token
            # see https://github.com/huggingface/transformers/pull/34537
            seg["tokens"] = seg["tokens"][:-1]
```
This looks a bit costly, I wonder if there's a cleaner way to compute all of this!
Why do we need to get rid of the last timestamp when preparing the input ids?
As mentioned above, we need to include the last timestamp in the case of double-ending timestamps (and not only the penultimate one, as done before) to enable the tokenizer to differentiate the two cases (single and double ending). Nevertheless, OAI does not have to worry about that because they don't concatenate all the tokens as we do, so when conditioning on the previous tokens, the last token in the double-ending case is simply omitted. To ensure we do the exact same, we need to remove it when preparing the `decoder_input_ids`.
The issue is that any segment can be a double-ending one, so they all need to be checked.
Looks like it is not that costly (complexity is O(batch_size * max number of segments)). When testing, it has no measurable impact on inference speed (see long-form results here; our implem is on par with OAI's).
Co-authored-by: Yoach Lacombe <[email protected]>
Great PR! Though you explained it well, I am a bit out of the loop here 😉 Sorry for the delay, good work!
```diff
         start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
-        end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
+        idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
+        end_timestamp_pos = sliced_tokens[idx_sliced_tokens].item() - timestamp_begin
```
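For intuition, a sketch with made-up token ids (not the PR's code): when the last slice ends with two timestamps, the real end of the segment is the first of the pair, since the second was only kept as a double-ending marker via `slices[-1] += 1`.

```python
timestamp_begin = 50364  # hypothetical first timestamp id

# last slice of a double-ending segment: <|0.20|> text text <|0.80|> <|1.00|>
sliced_tokens = [50364 + 10, 123, 456, 50364 + 40, 50364 + 50]
is_last_slice, single_timestamp_ending = True, False

idx = -1 if not is_last_slice or single_timestamp_ending else -2
start_timestamp_pos = sliced_tokens[0] - timestamp_begin   # 10
end_timestamp_pos = sliced_tokens[idx] - timestamp_begin   # 40 (T1), not 50 (T2)
print(start_timestamp_pos, end_timestamp_pos)              # 10 40
```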
Not sure if relevant, but each call to `.item()` will do a device sync; maybe calling `.item()` once after the loop would be better?
Totally relevant, thanks!! This was unnecessary and is fixed in 5fba3e0.
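One possible shape of that suggestion (a hedged sketch with made-up tensors, not the actual fix in 5fba3e0): keep everything as tensors inside the loop and transfer to host once at the end.

```python
import torch

timestamp_begin = 50364
sliced = [torch.tensor([50414, 50464]), torch.tensor([50464, 50514])]  # hypothetical slices

# gather start/end tokens with tensor arithmetic only (no per-iteration sync)...
starts = torch.stack([s[0] for s in sliced]) - timestamp_begin
ends = torch.stack([s[-1] for s in sliced]) - timestamp_begin

# ...then a single .tolist() per tensor transfers everything in one go
print(starts.tolist(), ends.tolist())  # [50, 100] [100, 150]
```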
* handle single timestamp ending
* include last timestamp token
* handle single timestamp ending
* avoid floating points arithm limitations
* ensure float64 operations
* new test
* make fixup
* make copies
* handle edge case double tokens ending with different tokens
* handle single timestamp ending
* make fixup
* handle conditioning on prev segments
* fix
* Update src/transformers/models/whisper/generation_whisper.py (Co-authored-by: Yoach Lacombe <[email protected]>)
* [run-slow] whisper
* don't call item() to avoid unnecessary sync
* fix

Co-authored-by: Yoach Lacombe <[email protected]>
Co-authored-by: Eustache Le Bihan <[email protected]>
```diff
         segments.append(
             {
-                "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
-                "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
+                "start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
```
I don't know if this is the right place or if there's an open issue, but this line now crashes on Macs / the MPS backend because float64 is not supported on Apple silicon. I had to downgrade to 4.46.3 - is there some other way to fix this?
What does this PR do?
Fixes #34472
What's happening
A special case in timestamp offsets was not handled.
When predicting timestamps, Whisper follows two strategies (see here, sketched below):
1. Single timestamp ending: the sequence generated for the current 30s window ends with a single timestamp token, and seeking jumps to the end of the window.
2. Double timestamp ending: the sequence ends with two consecutive timestamp tokens, and seeking jumps to the position of the last predicted timestamp.
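A hedged sketch of that seeking rule (variable names are illustrative and loosely mirror OAI's transcribe loop, not transformers' exact code):

```python
def next_seek(seek, segment_size, tokens, timestamp_begin, input_stride=2):
    # seek and segment_size are in mel frames (segment_size = 3000 for 30s)
    single_ending = len(tokens) >= 2 and (
        tokens[-1] >= timestamp_begin and tokens[-2] < timestamp_begin
    )
    if single_ending:
        # case 1: no closing timestamp pair -> jump to the end of the window
        return seek + segment_size
    # case 2: double timestamp ending -> resume at the last predicted timestamp
    last_timestamp_pos = tokens[-1] - timestamp_begin
    return seek + last_timestamp_pos * input_stride

# single ending: [text, <t>] -> seek advances by the full window
print(next_seek(0, 3000, [123, 50364 + 750], 50364))                # 3000
# double ending: [text, <t1>, <t2>] -> seek advances to t2's position
print(next_seek(0, 3000, [123, 50364 + 700, 50364 + 750], 50364))   # 1500
```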
Case 1 is correctly handled in `_retrieve_segments`, which is responsible for this seeking process during generation, making the generated segments correct (see snippet below), while it is not correctly handled in the tokenizer.

Snippet 🔧
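Since the original snippet is collapsed, here is a hedged sketch of what such a check could look like; the model id (`openai/whisper-small`) and dataset (`distil-whisper/librispeech_long`) are assumptions, while the `generated_ids["segments"]` / `batch_decode(..., output_offsets=True)` calls follow the test description below.

```python
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

ds = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000,
                   return_tensors="pt", truncation=False)

generated_ids = model.generate(**inputs, return_timestamps=True, return_segments=True)

# segments built at generation time (the correctly handled path)
for seg in generated_ids["segments"][0]:
    print(seg["start"].item(), seg["end"].item())

# offsets reconstructed by the tokenizer from the concatenated sequence (the buggy path)
decoded = processor.batch_decode(generated_ids["sequences"],
                                 skip_special_tokens=True, output_offsets=True)
print(decoded[0]["offsets"])
```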
Changes

On the existing:
The most important change here is the one in `_retrieve_segments`. When finishing generation with <|t1|><|t1|> (case 2), we must add it to the tokens of the returned segment. Otherwise, when we decode the concatenated sequence, we won't have any way to differentiate a single ending token (case 1) from a double ending token.
Added:
A new `test_small_longform_timestamps_generation` test, since this edge case was not caught before. IMO it's an important one to add since we can compare directly with OAI's expected output, so it's pretty robust. Code to reproduce the expected output can be found here. This new test catches an edge case we did not test before: single timestamp ending (here at 15.0s), triggering seeking to the end of the segment (here 30.0s). Moreover, we test both the segments through `generated_ids["segments"]`, which are built at generation time, and the offsets with `processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)`, which are reconstructed from the concatenated segments.