From 6dc4c36acbaddcfd7493f71d459ccf123c70e523 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 4 May 2022 18:39:30 +0200 Subject: [PATCH] minor change on TF Data2Vec test (#17085) Co-authored-by: ydshieh --- .../models/data2vec/modeling_tf_data2vec_vision.py | 2 +- .../data2vec/test_modeling_tf_data2vec_vision.py | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py index 35dc46d120..4c3446dc06 100644 --- a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py @@ -52,7 +52,7 @@ _EXPECTED_OUTPUT_SHAPE = [1, 197, 768] _IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k" _IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote" -DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = [ +TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = [ "facebook/data2vec-vision-base-ft1k", # See all Data2VecVision models at https://huggingface.co/models?filter=data2vec-vision ] diff --git a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py index 7c734a1e85..17b02d037c 100644 --- a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py @@ -28,17 +28,13 @@ from ...test_configuration_common import ConfigTester from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor -DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "facebook/data2vec-vision-base-ft1k", - # See all Data2VecVision models at https://huggingface.co/models?filter=data2vec-vision -] - - if is_tf_available(): import tensorflow as tf from transformers import TFData2VecVisionForImageClassification, TFData2VecVisionModel - + from transformers.models.data2vec.modeling_tf_data2vec_vision import ( + TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST, + ) if is_vision_available(): from PIL import Image @@ -421,7 +417,7 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase): @slow def test_model_from_pretrained(self): - for model_name in DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + for model_name in TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = TFData2VecVisionModel.from_pretrained(model_name) self.assertIsNotNone(model)