[Flax] Correct pt to flax conversion if from base to head (#13006)
* finish PR * add tests * correct tests * finish * correct other flax tests * better naming * correct naming * finish * apply sylvains suggestions
This commit is contained in:
committed by
GitHub
parent
33929448a1
commit
60e448c87e
@@ -17,6 +17,7 @@
|
||||
|
||||
import os
|
||||
from pickle import UnpicklingError
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -58,60 +59,97 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
|
||||
return flax_state_dict
|
||||
|
||||
|
||||
def rename_key_and_reshape_tensor(
    pt_tuple_key: Tuple[str, ...],
    pt_tensor: np.ndarray,
    random_flax_state_dict: Dict[Tuple[str, ...], "jnp.ndarray"],
    model_prefix: str,
) -> Tuple[Tuple[str, ...], np.ndarray]:
    """Rename a PyTorch weight name to the corresponding Flax weight name and reshape the tensor if necessary.

    The rules are tried in a fixed order and the first one that matches wins:
    layer norm, embedding, conv layer, linear layer, old-style ``gamma``,
    old-style ``beta``. If no rule matches, the inputs are returned unchanged.

    Args:
        pt_tuple_key: PyTorch parameter name split into its components.
        pt_tensor: The PyTorch weight as a numpy array.
        random_flax_state_dict: Flattened Flax parameters, keyed by tuple. Only
            key membership is inspected here; the values are never read.
        model_prefix: Base model prefix. A candidate key also counts as present
            when ``(model_prefix,) + key`` is in ``random_flax_state_dict``.

    Returns:
        A ``(flax_key, flax_tensor)`` pair.
    """

    def is_key_or_prefix_key_in_dict(key: Tuple[str, ...]) -> bool:
        """Check whether ``key`` or ``(model_prefix,) + key`` is in ``random_flax_state_dict``."""
        # Direct membership tests; no need to copy all dict keys into a set per call.
        return key in random_flax_state_dict or (model_prefix,) + key in random_flax_state_dict

    parent_key = pt_tuple_key[:-1]
    leaf_name = pt_tuple_key[-1]

    # layer norm
    renamed_pt_tuple_key = parent_key + ("scale",)
    if leaf_name in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # embedding
    renamed_pt_tuple_key = parent_key + ("embedding",)
    if leaf_name == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # conv layer: 4-D "weight" that Flax does not store under the same name
    renamed_pt_tuple_key = parent_key + ("kernel",)
    if leaf_name == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        # Reorder the axes to (2, 3, 1, 0); presumably PyTorch (out, in, h, w)
        # -> Flax (h, w, in, out) -- confirm against the layer definitions.
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        return renamed_pt_tuple_key, pt_tensor

    # linear layer: any other "weight" that Flax does not store under the same name
    renamed_pt_tuple_key = parent_key + ("kernel",)
    if leaf_name == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        pt_tensor = pt_tensor.T
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm weight
    renamed_pt_tuple_key = parent_key + ("weight",)
    if leaf_name == "gamma":
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm bias
    renamed_pt_tuple_key = parent_key + ("bias",)
    if leaf_name == "beta":
        return renamed_pt_tuple_key, pt_tensor

    return pt_tuple_key, pt_tensor
|
||||
|
||||
|
||||
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    """Convert a PyTorch state dict into a nested Flax parameter dict.

    Each PyTorch key is split on ".", optionally stripped of or prefixed with
    ``flax_model.base_model_prefix``, and renamed/reshaped by
    ``rename_key_and_reshape_tensor``.

    Args:
        pt_state_dict: Mapping from PyTorch parameter names to tensors; every
            value must provide ``.numpy()``.
        flax_model: Flax model providing ``base_model_prefix`` and ``params``.

    Returns:
        The converted parameters, un-flattened with ``unflatten_dict``.

    Raises:
        ValueError: If a converted weight matches a Flax parameter name but has
            a different shape.
    """
    # convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    model_prefix = flax_model.base_model_prefix
    random_flax_state_dict = flatten_dict(flax_model.params)
    flax_state_dict = {}

    # First component of every PyTorch key, computed once for both checks below.
    pt_top_level_keys = {k.split(".")[0] for k in pt_state_dict}

    load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
        model_prefix in pt_top_level_keys
    )
    load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
        model_prefix not in pt_top_level_keys
    )

    # Need to change some parameters name to match Flax names
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        # remove base model prefix if necessary
        has_base_model_prefix = pt_tuple_key[0] == model_prefix
        if load_model_with_head_into_base_model and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(
            pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
        )

        # add model prefix if necessary
        require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
        if load_base_model_into_model_with_head and require_base_model_prefix:
            flax_key = (model_prefix,) + flax_key

        if flax_key in random_flax_state_dict:
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(flax_state_dict)
|
||||
|
||||
@@ -154,10 +192,10 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
||||
flax_state_dict = flatten_dict(flax_state)
|
||||
pt_model_dict = pt_model.state_dict()
|
||||
|
||||
remove_base_model_prefix = (pt_model.base_model_prefix in flax_state) and (
|
||||
load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
|
||||
pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()])
|
||||
)
|
||||
add_base_model_prefix = (pt_model.base_model_prefix not in flax_state) and (
|
||||
load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
|
||||
pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()])
|
||||
)
|
||||
|
||||
@@ -170,9 +208,9 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
||||
require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict
|
||||
|
||||
# adapt flax_key to prepare for loading from/to base model only
|
||||
if remove_base_model_prefix and has_base_model_prefix:
|
||||
if load_model_with_head_into_base_model and has_base_model_prefix:
|
||||
flax_key_tuple = flax_key_tuple[1:]
|
||||
elif add_base_model_prefix and require_base_model_prefix:
|
||||
elif load_base_model_into_model_with_head and require_base_model_prefix:
|
||||
flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple
|
||||
|
||||
# rename flax weights to PyTorch format
|
||||
|
||||
@@ -213,9 +213,20 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
def test_save_load_from_base(self):
    """Intentionally left as a no-op for this test class."""
    pass
