minor change on TF Data2Vec test (#17085)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -52,7 +52,7 @@ _EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
|
|||||||
_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k"
|
_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k"
|
||||||
_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote"
|
_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote"
|
||||||
|
|
||||||
DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"facebook/data2vec-vision-base-ft1k",
|
"facebook/data2vec-vision-base-ft1k",
|
||||||
# See all Data2VecVision models at https://huggingface.co/models?filter=data2vec-vision
|
# See all Data2VecVision models at https://huggingface.co/models?filter=data2vec-vision
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -28,17 +28,13 @@ from ...test_configuration_common import ConfigTester
|
|||||||
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|
||||||
"facebook/data2vec-vision-base-ft1k",
|
|
||||||
# See all Data2VecVision models at https://huggingface.co/models?filter=data2vec-vision
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TFData2VecVisionForImageClassification, TFData2VecVisionModel
|
from transformers import TFData2VecVisionForImageClassification, TFData2VecVisionModel
|
||||||
|
from transformers.models.data2vec.modeling_tf_data2vec_vision import (
|
||||||
|
TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
)
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -421,7 +417,7 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
model = TFData2VecVisionModel.from_pretrained(model_name)
|
model = TFData2VecVisionModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user