[Flax] adding support for batch norm layers (#21581)
* [flax] adding support for batch norm layers
* fixing bugs related to pt+flax integration
* cleanup, batchnorm support in sharded pt to flax
* support for batchnorm tests in pt+flax integration
* simplifying checking batch norm layer
This commit is contained in:
@@ -83,6 +83,16 @@ def rename_key_and_reshape_tensor(
|
|||||||
if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
|
if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
|
||||||
return renamed_pt_tuple_key, pt_tensor
|
return renamed_pt_tuple_key, pt_tensor
|
||||||
|
|
||||||
|
# batch norm layer mean
|
||||||
|
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",)
|
||||||
|
if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
|
||||||
|
return renamed_pt_tuple_key, pt_tensor
|
||||||
|
|
||||||
|
# batch norm layer var
|
||||||
|
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",)
|
||||||
|
if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
|
||||||
|
return renamed_pt_tuple_key, pt_tensor
|
||||||
|
|
||||||
# embedding
|
# embedding
|
||||||
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
|
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
|
||||||
if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
|
if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
|
||||||
@@ -118,13 +128,25 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
|||||||
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||||
|
|
||||||
model_prefix = flax_model.base_model_prefix
|
model_prefix = flax_model.base_model_prefix
|
||||||
random_flax_state_dict = flatten_dict(flax_model.params)
|
|
||||||
|
# use params dict if the model contains batch norm layers
|
||||||
|
if "params" in flax_model.params:
|
||||||
|
flax_model_params = flax_model.params["params"]
|
||||||
|
else:
|
||||||
|
flax_model_params = flax_model.params
|
||||||
|
random_flax_state_dict = flatten_dict(flax_model_params)
|
||||||
|
|
||||||
|
# add batch_stats keys,values to dict
|
||||||
|
if "batch_stats" in flax_model.params:
|
||||||
|
flax_batch_stats = flatten_dict(flax_model.params["batch_stats"])
|
||||||
|
random_flax_state_dict.update(flax_batch_stats)
|
||||||
|
|
||||||
flax_state_dict = {}
|
flax_state_dict = {}
|
||||||
|
|
||||||
load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
|
load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
|
||||||
model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
|
model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
|
||||||
)
|
)
|
||||||
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
|
load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
|
||||||
model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
|
model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -154,6 +176,20 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
|||||||
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# add batch stats if the model contains batchnorm layers
|
||||||
|
if "batch_stats" in flax_model.params:
|
||||||
|
if "mean" in flax_key[-1] or "var" in flax_key[-1]:
|
||||||
|
flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
|
||||||
|
continue
|
||||||
|
# remove num_batches_tracked key
|
||||||
|
if "num_batches_tracked" in flax_key[-1]:
|
||||||
|
flax_state_dict.pop(flax_key, None)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# also add unexpected weight so that warning is thrown
|
||||||
|
flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)
|
||||||
|
|
||||||
|
else:
|
||||||
# also add unexpected weight so that warning is thrown
|
# also add unexpected weight so that warning is thrown
|
||||||
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
|
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
|
||||||
|
|
||||||
@@ -176,12 +212,21 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
|
|||||||
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||||
|
|
||||||
model_prefix = flax_model.base_model_prefix
|
model_prefix = flax_model.base_model_prefix
|
||||||
random_flax_state_dict = flatten_dict(flax_model.params)
|
|
||||||
|
|
||||||
load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
|
# use params dict if the model contains batch norm layers and then add batch_stats keys,values to dict
|
||||||
|
if "batch_stats" in flax_model.params:
|
||||||
|
flax_model_params = flax_model.params["params"]
|
||||||
|
|
||||||
|
random_flax_state_dict = flatten_dict(flax_model_params)
|
||||||
|
random_flax_state_dict.update(flatten_dict(flax_model.params["batch_stats"]))
|
||||||
|
else:
|
||||||
|
flax_model_params = flax_model.params
|
||||||
|
random_flax_state_dict = flatten_dict(flax_model_params)
|
||||||
|
|
||||||
|
load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
|
||||||
model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
|
model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
|
||||||
)
|
)
|
||||||
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
|
load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
|
||||||
model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
|
model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
|
||||||
)
|
)
|
||||||
# Need to change some parameters name to match Flax names
|
# Need to change some parameters name to match Flax names
|
||||||
@@ -209,6 +254,23 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
|
|||||||
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# add batch stats if the model contains batchnorm layers
|
||||||
|
if "batch_stats" in flax_model.params:
|
||||||
|
if "mean" in flax_key[-1]:
|
||||||
|
flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
|
||||||
|
continue
|
||||||
|
if "var" in flax_key[-1]:
|
||||||
|
flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
|
||||||
|
continue
|
||||||
|
# remove num_batches_tracked key
|
||||||
|
if "num_batches_tracked" in flax_key[-1]:
|
||||||
|
flax_state_dict.pop(flax_key, None)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# also add unexpected weight so that warning is thrown
|
||||||
|
flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)
|
||||||
|
|
||||||
|
else:
|
||||||
# also add unexpected weight so that warning is thrown
|
# also add unexpected weight so that warning is thrown
|
||||||
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
|
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
|
||||||
return unflatten_dict(flax_state_dict)
|
return unflatten_dict(flax_state_dict)
|
||||||
@@ -299,6 +361,15 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
|||||||
elif flax_key_tuple[-1] in ["scale", "embedding"]:
|
elif flax_key_tuple[-1] in ["scale", "embedding"]:
|
||||||
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
|
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
|
||||||
|
|
||||||
|
# adding batch stats from flax batch norm to pt
|
||||||
|
elif "mean" in flax_key_tuple[-1]:
|
||||||
|
flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",)
|
||||||
|
elif "var" in flax_key_tuple[-1]:
|
||||||
|
flax_key_tuple = flax_key_tuple[:-1] + ("running_var",)
|
||||||
|
|
||||||
|
if "batch_stats" in flax_state:
|
||||||
|
flax_key = ".".join(flax_key_tuple[1:]) # Remove the params/batch_stats header
|
||||||
|
else:
|
||||||
flax_key = ".".join(flax_key_tuple)
|
flax_key = ".".join(flax_key_tuple)
|
||||||
|
|
||||||
if flax_key in pt_model_dict:
|
if flax_key in pt_model_dict:
|
||||||
|
|||||||
@@ -837,6 +837,27 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
# keep the params on CPU if we don't want to initialize
|
# keep the params on CPU if we don't want to initialize
|
||||||
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
|
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
|
||||||
|
|
||||||
|
if "batch_stats" in state: # if flax model contains batch norm layers
|
||||||
|
# if model is base model only use model_prefix key
|
||||||
|
if (
|
||||||
|
cls.base_model_prefix not in dict(model.params_shape_tree["params"])
|
||||||
|
and cls.base_model_prefix in state["params"]
|
||||||
|
):
|
||||||
|
state["params"] = state["params"][cls.base_model_prefix]
|
||||||
|
state["batch_stats"] = state["batch_stats"][cls.base_model_prefix]
|
||||||
|
|
||||||
|
# if model is head model and we are loading weights from base model
|
||||||
|
# we initialize new params dict with base_model_prefix
|
||||||
|
if (
|
||||||
|
cls.base_model_prefix in dict(model.params_shape_tree["params"])
|
||||||
|
and cls.base_model_prefix not in state["params"]
|
||||||
|
):
|
||||||
|
state = {
|
||||||
|
"params": {cls.base_model_prefix: state["params"]},
|
||||||
|
"batch_stats": {cls.base_model_prefix: state["batch_stats"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
# if model is base model only use model_prefix key
|
# if model is base model only use model_prefix key
|
||||||
if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
|
if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
|
||||||
state = state[cls.base_model_prefix]
|
state = state[cls.base_model_prefix]
|
||||||
@@ -854,6 +875,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
missing_keys = model.required_params - set(state.keys())
|
missing_keys = model.required_params - set(state.keys())
|
||||||
unexpected_keys = set(state.keys()) - model.required_params
|
unexpected_keys = set(state.keys()) - model.required_params
|
||||||
|
|
||||||
|
# Disabling warning when porting pytorch weights to flax, since flax does not use num_batches_tracked
|
||||||
|
for unexpected_key in unexpected_keys.copy():
|
||||||
|
if "num_batches_tracked" in unexpected_key[-1]:
|
||||||
|
unexpected_keys.remove(unexpected_key)
|
||||||
|
|
||||||
if missing_keys and not _do_init:
|
if missing_keys and not _do_init:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
|
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
|
||||||
|
|||||||
@@ -118,6 +118,30 @@ def random_attention_mask(shape, rng=None):
|
|||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
|
|
||||||
|
def get_params(params, from_head_prefix=None):
    """Flatten model parameters into one dict, merging in batch norm statistics when present.

    Args:
        params: parameter tree of a model. If it has a top-level "batch_stats" entry, the
            weights are read from its "params" entry; otherwise the tree itself is used.
        from_head_prefix: optional key. When given, only the sub-tree stored under this key
            is extracted (from both the weights and, if present, the batch statistics).

    Returns:
        The dict produced by `flatten_dict` for the selected weights, updated with the
        flattened "batch_stats" entries when the model has them.
    """
    has_batch_stats = "batch_stats" in params

    # Pick the tree holding the weights, then narrow it to the requested prefix
    weights = params["params"] if has_batch_stats else params
    if from_head_prefix is not None:
        weights = weights[from_head_prefix]
    extracted_params = flatten_dict(unfreeze(weights))

    # Batch statistics are narrowed the same way and merged into the same flat dict
    if has_batch_stats:
        stats = params["batch_stats"]
        if from_head_prefix is not None:
            stats = stats[from_head_prefix]
        extracted_params.update(flatten_dict(stats))

    return extracted_params
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
class FlaxModelTesterMixin:
|
class FlaxModelTesterMixin:
|
||||||
model_tester = None
|
model_tester = None
|
||||||
@@ -426,14 +450,14 @@ class FlaxModelTesterMixin:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
model = base_class(config)
|
model = base_class(config)
|
||||||
base_params = flatten_dict(unfreeze(model.params))
|
base_params = get_params(model.params)
|
||||||
|
|
||||||
# check that all base model weights are loaded correctly
|
# check that all base model weights are loaded correctly
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
head_model = model_class.from_pretrained(tmpdirname)
|
head_model = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
|
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
|
||||||
|
|
||||||
for key in base_param_from_head.keys():
|
for key in base_param_from_head.keys():
|
||||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||||
@@ -448,14 +472,14 @@ class FlaxModelTesterMixin:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||||
|
|
||||||
# check that all base model weights are loaded correctly
|
# check that all base model weights are loaded correctly
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
base_model = base_class.from_pretrained(tmpdirname)
|
base_model = base_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
base_params = flatten_dict(unfreeze(base_model.params))
|
base_params = get_params(base_model.params)
|
||||||
|
|
||||||
for key in base_params_from_head.keys():
|
for key in base_params_from_head.keys():
|
||||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||||
@@ -471,7 +495,7 @@ class FlaxModelTesterMixin:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
model = base_class(config)
|
model = base_class(config)
|
||||||
base_params = flatten_dict(unfreeze(model.params))
|
base_params = get_params(model.params)
|
||||||
|
|
||||||
# convert Flax model to PyTorch model
|
# convert Flax model to PyTorch model
|
||||||
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||||
@@ -484,7 +508,7 @@ class FlaxModelTesterMixin:
|
|||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
|
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
|
||||||
|
|
||||||
for key in base_param_from_head.keys():
|
for key in base_param_from_head.keys():
|
||||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||||
@@ -500,7 +524,7 @@ class FlaxModelTesterMixin:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||||
|
|
||||||
# convert Flax model to PyTorch model
|
# convert Flax model to PyTorch model
|
||||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||||
@@ -512,7 +536,7 @@ class FlaxModelTesterMixin:
|
|||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
base_params = flatten_dict(unfreeze(base_model.params))
|
base_params = get_params(base_model.params)
|
||||||
|
|
||||||
for key in base_params_from_head.keys():
|
for key in base_params_from_head.keys():
|
||||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||||
@@ -529,7 +553,7 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.params = model.to_bf16(model.params)
|
model.params = model.to_bf16(model.params)
|
||||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||||
|
|
||||||
# convert Flax model to PyTorch model
|
# convert Flax model to PyTorch model
|
||||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||||
@@ -541,7 +565,7 @@ class FlaxModelTesterMixin:
|
|||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
base_params = flatten_dict(unfreeze(base_model.params))
|
base_params = get_params(base_model.params)
|
||||||
|
|
||||||
for key in base_params_from_head.keys():
|
for key in base_params_from_head.keys():
|
||||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||||
|
|||||||
Reference in New Issue
Block a user