[Flax] adding support for batch norm layers (#21581)
* [flax] adding support for batch norm layers
* fixing bugs related to pt+flax integration
* cleanup, batchnorm support in sharded pt to flax
* support for batchnorm tests in pt+flax integration
* simplifying checking batch norm layer
This commit is contained in:
@@ -83,6 +83,16 @@ def rename_key_and_reshape_tensor(
|
||||
if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
|
||||
return renamed_pt_tuple_key, pt_tensor
|
||||
|
||||
# batch norm layer mean
|
||||
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",)
|
||||
if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
|
||||
return renamed_pt_tuple_key, pt_tensor
|
||||
|
||||
# batch norm layer var
|
||||
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",)
|
||||
if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
|
||||
return renamed_pt_tuple_key, pt_tensor
|
||||
|
||||
# embedding
|
||||
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
|
||||
if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
|
||||
@@ -118,13 +128,25 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||
|
||||
model_prefix = flax_model.base_model_prefix
|
||||
random_flax_state_dict = flatten_dict(flax_model.params)
|
||||
|
||||
# use params dict if the model contains batch norm layers
|
||||
if "params" in flax_model.params:
|
||||
flax_model_params = flax_model.params["params"]
|
||||
else:
|
||||
flax_model_params = flax_model.params
|
||||
random_flax_state_dict = flatten_dict(flax_model_params)
|
||||
|
||||
# add batch_stats keys,values to dict
|
||||
if "batch_stats" in flax_model.params:
|
||||
flax_batch_stats = flatten_dict(flax_model.params["batch_stats"])
|
||||
random_flax_state_dict.update(flax_batch_stats)
|
||||
|
||||
flax_state_dict = {}
|
||||
|
||||
load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
|
||||
load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
|
||||
model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
|
||||
)
|
||||
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
|
||||
load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
|
||||
model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
|
||||
)
|
||||
|
||||
@@ -154,6 +176,20 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
||||
)
|
||||
|
||||
# add batch stats if the model contains batchnorm layers
|
||||
if "batch_stats" in flax_model.params:
|
||||
if "mean" in flax_key[-1] or "var" in flax_key[-1]:
|
||||
flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
|
||||
continue
|
||||
# remove num_batches_tracked key
|
||||
if "num_batches_tracked" in flax_key[-1]:
|
||||
flax_state_dict.pop(flax_key, None)
|
||||
continue
|
||||
|
||||
# also add unexpected weight so that warning is thrown
|
||||
flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)
|
||||
|
||||
else:
|
||||
# also add unexpected weight so that warning is thrown
|
||||
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
|
||||
|
||||
@@ -176,12 +212,21 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
|
||||
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||
|
||||
model_prefix = flax_model.base_model_prefix
|
||||
random_flax_state_dict = flatten_dict(flax_model.params)
|
||||
|
||||
load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
|
||||
# use params dict if the model contains batch norm layers and then add batch_stats keys,values to dict
|
||||
if "batch_stats" in flax_model.params:
|
||||
flax_model_params = flax_model.params["params"]
|
||||
|
||||
random_flax_state_dict = flatten_dict(flax_model_params)
|
||||
random_flax_state_dict.update(flatten_dict(flax_model.params["batch_stats"]))
|
||||
else:
|
||||
flax_model_params = flax_model.params
|
||||
random_flax_state_dict = flatten_dict(flax_model_params)
|
||||
|
||||
load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
|
||||
model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
|
||||
)
|
||||
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
|
||||
load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
|
||||
model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
|
||||
)
|
||||
# Need to change some parameters name to match Flax names
|
||||
@@ -209,6 +254,23 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
|
||||
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
||||
)
|
||||
|
||||
# add batch stats if the model contains batchnorm layers
|
||||
if "batch_stats" in flax_model.params:
|
||||
if "mean" in flax_key[-1]:
|
||||
flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
|
||||
continue
|
||||
if "var" in flax_key[-1]:
|
||||
flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
|
||||
continue
|
||||
# remove num_batches_tracked key
|
||||
if "num_batches_tracked" in flax_key[-1]:
|
||||
flax_state_dict.pop(flax_key, None)
|
||||
continue
|
||||
|
||||
# also add unexpected weight so that warning is thrown
|
||||
flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)
|
||||
|
||||
else:
|
||||
# also add unexpected weight so that warning is thrown
|
||||
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
|
||||
return unflatten_dict(flax_state_dict)
|
||||
@@ -299,6 +361,15 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
||||
elif flax_key_tuple[-1] in ["scale", "embedding"]:
|
||||
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
|
||||
|
||||
# adding batch stats from flax batch norm to pt
|
||||
elif "mean" in flax_key_tuple[-1]:
|
||||
flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",)
|
||||
elif "var" in flax_key_tuple[-1]:
|
||||
flax_key_tuple = flax_key_tuple[:-1] + ("running_var",)
|
||||
|
||||
if "batch_stats" in flax_state:
|
||||
flax_key = ".".join(flax_key_tuple[1:]) # Remove the params/batch_stats header
|
||||
else:
|
||||
flax_key = ".".join(flax_key_tuple)
|
||||
|
||||
if flax_key in pt_model_dict:
|
||||
|
||||
@@ -837,6 +837,27 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
# keep the params on CPU if we don't want to initialize
|
||||
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
|
||||
|
||||
if "batch_stats" in state: # if flax model contains batch norm layers
|
||||
# if model is base model only use model_prefix key
|
||||
if (
|
||||
cls.base_model_prefix not in dict(model.params_shape_tree["params"])
|
||||
and cls.base_model_prefix in state["params"]
|
||||
):
|
||||
state["params"] = state["params"][cls.base_model_prefix]
|
||||
state["batch_stats"] = state["batch_stats"][cls.base_model_prefix]
|
||||
|
||||
# if model is head model and we are loading weights from base model
|
||||
# we initialize new params dict with base_model_prefix
|
||||
if (
|
||||
cls.base_model_prefix in dict(model.params_shape_tree["params"])
|
||||
and cls.base_model_prefix not in state["params"]
|
||||
):
|
||||
state = {
|
||||
"params": {cls.base_model_prefix: state["params"]},
|
||||
"batch_stats": {cls.base_model_prefix: state["batch_stats"]},
|
||||
}
|
||||
|
||||
else:
|
||||
# if model is base model only use model_prefix key
|
||||
if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
|
||||
state = state[cls.base_model_prefix]
|
||||
@@ -854,6 +875,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
missing_keys = model.required_params - set(state.keys())
|
||||
unexpected_keys = set(state.keys()) - model.required_params
|
||||
|
||||
# Disable the warning when porting PyTorch weights to Flax: Flax does not use num_batches_tracked
|
||||
for unexpected_key in unexpected_keys.copy():
|
||||
if "num_batches_tracked" in unexpected_key[-1]:
|
||||
unexpected_keys.remove(unexpected_key)
|
||||
|
||||
if missing_keys and not _do_init:
|
||||
logger.warning(
|
||||
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
|
||||
|
||||
@@ -118,6 +118,30 @@ def random_attention_mask(shape, rng=None):
|
||||
return attn_mask
|
||||
|
||||
|
||||
def get_params(params, from_head_prefix=None):
    """Flatten model parameters into one dict, merging in batch norm statistics if present.

    Args:
        params: The model parameter tree. When it contains a ``"batch_stats"``
            key, the weights are read from ``params["params"]`` and the
            statistics from ``params["batch_stats"]``; otherwise ``params``
            itself is treated as the weight tree.
        from_head_prefix: Optional key. When given, only the sub-tree stored
            under this key is extracted (from the weights and, if present,
            from the batch statistics as well).

    Returns:
        The dict produced by ``flatten_dict`` for the selected weights,
        updated with the flattened batch statistics when they exist.
    """
    has_batch_stats = "batch_stats" in params

    # Select the weight tree first, then narrow it to the requested prefix.
    weights = params["params"] if has_batch_stats else params
    if from_head_prefix is not None:
        weights = weights[from_head_prefix]
    extracted_params = flatten_dict(unfreeze(weights))

    # Batch statistics live in their own collection and are merged in afterwards.
    if has_batch_stats:
        stats = params["batch_stats"]
        if from_head_prefix is not None:
            stats = stats[from_head_prefix]
        extracted_params.update(flatten_dict(stats))

    return extracted_params
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxModelTesterMixin:
|
||||
model_tester = None
|
||||
@@ -426,14 +450,14 @@ class FlaxModelTesterMixin:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
base_params = get_params(model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
|
||||
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||
@@ -448,14 +472,14 @@ class FlaxModelTesterMixin:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
base_params = get_params(base_model.params)
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
@@ -471,7 +495,7 @@ class FlaxModelTesterMixin:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
base_params = get_params(model.params)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
@@ -484,7 +508,7 @@ class FlaxModelTesterMixin:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
|
||||
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||
@@ -500,7 +524,7 @@ class FlaxModelTesterMixin:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
@@ -512,7 +536,7 @@ class FlaxModelTesterMixin:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
base_params = get_params(base_model.params)
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
@@ -529,7 +553,7 @@ class FlaxModelTesterMixin:
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
@@ -541,7 +565,7 @@ class FlaxModelTesterMixin:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
base_params = get_params(base_model.params)
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
|
||||
Reference in New Issue
Block a user