minor change on TF Data2Vec test (#17085)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-05-04 18:39:30 +02:00
committed by GitHub
parent 23619ef6b7
commit 6dc4c36acb
2 changed files with 5 additions and 9 deletions

View File

@@ -52,7 +52,7 @@ _EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k"
_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote"
DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = [
TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/data2vec-vision-base-ft1k",
# See all Data2VecVision models at https://huggingface.co/models?filter=data2vec-vision
]

View File

@@ -28,17 +28,13 @@ from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/data2vec-vision-base-ft1k",
# See all Data2VecVision models at https://huggingface.co/models?filter=data2vec-vision
]
if is_tf_available():
import tensorflow as tf
from transformers import TFData2VecVisionForImageClassification, TFData2VecVisionModel
from transformers.models.data2vec.modeling_tf_data2vec_vision import (
TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST,
)
if is_vision_available():
from PIL import Image
@@ -421,7 +417,7 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
@slow
def test_model_from_pretrained(self):
for model_name in DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
for model_name in TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFData2VecVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)