From 7f04373865393f625fb8f20bdabdab188120f9b8 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 21 Nov 2023 17:53:10 +0100 Subject: [PATCH] Explicitely specify `use_cache=True` in Flash Attention tests (#27635) explicit use_cache=True --- tests/models/mistral/test_modeling_mistral.py | 6 +++++- tests/test_modeling_common.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 31426435d0..0c28f46d5e 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -436,7 +436,11 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi # Just test that a large cache works as expected _ = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, ) @require_flash_attn diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9d9e96db43..c69b5ed77f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3166,7 +3166,11 @@ class ModelTesterMixin: # Just test that a large cache works as expected _ = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, ) @require_flash_attn