From eb881674f23f4dc90b9dfb03dce9ef918f4f8f9d Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Mon, 21 Jun 2021 20:26:42 +0530
Subject: [PATCH] [Flax] [WIP] allow loading head model with base model weights (#12255)

* boom boom

* remove flax clip example

* allow loading head model with base model weights

* add test

* fix imports

* disable save, load test for clip

* add test_save_load_to_base
---
Review note (ignored by `git am`): relative to the original commit, the two
new tests in tests/test_modeling_flax_common.py compare parameters with
jnp.abs(...).max() instead of a signed .sum() (opposite-signed errors could
cancel and negative drift always passed), and carry one-line docstrings. The
hunk size and diffstat below reflect those four extra lines; the post-image
blob id on that file's `index` line is therefore stale. Line structure and
indentation were restored from a whitespace-collapsed copy -- verify the
context lines against the target tree (or apply with --ignore-whitespace).

 src/transformers/modeling_flax_utils.py |  5 +++
 tests/test_modeling_flax_clip.py        | 14 ++++++++
 tests/test_modeling_flax_common.py      | 52 ++++++++++++++++++++++++-
 3 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index 0691eab3a8..6a2855edf2 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -348,6 +348,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
             state = state[cls.base_model_prefix]
 
+        # if model is head model and we are loading weights from base model
+        # we initialize new params dict with base_model_prefix
+        if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
+            state = {cls.base_model_prefix: state}
+
         # flatten dicts
         state = flatten_dict(state)
 
diff --git a/tests/test_modeling_flax_clip.py b/tests/test_modeling_flax_clip.py
index 7666c13bd7..da1fcd68ac 100644
--- a/tests/test_modeling_flax_clip.py
+++ b/tests/test_modeling_flax_clip.py
@@ -209,6 +209,13 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
                 [self.model_tester.num_attention_heads, seq_length, seq_length],
             )
 
+    # FlaxCLIPVisionModel does not have any base model
+    def test_save_load_from_base(self):
+        pass
+
+    def test_save_load_to_base(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_class_name in self.all_model_classes:
@@ -296,6 +303,13 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
     def setUp(self):
         self.model_tester = FlaxCLIPTextModelTester(self)
 
+    # FlaxCLIPTextModel does not have any base model
+    def test_save_load_from_base(self):
+        pass
+
+    def test_save_load_to_base(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_class_name in self.all_model_classes:
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index 10cc1f4538..f2d30eea41 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -32,7 +32,9 @@ if is_flax_available():
     import jax
     import jax.numpy as jnp
     import jaxlib.xla_extension as jax_xla
-    from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
+    from flax.core.frozen_dict import unfreeze
+    from flax.traverse_util import flatten_dict
+    from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING
     from transformers.modeling_flax_pytorch_utils import (
         convert_pytorch_state_dict_to_flax,
         load_flax_weights_in_pytorch_model,
@@ -273,6 +275,54 @@ class FlaxModelTesterMixin:
                 for output_loaded, output in zip(outputs_loaded, outputs):
                     self.assert_almost_equals(output_loaded, output, 1e-3)
 
+    def test_save_load_from_base(self):
+        """Check that a checkpoint saved by the base model loads into each head model (params within 1e-3)."""
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            model = base_class(config)
+            base_params = flatten_dict(unfreeze(model.params))
+
+            # check that all base model weights are loaded correctly
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                head_model = model_class.from_pretrained(tmpdirname)
+
+                base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
+
+                for key in base_param_from_head.keys():
+                    # largest absolute deviation: a signed sum would let opposite-signed errors cancel
+                    max_diff = jnp.abs(base_params[key] - base_param_from_head[key]).max().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+    def test_save_load_to_base(self):
+        """Check that a checkpoint saved by each head model loads into the base model (params within 1e-3)."""
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            model = model_class(config)
+            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
+
+            # check that all base model weights are loaded correctly
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                base_model = base_class.from_pretrained(tmpdirname)
+
+                base_params = flatten_dict(unfreeze(base_model.params))
+
+                for key in base_params_from_head.keys():
+                    # largest absolute deviation: a signed sum would let opposite-signed errors cancel
+                    max_diff = jnp.abs(base_params[key] - base_params_from_head[key]).max().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
     @slow
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()