Fix falcon with SDPA, alibi but no passed mask (#30123)
* fix falcon without attention_mask & alibi * add test * Update tests/models/falcon/test_modeling_falcon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -1098,16 +1098,12 @@ class FalconModel(FalconPreTrainedModel):
|
||||
elif head_mask is None:
|
||||
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
|
||||
|
||||
attention_mask_2d = attention_mask
|
||||
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# We take care to integrate alibi bias in the attention_mask here.
|
||||
if attention_mask_2d is None:
|
||||
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
|
||||
else:
|
||||
min_dtype = torch.finfo(alibi.dtype).min
|
||||
attention_mask = torch.masked_fill(
|
||||
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
|
||||
|
||||
@@ -666,3 +666,27 @@ class FalconLanguageGenerationTest(unittest.TestCase):
|
||||
self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1]) # left-padding exists
|
||||
self.assertEqual(unpadded_gen_text[0], expected_output)
|
||||
self.assertEqual(padded_gen_text[0], expected_output)
|
||||
|
||||
@slow
|
||||
@require_torch_sdpa
|
||||
def test_falcon_alibi_sdpa_matches_eager(self):
|
||||
input_ids = torch.randint(0, 1000, (5, 20))
|
||||
|
||||
config = FalconConfig(
|
||||
vocab_size=1000,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=3,
|
||||
num_attention_heads=4,
|
||||
new_decoder_architecture=True,
|
||||
alibi=True,
|
||||
)
|
||||
|
||||
falcon = FalconForCausalLM(config)
|
||||
falcon = falcon.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
# output_attentions=True dispatches to eager path
|
||||
falcon_output_eager = falcon(input_ids, output_attentions=True)[0]
|
||||
falcon_output_sdpa = falcon(input_ids)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(falcon_output_eager, falcon_output_sdpa, atol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user