Flash-Attn: fix generation when no attention mask or no pading (#32241)
* fix * fix prev test (half of failures) * [run-slow] llama, gemma2 * [run-slow] llama, gemma2
This commit is contained in:
committed by
GitHub
parent
27c7f971c0
commit
81233c069c
@@ -4270,6 +4270,18 @@ class ModelTesterMixin:
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
# Generate with one batch only to test generation when attention mask will be None
|
||||
# when real inputs are used, because there is no padding. See issue #32237 for more
|
||||
dummy_input = dummy_input[:1, ...]
|
||||
dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
|
||||
_ = model.generate(
|
||||
dummy_input,
|
||||
attention_mask=dummy_attention_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
@@ -4342,6 +4354,8 @@ class ModelTesterMixin:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
|
||||
self.skipTest("Model dummy inputs should contain padding in their attention mask")
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||
@@ -4356,7 +4370,6 @@ class ModelTesterMixin:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
assert 0 in inputs_dict["attention_mask"], "assert padding in testing inputs"
|
||||
# ensure left padding, to adapt for some models
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
|
||||
Reference in New Issue
Block a user