cross platform from_pretrained (#20538)
* add support for `from_pt` * add tf_flax utility file * Update src/transformers/modeling_tf_flax_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * remove flax related modifications * add test * remove FLAX related commits * fixup * remove safetensor todos * revert deletion Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -47,6 +47,7 @@ from .utils import (
|
||||
SAFE_WEIGHTS_NAME,
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
TF_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
ModelOutput,
|
||||
@@ -2392,7 +2393,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
save directory.
|
||||
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
||||
configuration JSON file named *config.json* is found in the directory.
|
||||
from_pt: (`bool`, *optional*, defaults to `False`):
|
||||
from_pt (`bool`, *optional*, defaults to `False`):
|
||||
Load the model weights from a PyTorch state_dict save file (see docstring of
|
||||
`pretrained_model_name_or_path` argument).
|
||||
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
|
||||
@@ -2531,7 +2532,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
if pretrained_model_name_or_path is not None:
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if is_local:
|
||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint in priority if from_pt
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
@@ -2559,7 +2560,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||
):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
||||
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
|
||||
@@ -2630,6 +2633,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
)
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
if resolved_archive_file is None and filename == WEIGHTS_NAME:
|
||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||
)
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
if resolved_archive_file is None:
|
||||
# Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
|
||||
# message.
|
||||
@@ -2646,8 +2656,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
|
||||
f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
|
||||
)
|
||||
|
||||
except EnvironmentError:
|
||||
@@ -2661,7 +2671,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
|
||||
)
|
||||
if is_local:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
|
||||
@@ -2127,6 +2127,14 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_checkpoint_sharding_hub_from_pt(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
|
||||
# the model above is the same as the model below, just a sharded pytorch version.
|
||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
def test_shard_checkpoint(self):
|
||||
# This is the model we will use, total size 340,000 bytes.
|
||||
model = tf.keras.Sequential(
|
||||
|
||||
Reference in New Issue
Block a user