Fix languages covered by M4Tv2 (#28019)
* correct language assessment + add tests * Update src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * make style + simplify and enrich test --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -758,7 +758,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
def update_generation(self, model):
|
||||
lang_code_to_id = {
|
||||
text_lang_code_to_id = {
|
||||
"fra": 4,
|
||||
"eng": 4,
|
||||
"rus": 4,
|
||||
}
|
||||
|
||||
speech_lang_code_to_id = {
|
||||
"fra": 4,
|
||||
"eng": 4,
|
||||
}
|
||||
@@ -773,9 +779,9 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
|
||||
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
|
||||
generation_config.__setattr__("text_decoder_lang_to_code_id", lang_code_to_id)
|
||||
generation_config.__setattr__("t2u_lang_code_to_id", lang_code_to_id)
|
||||
generation_config.__setattr__("vocoder_lang_code_to_id", lang_code_to_id)
|
||||
generation_config.__setattr__("text_decoder_lang_to_code_id", text_lang_code_to_id)
|
||||
generation_config.__setattr__("t2u_lang_code_to_id", speech_lang_code_to_id)
|
||||
generation_config.__setattr__("vocoder_lang_code_to_id", speech_lang_code_to_id)
|
||||
generation_config.__setattr__("id_to_text", id_to_text)
|
||||
generation_config.__setattr__("char_to_id", char_to_id)
|
||||
generation_config.__setattr__("eos_token_id", 0)
|
||||
@@ -784,13 +790,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
|
||||
|
||||
model.generation_config = generation_config
|
||||
|
||||
def prepare_text_input(self):
|
||||
def prepare_text_input(self, tgt_lang):
|
||||
config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs()
|
||||
|
||||
input_dict = {
|
||||
"input_ids": inputs,
|
||||
"attention_mask": input_mask,
|
||||
"tgt_lang": "eng",
|
||||
"tgt_lang": tgt_lang,
|
||||
"num_beams": 2,
|
||||
"do_sample": True,
|
||||
}
|
||||
@@ -837,6 +843,26 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
|
||||
output = model.generate(**inputs)
|
||||
return output
|
||||
|
||||
def test_generation_languages(self):
|
||||
config, input_text_rus = self.prepare_text_input(tgt_lang="rus")
|
||||
|
||||
model = SeamlessM4Tv2Model(config=config)
|
||||
self.update_generation(model)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# make sure that generating speech, with a language that is only supported for text translation, raises error
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(**input_text_rus)
|
||||
|
||||
# make sure that generating text only works
|
||||
model.generate(**input_text_rus, generate_speech=False)
|
||||
|
||||
# make sure it works for languages supported by both output modalities
|
||||
config, input_text_eng = self.prepare_text_input(tgt_lang="eng")
|
||||
model.generate(**input_text_eng)
|
||||
model.generate(**input_text_eng, generate_speech=False)
|
||||
|
||||
def test_speech_generation(self):
|
||||
config, input_speech, input_text = self.prepare_speech_and_text_input()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user