Add time stamps for wav2vec2 with lm (#15854)

* [Wav2Vec2 With LM] add timestamps

* correct

* correct

* Apply suggestions from code review

* correct

* Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py

* make style

* Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* make style

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2022-03-01 17:03:05 +01:00
committed by GitHub
parent 3f2e636850
commit e064f08150
4 changed files with 215 additions and 16 deletions

View File

@@ -20,13 +20,15 @@ import unittest
from multiprocessing import get_context
from pathlib import Path
import datasets
import numpy as np
from datasets import load_dataset
from transformers import AutoProcessor
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available, is_torch_available
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.testing_utils import require_pyctcdecode
from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow
from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
@@ -35,6 +37,10 @@ if is_pyctcdecode_available():
from huggingface_hub import snapshot_download
from pyctcdecode import BeamSearchDecoderCTC
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
from transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm import Wav2Vec2DecoderWithLMOutput
if is_torch_available():
from transformers import Wav2Vec2ForCTC
@require_pyctcdecode
@@ -350,3 +356,101 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_auto = processor_auto.batch_decode(logits)
self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
@staticmethod
def get_from_offsets(offsets, key):
retrieved_list = [d[key] for d in offsets]
return retrieved_list
def test_offsets_integration_fast(self):
processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
logits = self._get_dummy_logits()[0]
outputs = processor.decode(logits, output_word_offsets=True)
# check Wav2Vec2CTCTokenizerOutput keys for word
self.assertTrue(len(outputs.keys()), 2)
self.assertTrue("text" in outputs)
self.assertTrue("word_offsets" in outputs)
self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
self.assertEqual(" ".join(self.get_from_offsets(outputs["word_offsets"], "word")), outputs.text)
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "word"), ["<s>", "<s>", "</s>"])
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "start_offset"), [0, 2, 4])
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "end_offset"), [1, 3, 5])
def test_offsets_integration_fast_batch(self):
processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
logits = self._get_dummy_logits()
outputs = processor.batch_decode(logits, output_word_offsets=True)
# check Wav2Vec2CTCTokenizerOutput keys for word
self.assertTrue(len(outputs.keys()), 2)
self.assertTrue("text" in outputs)
self.assertTrue("word_offsets" in outputs)
self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
self.assertListEqual(
[" ".join(self.get_from_offsets(o, "word")) for o in outputs["word_offsets"]], outputs.text
)
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "word"), ["<s>", "<s>", "</s>"])
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "start_offset"), [0, 2, 4])
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "end_offset"), [1, 3, 5])
@slow
@require_torch
@require_torchaudio
def test_word_time_stamp_integration(self):
import torch
ds = load_dataset("common_voice", "en", split="train", streaming=True)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
ds_iter = iter(ds)
sample = next(ds_iter)
processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
# compare to filename `common_voice_en_100038.mp3` of dataset viewer on https://huggingface.co/datasets/common_voice/viewer/en/train
input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
with torch.no_grad():
logits = model(input_values).logits.cpu().numpy()
output = processor.decode(logits[0], output_word_offsets=True)
time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate
word_time_stamps = [
{
"start_time": d["start_offset"] * time_offset,
"end_time": d["end_offset"] * time_offset,
"word": d["word"],
}
for d in output["word_offsets"]
]
EXPECTED_TEXT = "WHY DOES A MILE SANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL"
# output words
self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), EXPECTED_TEXT)
self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), output.text)
# output times
start_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "start_time")]
end_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "end_time")]
# fmt: off
self.assertListEqual(
start_times,
[
1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66,
],
)
self.assertListEqual(
end_times,
[
1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94,
],
)
# fmt: on