[Flax] adding support for batch norm layers (#21581)
* [flax] adding support for batch norm layers * fixing bugs related to pt+flax integration * cleanup, batchnorm support in sharded pt to flax * support for batchnorm tests in pt+flax integration * simplifying checking batch norm layer
This commit is contained in:
@@ -118,6 +118,30 @@ def random_attention_mask(shape, rng=None):
|
||||
return attn_mask
|
||||
|
||||
|
||||
def get_params(params, from_head_prefix=None):
|
||||
"""Function extracts relevant parameters into flatten dict from model params,
|
||||
appends batch normalization statistics if present"""
|
||||
|
||||
# If Both parameters and batch normalization statistics are present
|
||||
if "batch_stats" in params:
|
||||
# Extract only parameters for the specified head prefix (if specified) and add batch statistics
|
||||
if from_head_prefix is not None:
|
||||
extracted_params = flatten_dict(unfreeze(params["params"][from_head_prefix]))
|
||||
extracted_params.update(flatten_dict(params["batch_stats"][from_head_prefix]))
|
||||
else:
|
||||
extracted_params = flatten_dict(unfreeze(params["params"]))
|
||||
extracted_params.update(flatten_dict(params["batch_stats"]))
|
||||
|
||||
# Only parameters are present
|
||||
else:
|
||||
if from_head_prefix is not None:
|
||||
extracted_params = flatten_dict(unfreeze(params[from_head_prefix]))
|
||||
else:
|
||||
extracted_params = flatten_dict(unfreeze(params))
|
||||
|
||||
return extracted_params
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxModelTesterMixin:
|
||||
model_tester = None
|
||||
@@ -426,14 +450,14 @@ class FlaxModelTesterMixin:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
base_params = get_params(model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
|
||||
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||
@@ -448,14 +472,14 @@ class FlaxModelTesterMixin:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
base_params = get_params(base_model.params)
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
@@ -471,7 +495,7 @@ class FlaxModelTesterMixin:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
base_params = get_params(model.params)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
@@ -484,7 +508,7 @@ class FlaxModelTesterMixin:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
|
||||
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||
@@ -500,7 +524,7 @@ class FlaxModelTesterMixin:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
@@ -512,7 +536,7 @@ class FlaxModelTesterMixin:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
base_params = get_params(base_model.params)
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
@@ -529,7 +553,7 @@ class FlaxModelTesterMixin:
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
@@ -541,7 +565,7 @@ class FlaxModelTesterMixin:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
base_params = get_params(base_model.params)
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
|
||||
Reference in New Issue
Block a user