From e577bd0f13e1820650810f6864253d70dc76ce08 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 5 Apr 2023 18:43:14 +0100 Subject: [PATCH] Use native TF checkpoints for the BLIP TF tests (#22593) * Use native TF checkpoints for the TF tests * Remove unneeded exceptions --- tests/models/blip/test_modeling_tf_blip.py | 25 ++++++------------- .../models/blip/test_modeling_tf_blip_text.py | 5 +--- 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/tests/models/blip/test_modeling_tf_blip.py b/tests/models/blip/test_modeling_tf_blip.py index b90205bd16..31630b17f9 100644 --- a/tests/models/blip/test_modeling_tf_blip.py +++ b/tests/models/blip/test_modeling_tf_blip.py @@ -189,10 +189,7 @@ class TFBlipVisionModelTest(TFModelTesterMixin, unittest.TestCase): @slow def test_model_from_pretrained(self): for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - try: - model = TFBlipVisionModel.from_pretrained(model_name) - except OSError: - model = TFBlipVisionModel.from_pretrained(model_name, from_pt=True) + model = TFBlipVisionModel.from_pretrained(model_name) self.assertIsNotNone(model) @@ -320,10 +317,7 @@ class TFBlipTextModelTest(TFModelTesterMixin, unittest.TestCase): @slow def test_model_from_pretrained(self): for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - try: - model = TFBlipTextModel.from_pretrained(model_name) - except OSError: - model = TFBlipTextModel.from_pretrained(model_name, from_pt=True) + model = TFBlipTextModel.from_pretrained(model_name) self.assertIsNotNone(model) def test_pt_tf_model_equivalence(self, allow_missing_keys=True): @@ -432,7 +426,7 @@ class TFBlipModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase @slow def test_model_from_pretrained(self): for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = TFBlipModel.from_pretrained(model_name, from_pt=True) + model = TFBlipModel.from_pretrained(model_name) self.assertIsNotNone(model) def test_pt_tf_model_equivalence(self, allow_missing_keys=True): @@ -635,7 +629,7 @@ class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase): @slow def test_model_from_pretrained(self): for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = TFBlipModel.from_pretrained(model_name, from_pt=True) + model = TFBlipModel.from_pretrained(model_name) self.assertIsNotNone(model) @unittest.skip(reason="Tested in individual model tests") @@ -750,10 +744,7 @@ class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase): @slow def test_model_from_pretrained(self): for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - try: - model = TFBlipModel.from_pretrained(model_name) - except OSError: - model = TFBlipModel.from_pretrained(model_name, from_pt=True) + model = TFBlipModel.from_pretrained(model_name) self.assertIsNotNone(model) @@ -769,7 +760,7 @@ def prepare_img(): @slow class TFBlipModelIntegrationTest(unittest.TestCase): def test_inference_image_captioning(self): - model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True) + model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") image = prepare_img() @@ -796,7 +787,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase): ) def test_inference_vqa(self): - model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", from_pt=True) + model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") image = prepare_img() @@ -808,7 +799,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase): self.assertEqual(out[0].numpy().tolist(), [30522, 1015, 102]) def test_inference_itm(self): - model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco", from_pt=True) + model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco") processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco") image = prepare_img() diff --git a/tests/models/blip/test_modeling_tf_blip_text.py b/tests/models/blip/test_modeling_tf_blip_text.py index 1b9a7cf37a..261056e918 100644 --- a/tests/models/blip/test_modeling_tf_blip_text.py +++ b/tests/models/blip/test_modeling_tf_blip_text.py @@ -160,10 +160,7 @@ class BlipTextModelTest(TFModelTesterMixin, unittest.TestCase): @slow def test_model_from_pretrained(self): for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - try: - model = TFBlipTextModel.from_pretrained(model_name) - except OSError: - model = TFBlipTextModel.from_pretrained(model_name, from_pt=True) + model = TFBlipTextModel.from_pretrained(model_name) self.assertIsNotNone(model) def test_pt_tf_model_equivalence(self, allow_missing_keys=True):