Pix2Struct: fix wrong broadcast axis of attention mask in visual encoder (#23976)
* fix wrong broadcast axis of attention mask in visual encoder * fix slow tests --------- Co-authored-by: younesbelkada <younesbelkada@gmail.com>
This commit is contained in:
@@ -210,7 +210,7 @@ class Pix2StructVisionAttention(nn.Module):
|
||||
attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype)
|
||||
|
||||
if attention_mask.dim() == 2:
|
||||
position_bias = position_bias + attention_mask[:, None, :, None].to(position_bias.device)
|
||||
position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
|
||||
else:
|
||||
# (batch_size, n_heads, seq_length, key_length)
|
||||
position_bias = position_bias + attention_mask.to(position_bias.device)
|
||||
@@ -1695,7 +1695,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
||||
>>> generated_ids = model.generate(**inputs, max_new_tokens=50)
|
||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
>>> print(generated_text)
|
||||
A picture of a stop sign with a red stop sign on it.
|
||||
A picture of a stop sign with a red stop sign
|
||||
```
|
||||
|
||||
Training:
|
||||
@@ -1719,7 +1719,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
||||
>>> outputs = model(**inputs, labels=labels)
|
||||
>>> loss = outputs.loss
|
||||
>>> print(f"{loss.item():.5f}")
|
||||
5.95566
|
||||
5.94282
|
||||
```"""
|
||||
use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@@ -757,12 +757,12 @@ class Pix2StructIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(
|
||||
processor.decode(predictions[0], skip_special_tokens=True),
|
||||
"A picture of a stop sign with a red stop sign on it.",
|
||||
"A picture of a stop sign with a red stop sign",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
processor.decode(predictions[1], skip_special_tokens=True),
|
||||
"An photography of the Temple Bar and the Temple Bar.",
|
||||
"An photography of the Temple Bar and other places in the city.",
|
||||
)
|
||||
|
||||
def test_vqa_model(self):
|
||||
|
||||
Reference in New Issue
Block a user