Fix paligemma detection inference (#31587)

* fix extended attention mask

* add slow test for detection instance

* [run-slow]paligemma
This commit is contained in:
Pablo Montalvo
2024-06-26 19:17:09 +02:00
committed by GitHub
parent e71f2863d7
commit 492ee17ec3
2 changed files with 28 additions and 3 deletions

View File

@@ -430,6 +430,32 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_torch
@require_read_token
def test_integration_detection_bug(self):
    """Reproduce https://github.com/huggingface/transformers/issues/31425.

    Per that report, too little context degraded the generated output. The test
    runs a "detect shoe" prompt on a fixture image with the bfloat16 revision of
    the checkpoint and compares the decoded generation against a fixed string of
    location tokens.
    """
    checkpoint = "google/paligemma-3b-pt-224"
    shoe_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/shoe.png"
    expected_text = "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe"

    # Load the bfloat16 weights, then move the model to the test device.
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        checkpoint,
        revision="bfloat16",
        torch_dtype=torch.bfloat16,
    )
    model = model.to(torch_device)

    # The prompt is deliberately kept as a one-element tuple, as in the reproducer.
    prompt = ("detect shoe",)

    # Stream the fixture image and hand the raw response to PIL.
    response = requests.get(shoe_url, stream=True)
    image = Image.open(response.raw)

    # Cast the processed inputs to bfloat16 first, then move them to the device.
    inputs = self.processor(text=prompt, images=image, return_tensors="pt")
    inputs = inputs.to(torch.bfloat16)
    inputs = inputs.to(torch_device)

    generated = model.generate(**inputs, max_new_tokens=20)
    decoded_text = self.processor.decode(generated[0], skip_special_tokens=True)

    self.assertEqual(decoded_text, expected_text)
@slow
@require_read_token
def test_paligemma_index_error_bug(self):