Flash-Attn: fix generation when no attention mask or no pading (#32241)

* fix

* fix prev test (half of failures)

* [run-slow] llama, gemma2

* [run-slow] llama, gemma2
This commit is contained in:
Raushan Turganbay
2024-07-26 14:45:55 +05:00
committed by GitHub
parent 27c7f971c0
commit 81233c069c
2 changed files with 18 additions and 3 deletions

View File

@@ -4270,6 +4270,18 @@ class ModelTesterMixin:
use_cache=True,
)
# Generate with one batch only to test generation when attention mask will be None
# when real inputs are used, because there is no padding. See issue #32237 for more
dummy_input = dummy_input[:1, ...]
dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@@ -4342,6 +4354,8 @@ class ModelTesterMixin:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
self.skipTest("Model dummy inputs should contain padding in their attention mask")
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
@@ -4356,7 +4370,6 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
assert 0 in inputs_dict["attention_mask"], "assert padding in testing inputs"
# ensure left padding, to adapt for some models
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)