Adding timestamps for CTC with LM in ASR pipeline. (#15863)
* Adding timestamps for CTC with LM in ASR pipeline. * Remove print. * Nit change.
This commit is contained in:
@@ -353,7 +353,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
word = char
|
word = char
|
||||||
|
|
||||||
last_state = state
|
last_state = state
|
||||||
if state == "WORD":
|
if last_state == "WORD":
|
||||||
word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
|
word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
|
||||||
|
|
||||||
return word_offsets
|
return word_offsets
|
||||||
|
|||||||
@@ -313,8 +313,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
# Optional return types
|
# Optional return types
|
||||||
optional = {}
|
optional = {}
|
||||||
|
|
||||||
if return_timestamps and self.type != "ctc":
|
if return_timestamps and self.type == "seq2seq":
|
||||||
raise ValueError("We cannot return_timestamps yet on non-ctc models !")
|
raise ValueError("We cannot return_timestamps yet on non-ctc models !")
|
||||||
|
if return_timestamps == "char" and self.type == "ctc_with_lm":
|
||||||
|
raise ValueError("CTC with LM cannot return `char` timestamps, only `words`")
|
||||||
|
|
||||||
final_items = []
|
final_items = []
|
||||||
key = "logits" if self.type == "ctc_with_lm" else "tokens"
|
key = "logits" if self.type == "ctc_with_lm" else "tokens"
|
||||||
@@ -335,32 +337,41 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
if self.type == "ctc_with_lm":
|
if self.type == "ctc_with_lm":
|
||||||
if decoder_kwargs is None:
|
if decoder_kwargs is None:
|
||||||
decoder_kwargs = {}
|
decoder_kwargs = {}
|
||||||
text = self.decoder.decode_beams(items, **decoder_kwargs)[0][0]
|
beams = self.decoder.decode_beams(items, **decoder_kwargs)
|
||||||
|
text = beams[0][0]
|
||||||
|
if return_timestamps:
|
||||||
|
# Simply cast from pyctcdecode format to wav2vec2 format to leverage
|
||||||
|
# pre-existing code later
|
||||||
|
chunk_offset = beams[0][2]
|
||||||
|
word_offsets = []
|
||||||
|
for word, (start_offset, end_offset) in chunk_offset:
|
||||||
|
word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
|
||||||
|
|
||||||
else:
|
else:
|
||||||
skip_special_tokens = self.type != "ctc"
|
skip_special_tokens = self.type != "ctc"
|
||||||
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
|
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
|
||||||
if return_timestamps:
|
if return_timestamps:
|
||||||
if return_timestamps == "char":
|
char_offsets = self.tokenizer.decode(
|
||||||
decoded = self.tokenizer.decode(
|
|
||||||
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
|
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
|
||||||
|
)["char_offsets"]
|
||||||
|
if return_timestamps == "word":
|
||||||
|
word_offsets = self.tokenizer._get_word_offsets(
|
||||||
|
char_offsets, self.tokenizer.replace_word_delimiter_char
|
||||||
)
|
)
|
||||||
elif return_timestamps == "word":
|
|
||||||
decoded = self.tokenizer.decode(
|
if return_timestamps:
|
||||||
items, skip_special_tokens=skip_special_tokens, output_word_offsets=True
|
if return_timestamps == "word":
|
||||||
)
|
offsets = word_offsets
|
||||||
|
else:
|
||||||
|
offsets = char_offsets
|
||||||
chunks = []
|
chunks = []
|
||||||
for item in decoded[f"{return_timestamps}_offsets"]:
|
for item in offsets:
|
||||||
start = (
|
start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
|
||||||
item["start_offset"]
|
start /= self.feature_extractor.sampling_rate
|
||||||
* self.model.config.inputs_to_logits_ratio
|
|
||||||
/ self.feature_extractor.sampling_rate
|
stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio
|
||||||
)
|
stop /= self.feature_extractor.sampling_rate
|
||||||
stop = (
|
|
||||||
item["end_offset"]
|
|
||||||
* self.model.config.inputs_to_logits_ratio
|
|
||||||
/ self.feature_extractor.sampling_rate
|
|
||||||
)
|
|
||||||
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
|
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
|
||||||
optional["chunks"] = chunks
|
optional["chunks"] = chunks
|
||||||
|
|
||||||
|
|||||||
@@ -188,6 +188,32 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
speech_recognizer.type = "ctc_with_lm"
|
||||||
|
# Simple test with CTC with LM, chunking + timestamps
|
||||||
|
output = speech_recognizer(filename, chunk_length_s=2.0, return_timestamps="word")
|
||||||
|
self.assertEqual(
|
||||||
|
output,
|
||||||
|
{
|
||||||
|
"text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumajcri",
|
||||||
|
"chunks": [
|
||||||
|
{"text": "y", "timestamp": (0.52, 0.54)},
|
||||||
|
{"text": "en", "timestamp": (0.6, 0.68)},
|
||||||
|
{"text": "las", "timestamp": (0.74, 0.84)},
|
||||||
|
{"text": "ramas", "timestamp": (0.94, 1.24)},
|
||||||
|
{"text": "medio", "timestamp": (1.32, 1.52)},
|
||||||
|
{"text": "sumergidas", "timestamp": (1.56, 2.22)},
|
||||||
|
{"text": "revoloteaban", "timestamp": (2.36, 3.0)},
|
||||||
|
{"text": "algunos", "timestamp": (3.06, 3.38)},
|
||||||
|
{"text": "pájaros", "timestamp": (3.46, 3.86)},
|
||||||
|
{"text": "de", "timestamp": (3.92, 4.0)},
|
||||||
|
{"text": "quimérico", "timestamp": (4.08, 4.6)},
|
||||||
|
{"text": "y", "timestamp": (4.66, 4.68)},
|
||||||
|
{"text": "legendario", "timestamp": (4.74, 5.26)},
|
||||||
|
{"text": "plumajcri", "timestamp": (5.34, 5.74)},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_small_model_tf(self):
|
def test_small_model_tf(self):
|
||||||
self.skipTest("Tensorflow not supported yet.")
|
self.skipTest("Tensorflow not supported yet.")
|
||||||
|
|||||||
Reference in New Issue
Block a user