Add time stamps for wav2vec2 with lm (#15854)
* [Wav2Vec2 With LM] add timestamps
* correct
* correct
* Apply suggestions from code review
* correct
* Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
* make style
* Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
* make style
* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
3f2e636850
commit
e064f08150
@@ -20,13 +20,15 @@ import unittest
|
||||
from multiprocessing import get_context
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import AutoProcessor
|
||||
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
|
||||
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available, is_torch_available
|
||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import require_pyctcdecode
|
||||
from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow
|
||||
|
||||
from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
|
||||
|
||||
@@ -35,6 +37,10 @@ if is_pyctcdecode_available():
|
||||
from huggingface_hub import snapshot_download
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||
from transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm import Wav2Vec2DecoderWithLMOutput
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import Wav2Vec2ForCTC
|
||||
|
||||
|
||||
@require_pyctcdecode
|
||||
@@ -350,3 +356,101 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
decoded_auto = processor_auto.batch_decode(logits)
|
||||
|
||||
self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
|
||||
|
||||
@staticmethod
|
||||
def get_from_offsets(offsets, key):
|
||||
retrieved_list = [d[key] for d in offsets]
|
||||
return retrieved_list
|
||||
|
||||
def test_offsets_integration_fast(self):
    """Check `decode(..., output_word_offsets=True)` on the first dummy-logits sample.

    Verifies the output type and keys, that the offset words joined by spaces
    reproduce `outputs.text`, and the expected words and start/end offsets.
    """
    processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
    logits = self._get_dummy_logits()[0]

    outputs = processor.decode(logits, output_word_offsets=True)

    # check Wav2Vec2DecoderWithLMOutput keys for word
    # NOTE: assertEqual, not assertTrue -- assertTrue would treat `2` as the
    # failure message and never compare the key count.
    self.assertEqual(len(outputs.keys()), 2)
    self.assertIn("text", outputs)
    self.assertIn("word_offsets", outputs)
    self.assertIsInstance(outputs, Wav2Vec2DecoderWithLMOutput)

    word_offsets = outputs["word_offsets"]
    words = self.get_from_offsets(word_offsets, "word")

    self.assertEqual(" ".join(words), outputs.text)
    self.assertListEqual(words, ["<s>", "<s>", "</s>"])
    self.assertListEqual(self.get_from_offsets(word_offsets, "start_offset"), [0, 2, 4])
    self.assertListEqual(self.get_from_offsets(word_offsets, "end_offset"), [1, 3, 5])
|
||||
|
||||
def test_offsets_integration_fast_batch(self):
    """Check `batch_decode(..., output_word_offsets=True)` on the dummy logits.

    Verifies the output type and keys, that each sample's offset words joined
    by spaces reproduce the matching entry of `outputs.text`, and the expected
    words and start/end offsets of the first sample.
    """
    processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
    logits = self._get_dummy_logits()

    outputs = processor.batch_decode(logits, output_word_offsets=True)

    # check Wav2Vec2DecoderWithLMOutput keys for word
    # NOTE: assertEqual, not assertTrue -- assertTrue would treat `2` as the
    # failure message and never compare the key count.
    self.assertEqual(len(outputs.keys()), 2)
    self.assertIn("text", outputs)
    self.assertIn("word_offsets", outputs)
    self.assertIsInstance(outputs, Wav2Vec2DecoderWithLMOutput)

    self.assertListEqual(
        [" ".join(self.get_from_offsets(o, "word")) for o in outputs["word_offsets"]], outputs.text
    )

    first_offsets = outputs["word_offsets"][0]
    self.assertListEqual(self.get_from_offsets(first_offsets, "word"), ["<s>", "<s>", "</s>"])
    self.assertListEqual(self.get_from_offsets(first_offsets, "start_offset"), [0, 2, 4])
    self.assertListEqual(self.get_from_offsets(first_offsets, "end_offset"), [1, 3, 5])
|
||||
|
||||
@slow
@require_torch
@require_torchaudio
def test_word_time_stamp_integration(self):
    """Decode one streamed Common Voice sample and check the words and their time stamps.

    Word offsets returned by the processor are multiplied by
    `inputs_to_logits_ratio / sampling_rate` and rounded to two decimals
    before being compared with the expected start and end times.
    """
    import torch

    checkpoint = "patrickvonplaten/wav2vec2-base-100h-with-lm"

    dataset = load_dataset("common_voice", "en", split="train", streaming=True)
    dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
    sample = next(iter(dataset))

    processor = AutoProcessor.from_pretrained(checkpoint)
    model = Wav2Vec2ForCTC.from_pretrained(checkpoint)

    # compare to filename `common_voice_en_100038.mp3` of dataset viewer on https://huggingface.co/datasets/common_voice/viewer/en/train
    input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values

    with torch.no_grad():
        logits = model(input_values).logits.cpu().numpy()

    output = processor.decode(logits[0], output_word_offsets=True)

    # multiplying an offset by this ratio gives the value compared against the expected times
    time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate

    word_offsets = output["word_offsets"]
    decoded_words = " ".join(item["word"] for item in word_offsets)
    start_times = [round(item["start_offset"] * time_offset, 2) for item in word_offsets]
    end_times = [round(item["end_offset"] * time_offset, 2) for item in word_offsets]

    EXPECTED_TEXT = "WHY DOES A MILE SANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL"

    # output words
    self.assertEqual(decoded_words, EXPECTED_TEXT)
    self.assertEqual(decoded_words, output.text)

    # output times
    # fmt: off
    expected_start_times = [
        1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66,
    ]
    expected_end_times = [
        1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94,
    ]
    # fmt: on

    self.assertListEqual(start_times, expected_start_times)
    self.assertListEqual(end_times, expected_end_times)
|
||||
|
||||
Reference in New Issue
Block a user