cross platform from_pretrained (#20538)
* add support for `from_pt` * add tf_flax utility file * Update src/transformers/modeling_tf_flax_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * remove flax related modifications * add test * remove FLAX related commits * fixup * remove safetensor todos * revert deletion Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -47,6 +47,7 @@ from .utils import (
|
|||||||
SAFE_WEIGHTS_NAME,
|
SAFE_WEIGHTS_NAME,
|
||||||
TF2_WEIGHTS_INDEX_NAME,
|
TF2_WEIGHTS_INDEX_NAME,
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
|
TF_WEIGHTS_NAME,
|
||||||
WEIGHTS_INDEX_NAME,
|
WEIGHTS_INDEX_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
@@ -2392,7 +2393,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
save directory.
|
save directory.
|
||||||
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
||||||
configuration JSON file named *config.json* is found in the directory.
|
configuration JSON file named *config.json* is found in the directory.
|
||||||
from_pt: (`bool`, *optional*, defaults to `False`):
|
from_pt (`bool`, *optional*, defaults to `False`):
|
||||||
Load the model weights from a PyTorch state_dict save file (see docstring of
|
Load the model weights from a PyTorch state_dict save file (see docstring of
|
||||||
`pretrained_model_name_or_path` argument).
|
`pretrained_model_name_or_path` argument).
|
||||||
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
|
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
|
||||||
@@ -2531,7 +2532,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
if pretrained_model_name_or_path is not None:
|
if pretrained_model_name_or_path is not None:
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if is_local:
|
||||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||||
# Load from a PyTorch checkpoint in priority if from_pt
|
# Load from a PyTorch checkpoint in priority if from_pt
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
@@ -2559,7 +2560,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
# At this stage we don't have a weight file so we will raise an error.
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
|
||||||
|
os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||||
|
):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
||||||
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
|
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
|
||||||
@@ -2630,6 +2633,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
)
|
)
|
||||||
if resolved_archive_file is not None:
|
if resolved_archive_file is not None:
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
|
if resolved_archive_file is None and filename == WEIGHTS_NAME:
|
||||||
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
|
resolved_archive_file = cached_file(
|
||||||
|
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||||
|
)
|
||||||
|
if resolved_archive_file is not None:
|
||||||
|
is_sharded = True
|
||||||
if resolved_archive_file is None:
|
if resolved_archive_file is None:
|
||||||
# Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
|
# Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
|
||||||
# message.
|
# message.
|
||||||
@@ -2646,8 +2656,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
|
||||||
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
@@ -2661,7 +2671,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||||
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||||
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
|
||||||
)
|
)
|
||||||
if is_local:
|
if is_local:
|
||||||
logger.info(f"loading weights file {archive_file}")
|
logger.info(f"loading weights file {archive_file}")
|
||||||
|
|||||||
@@ -2127,6 +2127,14 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||||
assert np.allclose(p1.numpy(), p2.numpy())
|
assert np.allclose(p1.numpy(), p2.numpy())
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_checkpoint_sharding_hub_from_pt(self):
|
||||||
|
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
|
||||||
|
# the model above is the same as the model below, just a sharded pytorch version.
|
||||||
|
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||||
|
assert np.allclose(p1.numpy(), p2.numpy())
|
||||||
|
|
||||||
def test_shard_checkpoint(self):
|
def test_shard_checkpoint(self):
|
||||||
# This is the model we will use, total size 340,000 bytes.
|
# This is the model we will use, total size 340,000 bytes.
|
||||||
model = tf.keras.Sequential(
|
model = tf.keras.Sequential(
|
||||||
|
|||||||
Reference in New Issue
Block a user