From 5f1fcc299cb00c1edce5eb1efb8bacdde2365690 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Wed, 31 Jul 2024 14:51:04 +0100 Subject: [PATCH] [Idefics2] - Fix FA2 call for Perceiver layer (#32275) * Fix FA2 call for Perciever layer * [run_slow] idefics2 * [run_slow] idefics2 * [run_slow] idefics2 * Fix up * [run_slow] idefics2 * [run_slow] idefics2 * [run_slow] idefics2 --- .../models/idefics2/modeling_idefics2.py | 2 +- .../models/idefics2/test_modeling_idefics2.py | 48 +++++++++++++++++-- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 7c80ddfe92..089520ae54 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -894,7 +894,7 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention): attention_mask, q_len, dropout=dropout_rate, - sliding_window=False, + sliding_window=None, is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask, ) diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index 057ce93cd8..d1a982dc59 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -29,7 +29,14 @@ from transformers import ( is_torch_available, is_vision_available, ) -from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device +from transformers.testing_utils import ( + require_bitsandbytes, + require_flash_attn, + require_torch, + require_torch_gpu, + slow, + torch_device, +) from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -491,13 +498,13 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase): torch.cuda.empty_cache() @slow + @unittest.skip("Test hits OOM on CI - https://github.com/huggingface/transformers/issues/32288") def test_integration_test(self): model = Idefics2ForConditionalGeneration.from_pretrained( "HuggingFaceM4/idefics2-8b-base", torch_dtype=torch.bfloat16, device_map="auto", ) - model.to(torch_device) # Create inputs text = "In this image, we see" @@ -517,7 +524,8 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase): def test_integration_test_4bit(self): # Let' s make sure we test the preprocessing to replace what is used model = Idefics2ForConditionalGeneration.from_pretrained( - "HuggingFaceM4/idefics2-8b-base", load_in_4bit=True, device_map="auto" + "HuggingFaceM4/idefics2-8b-base", + load_in_4bit=True, ) # Create pixel inputs @@ -530,3 +538,37 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase): expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River," self.assertEqual(generated_texts[0], expected_generated_text) + + @require_flash_attn + @require_torch_gpu + @require_bitsandbytes + def test_flash_attn_2_eager_equivalence(self): + # Create inputs + text = "In this image, we see" + images = self.image1 + inputs = self.processor(text=text, images=images, return_tensors="pt", padding=True) + inputs.to(torch_device) + + # Eager model + model_eager = Idefics2ForConditionalGeneration.from_pretrained( + "HuggingFaceM4/idefics2-8b-base", + attn_implementation="eager", + load_in_4bit=True, + ) + generated_ids_eager = model_eager.generate(**inputs, max_new_tokens=10) + generated_texts_eager = self.processor.batch_decode(generated_ids_eager, skip_special_tokens=True) + + del model_eager + + # Flash Attention 2 model + model_flash_attention_2 = Idefics2ForConditionalGeneration.from_pretrained( + "HuggingFaceM4/idefics2-8b-base", + attn_implementation="flash_attention_2", + load_in_4bit=True, + ) + generated_ids_flash_attention_2 = model_flash_attention_2.generate(**inputs, max_new_tokens=10) + generated_texts_flash_attention_2 = self.processor.batch_decode( + generated_ids_flash_attention_2, skip_special_tokens=True + ) + + self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_2[0])