[MMS] Scaling Speech Technology to 1,000+ Languages | Add attention adapter to Wav2Vec2 (#23813)
* add fine-tuned with adapter layer * Add set_target_lang to tokenizer * Implement load adapter * add tests * make style * Apply suggestions from code review * Update src/transformers/models/wav2vec2/tokenization_wav2vec2.py * make fix-copies * Apply suggestions from code review * make fix-copies * make style again * make style again * fix doc string * Update tests/models/wav2vec2/test_tokenization_wav2vec2.py * Apply suggestions from code review * fix * Correct wav2vec2 adapter * make style * Update src/transformers/models/wav2vec2/modeling_wav2vec2.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * add more nice docs * finish * finish * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Apply suggestions from code review * all finish --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
f49a3453ca
commit
5dfd407b37
@@ -772,3 +772,48 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
output = tokenizer.convert_tokens_to_string(tokens)
|
||||
|
||||
self.assertIsInstance(output["text"], str)
|
||||
|
||||
def test_nested_vocab(self):
    """Exercise a vocabulary file that nests one sub-vocabulary per language.

    The test writes a `{"eng": ..., "spa": ..., "ita": ...}` mapping to `vocab.json`, loads a
    `Wav2Vec2CTCTokenizer` from it with `target_lang="eng"`, and verifies that every call to
    `set_target_lang` swaps `tokenizer.encoder` to the matching sub-vocabulary and changes what
    `decode` returns. It then saves and reloads the tokenizer and asserts that the reloaded
    instance reports `target_lang == "ita"`, the last language selected before saving.
    """
    english = {"a": 7, "b": 8}
    spanish = {"a": 23, "c": 88}
    italian = {"a": 6, "d": 9}

    vocab_by_language = {"eng": english, "spa": spanish, "ita": italian}

    def run_language_checks(tok, check_ita_first=False):
        # A reloaded tokenizer is expected to already be in Italian before any explicit switch.
        if check_ita_first:
            self.assertEqual(tok.decode([6, 9, 9]), "ad")
            self.assertEqual(tok.encoder, italian)

        # English: ids 7/8 map to "a"/"b".
        tok.set_target_lang("eng")
        self.assertEqual(tok.encoder, english)
        self.assertEqual(tok.decode([7, 8, 7]), "aba")

        # Spanish: the same letter "a" now lives at a different id (23).
        tok.set_target_lang("spa")
        self.assertEqual(tok.decode([23, 88, 23]), "aca")
        self.assertEqual(tok.encoder, spanish)

        # Switching back to English must restore the English mapping; the repeated id 7 decodes
        # to a single "a".
        tok.set_target_lang("eng")
        self.assertEqual(tok.encoder, english)
        self.assertEqual(tok.decode([7, 7, 8]), "ab")

        # Finish on Italian so that this is the language in effect when the caller saves.
        tok.set_target_lang("ita")
        self.assertEqual(tok.decode([6, 9, 9]), "ad")
        self.assertEqual(tok.encoder, italian)

    with tempfile.TemporaryDirectory() as vocab_dir:
        vocab_path = os.path.join(vocab_dir, "vocab.json")
        with open(vocab_path, "w") as vocab_file:
            json.dump(vocab_by_language, vocab_file)

        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(vocab_dir, target_lang="eng")

    run_language_checks(tokenizer)

    with tempfile.TemporaryDirectory() as save_dir:
        # The helper left the tokenizer on "ita", so that is the target language
        # that should be saved and come back after reloading.
        tokenizer.save_pretrained(save_dir)
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(save_dir)

    self.assertEqual(tokenizer.target_lang, "ita")
    run_language_checks(tokenizer, check_ita_first=True)
|
||||
|
||||
Reference in New Issue
Block a user