Use native TF checkpoints for the BLIP TF tests (#22593)
* Use native TF checkpoints for the TF tests * Remove unneeded exceptions
This commit is contained in:
@@ -189,10 +189,7 @@ class TFBlipVisionModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
try:
|
||||
model = TFBlipVisionModel.from_pretrained(model_name)
|
||||
except OSError:
|
||||
model = TFBlipVisionModel.from_pretrained(model_name, from_pt=True)
|
||||
model = TFBlipVisionModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@@ -320,10 +317,7 @@ class TFBlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
try:
|
||||
model = TFBlipTextModel.from_pretrained(model_name)
|
||||
except OSError:
|
||||
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
|
||||
model = TFBlipTextModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
|
||||
@@ -432,7 +426,7 @@ class TFBlipModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
|
||||
model = TFBlipModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
|
||||
@@ -635,7 +629,7 @@ class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
|
||||
model = TFBlipModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="Tested in individual model tests")
|
||||
@@ -750,10 +744,7 @@ class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
try:
|
||||
model = TFBlipModel.from_pretrained(model_name)
|
||||
except OSError:
|
||||
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
|
||||
model = TFBlipModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@@ -769,7 +760,7 @@ def prepare_img():
|
||||
@slow
|
||||
class TFBlipModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_image_captioning(self):
|
||||
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True)
|
||||
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
|
||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
||||
image = prepare_img()
|
||||
|
||||
@@ -796,7 +787,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_inference_vqa(self):
|
||||
model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", from_pt=True)
|
||||
model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
|
||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
||||
|
||||
image = prepare_img()
|
||||
@@ -808,7 +799,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(out[0].numpy().tolist(), [30522, 1015, 102])
|
||||
|
||||
def test_inference_itm(self):
|
||||
model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco", from_pt=True)
|
||||
model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
|
||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
|
||||
|
||||
image = prepare_img()
|
||||
|
||||
@@ -160,10 +160,7 @@ class BlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
try:
|
||||
model = TFBlipTextModel.from_pretrained(model_name)
|
||||
except OSError:
|
||||
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
|
||||
model = TFBlipTextModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
|
||||
|
||||
Reference in New Issue
Block a user