From 492ee17ec337573325cc4612a32b36631564b02e Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Wed, 26 Jun 2024 19:17:09 +0200 Subject: [PATCH] Fix paligemma detection inference (#31587) * fix extended attention mask * add slow test for detection instance * [run-slow]paligemma --- .../models/paligemma/modeling_paligemma.py | 5 ++-- .../paligemma/test_modeling_paligemma.py | 26 +++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 7839f4f56a..9f5bc0c597 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -448,13 +448,11 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel): # Get the target length target_seqlen = cache_position[-1] + 1 - extended_attention_mask = torch.ones( - (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), + (attention_mask.shape[0], target_seqlen - attention_mask.shape[1] + 1), dtype=attention_mask.dtype, device=attention_mask.device, ) - # Filter out only the tokens that can be un-attended, this can happen # if one uses PaliGemma+ Fused modules where the cache on the # first iteration is already big enough, or if one passes custom cache @@ -467,6 +465,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel): attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + attention_mask = attention_mask.to(inputs_embeds.dtype) outputs = self.language_model( attention_mask=attention_mask, diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index 935ceaf72d..fbd460ce79 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ 
-430,6 +430,32 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase): EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + @slow + @require_torch + @require_read_token + def test_integration_detection_bug(self): + # this is a reproducer of https://github.com/huggingface/transformers/issues/31425 where not enough context + # negatively impacted segmentation generations. + model_id = "google/paligemma-3b-pt-224" + model = PaliGemmaForConditionalGeneration.from_pretrained( + model_id, revision="bfloat16", torch_dtype=torch.bfloat16 + ).to(torch_device) + prompt = ("detect shoe",) + + image = Image.open( + requests.get( + "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/shoe.png", + stream=True, + ).raw + ) + + inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = "detect shoe\n shoe" # fmt: skip + self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT) + @slow @require_read_token def test_paligemma_index_error_bug(self):