Fix flex_attention in training mode (#35605)

* fix flex

* add test

* style
This commit is contained in:
Cyril Vallez
2025-01-10 11:49:12 +01:00
committed by GitHub
parent a9bd1e6284
commit ccc0381d36
2 changed files with 14 additions and 1 deletions

View File

@@ -4807,6 +4807,19 @@ class ModelTesterMixin:
# Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops)
self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits, atol=1e-5))
@require_torch_gpu
def test_flex_attention_with_grads(self):
for model_class in self.all_model_classes:
if not model_class._supports_flex_attn:
self.skipTest(reason="This model does not support flex attention")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config._attn_implementation = "flex_attention"
model = model_class(config).to(device=torch_device, dtype=torch.float16)
self.assertTrue(model.config._attn_implementation == "flex_attention")
# If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
_ = model(inputs_dict["input_ids"].to(torch_device))
global_rng = random.Random()