From f7ca656f07537762d03d9f4084fbcfe7246095d0 Mon Sep 17 00:00:00 2001 From: Shubhamai Date: Fri, 24 Feb 2023 13:17:33 +0530 Subject: [PATCH] [Flax] adding support for batch norm layers (#21581) * [flax] adding support for batch norm layers * fixing bugs related to pt+flax integration * cleanup, batchnorm support in sharded pt to flax * support for batchnorm tests in pt+flax integration * simplifying checking batch norm layer --- .../modeling_flax_pytorch_utils.py | 93 ++++++++++++++++--- src/transformers/modeling_flax_utils.py | 40 ++++++-- tests/test_modeling_flax_common.py | 44 +++++++-- 3 files changed, 149 insertions(+), 28 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index c78b1b44cd..5de6a985df 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -83,6 +83,16 @@ def rename_key_and_reshape_tensor( if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): return renamed_pt_tuple_key, pt_tensor + # batch norm layer mean + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",) + if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key): + return renamed_pt_tuple_key, pt_tensor + + # batch norm layer var + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",) + if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key): + return renamed_pt_tuple_key, pt_tensor + # embedding renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): @@ -118,13 +128,25 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} model_prefix = flax_model.base_model_prefix - random_flax_state_dict = flatten_dict(flax_model.params) + + # use params dict if the model contains batch norm layers + if "params" in flax_model.params: + flax_model_params = flax_model.params["params"] + else: + flax_model_params = flax_model.params + random_flax_state_dict = flatten_dict(flax_model_params) + + # add batch_stats keys,values to dict + if "batch_stats" in flax_model.params: + flax_batch_stats = flatten_dict(flax_model.params["batch_stats"]) + random_flax_state_dict.update(flax_batch_stats) + flax_state_dict = {} - load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and ( + load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and ( model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()} ) - load_base_model_into_model_with_head = (model_prefix in flax_model.params) and ( + load_base_model_into_model_with_head = (model_prefix in flax_model_params) and ( model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()} ) @@ -154,8 +176,22 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." ) - # also add unexpected weight so that warning is thrown - flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + # add batch stats if the model contains batchnorm layers + if "batch_stats" in flax_model.params: + if "mean" in flax_key[-1] or "var" in flax_key[-1]: + flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor) + continue + # remove num_batches_tracked key + if "num_batches_tracked" in flax_key[-1]: + flax_state_dict.pop(flax_key, None) + continue + + # also add unexpected weight so that warning is thrown + flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor) + + else: + # also add unexpected weight so that warning is thrown + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) return unflatten_dict(flax_state_dict) @@ -176,12 +212,21 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} model_prefix = flax_model.base_model_prefix - random_flax_state_dict = flatten_dict(flax_model.params) - load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and ( + # use params dict if the model contains batch norm layers and then add batch_stats keys,values to dict + if "batch_stats" in flax_model.params: + flax_model_params = flax_model.params["params"] + + random_flax_state_dict = flatten_dict(flax_model_params) + random_flax_state_dict.update(flatten_dict(flax_model.params["batch_stats"])) + else: + flax_model_params = flax_model.params + random_flax_state_dict = flatten_dict(flax_model_params) + + load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and ( model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()} ) - load_base_model_into_model_with_head = (model_prefix in flax_model.params) and ( + load_base_model_into_model_with_head = (model_prefix in flax_model_params) and ( model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()} ) # Need to change some parameters name to match Flax names @@ -209,8 +254,25 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." ) - # also add unexpected weight so that warning is thrown - flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + # add batch stats if the model contains batchnorm layers + if "batch_stats" in flax_model.params: + if "mean" in flax_key[-1]: + flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor) + continue + if "var" in flax_key[-1]: + flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor) + continue + # remove num_batches_tracked key + if "num_batches_tracked" in flax_key[-1]: + flax_state_dict.pop(flax_key, None) + continue + + # also add unexpected weight so that warning is thrown + flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor) + + else: + # also add unexpected weight so that warning is thrown + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) return unflatten_dict(flax_state_dict) @@ -299,7 +361,16 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): elif flax_key_tuple[-1] in ["scale", "embedding"]: flax_key_tuple = flax_key_tuple[:-1] + ("weight",) - flax_key = ".".join(flax_key_tuple) + # adding batch stats from flax batch norm to pt + elif "mean" in flax_key_tuple[-1]: + flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",) + elif "var" in flax_key_tuple[-1]: + flax_key_tuple = flax_key_tuple[:-1] + ("running_var",) + + if "batch_stats" in flax_state: + flax_key = ".".join(flax_key_tuple[1:]) # Remove the params/batch_stats header + else: + flax_key = ".".join(flax_key_tuple) if flax_key in pt_model_dict: if flax_tensor.shape != pt_model_dict[flax_key].shape: diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 466f324ce8..271d42421b 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -837,14 +837,35 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): # keep the params on CPU if we don't want to initialize state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) - # if model is base model only use model_prefix key - if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state: - state = state[cls.base_model_prefix] + if "batch_stats" in state: # if flax model contains batch norm layers + # if model is base model only use model_prefix key + if ( + cls.base_model_prefix not in dict(model.params_shape_tree["params"]) + and cls.base_model_prefix in state["params"] + ): + state["params"] = state["params"][cls.base_model_prefix] + state["batch_stats"] = state["batch_stats"][cls.base_model_prefix] - # if model is head model and we are loading weights from base model - # we initialize new params dict with base_model_prefix - if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state: - state = {cls.base_model_prefix: state} + # if model is head model and we are loading weights from base model + # we initialize new params dict with base_model_prefix + if ( + cls.base_model_prefix in dict(model.params_shape_tree["params"]) + and cls.base_model_prefix not in state["params"] + ): + state = { + "params": {cls.base_model_prefix: state["params"]}, + "batch_stats": {cls.base_model_prefix: state["batch_stats"]}, + } + + else: + # if model is base model only use model_prefix key + if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state: + state = state[cls.base_model_prefix] + + # if model is head model and we are loading weights from base model + # we initialize new params dict with base_model_prefix + if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state: + state = {cls.base_model_prefix: state} # flatten dicts state = flatten_dict(state) @@ -854,6 +875,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): missing_keys = model.required_params - set(state.keys()) unexpected_keys = set(state.keys()) - model.required_params + # Disabling warning when porting pytorch weights to flax, flax does not uses num_batches_tracked + for unexpected_key in unexpected_keys.copy(): + if "num_batches_tracked" in unexpected_key[-1]: + unexpected_keys.remove(unexpected_key) + if missing_keys and not _do_init: logger.warning( f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index f93228e9b8..76e4d6c881 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -118,6 +118,30 @@ def random_attention_mask(shape, rng=None): return attn_mask +def get_params(params, from_head_prefix=None): + """Function extracts relevant parameters into flatten dict from model params, + appends batch normalization statistics if present""" + + # If Both parameters and batch normalization statistics are present + if "batch_stats" in params: + # Extract only parameters for the specified head prefix (if specified) and add batch statistics + if from_head_prefix is not None: + extracted_params = flatten_dict(unfreeze(params["params"][from_head_prefix])) + extracted_params.update(flatten_dict(params["batch_stats"][from_head_prefix])) + else: + extracted_params = flatten_dict(unfreeze(params["params"])) + extracted_params.update(flatten_dict(params["batch_stats"])) + + # Only parameters are present + else: + if from_head_prefix is not None: + extracted_params = flatten_dict(unfreeze(params[from_head_prefix])) + else: + extracted_params = flatten_dict(unfreeze(params)) + + return extracted_params + + @require_flax class FlaxModelTesterMixin: model_tester = None @@ -426,14 +450,14 @@ class FlaxModelTesterMixin: continue model = base_class(config) - base_params = flatten_dict(unfreeze(model.params)) + base_params = get_params(model.params) # check that all base model weights are loaded correctly with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) head_model = model_class.from_pretrained(tmpdirname) - base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix])) + base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix) for key in base_param_from_head.keys(): max_diff = (base_params[key] - base_param_from_head[key]).sum().item() @@ -448,14 +472,14 @@ class FlaxModelTesterMixin: continue model = model_class(config) - base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) + base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix) # check that all base model weights are loaded correctly with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) base_model = base_class.from_pretrained(tmpdirname) - base_params = flatten_dict(unfreeze(base_model.params)) + base_params = get_params(base_model.params) for key in base_params_from_head.keys(): max_diff = (base_params[key] - base_params_from_head[key]).sum().item() @@ -471,7 +495,7 @@ class FlaxModelTesterMixin: continue model = base_class(config) - base_params = flatten_dict(unfreeze(model.params)) + base_params = get_params(model.params) # convert Flax model to PyTorch model pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning @@ -484,7 +508,7 @@ class FlaxModelTesterMixin: pt_model.save_pretrained(tmpdirname) head_model = model_class.from_pretrained(tmpdirname, from_pt=True) - base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix])) + base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix) for key in base_param_from_head.keys(): max_diff = (base_params[key] - base_param_from_head[key]).sum().item() @@ -500,7 +524,7 @@ class FlaxModelTesterMixin: continue model = model_class(config) - base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) + base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix) # convert Flax model to PyTorch model pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning @@ -512,7 +536,7 @@ class FlaxModelTesterMixin: pt_model.save_pretrained(tmpdirname) base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - base_params = flatten_dict(unfreeze(base_model.params)) + base_params = get_params(base_model.params) for key in base_params_from_head.keys(): max_diff = (base_params[key] - base_params_from_head[key]).sum().item() @@ -529,7 +553,7 @@ class FlaxModelTesterMixin: model = model_class(config) model.params = model.to_bf16(model.params) - base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) + base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix) # convert Flax model to PyTorch model pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning @@ -541,7 +565,7 @@ class FlaxModelTesterMixin: pt_model.save_pretrained(tmpdirname) base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - base_params = flatten_dict(unfreeze(base_model.params)) + base_params = get_params(base_model.params) for key in base_params_from_head.keys(): max_diff = (base_params[key] - base_params_from_head[key]).sum().item()