[FA / tests] Add use_cache tests for FA models (#26415)
* add use_cache tests for FA * fixup
This commit is contained in:
@@ -2908,6 +2908,35 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
self.assertTrue(torch.equal(out, out_fa))
|
self.assertTrue(torch.equal(out, out_fa))
|
||||||
|
|
||||||
|
@require_flash_attn
|
||||||
|
@require_torch_gpu
|
||||||
|
@mark.flash_attn_test
|
||||||
|
@slow
|
||||||
|
def test_flash_attn_2_generate_use_cache(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
if not model_class._supports_flash_attn_2:
|
||||||
|
return
|
||||||
|
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||||
|
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
|
||||||
|
|
||||||
|
model = model_class.from_pretrained(
|
||||||
|
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
# Just test that a large cache works as expected
|
||||||
|
_ = model.generate(
|
||||||
|
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user