From a5a8291ad10bda4933c1859b40155571567297f6 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Tue, 13 Aug 2024 10:46:21 +0200 Subject: [PATCH] Fix tests (#32649) * skip failing tests * [no-filter] * [no-filter] * fix wording catch in FA2 test * [no-filter] * trigger normal CI without filtering --- tests/test_modeling_common.py | 6 +++++- tests/utils/test_modeling_utils.py | 5 ++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 898f8689d7..203146a808 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2831,7 +2831,11 @@ class ModelTesterMixin: model_forward_args = inspect.signature(model.forward).parameters if "inputs_embeds" not in model_forward_args: self.skipTest(reason="This model doesn't use `inputs_embeds`") - + has_inputs_embeds_forwarding = "inputs_embeds" in set( + inspect.signature(model.prepare_inputs_for_generation).parameters.keys() + ) + if not has_inputs_embeds_forwarding: + self.skipTest(reason="This model doesn't support `inputs_embeds` passed to `generate`.") inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1 diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 2db96d87c7..a91aa5b9f4 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -2427,8 +2427,7 @@ class TestAttentionImplementation(unittest.TestCase): _ = AutoModel.from_pretrained( "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" ) - - self.assertTrue("the package flash_attn seems not to be installed" in str(cm.exception)) + self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) def test_not_available_flash_with_config(self): if is_flash_attn_2_available(): @@ -2443,7 +2442,7 @@ class TestAttentionImplementation(unittest.TestCase): attn_implementation="flash_attention_2", ) - self.assertTrue("the package flash_attn seems not to be installed" in str(cm.exception)) + self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) def test_not_available_sdpa(self): if is_torch_sdpa_available():