Explicitly specify use_cache=True in Flash Attention tests (#27635)
explicit use_cache=True
This commit is contained in:
@@ -436,7 +436,11 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
|
|
||||||
# Just test that a large cache works as expected
|
# Just test that a large cache works as expected
|
||||||
_ = model.generate(
|
_ = model.generate(
|
||||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
|
dummy_input,
|
||||||
|
attention_mask=dummy_attention_mask,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
do_sample=False,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
|
|||||||
@@ -3166,7 +3166,11 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
# Just test that a large cache works as expected
|
# Just test that a large cache works as expected
|
||||||
_ = model.generate(
|
_ = model.generate(
|
||||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
|
dummy_input,
|
||||||
|
attention_mask=dummy_attention_mask,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
do_sample=False,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
|
|||||||
Reference in New Issue
Block a user