Add Flash Attention 2 to M2M100 model (#30256)
* Added flash attention 2. * Fixes. * Fix inheritance. * Fixed init. * Remove stuff. * Added documentation. * Add FA2 to M2M100 documentation. * Add test. * Fixed documentation. * Update src/transformers/models/m2m_100/modeling_m2m_100.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update docs/source/en/model_doc/nllb.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fixed variable name. --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
ec92f983af
commit
b65df514d1
@@ -19,12 +19,16 @@ import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import M2M100Config, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
require_flash_attn,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
require_torch_fp16,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -412,3 +416,48 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
|
||||
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
|
||||
)
|
||||
assert generated == expected_en
|
||||
|
||||
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_seq_to_seq_generation(self):
    """
    Overwriting the common test as the test is flaky on tiny models.

    Loads the `facebook/m2m100_418M` checkpoint with the Flash Attention 2
    attention implementation, translates a small padded batch of French
    sentences into English with beam search, and compares the decoded
    output against fixed reference translations.
    """
    checkpoint = "facebook/m2m100_418M"

    model = M2M100ForConditionalGeneration.from_pretrained(checkpoint, attn_implementation="flash_attention_2")
    model = model.to(torch_device)

    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint, src_lang="fr", tgt_lang="en")

    # Sentences of very different lengths, so the batch exercises padding.
    french_sources = [
        "L'affaire NSA souligne l'absence totale de débat sur le renseignement",
        "Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.",
        "Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent"
        " Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de"
        " l'ampleur de la surveillance américaine sur l'ensemble des communications en France.",
    ]

    # Reference English translations, one per source sentence, in order.
    expected_en = [
        "The NSA case highlights the total absence of intelligence debate",
        "I think there are two levels of response from the French government.",
        "When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
        " Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
        " communications in France.",
    ]

    encoded = tokenizer(french_sources, padding=True, return_tensors="pt")
    source_ids = encoded["input_ids"].to(torch_device)
    source_mask = encoded["attention_mask"].to(torch_device)

    # Beam search with English forced as the first generated (language) token.
    english_lang_id = tokenizer.get_lang_id("en")
    generated_ids = model.generate(
        input_ids=source_ids,
        attention_mask=source_mask,
        num_beams=5,
        forced_bos_token_id=english_lang_id,
    )

    generated = tokenizer.batch_decode(
        generated_ids.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
    )
    assert generated == expected_en
|
||||
|
||||
Reference in New Issue
Block a user