Load sharded pt to flax (#18419)
* initial commit * add small test * add cross pt tf flag to test * fix quality * style * update test with new repo * fix failing test * update * fix wrong param ordering * style * update based on review * update related to recent new caching mechanism * quality * Update based on review Co-authored-by: sgugger <sylvain.gugger@gmail.com> * quality and style * Update src/transformers/modeling_flax_utils.py Co-authored-by: sgugger <sylvain.gugger@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -38,7 +38,9 @@ logger = logging.get_logger(__name__)
|
|||||||
#####################
|
#####################
|
||||||
|
|
||||||
|
|
||||||
def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
|
def load_pytorch_checkpoint_in_flax_state_dict(
|
||||||
|
flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
|
||||||
|
):
|
||||||
"""Load pytorch checkpoints in a flax model"""
|
"""Load pytorch checkpoints in a flax model"""
|
||||||
try:
|
try:
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
@@ -50,6 +52,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
if not is_sharded:
|
||||||
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
||||||
logger.info(f"Loading PyTorch weights from {pt_path}")
|
logger.info(f"Loading PyTorch weights from {pt_path}")
|
||||||
|
|
||||||
@@ -57,7 +60,9 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
|
|||||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
||||||
|
|
||||||
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
|
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
|
||||||
|
else:
|
||||||
|
# model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
|
||||||
|
flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
|
||||||
return flax_state_dict
|
return flax_state_dict
|
||||||
|
|
||||||
|
|
||||||
@@ -156,6 +161,61 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
|||||||
return unflatten_dict(flax_state_dict)
|
return unflatten_dict(flax_state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# Sharded Pytorch => Flax #
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
    """Convert the shards of a PyTorch checkpoint into a single nested Flax state dict.

    Every shard is loaded with `torch.load`, its tensors are turned into NumPy arrays, and each weight is renamed
    and reshaped with `rename_key_and_reshape_tensor` before being merged into one flat dict keyed by tuples.

    Args:
        shard_filenames: Iterable of paths to the PyTorch shard files to load.
        flax_model: Model whose `base_model_prefix` and `params` serve as the reference for key names and shapes.

    Returns:
        The merged weights, passed through `unflatten_dict`.

    Raises:
        ValueError: If a converted weight maps to a key present in `flax_model.params` but has a different shape.
    """
    import torch

    # These only depend on the model, not on the shard, so compute them once.
    model_prefix = flax_model.base_model_prefix
    random_flax_state_dict = flatten_dict(flax_model.params)
    model_has_prefix = model_prefix in flax_model.params

    flax_state_dict = {}
    for shard_file in shard_filenames:
        # `Tensor.numpy()` only works on CPU tensors, so force the shard onto the CPU regardless of the device
        # it was saved from.
        pt_state_dict = torch.load(shard_file, map_location="cpu")
        pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

        # The head/base decision is taken per shard, from the top-level key prefixes found in that shard.
        shard_has_prefix = model_prefix in {k.split(".")[0] for k in pt_state_dict}
        load_model_with_head_into_base_model = (not model_has_prefix) and shard_has_prefix
        load_base_model_into_model_with_head = model_has_prefix and (not shard_has_prefix)

        # Need to change some parameters name to match Flax names
        for pt_key, pt_tensor in pt_state_dict.items():
            pt_tuple_key = tuple(pt_key.split("."))

            # remove base model prefix if necessary
            has_base_model_prefix = pt_tuple_key[0] == model_prefix
            if load_model_with_head_into_base_model and has_base_model_prefix:
                pt_tuple_key = pt_tuple_key[1:]

            # Correctly rename weight parameters
            flax_key, flax_tensor = rename_key_and_reshape_tensor(
                pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
            )

            # add model prefix if necessary
            require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
            if load_base_model_into_model_with_head and require_base_model_prefix:
                flax_key = (model_prefix,) + flax_key

            if flax_key in random_flax_state_dict:
                if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                    raise ValueError(
                        f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                        f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                    )

            # Keys that are absent from the model params are stored as well (only the shape check is skipped),
            # presumably so that the caller can warn about unexpected weights.
            flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
    return unflatten_dict(flax_state_dict)
|
||||||
|
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
# Flax => PyTorch #
|
# Flax => PyTorch #
|
||||||
#####################
|
#####################
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_d
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
FLAX_WEIGHTS_INDEX_NAME,
|
FLAX_WEIGHTS_INDEX_NAME,
|
||||||
FLAX_WEIGHTS_NAME,
|
FLAX_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_INDEX_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
@@ -650,6 +651,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
|
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
|
||||||
|
# Load from a sharded pytorch checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||||
|
is_sharded = True
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||||
# Load from a Flax checkpoint
|
# Load from a Flax checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
||||||
@@ -700,6 +705,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
)
|
)
|
||||||
if resolved_archive_file is not None:
|
if resolved_archive_file is not None:
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
|
# The checkpoint may be a sharded PyTorch checkpoint; in that case, try to fetch the PyTorch weights index file.
|
||||||
|
elif resolved_archive_file is None and from_pt:
|
||||||
|
resolved_archive_file = cached_file(
|
||||||
|
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||||
|
)
|
||||||
|
if resolved_archive_file is not None:
|
||||||
|
is_sharded = True
|
||||||
if resolved_archive_file is None:
|
if resolved_archive_file is None:
|
||||||
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
||||||
# message.
|
# message.
|
||||||
@@ -714,6 +726,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
|
f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
|
||||||
" load this model from those weights."
|
" load this model from those weights."
|
||||||
)
|
)
|
||||||
|
elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
|
f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
|
||||||
|
" `from_pt=True` to load this model from those weights."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
@@ -761,7 +779,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
|
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
|
||||||
|
|
||||||
if from_pt:
|
if from_pt:
|
||||||
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
|
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
if is_sharded:
|
if is_sharded:
|
||||||
|
|||||||
@@ -1099,6 +1099,14 @@ class FlaxModelTesterMixin:
|
|||||||
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
|
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
|
||||||
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
|
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
def test_from_sharded_pt(self):
    """Check that weights loaded from a sharded PyTorch checkpoint match the reference Flax checkpoint."""
    model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
    ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")

    # Flatten the converted params once instead of on every loop iteration.
    flat_params = flatten_dict(model.params)
    for key, ref_val in flatten_dict(ref_model.params).items():
        val = flat_params[key]
        # Use the unittest assertion, like the other tests of this class, so the check is not stripped under -O.
        self.assertTrue(np.allclose(np.array(val), np.array(ref_val)))
|
||||||
|
|
||||||
def test_gradient_checkpointing(self):
|
def test_gradient_checkpointing(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user