[Whisper] Fix whisper tokenizer (#34537)
* handle single timestamp ending * include last timestamp token * handle single timestamp ending * avoid floating points arithm limitations * ensure float64 operations * new test * make fixup * make copies * handle edge case double tokens ending with different tokens * handle single timestamp ending * make fixup * handle conditioning on prev segments * fix * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * [run-slow] whisper * don't call item() to avoid unnecessary sync * fix --------- Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Co-authored-by: Eustache Le Bihan <eustlb@users.noreply.huggingface.co>
This commit is contained in:
@@ -2096,6 +2096,94 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
|
||||
self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_small_longform_timestamps_generation(self):
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-small.en")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en")
|
||||
model.to(torch_device)
|
||||
|
||||
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
|
||||
sample = dataset[0]["audio"]["array"]
|
||||
sampling_rate = dataset[0]["audio"]["sampling_rate"]
|
||||
|
||||
sample = [*sample[: 15 * sampling_rate], *np.zeros(16 * sampling_rate).tolist(), *sample[15 * sampling_rate :]]
|
||||
sample = np.array(sample)
|
||||
|
||||
input_features = processor(
|
||||
sample,
|
||||
sampling_rate=16_000,
|
||||
padding="longest",
|
||||
truncation=False,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
).input_features
|
||||
|
||||
input_features = input_features.to(torch_device)
|
||||
generated_ids = model.generate(input_features, return_timestamps=True, return_segments=True)
|
||||
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
{
|
||||
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
|
||||
"timestamp": (0.0, 6.38),
|
||||
},
|
||||
{
|
||||
"text": " Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
"timestamp": (6.38, 11.32),
|
||||
},
|
||||
{
|
||||
"text": " He tells us that at this festive season of the year,",
|
||||
"timestamp": (11.32, 15.0),
|
||||
},
|
||||
{
|
||||
"text": " With Christmas and roast beef looming before us, similes drawn from eating and its results",
|
||||
"timestamp": (30.0, 36.76),
|
||||
},
|
||||
{
|
||||
"text": " occur most readily to the mind.",
|
||||
"timestamp": (36.76, 39.80),
|
||||
},
|
||||
{
|
||||
"text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and",
|
||||
"timestamp": (39.80, 45.36),
|
||||
},
|
||||
{
|
||||
"text": " can discover in it but little of rocky Ithaca.",
|
||||
"timestamp": (45.36, 49.0),
|
||||
},
|
||||
{
|
||||
"text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles",
|
||||
"timestamp": (49.0, 56.28),
|
||||
},
|
||||
{
|
||||
"text": " are as national as a jingo poem. Mr. Burkett fosters landscape's smile at one much in",
|
||||
"timestamp": (56.28, 64.12),
|
||||
},
|
||||
{
|
||||
"text": " the same way that Mr. Karker used to flash his teeth. And Mr. John Collier gives his",
|
||||
"timestamp": (64.12, 70.76),
|
||||
},
|
||||
{
|
||||
"text": " sitter a cheerful slap on the back before he says, like a shampoo or in a Turkish bath,",
|
||||
"timestamp": (70.76, 77.16),
|
||||
},
|
||||
{
|
||||
"text": " Next Man",
|
||||
"timestamp": (77.16, 78.16),
|
||||
},
|
||||
]
|
||||
|
||||
transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
|
||||
self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)
|
||||
|
||||
transcript_segments = [
|
||||
{
|
||||
"text": processor.decode(seg["tokens"], skip_special_tokens=True),
|
||||
"timestamp": (seg["start"].item(), seg["end"].item()),
|
||||
}
|
||||
for seg in generated_ids["segments"][0]
|
||||
]
|
||||
self.assertEqual(transcript_segments, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_large_timestamp_generation(self):
|
||||
set_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user