fix loading flax bf16 weights in pt (#14369)
* fix loading flax bf16 weights in pt * fix clip test * fix t5 test * add logging statement * Update src/transformers/modeling_flax_pytorch_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * switch back to native any * fix check for bf16 weights Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -227,6 +227,11 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
def test_save_load_to_base_pt(self):
|
||||
pass
|
||||
|
||||
# FlaxCLIPVisionModel does not have any base model
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
@@ -332,6 +337,11 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
def test_save_load_to_base_pt(self):
|
||||
pass
|
||||
|
||||
# FlaxCLIPVisionModel does not have any base model
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
|
||||
@@ -384,6 +384,35 @@ class FlaxModelTesterMixin:
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -430,6 +430,36 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
|
||||
Reference in New Issue
Block a user