Fix falcon with SDPA, alibi but no passed mask (#30123)

* fix falcon without attention_mask & alibi

* add test

* Update tests/models/falcon/test_modeling_falcon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
fxmarty
2024-04-08 16:25:07 +02:00
committed by GitHub
parent 1773afcec3
commit 1897874edc
2 changed files with 34 additions and 14 deletions

View File

@@ -1098,16 +1098,12 @@ class FalconModel(FalconPreTrainedModel):
elif head_mask is None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
attention_mask_2d = attention_mask
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# We take care to integrate alibi bias in the attention_mask here.
if attention_mask_2d is None:
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
else:
min_dtype = torch.finfo(alibi.dtype).min
attention_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),

View File

@@ -666,3 +666,27 @@ class FalconLanguageGenerationTest(unittest.TestCase):
self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1]) # left-padding exists
self.assertEqual(unpadded_gen_text[0], expected_output)
self.assertEqual(padded_gen_text[0], expected_output)
@slow
@require_torch_sdpa
def test_falcon_alibi_sdpa_matches_eager(self):
input_ids = torch.randint(0, 1000, (5, 20))
config = FalconConfig(
vocab_size=1000,
hidden_size=64,
num_hidden_layers=3,
num_attention_heads=4,
new_decoder_architecture=True,
alibi=True,
)
falcon = FalconForCausalLM(config)
falcon = falcon.eval()
with torch.no_grad():
# output_attentions=True dispatches to eager path
falcon_output_eager = falcon(input_ids, output_attentions=True)[0]
falcon_output_sdpa = falcon(input_ids)[0]
self.assertTrue(torch.allclose(falcon_output_eager, falcon_output_sdpa, atol=1e-3))