[MMS] Scaling Speech Technology to 1,000+ Languages | Add attention adapter to Wav2Vec2 (#23813)

* add fine-tuned with adapter layer

* Add set_target_lang to tokenizer

* Implement load adapter

* add tests

* make style

* Apply suggestions from code review

* Update src/transformers/models/wav2vec2/tokenization_wav2vec2.py

* make fix-copies

* Apply suggestions from code review

* make fix-copies

* make style again

* make style again

* fix doc string

* Update tests/models/wav2vec2/test_tokenization_wav2vec2.py

* Apply suggestions from code review

* fix

* Correct wav2vec2 adapter

* make style

* Update src/transformers/models/wav2vec2/modeling_wav2vec2.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* add more nice docs

* finish

* finish

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review

* all finish

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2023-06-02 10:30:24 +01:00
committed by GitHub
parent f49a3453ca
commit 5dfd407b37
24 changed files with 823 additions and 33 deletions

View File

@@ -772,3 +772,48 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
output = tokenizer.convert_tokens_to_string(tokens)
self.assertIsInstance(output["text"], str)
def test_nested_vocab(self):
    """Exercise a vocab file that nests one sub-vocabulary per language.

    The test writes a ``{"eng": ..., "spa": ..., "ita": ...}`` mapping to
    ``vocab.json``, loads a tokenizer from it with ``target_lang="eng"``,
    and checks that ``set_target_lang`` swaps both ``encoder`` and the
    result of ``decode``. It then saves and reloads the tokenizer and
    asserts that the last selected language ("ita") is the one restored.
    """
    vocab_eng = {"a": 7, "b": 8}
    vocab_spa = {"a": 23, "c": 88}
    vocab_ita = {"a": 6, "d": 9}
    vocab_by_lang = {"eng": vocab_eng, "spa": vocab_spa, "ita": vocab_ita}

    def run_language_checks(tok, check_ita_first=False):
        """Cycle `tok` through eng -> spa -> eng -> ita, asserting at each step.

        With `check_ita_first` set, `tok` is expected to start out on "ita";
        that is verified and the tokenizer is then moved to "eng" so the
        remaining sequence starts from the same state as in the default case.
        """
        if check_ita_first:
            self.assertEqual(tok.decode([6, 9, 9]), "ad")
            self.assertEqual(tok.encoder, vocab_ita)
            tok.set_target_lang("eng")

        # English is the active language at this point.
        self.assertEqual(tok.encoder, vocab_eng)
        self.assertEqual(tok.decode([7, 8, 7]), "aba")

        # Switch to Spanish.
        tok.set_target_lang("spa")
        self.assertEqual(tok.decode([23, 88, 23]), "aca")
        self.assertEqual(tok.encoder, vocab_spa)

        # Back to English; the repeated id 7 is expected to decode to a single "a".
        tok.set_target_lang("eng")
        self.assertEqual(tok.encoder, vocab_eng)
        self.assertEqual(tok.decode([7, 7, 8]), "ab")

        # Finish on Italian so it is the last language that was selected.
        tok.set_target_lang("ita")
        self.assertEqual(tok.decode([6, 9, 9]), "ad")
        self.assertEqual(tok.encoder, vocab_ita)

    with tempfile.TemporaryDirectory() as tempdir:
        vocab_path = os.path.join(tempdir, "vocab.json")
        with open(vocab_path, "w") as vocab_file:
            json.dump(vocab_by_lang, vocab_file)

        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tempdir, target_lang="eng")

    run_language_checks(tokenizer)

    with tempfile.TemporaryDirectory() as tempdir:
        # "ita" was selected last above, so it should be the saved target language
        tokenizer.save_pretrained(tempdir)
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tempdir)

    self.assertEqual(tokenizer.target_lang, "ita")
    run_language_checks(tokenizer, check_ita_first=True)