Fix falcon with SDPA, alibi but no passed mask (#30123)
* fix falcon without attention_mask & alibi * add test * Update tests/models/falcon/test_modeling_falcon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -666,3 +666,27 @@ class FalconLanguageGenerationTest(unittest.TestCase):
|
||||
self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1]) # left-padding exists
|
||||
self.assertEqual(unpadded_gen_text[0], expected_output)
|
||||
self.assertEqual(padded_gen_text[0], expected_output)
|
||||
|
||||
@slow
|
||||
@require_torch_sdpa
|
||||
def test_falcon_alibi_sdpa_matches_eager(self):
|
||||
input_ids = torch.randint(0, 1000, (5, 20))
|
||||
|
||||
config = FalconConfig(
|
||||
vocab_size=1000,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=3,
|
||||
num_attention_heads=4,
|
||||
new_decoder_architecture=True,
|
||||
alibi=True,
|
||||
)
|
||||
|
||||
falcon = FalconForCausalLM(config)
|
||||
falcon = falcon.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
# output_attentions=True dispatches to eager path
|
||||
falcon_output_eager = falcon(input_ids, output_attentions=True)[0]
|
||||
falcon_output_sdpa = falcon(input_ids)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(falcon_output_eager, falcon_output_sdpa, atol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user