Remove deprecated use_flash_attention_2 parameter (#37131)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -488,7 +488,7 @@ class DiffLlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
new_model = DiffLlamaForCausalLM.from_pretrained(
|
||||
tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
|
||||
tmp_dir, attn_implementation="flash_attention_2", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
|
||||
|
||||
Reference in New Issue
Block a user