[FlaxClip] fix test from/save pretrained test (#12284)
* boom boom * remove flax clip example * fix from_save_pretrained
This commit is contained in:
@@ -511,3 +511,37 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||||
|
|
||||||
|
# overwrite from common since FlaxCLIPModel returns nested output
|
||||||
|
# which is not supported in the common test
|
||||||
|
def test_from_pretrained_save_pretrained(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if model_class.__name__ != "FlaxBertModel":
|
||||||
|
continue
|
||||||
|
|
||||||
|
with self.subTest(model_class.__name__):
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
outputs = model(**prepared_inputs_dict).to_tuple()
|
||||||
|
|
||||||
|
# verify that normal save_pretrained works as expected
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()[:4]
|
||||||
|
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||||
|
self.assert_almost_equals(output_loaded, output, 1e-3)
|
||||||
|
|
||||||
|
# verify that save_pretrained for distributed training
|
||||||
|
# with `params=params` works as expected
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname, params=model.params)
|
||||||
|
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()[:4]
|
||||||
|
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||||
|
self.assert_almost_equals(output_loaded, output, 1e-3)
|
||||||
|
|||||||
Reference in New Issue
Block a user