Add SeamlessM4T v2 (#27779)
* add working convertion script * first non-working version of modeling code * update modeling code (working) * make style * make fix-copies * add config docstrings * add config to ignore docstrings formatage due to unconventional markdown * fix copies * fix generation num_return_sequences * enrich docs * add and fix tests beside integration tests * update integration tests * update repo id * add tie weights and make style * correct naming in .md * fix imports and so on * correct docstrings * fix fp16 speech forward * fix speechencoder attention * make style * fix copied from * rename SeamlessM4Tv2-v2 to SeamlessM4Tv2 * Apply suggestions on configuration Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * remove useless public models * fix private models + better naming for T2U models * clean speech encoder relative position embeddings * refactor chunk attention * add docstrings to chunk attention method * improve naming and docstrings * rename some attention variables + add temperature sampling in T2U model * rename DOCSTRINGS variable names * make style + remove 2 useless config parameters * enrich model card * remove any attention_head reference + fix temperature in T2U * new fmt and make style * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * rename spkr_id->speaker_id and change docstrings of get_char_input_ids * simplify v2attention * make style * Update seamless_m4t_v2.md * update code and tests with last update * update repo ids * fill article name, abstract andauthors * update not_doctested and slow_doc tests --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -16,7 +16,6 @@
|
||||
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -479,10 +478,6 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The speech encoder doesn't support head masking")
|
||||
def test_generate_with_head_masking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SeamlessM4TModel can takes input_ids or input_features")
|
||||
def test_forward_signature(self):
|
||||
pass
|
||||
@@ -714,43 +709,6 @@ class SeamlessM4TModelWithTextInputTest(
|
||||
def test_model_weights_reload_no_missing_tied_weights(self):
|
||||
pass
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
head_masking = {
|
||||
"head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads, device=torch_device),
|
||||
"decoder_head_mask": torch.zeros(
|
||||
config.decoder_layers, config.decoder_attention_heads, device=torch_device
|
||||
),
|
||||
"cross_attn_head_mask": torch.zeros(
|
||||
config.decoder_layers, config.decoder_attention_heads, device=torch_device
|
||||
),
|
||||
}
|
||||
|
||||
signature = inspect.signature(model.forward)
|
||||
# We want to test only models where encoder/decoder head masking is implemented
|
||||
if not set(head_masking.keys()) < {*signature.parameters.keys()}:
|
||||
continue
|
||||
|
||||
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
|
||||
out = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_beams=1,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
remove_invalid_values=True,
|
||||
**{name: mask},
|
||||
)
|
||||
# We check the state of decoder_attentions and cross_attentions just from the last step
|
||||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||
|
||||
@unittest.skip(reason="SeamlessM4TModel can take input_ids or input_features")
|
||||
def test_forward_signature(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user