Fix loading CLIPVisionConfig and CLIPTextConfig (#16198)

* override from_pretrained

* add tests

* remove docstrings

* fix typo

* Trigger CI
This commit is contained in:
Suraj Patil
2022-03-16 16:24:01 +01:00
committed by GitHub
parent 09013efdf1
commit 190994573a
2 changed files with 51 additions and 0 deletions

View File

@@ -588,6 +588,21 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
self.assertTrue(models_equal)
def test_load_vision_text_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Save CLIPConfig and check if we can load CLIPVisionConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
vision_config = CLIPVisionConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
# Save CLIPConfig and check if we can load CLIPTextConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
text_config = CLIPTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
# overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):