|
||||
|
||||
# FlaxCLIPVisionModel does not have any base model
|
||||
def test_save_load_to_base(self):
    """Intentionally left as a no-op for this test class."""
    pass
|
||||
|
||||
# FlaxCLIPVisionModel does not have any base model
|
||||
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
    """Intentionally left as a no-op for this test class."""
    pass
|
||||
|
||||
# FlaxCLIPVisionModel does not have any base model
|
||||
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
    """Intentionally left as a no-op for this test class."""
    pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
@@ -307,9 +318,20 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
def test_save_load_from_base(self):
    """Intentionally left as a no-op for this test class."""
    pass
|
||||
|
||||
# FlaxCLIPTextModel does not have any base model
|
||||
def test_save_load_to_base(self):
    """Intentionally left as a no-op for this test class."""
    pass
|
||||
|
||||
# FlaxCLIPTextModel does not have any base model
|
||||
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
    """Intentionally left as a no-op for this test class."""
    pass
|
||||
|
||||
# FlaxCLIPTextModel does not have any base model
|
||||
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
    """Intentionally left as a no-op for this test class."""
    pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
|
||||
@@ -334,6 +334,63 @@ class FlaxModelTesterMixin:
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
    """Save a PyTorch base model and load it into each Flax model with a head.

    For every non-base class in ``self.all_model_classes``, a Flax base model
    is created, copied into the matching PyTorch class, saved to a temporary
    directory and re-loaded with ``from_pt=True``. The parameters under the
    head model's ``base_model_prefix`` must match the original base parameters.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        model = base_class(config)
        base_params = flatten_dict(unfreeze(model.params))

        # convert Flax model to PyTorch model
        pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
        pt_model = pt_model_class(config).eval()
        pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            # save pt model
            pt_model.save_pretrained(tmpdirname)
            head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

            base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))

            for key in base_param_from_head.keys():
                # Largest absolute element-wise difference: a signed sum would let
                # negative or mutually cancelling errors pass the check.
                max_diff = abs(base_params[key] - base_param_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
    """Save a PyTorch model with a head and load it into the Flax base model.

    For every non-base class in ``self.all_model_classes``, a Flax model with
    a head is created, copied into the matching PyTorch class, saved to a
    temporary directory and re-loaded into the base class with ``from_pt=True``.
    The loaded base parameters must match those under the head model's
    ``base_model_prefix``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        model = model_class(config)
        base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

        # convert Flax model to PyTorch model
        pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
        pt_model = pt_model_class(config).eval()
        pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

            base_params = flatten_dict(unfreeze(base_model.params))

            for key in base_params_from_head.keys():
                # Largest absolute element-wise difference: a signed sum would let
                # negative or mutually cancelling errors pass the check.
                max_diff = abs(base_params[key] - base_params_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@slow
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -17,8 +17,15 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import is_flax_available
|
||||
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
|
||||
from transformers.testing_utils import (
|
||||
is_pt_flax_cross_test,
|
||||
require_flax,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
@@ -40,6 +47,7 @@ if is_flax_available():
|
||||
from flax.training.common_utils import onehot
|
||||
from flax.traverse_util import flatten_dict
|
||||
from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, shift_tokens_right
|
||||
|
||||
|
||||
@@ -363,6 +371,65 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
    """Save a PyTorch base model and load it into each Flax model with a head.

    Unlike a lookup under ``base_model_prefix``, this variant compares the head
    model's full flattened parameter dict against the base parameters, key by
    key.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        model = base_class(config)
        base_params = flatten_dict(unfreeze(model.params))

        # convert Flax model to PyTorch model
        pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
        pt_model = pt_model_class(config).eval()
        pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            # save pt model
            pt_model.save_pretrained(tmpdirname)
            head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

            base_param_from_head = flatten_dict(unfreeze(head_model.params))

            for key in base_param_from_head.keys():
                # Largest absolute element-wise difference: a signed sum would let
                # negative or mutually cancelling errors pass the check.
                max_diff = abs(base_params[key] - base_param_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
    """Save a PyTorch model with a head and load it into the Flax base model.

    Unlike a lookup under ``base_model_prefix``, this variant takes the head
    model's full flattened parameter dict as the reference and compares it,
    key by key, against the parameters of the re-loaded base model.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        model = model_class(config)
        base_params_from_head = flatten_dict(unfreeze(model.params))

        # convert Flax model to PyTorch model
        pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
        pt_model = pt_model_class(config).eval()
        pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

            base_params = flatten_dict(unfreeze(base_model.params))

            for key in base_params_from_head.keys():
                # Largest absolute element-wise difference: a signed sum would let
                # negative or mutually cancelling errors pass the check.
                max_diff = abs(base_params[key] - base_params_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
|
||||
Reference in New Issue
Block a user