Fix languages covered by M4Tv2 (#28019)
* correct language assessment + add tests * Update src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * make style + simplify and enrich test --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -4596,7 +4596,11 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel):
|
|||||||
if tgt_lang is not None:
|
if tgt_lang is not None:
|
||||||
# also accept __xxx__
|
# also accept __xxx__
|
||||||
tgt_lang = tgt_lang.replace("__", "")
|
tgt_lang = tgt_lang.replace("__", "")
|
||||||
for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
|
if generate_speech:
|
||||||
|
keys_to_check = ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]
|
||||||
|
else:
|
||||||
|
keys_to_check = ["text_decoder_lang_to_code_id"]
|
||||||
|
for key in keys_to_check:
|
||||||
lang_code_to_id = getattr(self.generation_config, key, None)
|
lang_code_to_id = getattr(self.generation_config, key, None)
|
||||||
if lang_code_to_id is None:
|
if lang_code_to_id is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -758,7 +758,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
|
|||||||
self.tmpdirname = tempfile.mkdtemp()
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
|
|
||||||
def update_generation(self, model):
|
def update_generation(self, model):
|
||||||
lang_code_to_id = {
|
text_lang_code_to_id = {
|
||||||
|
"fra": 4,
|
||||||
|
"eng": 4,
|
||||||
|
"rus": 4,
|
||||||
|
}
|
||||||
|
|
||||||
|
speech_lang_code_to_id = {
|
||||||
"fra": 4,
|
"fra": 4,
|
||||||
"eng": 4,
|
"eng": 4,
|
||||||
}
|
}
|
||||||
@@ -773,9 +779,9 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
|
|||||||
|
|
||||||
generation_config = copy.deepcopy(model.generation_config)
|
generation_config = copy.deepcopy(model.generation_config)
|
||||||
|
|
||||||
generation_config.__setattr__("text_decoder_lang_to_code_id", lang_code_to_id)
|
generation_config.__setattr__("text_decoder_lang_to_code_id", text_lang_code_to_id)
|
||||||
generation_config.__setattr__("t2u_lang_code_to_id", lang_code_to_id)
|
generation_config.__setattr__("t2u_lang_code_to_id", speech_lang_code_to_id)
|
||||||
generation_config.__setattr__("vocoder_lang_code_to_id", lang_code_to_id)
|
generation_config.__setattr__("vocoder_lang_code_to_id", speech_lang_code_to_id)
|
||||||
generation_config.__setattr__("id_to_text", id_to_text)
|
generation_config.__setattr__("id_to_text", id_to_text)
|
||||||
generation_config.__setattr__("char_to_id", char_to_id)
|
generation_config.__setattr__("char_to_id", char_to_id)
|
||||||
generation_config.__setattr__("eos_token_id", 0)
|
generation_config.__setattr__("eos_token_id", 0)
|
||||||
@@ -784,13 +790,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
|
|||||||
|
|
||||||
model.generation_config = generation_config
|
model.generation_config = generation_config
|
||||||
|
|
||||||
def prepare_text_input(self):
|
def prepare_text_input(self, tgt_lang):
|
||||||
config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs()
|
config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs()
|
||||||
|
|
||||||
input_dict = {
|
input_dict = {
|
||||||
"input_ids": inputs,
|
"input_ids": inputs,
|
||||||
"attention_mask": input_mask,
|
"attention_mask": input_mask,
|
||||||
"tgt_lang": "eng",
|
"tgt_lang": tgt_lang,
|
||||||
"num_beams": 2,
|
"num_beams": 2,
|
||||||
"do_sample": True,
|
"do_sample": True,
|
||||||
}
|
}
|
||||||
@@ -837,6 +843,26 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
|
|||||||
output = model.generate(**inputs)
|
output = model.generate(**inputs)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def test_generation_languages(self):
|
||||||
|
config, input_text_rus = self.prepare_text_input(tgt_lang="rus")
|
||||||
|
|
||||||
|
model = SeamlessM4Tv2Model(config=config)
|
||||||
|
self.update_generation(model)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# make sure that generating speech, with a language that is only supported for text translation, raises error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model.generate(**input_text_rus)
|
||||||
|
|
||||||
|
# make sure that generating text only works
|
||||||
|
model.generate(**input_text_rus, generate_speech=False)
|
||||||
|
|
||||||
|
# make sure it works for languages supported by both output modalities
|
||||||
|
config, input_text_eng = self.prepare_text_input(tgt_lang="eng")
|
||||||
|
model.generate(**input_text_eng)
|
||||||
|
model.generate(**input_text_eng, generate_speech=False)
|
||||||
|
|
||||||
def test_speech_generation(self):
|
def test_speech_generation(self):
|
||||||
config, input_speech, input_text = self.prepare_speech_and_text_input()
|
config, input_speech, input_text = self.prepare_speech_and_text_input()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user