Fix falcon with SDPA, alibi but no passed mask (#30123)
* fix falcon without attention_mask & alibi
* add test
* Update tests/models/falcon/test_modeling_falcon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -1098,27 +1098,23 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
elif head_mask is None:
|
elif head_mask is None:
|
||||||
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
|
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
|
||||||
|
|
||||||
attention_mask_2d = attention_mask
|
|
||||||
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
|
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||||
)
|
)
|
||||||
|
|
||||||
# We take care to integrate alibi bias in the attention_mask here.
|
# We take care to integrate alibi bias in the attention_mask here.
|
||||||
if attention_mask_2d is None:
|
min_dtype = torch.finfo(alibi.dtype).min
|
||||||
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
|
attention_mask = torch.masked_fill(
|
||||||
else:
|
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
|
||||||
min_dtype = torch.finfo(alibi.dtype).min
|
attention_mask < -1,
|
||||||
attention_mask = torch.masked_fill(
|
min_dtype,
|
||||||
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
|
)
|
||||||
attention_mask < -1,
|
|
||||||
min_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
|
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
|
||||||
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
|
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
if seq_length > 1 and attention_mask.device.type == "cuda":
|
if seq_length > 1 and attention_mask.device.type == "cuda":
|
||||||
attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
|
attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
|
||||||
else:
|
else:
|
||||||
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
|
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
|||||||
@@ -666,3 +666,27 @@ class FalconLanguageGenerationTest(unittest.TestCase):
|
|||||||
self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1]) # left-padding exists
|
self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1]) # left-padding exists
|
||||||
self.assertEqual(unpadded_gen_text[0], expected_output)
|
self.assertEqual(unpadded_gen_text[0], expected_output)
|
||||||
self.assertEqual(padded_gen_text[0], expected_output)
|
self.assertEqual(padded_gen_text[0], expected_output)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_sdpa
|
||||||
|
def test_falcon_alibi_sdpa_matches_eager(self):
|
||||||
|
input_ids = torch.randint(0, 1000, (5, 20))
|
||||||
|
|
||||||
|
config = FalconConfig(
|
||||||
|
vocab_size=1000,
|
||||||
|
hidden_size=64,
|
||||||
|
num_hidden_layers=3,
|
||||||
|
num_attention_heads=4,
|
||||||
|
new_decoder_architecture=True,
|
||||||
|
alibi=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
falcon = FalconForCausalLM(config)
|
||||||
|
falcon = falcon.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# output_attentions=True dispatches to eager path
|
||||||
|
falcon_output_eager = falcon(input_ids, output_attentions=True)[0]
|
||||||
|
falcon_output_sdpa = falcon(input_ids)[0]
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(falcon_output_eager, falcon_output_sdpa, atol=1e-3))
|
||||||
|
|||||||
Reference in New Issue
Block a user