Fix loading CLIPVisionConfig and CLIPTextConfig (#16198)
* override from_pretrained * add tests * remove docstrings * fix typo * Trigger CI
This commit is contained in:
@@ -588,6 +588,21 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_load_vision_text_config(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Save CLIPConfig and check if we can load CLIPVisionConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
vision_config = CLIPVisionConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
|
||||
|
||||
# Save CLIPConfig and check if we can load CLIPTextConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
text_config = CLIPTextConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
||||
|
||||
# overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self):
|
||||
|
||||
Reference in New Issue
Block a user