From df323476a3abc9ad256f50858d86e979dc236795 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Wed, 14 Aug 2024 16:36:17 +0530 Subject: [PATCH] fix: Fixed failing tests in `tests/utils/test_add_new_model_like.py` (#32678) * Fixed failing tests in tests/utils/test_add_new_model_like.py * Fixed formatting using ruff. * Small nit. --- tests/utils/test_add_new_model_like.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/utils/test_add_new_model_like.py b/tests/utils/test_add_new_model_like.py index 1eb6d56e6c..27e53ed063 100644 --- a/tests/utils/test_add_new_model_like.py +++ b/tests/utils/test_add_new_model_like.py @@ -61,6 +61,7 @@ VIT_MODEL_FILES = { "src/transformers/models/vit/convert_vit_timm_to_pytorch.py", "src/transformers/models/vit/feature_extraction_vit.py", "src/transformers/models/vit/image_processing_vit.py", + "src/transformers/models/vit/image_processing_vit_fast.py", "src/transformers/models/vit/modeling_vit.py", "src/transformers/models/vit/modeling_tf_vit.py", "src/transformers/models/vit/modeling_flax_vit.py", @@ -662,7 +663,13 @@ NEW_BERT_CONSTANT = "value" def test_retrieve_model_classes(self): gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2").items()} expected_gpt_classes = { - "pt": {"GPT2ForTokenClassification", "GPT2Model", "GPT2LMHeadModel", "GPT2ForSequenceClassification"}, + "pt": { + "GPT2ForTokenClassification", + "GPT2Model", + "GPT2LMHeadModel", + "GPT2ForSequenceClassification", + "GPT2ForQuestionAnswering", + }, "tf": {"TFGPT2Model", "TFGPT2ForSequenceClassification", "TFGPT2LMHeadModel"}, "flax": {"FlaxGPT2Model", "FlaxGPT2LMHeadModel"}, } @@ -836,7 +843,7 @@ NEW_BERT_CONSTANT = "value" ] expected_model_classes = { "pt": set(wav2vec2_classes), - "tf": {f"TF{m}" for m in wav2vec2_classes[:1]}, + "tf": {f"TF{m}" for m in [wav2vec2_classes[0], wav2vec2_classes[-2]]}, "flax": {f"Flax{m}" for m in wav2vec2_classes[:2]}, } @@ -870,7 +877,7 @@ NEW_BERT_CONSTANT = "value" self.assertEqual(wav2vec2_model_patterns.model_type, "wav2vec2") self.assertEqual(wav2vec2_model_patterns.model_lower_cased, "wav2vec2") self.assertEqual(wav2vec2_model_patterns.model_camel_cased, "Wav2Vec2") - self.assertEqual(wav2vec2_model_patterns.model_upper_cased, "WAV_2_VEC_2") + self.assertEqual(wav2vec2_model_patterns.model_upper_cased, "WAV2VEC2") self.assertEqual(wav2vec2_model_patterns.config_class, "Wav2Vec2Config") self.assertEqual(wav2vec2_model_patterns.feature_extractor_class, "Wav2Vec2FeatureExtractor") self.assertEqual(wav2vec2_model_patterns.processor_class, "Wav2Vec2Processor")