@@ -1608,7 +1608,6 @@ class GenerationTesterMixin:
|
|||||||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||||
|
|
||||||
@slow # TODO (Joao): fix GPTBigCode
|
|
||||||
def test_left_padding_compatibility(self):
|
def test_left_padding_compatibility(self):
|
||||||
# The check done in this test is fairly difficult -- depending on the model architecture, passing the right
|
# The check done in this test is fairly difficult -- depending on the model architecture, passing the right
|
||||||
# position index for the position embeddings can still result in a different output, due to numerical masking.
|
# position index for the position embeddings can still result in a different output, due to numerical masking.
|
||||||
@@ -1648,7 +1647,7 @@ class GenerationTesterMixin:
|
|||||||
position_ids.masked_fill_(padded_attention_mask == 0, 1)
|
position_ids.masked_fill_(padded_attention_mask == 0, 1)
|
||||||
model_kwargs["position_ids"] = position_ids
|
model_kwargs["position_ids"] = position_ids
|
||||||
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
|
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
|
||||||
if not torch.allclose(next_logits_wo_padding, next_logits_with_padding):
|
if not torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-7):
|
||||||
no_failures = False
|
no_failures = False
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user