
[Whisper] Fix whisper tokenizer #34537

Merged: 27 commits into huggingface:main, Dec 5, 2024

Conversation

Contributor

@eustlb eustlb commented Oct 31, 2024

What does this PR do?

Fixes #34472

What's happening

A special case in timestamp offsets was not handled.
When predicting timestamps, Whisper follows two strategies (see here):

  1. single timestamp at the end: predicted sequence ends with <|t1|> → no speech after t1, seek to* end of the 30sec segment
  2. double timestamp at the end: predicted sequence ends with <|t1|><|t2|> → seek to* t1

*Note: Whisper works on 30-second windows of audio. Above, "seek to" means sliding this 30-second window to a new start position (see the sketch below).
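
To make the seeking rule concrete, here is a minimal sketch (toy token ids only; the value 50364 for the first timestamp id and the helper next_seek are assumptions for illustration, not the transformers implementation):

# Toy illustration of the two ending strategies described above.
timestamp_begin = 50364   # assumed id of <|0.00|>
time_precision = 0.02     # seconds per timestamp step
window_len = 30.0         # Whisper decodes 30-second windows

def next_seek(tokens, window_start):
    """Return the start (in seconds) of the next 30-second window."""
    ends_with_two_timestamps = (
        len(tokens) >= 2
        and tokens[-1] >= timestamp_begin
        and tokens[-2] >= timestamp_begin
    )
    if ends_with_two_timestamps:
        # case 2: <|t1|><|t2|> at the end -> seek to t1
        t1 = (tokens[-2] - timestamp_begin) * time_precision
        return window_start + t1
    # case 1: single <|t1|> at the end -> no speech after t1, seek to the window end
    return window_start + window_len

seq_case_1 = [400, 401, timestamp_begin + 750]                         # ... <|15.00|>
seq_case_2 = [400, 401, timestamp_begin + 566, timestamp_begin + 750]  # ... <|11.32|><|15.00|>
print(next_seek(seq_case_1, 0.0))  # 30.0
print(next_seek(seq_case_2, 0.0))  # 11.32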

Case 1 is correctly handled in _retrieve_segments, which is responsible for this seeking process during generation, so the generated segments are correct (see snippet below). It is not handled correctly in the tokenizer, however.

Snippet 🔧
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import numpy as np

# load model + processor
processor = AutoProcessor.from_pretrained("openai/whisper-small.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en")

# load dataset
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]["array"]
sampling_rate = dataset[0]["audio"]["sampling_rate"]

sample = [*sample[:15 * sampling_rate], *np.zeros(16 * sampling_rate).tolist(), *sample[15 * sampling_rate:]]
sample = np.array(sample)

# pre-process
inputs = processor(
    sample,
    sampling_rate=16_000,
    padding="longest",
    truncation=False,
    return_attention_mask=True,
    return_tensors="pt",
)

# inference
output = model.generate(**inputs, return_timestamps=True, return_segments=True)

# this is correct
print("=" * 10, "this is correct", "=" * 10)
for seg in output["segments"][0]:
    print(f"{seg['start'].item():.2f} -> {seg['end'].item():.2f}: {processor.decode(seg['tokens'])}")

# this is wrong
# pass token ids to processor's decode method
print("=" * 10, "this is wrong", "=" * 10)
result = processor.batch_decode(output["sequences"], skip_special_tokens=True, output_offsets=True)
print("\n".join([f"{chunk['timestamp'][0]:.2f} -> {chunk['timestamp'][1]:.2f} : {chunk['text']}" for chunk in result[0]["offsets"]]))

Returns:

========== this is correct ==========
0.00 -> 6.38:  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.
6.38 -> 11.32:  Nor is Mr. Quilter's manner less interesting than his matter.
11.32 -> 15.00:  He tells us that at this festive season of the year,
30.00 -> 36.76:  With Christmas and roast beef looming before us, similes drawn from eating and its results
36.76 -> 39.80:  occur most readily to the mind.
39.80 -> 45.38:  He has grave doubts whether Sir Frederick Layton's work is really Greek after all and
45.38 -> 49.00:  can discover in it but little of rocky Ithaca.
49.00 -> 56.28:  Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles
56.28 -> 64.12:  are as national as a jingo poem. Mr. Burkett fosters landscape's smile at one much in
64.12 -> 70.76:  the same way that Mr. Karker used to flash his teeth. And Mr. John Collier gives his
70.76 -> 77.16:  sitter a cheerful slap on the back before he says, like a shampoo or in a Turkish bath,
77.16 -> 78.16:  Next Man
========== this is wrong ==========
0.00 -> 6.38 :  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.
6.38 -> 11.32 :  Nor is Mr. Quilter's manner less interesting than his matter.
11.32 -> 15.00 :  He tells us that at this festive season of the year,
15.00 -> 21.76 :  With Christmas and roast beef looming before us, similes drawn from eating and its results
21.76 -> 24.80 :  occur most readily to the mind.
24.80 -> 30.38 :  He has grave doubts whether Sir Frederick Layton's work is really Greek after all and
30.38 -> 34.00 :  can discover in it but little of rocky Ithaca.
34.00 -> 41.28 :  Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles
41.28 -> 49.12 :  are as national as a jingo poem. Mr. Burkett fosters landscape's smile at one much in
49.12 -> 55.76 :  the same way that Mr. Karker used to flash his teeth. And Mr. John Collier gives his
55.76 -> 62.16 :  sitter a cheerful slap on the back before he says, like a shampoo or in a Turkish bath,
62.16 -> 63.16 :  Next Man

Changes

On the existing code

The most important change here is the one in _retrieve_segments. When generation finishes with <|t1|><|t2|> (case 2), we must add the final timestamp token to the tokens of the returned segment. Otherwise, when we later decode the concatenated sequence, we have no way to differentiate a single ending timestamp (case 1) from a double ending timestamp (case 2).

Added

A new test_small_longform_timestamps_generation test, since this edge case was not caught before. IMO it's an important one to add since we can compare directly with OAI's expected output, which makes it pretty robust. Code to reproduce the expected output can be found here. This new test catches an edge case we did not test before: a single-timestamp ending (here at 15.0s), which triggers seeking to the end of the segment (here 30.0s). Moreover, we test both the segments, through generated_ids["segments"] that are built at generation time, and the offsets, with processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True), which are reconstructed from the concatenated segments.
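
Roughly, the consistency the new test checks can be sketched as follows, reusing output and processor from the snippet above (this is an approximation of the test's intent, not the literal test code; for this sample the segments and offsets should line up one-to-one):

# Sketch only: after the fix, segments built at generation time and offsets
# reconstructed by the tokenizer from the concatenated sequence should agree.
segments = output["segments"][0]
offsets = processor.batch_decode(
    output["sequences"], skip_special_tokens=True, output_offsets=True
)[0]["offsets"]

assert len(segments) == len(offsets)
for seg, chunk in zip(segments, offsets):
    assert round(seg["start"].item(), 2) == round(chunk["timestamp"][0], 2)
    assert round(seg["end"].item(), 2) == round(chunk["timestamp"][1], 2)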

@eustlb eustlb changed the title Fix whispertokenizer [Whisper] Fix whisper tokenizer Oct 31, 2024
@eustlb eustlb changed the title [Whisper] Fix whisper tokenizer [WPI] [Whisper] Fix whisper tokenizer Oct 31, 2024
@eustlb eustlb changed the title [WPI] [Whisper] Fix whisper tokenizer [WIP] [Whisper] Fix whisper tokenizer Oct 31, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor Author

@eustlb eustlb left a comment


Justifications of the changes 🤗

@@ -629,7 +629,7 @@ def generate(
cur_bsz=cur_bsz,
batch_idx_map=batch_idx_map,
)
- time_offset = seek * time_precision / input_stride
+ time_offset = seek.to(torch.float64) * time_precision / input_stride
Contributor Author


The original Whisper code base does such computations in float64. We need to ensure we do the same, especially since we are comparing against the original Whisper outputs in the tests.
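
To illustrate why the explicit cast matters (the seek value below is made up; this is not library code): in PyTorch, an integer tensor multiplied by a Python float is promoted to the default dtype, float32, rather than float64.

import torch

seek = torch.tensor(224000)   # int64 frame index (made-up value)
time_precision = 0.02
input_stride = 2

# An int64 tensor times a Python float is promoted to torch's default dtype
# (float32), while the original Whisper code base computes offsets in float64.
print((seek * time_precision / input_stride).dtype)                    # torch.float32
print((seek.to(torch.float64) * time_precision / input_stride).dtype)  # torch.float64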

Comment on lines +1802 to +1804
else:
# we want to include the last timestamp token in the last segment to know it was no single ending
slices[-1] += 1
Contributor Author


That's the only way to know later that we have a double-timestamp-ending segment.

Contributor


So you're offsetting the last timestamp by one when the last two tokens are timestamps? Let's say we have [..., T1, T2], you're doing [..., T1, T2+1]?

Contributor Author


slices is a list of indexes along the generated sequence of tokens seek_sequence. Let's say we have 100 tokens and the sequence is not single-ending, meaning the last two tokens are timestamps [..., T1, T2]. In that case slices[-1] = 99, yet when we later slice the segments with seek_sequence[last_slice:current_slice], we want to make sure T2 is included in the slice (so that we can later tell it is a double-timestamp-ending segment, as explained in the PR's comment) → by adding 1 to the last slice, we ensure the last iteration slices seek_sequence[last_slice:100] and that T2 gets included.
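
A toy version of the indexing described above (the values and the simplified loop are assumptions for illustration, not the library code):

# 100 generated tokens ending with two timestamp tokens [..., T1, T2]
timestamp_begin = 50364
seek_sequence = list(range(100))
seek_sequence[-2] = timestamp_begin + 566   # T1
seek_sequence[-1] = timestamp_begin + 750   # T2

slices = [0, 99]                 # boundary indexes; 99 points at T2
single_timestamp_ending = False
if not single_timestamp_ending:
    # include T2 in the last segment so the tokenizer can later tell it apart
    # from a single-timestamp ending
    slices[-1] += 1

last_slice = slices[0]
for current_slice in slices[1:]:
    segment_tokens = seek_sequence[last_slice:current_slice]
    last_slice = current_slice

print(segment_tokens[-1] == timestamp_begin + 750)   # True: T2 is included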

Comment on lines +626 to +627
start_timestamp_position * time_precision + prev_segments_len * time_precision,
end_timestamp_position * time_precision + prev_segments_len * time_precision,
Contributor Author


It's done this way in the original codebase: the float64 values are summed after the multiplication by the position. This way we avoid annoying floating-point arithmetic issues.
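
A rough sketch of that pattern (the variable names come from the diff above, but the tensor values and the explicit float64 casts shown here are assumptions for illustration):

import torch

time_precision = 0.02
start_timestamp_position = torch.tensor(750)   # e.g. <|15.00|>, relative to timestamp_begin
prev_segments_len = torch.tensor(1500)         # 30.00 s of previously decoded segments

# Multiply each position by time_precision in float64 first, then sum, mirroring
# the original code base rather than summing integer positions and multiplying once.
start = (
    start_timestamp_position.to(torch.float64) * time_precision
    + prev_segments_len.to(torch.float64) * time_precision
)
print(start.item())   # 45.0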

tests/models/whisper/test_modeling_whisper.py (outdated review comment, resolved)
@eustlb eustlb marked this pull request as ready for review October 31, 2024 19:13
@eustlb eustlb changed the title [WIP] [Whisper] Fix whisper tokenizer [Whisper] Fix whisper tokenizer Oct 31, 2024
@eustlb eustlb force-pushed the fix-whispertokenizer branch from 61a8515 to 7d6f9b4 Compare November 4, 2024 17:34
Contributor

@ylacombe ylacombe left a comment


LGTM @eustlb! Thanks for working on this!

It's quite surprising to me that the 2nd case wasn't handled well by the current integration! However, the proposed solution sounds OK, as long as it passes the new integration tests!

Comment on lines +551 to +552
last_was_single_ending = i >= 2 and not (
token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
Contributor


Are token_ids always in ascending order? In that case, don't we always have token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin when entering the if token >= timestamp_begin loop, if i >= 2?

Contributor


Aren't we supposed to always have two subsequent timestamp tokens now that you've dealt with the single timestamp token in the generation file?

Contributor Author


Not necessarily. After concatenating all the generated sequences for each 30-second segment, at the tokenizer decoding phase we have two possibilities:

  1. [..., <t1>, <t2>, <0.0>, ...] (last segment was double timestamp ending)
  2. [..., <t1>, <0.0>, ...] (last segment was single timestamp ending)

Note that the only reason we can differentiate these two cases here is the slices[-1] += 1 above, which ensures <t2> is included when the segment is not single-timestamp ending. So in case 2 we have token >= timestamp_begin (the <0.0>) and token_ids[i - 1] >= timestamp_begin, but not token_ids[i - 2] >= timestamp_begin.
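
A toy check of that distinction (the timestamp_begin value and the helper are assumptions for illustration, not the tokenizer code):

timestamp_begin = 50364   # assumed id of <|0.00|>

def previous_segment_was_double_ending(token_ids, i):
    # `i` is the position of the <|0.00|> that starts the next 30-second window
    return (
        i >= 2
        and token_ids[i - 1] >= timestamp_begin
        and token_ids[i - 2] >= timestamp_begin
    )

double = [400, timestamp_begin + 566, timestamp_begin + 750, timestamp_begin]  # ..., <t1>, <t2>, <0.0>
single = [400, 401, timestamp_begin + 750, timestamp_begin]                    # ..., <t1>, <0.0>
print(previous_segment_was_double_ending(double, 3))  # True
print(previous_segment_was_double_ending(single, 3))  # False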

src/transformers/models/whisper/generation_whisper.py (outdated review comment, resolved)
Comment on lines +1693 to +1699
for segments in active_segments:
for seg in segments:
if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
# the segment finishes with two timestamp tokens
# we need to ignore the last timestamp token
# see https://github.com/huggingface/transformers/pull/34537
seg["tokens"] = seg["tokens"][:-1]
Contributor


This looks a bit costly; I wonder if there's a cleaner way to compute all of this!

Why do we need to get rid of the last timestamp when preparing the input ids?

Contributor Author


As mentioned above, we need to include the last timestamp token in the case of double-ending timestamps (and not only the penultimate one, as done before) so that the tokenizer can differentiate the two cases (single and double ending). OAI does not have to worry about this because they don't concatenate all the tokens as we do; as a result, when conditioning on the previous tokens, the last token of a double-ending segment is omitted. To ensure we do exactly the same, we need to remove it when preparing the decoder_input_ids.

Contributor Author


The issue is that any segment can be a double-ending one, so they all need to be checked.

Contributor Author


It looks like it is not that costly (complexity is O(batch_size * max number of segments)). When testing, it has no measurable impact on inference speed (see the long-form results here; our implementation is on par with OAI's).


Collaborator

@ArthurZucker ArthurZucker left a comment


Great PR! Even though you explained it, I am a bit out of the loop here 😉 Sorry for the delay, good work!

Comment on lines 1826 to 1828
  start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
- end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
+ idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
+ end_timestamp_pos = sliced_tokens[idx_sliced_tokens].item() - timestamp_begin
Collaborator


Not sure if relevant, but each call to .item() will do a device sync; maybe calling .item() once after the loop would be better?

Contributor Author


Totally relevant, thanks!! This was unnecessary and is fixed in 5fba3e0.
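
For illustration, the kind of change the suggestion points at could look like the sketch below (assumed variable names and values; not necessarily what commit 5fba3e0 does):

import torch

timestamp_begin = 50364
sliced_tokens = torch.tensor([timestamp_begin + 566, 400, 401, timestamp_begin + 750])

# Transfer the whole slice to Python ints once instead of calling .item() per token;
# each .item() on a GPU tensor forces a device synchronization.
token_list = sliced_tokens.tolist()
start_timestamp_pos = token_list[0] - timestamp_begin
end_timestamp_pos = token_list[-1] - timestamp_begin
print(start_timestamp_pos, end_timestamp_pos)   # 566 750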

@eustlb eustlb merged commit 54aae12 into huggingface:main Dec 5, 2024
26 checks passed
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* handle single timestamp ending

* include last timestamp token

* handle single timestamp ending

* avoid floating points arithm limitations

* ensure float64 operations

* new test

* make fixup

* make copies

* handle edge case double tokens ending with different tokens

* handle single timestamp ending

* make fixup

* handle conditioning on prev segments

* fix

* Update src/transformers/models/whisper/generation_whisper.py

Co-authored-by: Yoach Lacombe <[email protected]>

* [run-slow] whisper

* don't call item() to avoid unnecessary sync

* fix

---------

Co-authored-by: Yoach Lacombe <[email protected]>
Co-authored-by: Eustache Le Bihan <[email protected]>
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
shyshin pushed a commit to shyshin/transformers that referenced this pull request Dec 9, 2024

  segments.append(
      {
-         "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
-         "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
+         "start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
+         "end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,


I don't know if this is the right place or if there's an open issue, but this line now crashes on Macs / the MPS backend, because float64 is not supported on Apple silicon. I had to downgrade to 4.46.3. Is there some other way to fix this?

Successfully merging this pull request may close these issues.

WhisperTokenizer decode is offsetting timestamps incorrectly