Fix paligemma detection inference (#31587)
* fix extended attention mask * add slow test for detection instance * [run-slow]paligemma
This commit is contained in:
@@ -430,6 +430,32 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
|
||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_read_token
|
||||
def test_integration_detection_bug(self):
|
||||
# this is a reproducer of https://github.com/huggingface/transformers/issues/31425 where not enough context
|
||||
# impacted negatively segmentation generations.
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||
model_id, revision="bfloat16", torch_dtype=torch.bfloat16
|
||||
).to(torch_device)
|
||||
prompt = ("detect shoe",)
|
||||
|
||||
image = Image.open(
|
||||
requests.get(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/shoe.png",
|
||||
stream=True,
|
||||
).raw
|
||||
)
|
||||
|
||||
inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe" # fmt: skip
|
||||
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_paligemma_index_error_bug(self):
|
||||
|
||||
Reference in New Issue
Block a user