Remove deprecated use_flash_attention_2 parameter (#37131)

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
Yuanyuan Chen
2025-06-02 17:06:25 +08:00
committed by GitHub
parent 51d732709e
commit fde1120b6c
4 changed files with 7 additions and 30 deletions

View File

@@ -488,7 +488,7 @@ class DiffLlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
model.save_pretrained(tmp_dir)
new_model = DiffLlamaForCausalLM.from_pretrained(
tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
tmp_dir, attn_implementation="flash_attention_2", torch_dtype=torch.float16
).to("cuda")
self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")