Use native TF checkpoints for the BLIP TF tests (#22593)

* Use native TF checkpoints for the TF tests

* Remove unneeded exceptions
This commit is contained in:
Matt
2023-04-05 18:43:14 +01:00
committed by GitHub
parent 176ceff91f
commit e577bd0f13
2 changed files with 9 additions and 21 deletions

View File

@@ -189,10 +189,7 @@ class TFBlipVisionModelTest(TFModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try: model = TFBlipVisionModel.from_pretrained(model_name)
model = TFBlipVisionModel.from_pretrained(model_name)
except OSError:
model = TFBlipVisionModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model) self.assertIsNotNone(model)
@@ -320,10 +317,7 @@ class TFBlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try: model = TFBlipTextModel.from_pretrained(model_name)
model = TFBlipTextModel.from_pretrained(model_name)
except OSError:
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model) self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True): def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
@@ -432,7 +426,7 @@ class TFBlipModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFBlipModel.from_pretrained(model_name, from_pt=True) model = TFBlipModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True): def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
@@ -635,7 +629,7 @@ class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFBlipModel.from_pretrained(model_name, from_pt=True) model = TFBlipModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
@unittest.skip(reason="Tested in individual model tests") @unittest.skip(reason="Tested in individual model tests")
@@ -750,10 +744,7 @@ class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try: model = TFBlipModel.from_pretrained(model_name)
model = TFBlipModel.from_pretrained(model_name)
except OSError:
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model) self.assertIsNotNone(model)
@@ -769,7 +760,7 @@ def prepare_img():
@slow @slow
class TFBlipModelIntegrationTest(unittest.TestCase): class TFBlipModelIntegrationTest(unittest.TestCase):
def test_inference_image_captioning(self): def test_inference_image_captioning(self):
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True) model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
image = prepare_img() image = prepare_img()
@@ -796,7 +787,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase):
) )
def test_inference_vqa(self): def test_inference_vqa(self):
model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", from_pt=True) model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
image = prepare_img() image = prepare_img()
@@ -808,7 +799,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase):
self.assertEqual(out[0].numpy().tolist(), [30522, 1015, 102]) self.assertEqual(out[0].numpy().tolist(), [30522, 1015, 102])
def test_inference_itm(self): def test_inference_itm(self):
model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco", from_pt=True) model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco") processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
image = prepare_img() image = prepare_img()

View File

@@ -160,10 +160,7 @@ class BlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try: model = TFBlipTextModel.from_pretrained(model_name)
model = TFBlipTextModel.from_pretrained(model_name)
except OSError:
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model) self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True): def test_pt_tf_model_equivalence(self, allow_missing_keys=True):