[Flax] [WIP] allow loading head model with base model weights (#12255)

* boom boom

* remove flax clip example

* allow loading head model with base model weights

* add test

* fix imports

* disable save, load test for clip

* add test_save_load_to_base
This commit is contained in:
Suraj Patil
2021-06-21 20:26:42 +05:30
committed by GitHub
parent 8d5b7f36e5
commit eb881674f2
3 changed files with 66 additions and 1 deletions

View File

@@ -209,6 +209,13 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
[self.model_tester.num_attention_heads, seq_length, seq_length],
)
# FlaxCLIPVisionModel does not have any base model
def test_save_load_from_base(self):
pass
def test_save_load_to_base(self):
pass
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
@@ -296,6 +303,13 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
def setUp(self):
self.model_tester = FlaxCLIPTextModelTester(self)
# FlaxCLIPTextModel does not have any base model
def test_save_load_from_base(self):
pass
def test_save_load_to_base(self):
pass
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes: