From 8d5b7f36e51122453d5b117f64d110135e8d40dd Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 21 Jun 2021 20:24:34 +0530 Subject: [PATCH] [FlaxClip] fix test from/save pretrained test (#12284) * boom boom * remove flax clip example * fix from_save_pretrained --- tests/test_modeling_flax_clip.py | 34 ++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/test_modeling_flax_clip.py b/tests/test_modeling_flax_clip.py index 8a82b94ca9..7666c13bd7 100644 --- a/tests/test_modeling_flax_clip.py +++ b/tests/test_modeling_flax_clip.py @@ -511,3 +511,37 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase): ) for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]): self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + + # overwrite from common since FlaxCLIPModel returns nested output + # which is not supported in the common test + def test_from_pretrained_save_pretrained(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class.__name__ != "FlaxBertModel": + continue + + with self.subTest(model_class.__name__): + model = model_class(config) + + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + outputs = model(**prepared_inputs_dict).to_tuple() + + # verify that normal save_pretrained works as expected + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_loaded = model_class.from_pretrained(tmpdirname) + + outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()[:4] + for output_loaded, output in zip(outputs_loaded, outputs): + self.assert_almost_equals(output_loaded, output, 1e-3) + + # verify that save_pretrained for distributed training + # with `params=params` works as expected + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, params=model.params) + model_loaded = model_class.from_pretrained(tmpdirname) + + outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()[:4] + for output_loaded, output in zip(outputs_loaded, outputs): + self.assert_almost_equals(output_loaded, output, 1e-3)