Fix paligemma detection inference (#31587)
* fix extended attention mask
* add slow test for detection instance
* [run-slow]paligemma
This commit is contained in:
@@ -448,13 +448,11 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
|
|
||||||
# Get the target length
|
# Get the target length
|
||||||
target_seqlen = cache_position[-1] + 1
|
target_seqlen = cache_position[-1] + 1
|
||||||
|
|
||||||
extended_attention_mask = torch.ones(
|
extended_attention_mask = torch.ones(
|
||||||
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1] + 1),
|
||||||
dtype=attention_mask.dtype,
|
dtype=attention_mask.dtype,
|
||||||
device=attention_mask.device,
|
device=attention_mask.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filter out only the tokens that can be un-attended, this can happen
|
# Filter out only the tokens that can be un-attended, this can happen
|
||||||
# if one uses PaliGemma+ Fused modules where the cache on the
|
# if one uses PaliGemma+ Fused modules where the cache on the
|
||||||
# first iteration is already big enough, or if one passes custom cache
|
# first iteration is already big enough, or if one passes custom cache
|
||||||
@@ -467,6 +465,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
|
|
||||||
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
||||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||||
|
|
||||||
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
|||||||
@@ -430,6 +430,32 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
|
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
|
||||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_read_token
|
||||||
|
def test_integration_detection_bug(self):
|
||||||
|
# this is a reproducer of https://github.com/huggingface/transformers/issues/31425 where not enough context
|
||||||
|
# negatively impacted segmentation generations.
|
||||||
|
model_id = "google/paligemma-3b-pt-224"
|
||||||
|
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||||
|
model_id, revision="bfloat16", torch_dtype=torch.bfloat16
|
||||||
|
).to(torch_device)
|
||||||
|
prompt = ("detect shoe",)
|
||||||
|
|
||||||
|
image = Image.open(
|
||||||
|
requests.get(
|
||||||
|
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/shoe.png",
|
||||||
|
stream=True,
|
||||||
|
).raw
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(torch_device)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
|
||||||
|
EXPECTED_DECODED_TEXT = "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe" # fmt: skip
|
||||||
|
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_read_token
|
@require_read_token
|
||||||
def test_paligemma_index_error_bug(self):
|
def test_paligemma_index_error_bug(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user