[Idefics2] - Fix FA2 call for Perceiver layer (#32275)
* Fix FA2 call for Perceiver layer * [run_slow] idefics2 * [run_slow] idefics2 * [run_slow] idefics2 * Fix up * [run_slow] idefics2 * [run_slow] idefics2 * [run_slow] idefics2
This commit is contained in:
@@ -894,7 +894,7 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
q_len,
|
q_len,
|
||||||
dropout=dropout_rate,
|
dropout=dropout_rate,
|
||||||
sliding_window=False,
|
sliding_window=None,
|
||||||
is_causal=self.is_causal,
|
is_causal=self.is_causal,
|
||||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -29,7 +29,14 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
|
from transformers.testing_utils import (
|
||||||
|
require_bitsandbytes,
|
||||||
|
require_flash_attn,
|
||||||
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -491,13 +498,13 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
@unittest.skip("Test hits OOM on CI - https://github.com/huggingface/transformers/issues/32288")
|
||||||
def test_integration_test(self):
|
def test_integration_test(self):
|
||||||
model = Idefics2ForConditionalGeneration.from_pretrained(
|
model = Idefics2ForConditionalGeneration.from_pretrained(
|
||||||
"HuggingFaceM4/idefics2-8b-base",
|
"HuggingFaceM4/idefics2-8b-base",
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
)
|
)
|
||||||
model.to(torch_device)
|
|
||||||
|
|
||||||
# Create inputs
|
# Create inputs
|
||||||
text = "<image>In this image, we see"
|
text = "<image>In this image, we see"
|
||||||
@@ -517,7 +524,8 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
def test_integration_test_4bit(self):
|
def test_integration_test_4bit(self):
|
||||||
# Let' s make sure we test the preprocessing to replace what is used
|
# Let' s make sure we test the preprocessing to replace what is used
|
||||||
model = Idefics2ForConditionalGeneration.from_pretrained(
|
model = Idefics2ForConditionalGeneration.from_pretrained(
|
||||||
"HuggingFaceM4/idefics2-8b-base", load_in_4bit=True, device_map="auto"
|
"HuggingFaceM4/idefics2-8b-base",
|
||||||
|
load_in_4bit=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create pixel inputs
|
# Create pixel inputs
|
||||||
@@ -530,3 +538,37 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
|
expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
|
||||||
self.assertEqual(generated_texts[0], expected_generated_text)
|
self.assertEqual(generated_texts[0], expected_generated_text)
|
||||||
|
|
||||||
|
@require_flash_attn
|
||||||
|
@require_torch_gpu
|
||||||
|
@require_bitsandbytes
|
||||||
|
def test_flash_attn_2_eager_equivalence(self):
|
||||||
|
# Create inputs
|
||||||
|
text = "<image>In this image, we see"
|
||||||
|
images = self.image1
|
||||||
|
inputs = self.processor(text=text, images=images, return_tensors="pt", padding=True)
|
||||||
|
inputs.to(torch_device)
|
||||||
|
|
||||||
|
# Eager model
|
||||||
|
model_eager = Idefics2ForConditionalGeneration.from_pretrained(
|
||||||
|
"HuggingFaceM4/idefics2-8b-base",
|
||||||
|
attn_implementation="eager",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
generated_ids_eager = model_eager.generate(**inputs, max_new_tokens=10)
|
||||||
|
generated_texts_eager = self.processor.batch_decode(generated_ids_eager, skip_special_tokens=True)
|
||||||
|
|
||||||
|
del model_eager
|
||||||
|
|
||||||
|
# Flash Attention 2 model
|
||||||
|
model_flash_attention_2 = Idefics2ForConditionalGeneration.from_pretrained(
|
||||||
|
"HuggingFaceM4/idefics2-8b-base",
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
generated_ids_flash_attention_2 = model_flash_attention_2.generate(**inputs, max_new_tokens=10)
|
||||||
|
generated_texts_flash_attention_2 = self.processor.batch_decode(
|
||||||
|
generated_ids_flash_attention_2, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_2[0])
|
||||||
|
|||||||
Reference in New Issue
Block a user