From bb1d0d0d9e7ca356cf5673031183e955cc160158 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Thu, 14 Dec 2023 14:43:44 +0000 Subject: [PATCH] Fix languages covered by M4Tv2 (#28019) * correct language assessment + add tests * Update src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * make style + simplify and enrich test --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../modeling_seamless_m4t_v2.py | 6 ++- .../test_modeling_seamless_m4t_v2.py | 38 ++++++++++++++++--- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index f1a26b3e5b..bceb1b4946 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -4596,7 +4596,11 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel): if tgt_lang is not None: # also accept __xxx__ tgt_lang = tgt_lang.replace("__", "") - for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: + if generate_speech: + keys_to_check = ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"] + else: + keys_to_check = ["text_decoder_lang_to_code_id"] + for key in keys_to_check: lang_code_to_id = getattr(self.generation_config, key, None) if lang_code_to_id is None: raise ValueError( diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py index 8627220c71..795f3d8042 100644 --- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py +++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py @@ -758,7 +758,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase): self.tmpdirname = tempfile.mkdtemp() def update_generation(self, model): - lang_code_to_id = { + text_lang_code_to_id = { + "fra": 4, + "eng": 4, + "rus": 4, + } + + speech_lang_code_to_id = { "fra": 4, "eng": 4, } @@ -773,9 +779,9 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase): generation_config = copy.deepcopy(model.generation_config) - generation_config.__setattr__("text_decoder_lang_to_code_id", lang_code_to_id) - generation_config.__setattr__("t2u_lang_code_to_id", lang_code_to_id) - generation_config.__setattr__("vocoder_lang_code_to_id", lang_code_to_id) + generation_config.__setattr__("text_decoder_lang_to_code_id", text_lang_code_to_id) + generation_config.__setattr__("t2u_lang_code_to_id", speech_lang_code_to_id) + generation_config.__setattr__("vocoder_lang_code_to_id", speech_lang_code_to_id) generation_config.__setattr__("id_to_text", id_to_text) generation_config.__setattr__("char_to_id", char_to_id) generation_config.__setattr__("eos_token_id", 0) @@ -784,13 +790,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase): model.generation_config = generation_config - def prepare_text_input(self): + def prepare_text_input(self, tgt_lang): config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs() input_dict = { "input_ids": inputs, "attention_mask": input_mask, - "tgt_lang": "eng", + "tgt_lang": tgt_lang, "num_beams": 2, "do_sample": True, } @@ -837,6 +843,26 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase): output = model.generate(**inputs) return output + def test_generation_languages(self): + config, input_text_rus = self.prepare_text_input(tgt_lang="rus") + + model = SeamlessM4Tv2Model(config=config) + self.update_generation(model) + model.to(torch_device) + model.eval() + + # make sure that generating speech, with a language that is only supported for text translation, raises error + with self.assertRaises(ValueError): + model.generate(**input_text_rus) + + # make sure that generating text only works + model.generate(**input_text_rus, generate_speech=False) + + # make sure it works for languages supported by both output modalities + config, input_text_eng = self.prepare_text_input(tgt_lang="eng") + model.generate(**input_text_eng) + model.generate(**input_text_eng, generate_speech=False) + def test_speech_generation(self): config, input_speech, input_text = self.prepare_speech_and_text_input()