[Flax] adding support for batch norm layers (#21581)

* [flax] adding support for batch norm layers

* fixing bugs related to pt+flax integration

* cleanup, batchnorm support in sharded pt to flax

* support for batchnorm tests in pt+flax integration

* simplifying checking batch norm layer
This commit is contained in:
Shubhamai
2023-02-24 13:17:33 +05:30
committed by GitHub
parent 279008adc3
commit f7ca656f07
3 changed files with 149 additions and 28 deletions

View File

@@ -118,6 +118,30 @@ def random_attention_mask(shape, rng=None):
return attn_mask
def get_params(params, from_head_prefix=None):
"""Function extracts relevant parameters into flatten dict from model params,
appends batch normalization statistics if present"""
# If Both parameters and batch normalization statistics are present
if "batch_stats" in params:
# Extract only parameters for the specified head prefix (if specified) and add batch statistics
if from_head_prefix is not None:
extracted_params = flatten_dict(unfreeze(params["params"][from_head_prefix]))
extracted_params.update(flatten_dict(params["batch_stats"][from_head_prefix]))
else:
extracted_params = flatten_dict(unfreeze(params["params"]))
extracted_params.update(flatten_dict(params["batch_stats"]))
# Only parameters are present
else:
if from_head_prefix is not None:
extracted_params = flatten_dict(unfreeze(params[from_head_prefix]))
else:
extracted_params = flatten_dict(unfreeze(params))
return extracted_params
@require_flax
class FlaxModelTesterMixin:
model_tester = None
@@ -426,14 +450,14 @@ class FlaxModelTesterMixin:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
base_params = get_params(model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname)
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
@@ -448,14 +472,14 @@ class FlaxModelTesterMixin:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname)
base_params = flatten_dict(unfreeze(base_model.params))
base_params = get_params(base_model.params)
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
@@ -471,7 +495,7 @@ class FlaxModelTesterMixin:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
base_params = get_params(model.params)
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
@@ -484,7 +508,7 @@ class FlaxModelTesterMixin:
pt_model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
@@ -500,7 +524,7 @@ class FlaxModelTesterMixin:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
@@ -512,7 +536,7 @@ class FlaxModelTesterMixin:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
base_params = get_params(base_model.params)
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
@@ -529,7 +553,7 @@ class FlaxModelTesterMixin:
model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
@@ -541,7 +565,7 @@ class FlaxModelTesterMixin:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
base_params = get_params(base_model.params)
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()