Fix flex_attention in training mode (#35605)
* fix flex * add test * style
This commit is contained in:
@@ -27,7 +27,7 @@ def flex_attention_forward(
|
|||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
score = softcap * torch.tanh(score / softcap)
|
score = softcap * torch.tanh(score / softcap)
|
||||||
if causal_mask is not None:
|
if causal_mask is not None:
|
||||||
score += causal_mask[b][0][q_idx][kv_idx]
|
score = score + causal_mask[b][0][q_idx][kv_idx]
|
||||||
return score
|
return score
|
||||||
|
|
||||||
attn_output, attention_weights = flex_attention(
|
attn_output, attention_weights = flex_attention(
|
||||||
|
|||||||
@@ -4807,6 +4807,19 @@ class ModelTesterMixin:
|
|||||||
# Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops)
|
# Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops)
|
||||||
self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits, atol=1e-5))
|
self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits, atol=1e-5))
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_flex_attention_with_grads(self):
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if not model_class._supports_flex_attn:
|
||||||
|
self.skipTest(reason="This model does not support flex attention")
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config._attn_implementation = "flex_attention"
|
||||||
|
model = model_class(config).to(device=torch_device, dtype=torch.float16)
|
||||||
|
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
||||||
|
|
||||||
|
# If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
|
||||||
|
_ = model(inputs_dict["input_ids"].to(torch_device))
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user