Time stamps for CTC models (#15687)
* [Wav2Vec2 Time Stamps] * Add first version * add word time stamps * Fix * save intermediate space * improve * [Finish CTC Tokenizer] * remove @ * remove @ * push * continue with phonemes * up * finish PR * up * add example * rename * finish * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * correct split * finalize Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
32295b15a1
commit
c44d3675c2
@@ -29,7 +29,7 @@ from transformers import (
|
||||
Wav2Vec2CTCTokenizer,
|
||||
Wav2Vec2Tokenizer,
|
||||
)
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2CTCTokenizerOutput
|
||||
from transformers.testing_utils import require_torch, slow
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
@@ -422,27 +422,16 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_tokenizer_decode_special(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
# fmt: off
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
|
||||
]
|
||||
sample_ids_2 = [
|
||||
[11, 5, 5, 5, 5, 5, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[
|
||||
24,
|
||||
22,
|
||||
5,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
24,
|
||||
22,
|
||||
5,
|
||||
77,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
],
|
||||
[24, 22, 5, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.word_delimiter_token_id],
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
batch_tokens_2 = tokenizer.batch_decode(sample_ids_2)
|
||||
@@ -454,27 +443,12 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer.add_tokens(["!", "?"])
|
||||
tokenizer.add_special_tokens({"cls_token": "$$$"})
|
||||
|
||||
# fmt: off
|
||||
sample_ids = [
|
||||
[
|
||||
11,
|
||||
5,
|
||||
15,
|
||||
tokenizer.pad_token_id,
|
||||
15,
|
||||
8,
|
||||
98,
|
||||
32,
|
||||
32,
|
||||
33,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
32,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
34,
|
||||
],
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
|
||||
]
|
||||
# fmt: on
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
||||
@@ -499,6 +473,187 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
||||
self.assertEqual(sent, expected_sent)
|
||||
|
||||
@staticmethod
|
||||
def get_from_offsets(offsets, key):
|
||||
retrieved_list = [d[key] for d in offsets]
|
||||
return retrieved_list
|
||||
|
||||
def test_offsets(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
# fmt: off
|
||||
# HEEEEE||LLL<pad>LO<unk> => HE LLO<unk>
|
||||
# 1H + 5E + 2| + 3L + 1<pad> + 1L + 1O + 1<unk>
|
||||
sample_ids = [11, 5, 5, 5, 5, 5, 4, 4, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98]
|
||||
# fmt: on
|
||||
|
||||
outputs_char = tokenizer.decode(sample_ids, output_char_offsets=True)
|
||||
# check Wav2Vec2CTCTokenizerOutput keys for char
|
||||
self.assertTrue(len(outputs_char.keys()), 2)
|
||||
self.assertTrue("text" in outputs_char)
|
||||
self.assertTrue("char_offsets" in outputs_char)
|
||||
self.assertTrue(isinstance(outputs_char, Wav2Vec2CTCTokenizerOutput))
|
||||
|
||||
outputs_word = tokenizer.decode(sample_ids, output_word_offsets=True)
|
||||
# check Wav2Vec2CTCTokenizerOutput keys for word
|
||||
self.assertTrue(len(outputs_word.keys()), 2)
|
||||
self.assertTrue("text" in outputs_word)
|
||||
self.assertTrue("word_offsets" in outputs_word)
|
||||
self.assertTrue(isinstance(outputs_word, Wav2Vec2CTCTokenizerOutput))
|
||||
|
||||
outputs = tokenizer.decode(sample_ids, output_char_offsets=True, output_word_offsets=True)
|
||||
# check Wav2Vec2CTCTokenizerOutput keys for both
|
||||
self.assertTrue(len(outputs.keys()), 3)
|
||||
self.assertTrue("text" in outputs)
|
||||
self.assertTrue("char_offsets" in outputs)
|
||||
self.assertTrue("word_offsets" in outputs)
|
||||
self.assertTrue(isinstance(outputs, Wav2Vec2CTCTokenizerOutput))
|
||||
|
||||
# check that order of chars is correct and identical for both outputs
|
||||
self.assertEqual("".join(self.get_from_offsets(outputs["char_offsets"], "char")), outputs.text)
|
||||
self.assertEqual(
|
||||
self.get_from_offsets(outputs["char_offsets"], "char"), ["H", "E", " ", "L", "L", "O", "<unk>"]
|
||||
)
|
||||
self.assertListEqual(
|
||||
self.get_from_offsets(outputs["char_offsets"], "char"),
|
||||
self.get_from_offsets(outputs_char["char_offsets"], "char"),
|
||||
)
|
||||
|
||||
# check that order of words is correct and identical to both outputs
|
||||
self.assertEqual(" ".join(self.get_from_offsets(outputs["word_offsets"], "word")), outputs.text)
|
||||
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "word"), ["HE", "LLO<unk>"])
|
||||
self.assertListEqual(
|
||||
self.get_from_offsets(outputs["word_offsets"], "word"),
|
||||
self.get_from_offsets(outputs_word["word_offsets"], "word"),
|
||||
)
|
||||
|
||||
# check that offsets are actually correct for char
|
||||
# 0 is H, 1 is E, 6 is | (" "), 8 is 1st L, 12 is 2nd L, 13 is O, 14 is <unk>
|
||||
self.assertListEqual(self.get_from_offsets(outputs["char_offsets"], "start_offset"), [0, 1, 6, 8, 12, 13, 14])
|
||||
# 1 is H, 6 is E, 8 is | (" "), 11 is 1st L (note due to <pad>
|
||||
# different begin of 2nd L), 13 is 2nd L, 14 is O, 15 is <unk>
|
||||
self.assertListEqual(self.get_from_offsets(outputs["char_offsets"], "end_offset"), [1, 6, 8, 11, 13, 14, 15])
|
||||
|
||||
# check that offsets are actually correct for word
|
||||
# H is at 1st position of first word, first L is at 8th position of second word
|
||||
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "start_offset"), [0, 8])
|
||||
# last E is at 6th position of first word, first L is at last (15th) position of second word
|
||||
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "end_offset"), [6, 15])
|
||||
|
||||
def test_offsets_batch(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
def check_list_tuples_equal(outputs_batch, outputs_list):
|
||||
self.assertTrue(isinstance(outputs_batch, Wav2Vec2CTCTokenizerOutput))
|
||||
self.assertTrue(isinstance(outputs_list[0], Wav2Vec2CTCTokenizerOutput))
|
||||
|
||||
# transform list to ModelOutput
|
||||
outputs_batch_2 = Wav2Vec2CTCTokenizerOutput({k: [d[k] for d in outputs_list] for k in outputs_list[0]})
|
||||
|
||||
self.assertListEqual(outputs_batch["text"], outputs_batch_2["text"])
|
||||
|
||||
def recursive_check(list_or_dict_1, list_or_dict_2):
|
||||
if isinstance(list_or_dict_1, list):
|
||||
[recursive_check(l1, l2) for l1, l2 in zip(list_or_dict_1, list_or_dict_2)]
|
||||
self.assertEqual(list_or_dict_1, list_or_dict_2)
|
||||
|
||||
if "char_offsets" in outputs_batch:
|
||||
recursive_check(outputs_batch["char_offsets"], outputs_batch_2["char_offsets"])
|
||||
|
||||
if "word_offsets" in outputs_batch:
|
||||
recursive_check(outputs_batch["word_offsets"], outputs_batch_2["word_offsets"])
|
||||
|
||||
# fmt: off
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 4, 8, 98, 32, 32, 32, 32, 4, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, tokenizer.word_delimiter_token_id, 24, 22, 22, 22, 4, 5, 77, tokenizer.pad_token_id, 22, 22, 4, 34, 34, 34, 34],
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
# We assume that `decode` works as expected. All we will check now is
|
||||
# the output type is correct and the output is identical to `decode`
|
||||
|
||||
# char
|
||||
outputs_char_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True)
|
||||
outputs_char = [tokenizer.decode(ids, output_char_offsets=True) for ids in sample_ids]
|
||||
check_list_tuples_equal(outputs_char_batch, outputs_char)
|
||||
|
||||
# word
|
||||
outputs_word_batch = tokenizer.batch_decode(sample_ids, output_word_offsets=True)
|
||||
outputs_word = [tokenizer.decode(ids, output_word_offsets=True) for ids in sample_ids]
|
||||
check_list_tuples_equal(outputs_word_batch, outputs_word)
|
||||
|
||||
# both
|
||||
outputs_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True, output_word_offsets=True)
|
||||
outputs = [tokenizer.decode(ids, output_word_offsets=True, output_char_offsets=True) for ids in sample_ids]
|
||||
check_list_tuples_equal(outputs_batch, outputs)
|
||||
|
||||
def test_offsets_integration(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
# pred_ids correspond to the following code
|
||||
# ```
|
||||
# from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
|
||||
# from datasets import load_dataset
|
||||
# import datasets
|
||||
# import torch
|
||||
# model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
# feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
#
|
||||
# ds = load_dataset("common_voice", "en", split="train", streaming=True)
|
||||
# ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||
# ds_iter = iter(ds)
|
||||
# sample = next(ds_iter)
|
||||
#
|
||||
# input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
|
||||
# logits = model(input_values).logits
|
||||
# pred_ids = torch.argmax(logits, axis=-1).cpu().tolist()
|
||||
# ```
|
||||
# fmt: off
|
||||
pred_ids = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 11, 0, 0, 0, 22, 0, 0, 4, 4, 4, 14, 0, 0, 0, 0, 0, 8, 8, 0, 5, 5, 0, 12, 0, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 10, 0, 0, 0, 15, 0, 0, 10, 0, 0, 0, 12, 0, 0, 0, 0, 0, 7, 0, 9, 0, 0, 14, 0, 0, 0, 13, 0, 7, 0, 0, 4, 4, 0, 15, 8, 8, 0, 0, 8, 0, 26, 0, 0, 4, 4, 0, 0, 15, 0, 0, 0, 0, 0, 0, 10, 0, 26, 5, 5, 0, 4, 4, 0, 0, 12, 11, 0, 0, 5, 4, 4, 4, 0, 18, 0, 0, 0, 7, 9, 9, 0, 6, 0, 12, 12, 4, 4, 0, 6, 0, 0, 8, 0, 4, 4, 4, 0, 19, 0, 0, 8, 9, 9, 0, 0, 0, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 16, 16, 0, 0, 17, 5, 5, 5, 0, 4, 4, 4, 0, 0, 29, 29, 0, 0, 0, 0, 8, 11, 0, 9, 9, 0, 0, 0, 4, 4, 0, 12, 12, 0, 0, 0, 9, 0, 0, 0, 0, 0, 8, 18, 0, 0, 0, 4, 4, 0, 0, 8, 9, 0, 4, 4, 0, 6, 11, 5, 0, 4, 4, 0, 13, 13, 0, 0, 0, 10, 0, 0, 25, 0, 0, 6, 0, 4, 4, 0, 0, 0, 0, 7, 0, 0, 23, 0, 0, 4, 4, 0, 0, 0, 6, 11, 0, 5, 4, 4, 18, 0, 0, 0, 0, 0, 0, 7, 15, 0, 0, 0, 15, 15, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
|
||||
|
||||
# wav2vec2-base downsamples input audio by a factor of 320
|
||||
# sampling rate for wav2vec2-base is 16_000
|
||||
time_offset_wav2vec2_base = 320 / 16_000
|
||||
|
||||
expected_char_time_stamps_text = ['W', 'H', 'Y', ' ', 'D', 'O', 'E', 'S', ' ', 'M', 'I', 'L', 'I', 'S', 'A', 'N', 'D', 'R', 'A', ' ', 'L', 'O', 'O', 'K', ' ', 'L', 'I', 'K', 'E', ' ', 'S', 'H', 'E', ' ', 'W', 'A', 'N', 'T', 'S', ' ', 'T', 'O', ' ', 'C', 'O', 'N', 'S', 'U', 'M', 'E', ' ', 'J', 'O', 'H', 'N', ' ', 'S', 'N', 'O', 'W', ' ', 'O', 'N', ' ', 'T', 'H', 'E', ' ', 'R', 'I', 'V', 'T', ' ', 'A', 'P', ' ', 'T', 'H', 'E', ' ', 'W', 'A', 'L', 'L', ' ']
|
||||
expected_char_time_stamps_start = [1.42, 1.44, 1.52, 1.58, 1.64, 1.76, 1.82, 1.88, 1.92, 2.26, 2.32, 2.4, 2.46, 2.54, 2.66, 2.7, 2.76, 2.84, 2.88, 2.94, 3.0, 3.02, 3.1, 3.14, 3.2, 3.28, 3.42, 3.46, 3.48, 3.54, 3.62, 3.64, 3.7, 3.72, 3.8, 3.88, 3.9, 3.96, 4.0, 4.04, 4.1, 4.16, 4.2, 4.28, 4.34, 4.36, 4.48, 4.66, 4.74, 4.76, 4.84, 4.94, 5.06, 5.08, 5.12, 5.22, 5.28, 5.38, 5.5, 5.52, 5.6, 5.68, 5.7, 5.74, 5.8, 5.82, 5.84, 5.88, 5.94, 6.04, 6.1, 6.16, 6.2, 6.32, 6.38, 6.44, 6.54, 6.56, 6.6, 6.62, 6.66, 6.8, 6.82, 6.9, 6.96]
|
||||
expected_char_time_stamps_end = [1.44, 1.46, 1.54, 1.64, 1.66, 1.8, 1.86, 1.9, 2.06, 2.28, 2.34, 2.42, 2.48, 2.56, 2.68, 2.72, 2.78, 2.86, 2.9, 2.98, 3.02, 3.06, 3.12, 3.16, 3.24, 3.3, 3.44, 3.48, 3.52, 3.58, 3.64, 3.66, 3.72, 3.78, 3.82, 3.9, 3.94, 3.98, 4.04, 4.08, 4.12, 4.18, 4.26, 4.3, 4.36, 4.4, 4.52, 4.7, 4.76, 4.82, 4.9, 4.98, 5.08, 5.1, 5.16, 5.26, 5.32, 5.4, 5.52, 5.54, 5.64, 5.7, 5.72, 5.78, 5.82, 5.84, 5.86, 5.92, 5.98, 6.06, 6.12, 6.18, 6.24, 6.34, 6.4, 6.48, 6.56, 6.58, 6.62, 6.66, 6.68, 6.82, 6.84, 6.94, 7.02]
|
||||
|
||||
expected_word_time_stamps_text = ['WHY', 'DOES', 'MILISANDRA', 'LOOK', 'LIKE', 'SHE', 'WANTS', 'TO', 'CONSUME', 'JOHN', 'SNOW', 'ON', 'THE', 'RIVT', 'AP', 'THE', 'WALL']
|
||||
expected_word_time_stamps_start = [1.42, 1.64, 2.26, 3.0, 3.28, 3.62, 3.8, 4.1, 4.28, 4.94, 5.28, 5.68, 5.8, 5.94, 6.32, 6.54, 6.66]
|
||||
expected_word_time_stamps_end = [1.54, 1.9, 2.9, 3.16, 3.52, 3.72, 4.04, 4.18, 4.82, 5.16, 5.54, 5.72, 5.86, 6.18, 6.4, 6.62, 6.94]
|
||||
# fmt: on
|
||||
|
||||
output = tokenizer.batch_decode(pred_ids, output_char_offsets=True, output_word_offsets=True)
|
||||
|
||||
char_offsets_text = self.get_from_offsets(output["char_offsets"][0], "char")
|
||||
char_offsets_start = self.get_from_offsets(output["char_offsets"][0], "start_offset")
|
||||
char_offsets_end = self.get_from_offsets(output["char_offsets"][0], "end_offset")
|
||||
|
||||
word_offsets_text = self.get_from_offsets(output["word_offsets"][0], "word")
|
||||
word_offsets_start = self.get_from_offsets(output["word_offsets"][0], "start_offset")
|
||||
word_offsets_end = self.get_from_offsets(output["word_offsets"][0], "end_offset")
|
||||
|
||||
# let's transform offsets to time stamps in seconds
|
||||
char_time_stamps_start = [round(c * time_offset_wav2vec2_base, 2) for c in char_offsets_start]
|
||||
char_time_stamps_end = [round(c * time_offset_wav2vec2_base, 2) for c in char_offsets_end]
|
||||
|
||||
word_time_stamps_start = [round(w * time_offset_wav2vec2_base, 2) for w in word_offsets_start]
|
||||
word_time_stamps_end = [round(w * time_offset_wav2vec2_base, 2) for w in word_offsets_end]
|
||||
|
||||
# NOTE: you can verify the above results by checking out the dataset viewer
|
||||
# on https://huggingface.co/datasets/common_voice/viewer/en/train and
|
||||
# downloading / playing the sample `common_voice_en_100038.mp3`. As
|
||||
# you can hear the time-stamps match more or less
|
||||
|
||||
self.assertListEqual(expected_char_time_stamps_text, char_offsets_text)
|
||||
self.assertListEqual(expected_char_time_stamps_start, char_time_stamps_start)
|
||||
self.assertListEqual(expected_char_time_stamps_end, char_time_stamps_end)
|
||||
|
||||
self.assertListEqual(expected_word_time_stamps_text, word_offsets_text)
|
||||
self.assertListEqual(expected_word_time_stamps_start, word_time_stamps_start)
|
||||
self.assertListEqual(expected_word_time_stamps_end, word_time_stamps_end)
|
||||
|
||||
def test_pretrained_model_lists(self):
|
||||
# Wav2Vec2Model has no max model length => no testing
|
||||
pass
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import Tuple
|
||||
|
||||
from transformers import Wav2Vec2PhonemeCTCTokenizer
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizerOutput
|
||||
from transformers.testing_utils import require_phonemizer
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
@@ -248,23 +249,94 @@ class Wav2Vec2PhonemeCTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
self.assertEqual(batch_tokens, ["k s ɾ ɾ l ɭʲ!?!? $$$", "j ð s j ð s oːɹ $$$"])
|
||||
|
||||
# overwrite common test
|
||||
@staticmethod
|
||||
def get_from_offsets(offsets, key):
|
||||
retrieved_list = [d[key] for d in offsets]
|
||||
return retrieved_list
|
||||
|
||||
def test_offsets(self):
|
||||
tokenizer = self.get_tokenizer(word_delimiter_token="|")
|
||||
tokenizer.add_tokens("|")
|
||||
|
||||
# fmt: off
|
||||
# ksssɾɾ|ɾɾ<pad>ɾɾ|<pad>ɾlll|ɭʲ -> k s ɾ ɾ | ɾ l | ɭʲ"
|
||||
sample_ids = [11, 5, 5, 5, 15, 15, tokenizer.pad_token_id, 15, 15, tokenizer.word_delimiter_token_id, tokenizer.pad_token_id, 15, 8, 8, 8, tokenizer.word_delimiter_token_id, 98]
|
||||
# fmt: on
|
||||
|
||||
outputs = tokenizer.decode(sample_ids, output_char_offsets=True, filter_word_delimiter_token=False)
|
||||
# check Wav2Vec2CTCTokenizerOutput keys for char
|
||||
self.assertTrue(len(outputs.keys()), 2)
|
||||
self.assertTrue("text" in outputs)
|
||||
self.assertTrue("char_offsets" in outputs)
|
||||
self.assertTrue(isinstance(outputs, Wav2Vec2PhonemeCTCTokenizerOutput))
|
||||
|
||||
# check that order of chars is correct and identical for both outputs
|
||||
self.assertEqual(" ".join(self.get_from_offsets(outputs["char_offsets"], "char")), outputs.text)
|
||||
self.assertListEqual(
|
||||
self.get_from_offsets(outputs["char_offsets"], "char"), ["k", "s", "ɾ", "ɾ", "|", "ɾ", "l", "|", "ɭʲ"]
|
||||
)
|
||||
|
||||
# check that offsets are actually correct for char
|
||||
# 0-1 is 11, 1-4 is 5, 4-6 is first 15, 6-7 is <pad> (thus not shown), 7-9 is second 15, 9-10 is word_delimiter_token,
|
||||
# 10-11 is <pad> (thus not shown), 11-12 is third 15, 12-15 is 8, 15-16 is word_delimiter_token, 16-17 is 98
|
||||
self.assertListEqual(
|
||||
self.get_from_offsets(outputs["char_offsets"], "start_offset"), [0, 1, 4, 7, 9, 11, 12, 15, 16]
|
||||
)
|
||||
self.assertListEqual(
|
||||
self.get_from_offsets(outputs["char_offsets"], "end_offset"), [1, 4, 6, 9, 10, 12, 15, 16, 17]
|
||||
)
|
||||
|
||||
def test_offsets_batch(self):
|
||||
tokenizer = self.get_tokenizer(word_delimiter_token="|")
|
||||
|
||||
def check_list_tuples_equal(outputs_batch, outputs_list):
|
||||
self.assertTrue(isinstance(outputs_batch, Wav2Vec2PhonemeCTCTokenizerOutput))
|
||||
self.assertTrue(isinstance(outputs_list[0], Wav2Vec2PhonemeCTCTokenizerOutput))
|
||||
|
||||
# transform list to ModelOutput
|
||||
outputs_batch_2 = Wav2Vec2PhonemeCTCTokenizerOutput(
|
||||
{k: [d[k] for d in outputs_list] for k in outputs_list[0]}
|
||||
)
|
||||
|
||||
self.assertListEqual(outputs_batch["text"], outputs_batch_2["text"])
|
||||
|
||||
def recursive_check(list_or_dict_1, list_or_dict_2):
|
||||
if isinstance(list_or_dict_1, list):
|
||||
[recursive_check(l1, l2) for l1, l2 in zip(list_or_dict_1, list_or_dict_2)]
|
||||
self.assertEqual(list_or_dict_1, list_or_dict_2)
|
||||
|
||||
if "char_offsets" in outputs_batch:
|
||||
recursive_check(outputs_batch["char_offsets"], outputs_batch_2["char_offsets"])
|
||||
|
||||
# fmt: off
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 4, 8, 98, 32, 32, 32, 32, 4, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, tokenizer.word_delimiter_token_id, 24, 22, 22, 22, 4, 5, 77, tokenizer.pad_token_id, 22, 22, 4, 34, 34, 34, 34],
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
# We assume that `decode` works as expected. All we will check now is
|
||||
# the output type is correct and the output is identical to `decode`
|
||||
|
||||
# char
|
||||
outputs_char_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True)
|
||||
outputs_char = [tokenizer.decode(ids, output_char_offsets=True) for ids in sample_ids]
|
||||
check_list_tuples_equal(outputs_char_batch, outputs_char)
|
||||
|
||||
@unittest.skip("Wav2Vec2PhonemeTokenizer always lower cases letters to correctly map to phonemes")
|
||||
def test_added_tokens_do_lower_case(self):
|
||||
# Wav2Vec2PhonemeTokenizer always lower cases letters to correctly map to phonemes
|
||||
pass
|
||||
|
||||
# overwrite common test
|
||||
@unittest.skip("Wav2Vec2PhonemeTokenizer always puts spaces between phonemes")
|
||||
def test_encode_decode_with_spaces(self):
|
||||
# Wav2Vec2PhonemeTokenizer always puts spaces between phonemes
|
||||
pass
|
||||
|
||||
# overwrite common test
|
||||
@unittest.skip("encodes to text to ids, but decodes ids to phonemes -> not possible to have internal consistency")
|
||||
def test_internal_consistency(self):
|
||||
# encodes to text to ids, but decodes ids to phonemes -> not possible to have internal consistency
|
||||
pass
|
||||
|
||||
@unittest.skip("Wav2Vec2PhonemeModel has no max model length => no testing")
|
||||
def test_pretrained_model_lists(self):
|
||||
# Wav2Vec2PhonemeModel has no max model length => no testing
|
||||
pass
|
||||
|
||||
# overwrite common
|
||||
|
||||
Reference in New Issue
Block a user