Add Flash Attention 2 to M2M100 model (#30256)

* Added flash attention 2.

* Fixes.

* Fix inheritance.

* Fixed init.

* Remove stuff.

* Added documentation.

* Add FA2 to M2M100 documentation.

* Add test.

* Fixed documentation.

* Update src/transformers/models/m2m_100/modeling_m2m_100.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/en/model_doc/nllb.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fixed variable name.

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Alexander Visheratin
2024-04-18 04:27:58 -04:00
committed by GitHub
parent ec92f983af
commit b65df514d1
5 changed files with 379 additions and 14 deletions

View File

@@ -19,12 +19,16 @@ import copy
import tempfile
import unittest
import pytest
from transformers import M2M100Config, is_torch_available
from transformers.testing_utils import (
require_flash_attn,
require_sentencepiece,
require_tokenizers,
require_torch,
require_torch_fp16,
require_torch_gpu,
slow,
torch_device,
)
@@ -412,3 +416,48 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
assert generated == expected_en
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_seq_to_seq_generation(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
model = M2M100ForConditionalGeneration.from_pretrained(
"facebook/m2m100_418M", attn_implementation="flash_attention_2"
).to(torch_device)
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en")
src_fr = [
"L'affaire NSA souligne l'absence totale de débat sur le renseignement",
"Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.",
"Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent"
" Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de"
" l'ampleur de la surveillance américaine sur l'ensemble des communications en France.",
]
# The below article tests that we don't add any hypotheses outside of the top n_beams
dct = tokenizer(src_fr, padding=True, return_tensors="pt")
hypotheses_batch = model.generate(
input_ids=dct["input_ids"].to(torch_device),
attention_mask=dct["attention_mask"].to(torch_device),
num_beams=5,
forced_bos_token_id=tokenizer.get_lang_id("en"),
)
expected_en = [
"The NSA case highlights the total absence of intelligence debate",
"I think there are two levels of response from the French government.",
"When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
" Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
" communications in France.",
]
generated = tokenizer.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
assert generated == expected_en