Add time stamps for wav2vec2 with lm (#15854)
* [Wav2Vec2 With LM] add timestamps * correct * correct * Apply suggestions from code review * correct * Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py * make style * Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * make style * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
3f2e636850
commit
e064f08150
@@ -97,6 +97,8 @@ WAV2VEC2_KWARGS_DOCSTRING = r"""
|
|||||||
Whether or not to print more information and warnings.
|
Whether or not to print more information and warnings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
ListOfDict = List[Dict[str, Union[int, str]]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Wav2Vec2CTCTokenizerOutput(ModelOutput):
|
class Wav2Vec2CTCTokenizerOutput(ModelOutput):
|
||||||
@@ -106,18 +108,18 @@ class Wav2Vec2CTCTokenizerOutput(ModelOutput):
|
|||||||
Args:
|
Args:
|
||||||
text (list of `str` or `str`):
|
text (list of `str` or `str`):
|
||||||
Decoded logits in text from. Usually the speech transcription.
|
Decoded logits in text from. Usually the speech transcription.
|
||||||
char_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
|
char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
|
||||||
Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
|
Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
|
||||||
offsets can be used to compute time stamps for each charater. Total logit score of the beam associated with
|
offsets can be used to compute time stamps for each charater. Total logit score of the beam associated with
|
||||||
produced text.
|
produced text.
|
||||||
word_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
|
word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
|
||||||
Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
|
Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
|
||||||
can be used to compute time stamps for each word.
|
can be used to compute time stamps for each word.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text: Union[List[str], str]
|
text: Union[List[str], str]
|
||||||
char_offsets: List[Dict[str, Union[float, str]]] = None
|
char_offsets: Union[List[ListOfDict], ListOfDict] = None
|
||||||
word_offsets: List[Dict[str, Union[float, str]]] = None
|
word_offsets: Union[List[ListOfDict], ListOfDict] = None
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||||
|
|||||||
@@ -66,6 +66,9 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"facebook/wav2vec2-lv-60-espeak-cv-ft": sys.maxsize}
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"facebook/wav2vec2-lv-60-espeak-cv-ft": sys.maxsize}
|
||||||
|
|
||||||
|
|
||||||
|
ListOfDict = List[Dict[str, Union[int, str]]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput):
|
class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@@ -74,14 +77,14 @@ class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput):
|
|||||||
Args:
|
Args:
|
||||||
text (list of `str` or `str`):
|
text (list of `str` or `str`):
|
||||||
Decoded logits in text from. Usually the speech transcription.
|
Decoded logits in text from. Usually the speech transcription.
|
||||||
char_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
|
char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
|
||||||
Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
|
Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
|
||||||
offsets can be used to compute time stamps for each charater. Total logit score of the beam associated with
|
offsets can be used to compute time stamps for each charater. Total logit score of the beam associated with
|
||||||
produced text.
|
produced text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text: Union[List[str], str]
|
text: Union[List[str], str]
|
||||||
char_offsets: List[Dict[str, Union[float, str]]] = None
|
char_offsets: Union[List[ListOfDict], ListOfDict] = None
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
|
class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import os
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from multiprocessing import get_context
|
from multiprocessing import get_context
|
||||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -34,23 +34,30 @@ if TYPE_CHECKING:
|
|||||||
from ...tokenization_utils import PreTrainedTokenizerBase
|
from ...tokenization_utils import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
|
||||||
|
ListOfDict = List[Dict[str, Union[int, str]]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Wav2Vec2DecoderWithLMOutput(ModelOutput):
|
class Wav2Vec2DecoderWithLMOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Output type of [`Wav2Vec2DecoderWithLM`], with transcription.
|
Output type of [`Wav2Vec2DecoderWithLM`], with transcription.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text (list of `str`):
|
text (list of `str` or `str`):
|
||||||
Decoded logits in text from. Usually the speech transcription.
|
Decoded logits in text from. Usually the speech transcription.
|
||||||
logit_score (list of `float`):
|
logit_score (list of `float` or `float`):
|
||||||
Total logit score of the beam associated with produced text.
|
Total logit score of the beam associated with produced text.
|
||||||
lm_score (list of `float`):
|
lm_score (list of `float`):
|
||||||
Fused lm_score of the beam associated with produced text.
|
Fused lm_score of the beam associated with produced text.
|
||||||
|
word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
|
||||||
|
Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
|
||||||
|
can be used to compute time stamps for each word.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text: Union[List[str], str]
|
text: Union[List[str], str]
|
||||||
logit_score: Union[List[float], float] = None
|
logit_score: Union[List[float], float] = None
|
||||||
lm_score: Union[List[float], float] = None
|
lm_score: Union[List[float], float] = None
|
||||||
|
word_offsets: Union[List[ListOfDict], ListOfDict] = None
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||||
@@ -232,6 +239,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
beta: Optional[float] = None,
|
beta: Optional[float] = None,
|
||||||
unk_score_offset: Optional[float] = None,
|
unk_score_offset: Optional[float] = None,
|
||||||
lm_score_boundary: Optional[bool] = None,
|
lm_score_boundary: Optional[bool] = None,
|
||||||
|
output_word_offsets: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Batch decode output logits to audio transcription with language model support.
|
Batch decode output logits to audio transcription with language model support.
|
||||||
@@ -267,6 +275,18 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
Amount of log score offset for unknown tokens
|
Amount of log score offset for unknown tokens
|
||||||
lm_score_boundary (`bool`, *optional*):
|
lm_score_boundary (`bool`, *optional*):
|
||||||
Whether to have kenlm respect boundaries when scoring
|
Whether to have kenlm respect boundaries when scoring
|
||||||
|
output_word_offsets (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
|
||||||
|
and model downsampling rate to compute the time-stamps of transcribed words.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Please take a look at the Example of [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
|
||||||
|
better understand how to make use of `output_word_offsets`.
|
||||||
|
[`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] works the same way with batched
|
||||||
|
output.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
||||||
@@ -310,13 +330,18 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
pool.close()
|
pool.close()
|
||||||
|
|
||||||
# extract text and scores
|
# extract text and scores
|
||||||
batch_texts, logit_scores, lm_scores = [], [], []
|
batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], []
|
||||||
for d in decoded_beams:
|
for d in decoded_beams:
|
||||||
batch_texts.append(d[0][0])
|
batch_texts.append(d[0][0])
|
||||||
logit_scores.append(d[0][-2])
|
logit_scores.append(d[0][-2])
|
||||||
lm_scores.append(d[0][-1])
|
lm_scores.append(d[0][-1])
|
||||||
# more output features will be added in the future
|
word_offsets.append([{"word": t[0], "start_offset": t[1][0], "end_offset": t[1][1]} for t in d[0][1]])
|
||||||
return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores)
|
|
||||||
|
word_offsets = word_offsets if output_word_offsets else None
|
||||||
|
|
||||||
|
return Wav2Vec2DecoderWithLMOutput(
|
||||||
|
text=batch_texts, logit_score=logit_scores, lm_score=lm_scores, word_offsets=word_offsets
|
||||||
|
)
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
self,
|
self,
|
||||||
@@ -330,6 +355,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
beta: Optional[float] = None,
|
beta: Optional[float] = None,
|
||||||
unk_score_offset: Optional[float] = None,
|
unk_score_offset: Optional[float] = None,
|
||||||
lm_score_boundary: Optional[bool] = None,
|
lm_score_boundary: Optional[bool] = None,
|
||||||
|
output_word_offsets: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Decode output logits to audio transcription with language model support.
|
Decode output logits to audio transcription with language model support.
|
||||||
@@ -357,11 +383,65 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
Amount of log score offset for unknown tokens
|
Amount of log score offset for unknown tokens
|
||||||
lm_score_boundary (`bool`, *optional*):
|
lm_score_boundary (`bool`, *optional*):
|
||||||
Whether to have kenlm respect boundaries when scoring
|
Whether to have kenlm respect boundaries when scoring
|
||||||
|
output_word_offsets (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
|
||||||
|
and model downsampling rate to compute the time-stamps of transcribed words.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
|
||||||
|
better understand how to make use of `output_word_offsets`.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
||||||
|
|
||||||
"""
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> # Let's see how to retrieve time steps for a model
|
||||||
|
>>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> import datasets
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> # import model, feature extractor, tokenizer
|
||||||
|
>>> model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
|
||||||
|
|
||||||
|
>>> # load first sample of English common_voice
|
||||||
|
>>> dataset = load_dataset("common_voice", "en", split="train", streaming=True)
|
||||||
|
>>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||||
|
>>> dataset_iter = iter(dataset)
|
||||||
|
>>> sample = next(dataset_iter)
|
||||||
|
|
||||||
|
>>> # forward sample through model to get greedily predicted transcription ids
|
||||||
|
>>> input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
|
||||||
|
>>> with torch.no_grad():
|
||||||
|
... logits = model(input_values).logits[0].cpu().numpy()
|
||||||
|
|
||||||
|
>>> # retrieve word stamps (analogous commands for `output_char_offsets`)
|
||||||
|
>>> outputs = tokenizer.decode(logits, output_word_offsets=True)
|
||||||
|
>>> # compute `time_offset` in seconds as product of downsampling ratio and sampling_rate
|
||||||
|
>>> time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate
|
||||||
|
|
||||||
|
>>> word_offsets = [
|
||||||
|
... {
|
||||||
|
... "word": d["word"],
|
||||||
|
... "start_time": d["start_offset"] * time_offset,
|
||||||
|
... "end_time": d["end_offset"] * time_offset,
|
||||||
|
... }
|
||||||
|
... for d in outputs.word_offsets
|
||||||
|
... ]
|
||||||
|
>>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer:
|
||||||
|
>>> # https://huggingface.co/datasets/common_voice/viewer/en/train
|
||||||
|
>>> word_offset
|
||||||
|
>>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES',
|
||||||
|
>>> # 'start_time': 1.64, 'end_time': 1.88}, {'word': 'A',
|
||||||
|
>>> # 'start_time': 2.12, 'end_time': 2.14}, {'word': 'MILE', 'start_time': 2.26, 'end_time': 2.46}, ...
|
||||||
|
```"""
|
||||||
|
|
||||||
from pyctcdecode.constants import (
|
from pyctcdecode.constants import (
|
||||||
DEFAULT_BEAM_WIDTH,
|
DEFAULT_BEAM_WIDTH,
|
||||||
DEFAULT_HOTWORD_WEIGHT,
|
DEFAULT_HOTWORD_WEIGHT,
|
||||||
@@ -390,9 +470,19 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
hotword_weight=hotword_weight,
|
hotword_weight=hotword_weight,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
word_offsets = None
|
||||||
|
if output_word_offsets:
|
||||||
|
word_offsets = [
|
||||||
|
{"word": word, "start_offset": start_offset, "end_offset": end_offset}
|
||||||
|
for word, (start_offset, end_offset) in decoded_beams[0][2]
|
||||||
|
]
|
||||||
|
|
||||||
# more output features will be added in the future
|
# more output features will be added in the future
|
||||||
return Wav2Vec2DecoderWithLMOutput(
|
return Wav2Vec2DecoderWithLMOutput(
|
||||||
text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1]
|
text=decoded_beams[0][0],
|
||||||
|
logit_score=decoded_beams[0][-2],
|
||||||
|
lm_score=decoded_beams[0][-1],
|
||||||
|
word_offsets=word_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|||||||
@@ -20,13 +20,15 @@ import unittest
|
|||||||
from multiprocessing import get_context
|
from multiprocessing import get_context
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoProcessor
|
||||||
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
|
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available, is_torch_available
|
||||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
||||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||||
from transformers.testing_utils import require_pyctcdecode
|
from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow
|
||||||
|
|
||||||
from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
|
from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
|
||||||
|
|
||||||
@@ -35,6 +37,10 @@ if is_pyctcdecode_available():
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from pyctcdecode import BeamSearchDecoderCTC
|
from pyctcdecode import BeamSearchDecoderCTC
|
||||||
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||||
|
from transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm import Wav2Vec2DecoderWithLMOutput
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers import Wav2Vec2ForCTC
|
||||||
|
|
||||||
|
|
||||||
@require_pyctcdecode
|
@require_pyctcdecode
|
||||||
@@ -350,3 +356,101 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
decoded_auto = processor_auto.batch_decode(logits)
|
decoded_auto = processor_auto.batch_decode(logits)
|
||||||
|
|
||||||
self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
|
self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_from_offsets(offsets, key):
|
||||||
|
retrieved_list = [d[key] for d in offsets]
|
||||||
|
return retrieved_list
|
||||||
|
|
||||||
|
def test_offsets_integration_fast(self):
|
||||||
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
|
||||||
|
logits = self._get_dummy_logits()[0]
|
||||||
|
|
||||||
|
outputs = processor.decode(logits, output_word_offsets=True)
|
||||||
|
# check Wav2Vec2CTCTokenizerOutput keys for word
|
||||||
|
self.assertTrue(len(outputs.keys()), 2)
|
||||||
|
self.assertTrue("text" in outputs)
|
||||||
|
self.assertTrue("word_offsets" in outputs)
|
||||||
|
self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
|
||||||
|
|
||||||
|
self.assertEqual(" ".join(self.get_from_offsets(outputs["word_offsets"], "word")), outputs.text)
|
||||||
|
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "word"), ["<s>", "<s>", "</s>"])
|
||||||
|
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "start_offset"), [0, 2, 4])
|
||||||
|
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "end_offset"), [1, 3, 5])
|
||||||
|
|
||||||
|
def test_offsets_integration_fast_batch(self):
|
||||||
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
|
||||||
|
logits = self._get_dummy_logits()
|
||||||
|
|
||||||
|
outputs = processor.batch_decode(logits, output_word_offsets=True)
|
||||||
|
|
||||||
|
# check Wav2Vec2CTCTokenizerOutput keys for word
|
||||||
|
self.assertTrue(len(outputs.keys()), 2)
|
||||||
|
self.assertTrue("text" in outputs)
|
||||||
|
self.assertTrue("word_offsets" in outputs)
|
||||||
|
self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
[" ".join(self.get_from_offsets(o, "word")) for o in outputs["word_offsets"]], outputs.text
|
||||||
|
)
|
||||||
|
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "word"), ["<s>", "<s>", "</s>"])
|
||||||
|
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "start_offset"), [0, 2, 4])
|
||||||
|
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "end_offset"), [1, 3, 5])
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_torchaudio
|
||||||
|
def test_word_time_stamp_integration(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
ds = load_dataset("common_voice", "en", split="train", streaming=True)
|
||||||
|
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||||
|
ds_iter = iter(ds)
|
||||||
|
sample = next(ds_iter)
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
|
||||||
|
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
|
||||||
|
|
||||||
|
# compare to filename `common_voice_en_100038.mp3` of dataset viewer on https://huggingface.co/datasets/common_voice/viewer/en/train
|
||||||
|
input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(input_values).logits.cpu().numpy()
|
||||||
|
|
||||||
|
output = processor.decode(logits[0], output_word_offsets=True)
|
||||||
|
|
||||||
|
time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate
|
||||||
|
word_time_stamps = [
|
||||||
|
{
|
||||||
|
"start_time": d["start_offset"] * time_offset,
|
||||||
|
"end_time": d["end_offset"] * time_offset,
|
||||||
|
"word": d["word"],
|
||||||
|
}
|
||||||
|
for d in output["word_offsets"]
|
||||||
|
]
|
||||||
|
|
||||||
|
EXPECTED_TEXT = "WHY DOES A MILE SANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL"
|
||||||
|
|
||||||
|
# output words
|
||||||
|
self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), EXPECTED_TEXT)
|
||||||
|
self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), output.text)
|
||||||
|
|
||||||
|
# output times
|
||||||
|
start_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "start_time")]
|
||||||
|
end_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "end_time")]
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
self.assertListEqual(
|
||||||
|
start_times,
|
||||||
|
[
|
||||||
|
1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
end_times,
|
||||||
|
[
|
||||||
|
1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|||||||
Reference in New Issue
Block a user