Pix2Struct: fix wrong broadcast axis of attention mask in visual encoder (#23976)
* fix wrong broadcast axis of attention mask in visual encoder * fix slow tests --------- Co-authored-by: younesbelkada <younesbelkada@gmail.com>
This commit is contained in:
@@ -210,7 +210,7 @@ class Pix2StructVisionAttention(nn.Module):
|
|||||||
attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype)
|
attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype)
|
||||||
|
|
||||||
if attention_mask.dim() == 2:
|
if attention_mask.dim() == 2:
|
||||||
position_bias = position_bias + attention_mask[:, None, :, None].to(position_bias.device)
|
position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
|
||||||
else:
|
else:
|
||||||
# (batch_size, n_heads, seq_length, key_length)
|
# (batch_size, n_heads, seq_length, key_length)
|
||||||
position_bias = position_bias + attention_mask.to(position_bias.device)
|
position_bias = position_bias + attention_mask.to(position_bias.device)
|
||||||
@@ -1695,7 +1695,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
|||||||
>>> generated_ids = model.generate(**inputs, max_new_tokens=50)
|
>>> generated_ids = model.generate(**inputs, max_new_tokens=50)
|
||||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
>>> print(generated_text)
|
>>> print(generated_text)
|
||||||
A picture of a stop sign with a red stop sign on it.
|
A picture of a stop sign with a red stop sign
|
||||||
```
|
```
|
||||||
|
|
||||||
Training:
|
Training:
|
||||||
@@ -1719,7 +1719,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
|||||||
>>> outputs = model(**inputs, labels=labels)
|
>>> outputs = model(**inputs, labels=labels)
|
||||||
>>> loss = outputs.loss
|
>>> loss = outputs.loss
|
||||||
>>> print(f"{loss.item():.5f}")
|
>>> print(f"{loss.item():.5f}")
|
||||||
5.95566
|
5.94282
|
||||||
```"""
|
```"""
|
||||||
use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|||||||
@@ -757,12 +757,12 @@ class Pix2StructIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
processor.decode(predictions[0], skip_special_tokens=True),
|
processor.decode(predictions[0], skip_special_tokens=True),
|
||||||
"A picture of a stop sign with a red stop sign on it.",
|
"A picture of a stop sign with a red stop sign",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
processor.decode(predictions[1], skip_special_tokens=True),
|
processor.decode(predictions[1], skip_special_tokens=True),
|
||||||
"An photography of the Temple Bar and the Temple Bar.",
|
"An photography of the Temple Bar and other places in the city.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_vqa_model(self):
|
def test_vqa_model(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user