From 60e448c87eff29b166bf2821f5389056a72343e3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 5 Aug 2021 18:38:50 +0200 Subject: [PATCH] [Flax] Correct pt to flax conversion if from base to head (#13006) * finish PR * add tests * correct tests * finish * correct other flax tests * better naming * correct naming * finish * apply sylvains suggestions --- .../modeling_flax_pytorch_utils.py | 108 ++++++++++++------ tests/test_modeling_flax_clip.py | 22 ++++ tests/test_modeling_flax_common.py | 57 +++++++++ tests/test_modeling_flax_t5.py | 69 ++++++++++- 4 files changed, 220 insertions(+), 36 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 8d9bdd91ef..7b1588e95b 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -17,6 +17,7 @@ import os from pickle import UnpicklingError +from typing import Dict, Tuple import numpy as np @@ -58,60 +59,97 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa return flax_state_dict +def rename_key_and_reshape_tensor( + pt_tuple_key: Tuple[str], + pt_tensor: np.ndarray, + random_flax_state_dict: Dict[str, jnp.ndarray], + model_prefix: str, +) -> (Tuple[str], np.ndarray): + """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" + + def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool: + """Checks if ``key`` of ``(prefix,) + key`` is in random_flax_state_dict""" + return len(set(random_flax_state_dict) & set([key, (model_prefix,) + key])) > 0 + + # layer norm + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): + return renamed_pt_tuple_key, pt_tensor + + # embedding + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) + if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): + return renamed_pt_tuple_key, pt_tensor + + # conv layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key): + pt_tensor = pt_tensor.transpose(2, 3, 1, 0) + return renamed_pt_tuple_key, pt_tensor + + # linear layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key): + pt_tensor = pt_tensor.T + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm weight + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) + if pt_tuple_key[-1] == "gamma": + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm bias + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) + if pt_tuple_key[-1] == "beta": + return renamed_pt_tuple_key, pt_tensor + + return pt_tuple_key, pt_tensor + + def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): # convert pytorch tensor to numpy pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} + model_prefix = flax_model.base_model_prefix random_flax_state_dict = flatten_dict(flax_model.params) flax_state_dict = {} - remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and ( - flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()]) + load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and ( + model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()]) ) - add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and ( - flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()]) + load_base_model_into_model_with_head = (model_prefix in flax_model.params) and ( + model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()]) ) - # Need to change some parameters name to match Flax names so that we don't have to fork any layer + # Need to change some parameters name to match Flax names for pt_key, pt_tensor in pt_state_dict.items(): pt_tuple_key = tuple(pt_key.split(".")) - has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix - require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict - - if remove_base_model_prefix and has_base_model_prefix: + # remove base model prefix if necessary + has_base_model_prefix = pt_tuple_key[0] == model_prefix + if load_model_with_head_into_base_model and has_base_model_prefix: pt_tuple_key = pt_tuple_key[1:] - elif add_base_model_prefix and require_base_model_prefix: - pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key # Correctly rename weight parameters - if pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: - pt_tuple_key = pt_tuple_key[:-1] + ("scale",) - if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: - pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) - elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict: - # conv layer - pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) - pt_tensor = pt_tensor.transpose(2, 3, 1, 0) - elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict: - # linear layer - pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) - pt_tensor = pt_tensor.T - elif pt_tuple_key[-1] == "gamma": - pt_tuple_key = pt_tuple_key[:-1] + ("weight",) - elif pt_tuple_key[-1] == "beta": - pt_tuple_key = pt_tuple_key[:-1] + ("bias",) + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix + ) - if pt_tuple_key in random_flax_state_dict: - if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape: + # add model prefix if necessary + require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict + if load_base_model_into_model_with_head and require_base_model_prefix: + flax_key = (model_prefix,) + flax_key + + if flax_key in random_flax_state_dict: + if flax_tensor.shape != random_flax_state_dict[flax_key].shape: raise ValueError( f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " - f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}." + f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." ) # also add unexpected weight so that warning is thrown - flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor) + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) return unflatten_dict(flax_state_dict) @@ -154,10 +192,10 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): flax_state_dict = flatten_dict(flax_state) pt_model_dict = pt_model.state_dict() - remove_base_model_prefix = (pt_model.base_model_prefix in flax_state) and ( + load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and ( pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()]) ) - add_base_model_prefix = (pt_model.base_model_prefix not in flax_state) and ( + load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and ( pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()]) ) @@ -170,9 +208,9 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict # adapt flax_key to prepare for loading from/to base model only - if remove_base_model_prefix and has_base_model_prefix: + if load_model_with_head_into_base_model and has_base_model_prefix: flax_key_tuple = flax_key_tuple[1:] - elif add_base_model_prefix and require_base_model_prefix: + elif load_base_model_into_model_with_head and require_base_model_prefix: flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple # rename flax weights to PyTorch format diff --git a/tests/test_modeling_flax_clip.py b/tests/test_modeling_flax_clip.py index da1fcd68ac..8036350530 100644 --- a/tests/test_modeling_flax_clip.py +++ b/tests/test_modeling_flax_clip.py @@ -213,9 +213,20 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase): def test_save_load_from_base(self): pass + # FlaxCLIPVisionModel does not have any base model def test_save_load_to_base(self): pass + # FlaxCLIPVisionModel does not have any base model + @is_pt_flax_cross_test + def test_save_load_from_base_pt(self): + pass + + # FlaxCLIPVisionModel does not have any base model + @is_pt_flax_cross_test + def test_save_load_to_base_pt(self): + pass + @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: @@ -307,9 +318,20 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase): def test_save_load_from_base(self): pass + # FlaxCLIPVisionModel does not have any base model def test_save_load_to_base(self): pass + # FlaxCLIPVisionModel does not have any base model + @is_pt_flax_cross_test + def test_save_load_from_base_pt(self): + pass + + # FlaxCLIPVisionModel does not have any base model + @is_pt_flax_cross_test + def test_save_load_to_base_pt(self): + pass + @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index c959a04ac2..03139ac73a 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -334,6 +334,63 @@ class FlaxModelTesterMixin: max_diff = (base_params[key] - base_params_from_head[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + @is_pt_flax_cross_test + def test_save_load_from_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + # save pt model + pt_model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname, from_pt=True) + + base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix])) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + @is_pt_flax_cross_test + def test_save_load_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + @slow def test_jit_compilation(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_flax_t5.py b/tests/test_modeling_flax_t5.py index 9a635e07f1..424ec0b420 100644 --- a/tests/test_modeling_flax_t5.py +++ b/tests/test_modeling_flax_t5.py @@ -17,8 +17,15 @@ import unittest import numpy as np +import transformers from transformers import is_flax_available -from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow +from transformers.testing_utils import ( + is_pt_flax_cross_test, + require_flax, + require_sentencepiece, + require_tokenizers, + slow, +) from .test_configuration_common import ConfigTester from .test_generation_flax_utils import FlaxGenerationTesterMixin @@ -40,6 +47,7 @@ if is_flax_available(): from flax.training.common_utils import onehot from flax.traverse_util import flatten_dict from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer + from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, shift_tokens_right @@ -363,6 +371,65 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest. max_diff = (base_params[key] - base_params_from_head[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_from_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + # save pt model + pt_model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname, from_pt=True) + + base_param_from_head = flatten_dict(unfreeze(head_model.params)) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + @require_sentencepiece @require_tokenizers