Add variant to transformers (#21332)
* Bump onnx in /examples/research_projects/decision_transformer Bumps [onnx](https://github.com/onnx/onnx) from 1.11.0 to 1.13.0. - [Release notes](https://github.com/onnx/onnx/releases) - [Changelog](https://github.com/onnx/onnx/blob/main/docs/Changelog.md) - [Commits](https://github.com/onnx/onnx/compare/v1.11.0...v1.13.0) --- updated-dependencies: - dependency-name: onnx dependency-type: direct:production ... Signed-off-by: dependabot[bot] <support@github.com> * adapt * finish * Update examples/research_projects/decision_transformer/requirements.txt * up * add tests * Apply suggestions from code review Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * fix test --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
bc44e947f3
commit
90cddfa824
@@ -667,6 +667,15 @@ def _load_state_dict_into_meta_model(
|
|||||||
return error_msgs, offload_index, state_dict_index
|
return error_msgs, offload_index, state_dict_index
|
||||||
|
|
||||||
|
|
||||||
|
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||||
|
if variant is not None:
|
||||||
|
splits = weights_name.split(".")
|
||||||
|
splits = splits[:-1] + [variant] + splits[-1:]
|
||||||
|
weights_name = ".".join(splits)
|
||||||
|
|
||||||
|
return weights_name
|
||||||
|
|
||||||
|
|
||||||
class ModuleUtilsMixin:
|
class ModuleUtilsMixin:
|
||||||
"""
|
"""
|
||||||
A few utilities for `torch.nn.Modules`, to be used as a mixin.
|
A few utilities for `torch.nn.Modules`, to be used as a mixin.
|
||||||
@@ -1567,6 +1576,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
push_to_hub: bool = False,
|
push_to_hub: bool = False,
|
||||||
max_shard_size: Union[int, str] = "10GB",
|
max_shard_size: Union[int, str] = "10GB",
|
||||||
safe_serialization: bool = False,
|
safe_serialization: bool = False,
|
||||||
|
variant: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1604,6 +1614,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
safe_serialization (`bool`, *optional*, defaults to `False`):
|
safe_serialization (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||||
|
variant (`str`, *optional*):
|
||||||
|
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||||
|
|
||||||
kwargs:
|
kwargs:
|
||||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||||
@@ -1675,6 +1687,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# Shard the model if it is too big.
|
# Shard the model if it is too big.
|
||||||
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||||
|
weights_name = _add_variant(weights_name, variant)
|
||||||
|
|
||||||
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
|
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
|
||||||
|
|
||||||
# Clean the folder from a previous save
|
# Clean the folder from a previous save
|
||||||
@@ -1701,10 +1715,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
save_function(shard, os.path.join(save_directory, shard_file))
|
save_function(shard, os.path.join(save_directory, shard_file))
|
||||||
|
|
||||||
if index is None:
|
if index is None:
|
||||||
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
|
path_to_weights = os.path.join(save_directory, _add_variant(WEIGHTS_NAME, variant))
|
||||||
|
logger.info(f"Model weights saved in {path_to_weights}")
|
||||||
else:
|
else:
|
||||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
||||||
save_index_file = os.path.join(save_directory, save_index_file)
|
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
||||||
# Save the index as well
|
# Save the index as well
|
||||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||||
@@ -1931,6 +1946,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
subfolder (`str`, *optional*, defaults to `""`):
|
subfolder (`str`, *optional*, defaults to `""`):
|
||||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||||
specify the folder name here.
|
specify the folder name here.
|
||||||
|
variant (`str`, *optional*):
|
||||||
|
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
||||||
|
ignored when using `from_tf` or `from_flax`.
|
||||||
|
|
||||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||||
@@ -2017,6 +2035,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None)
|
load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None)
|
||||||
subfolder = kwargs.pop("subfolder", "")
|
subfolder = kwargs.pop("subfolder", "")
|
||||||
commit_hash = kwargs.pop("_commit_hash", None)
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
|
variant = kwargs.pop("variant", None)
|
||||||
|
|
||||||
if trust_remote_code is True:
|
if trust_remote_code is True:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -2132,42 +2151,57 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Load from a Flax checkpoint in priority if from_flax
|
# Load from a Flax checkpoint in priority if from_flax
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
||||||
elif is_safetensors_available() and os.path.isfile(
|
elif is_safetensors_available() and os.path.isfile(
|
||||||
os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
|
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
|
||||||
):
|
):
|
||||||
# Load from a safetensors checkpoint
|
# Load from a safetensors checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
|
archive_file = os.path.join(
|
||||||
|
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
|
||||||
|
)
|
||||||
elif is_safetensors_available() and os.path.isfile(
|
elif is_safetensors_available() and os.path.isfile(
|
||||||
os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
|
os.path.join(
|
||||||
|
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||||
|
)
|
||||||
):
|
):
|
||||||
# Load from a sharded safetensors checkpoint
|
# Load from a sharded safetensors checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
|
archive_file = os.path.join(
|
||||||
|
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||||
|
)
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
|
elif os.path.isfile(
|
||||||
|
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
|
||||||
|
):
|
||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
archive_file = os.path.join(
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)):
|
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
|
||||||
|
)
|
||||||
|
elif os.path.isfile(
|
||||||
|
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
|
||||||
|
):
|
||||||
# Load from a sharded PyTorch checkpoint
|
# Load from a sharded PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
|
archive_file = os.path.join(
|
||||||
|
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
|
||||||
|
)
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
# At this stage we don't have a weight file so we will raise an error.
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
elif os.path.isfile(
|
elif os.path.isfile(
|
||||||
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
||||||
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
|
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
|
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
|
||||||
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
|
f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
|
||||||
"weights."
|
" `from_tf=True` to load this model from those weights."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
|
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
|
||||||
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
|
f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
|
||||||
"weights."
|
" to load this model from those weights."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
|
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME},"
|
||||||
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
|
||||||
|
f" {pretrained_model_name_or_path}."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
@@ -2190,9 +2224,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
elif from_flax:
|
elif from_flax:
|
||||||
filename = FLAX_WEIGHTS_NAME
|
filename = FLAX_WEIGHTS_NAME
|
||||||
elif is_safetensors_available():
|
elif is_safetensors_available():
|
||||||
filename = SAFE_WEIGHTS_NAME
|
filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
|
||||||
else:
|
else:
|
||||||
filename = WEIGHTS_NAME
|
filename = _add_variant(WEIGHTS_NAME, variant)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
@@ -2213,23 +2247,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
||||||
# result when internet is up, the repo and revision exist, but the file does not.
|
# result when internet is up, the repo and revision exist, but the file does not.
|
||||||
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
|
if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
|
||||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
resolved_archive_file = cached_file(
|
resolved_archive_file = cached_file(
|
||||||
pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
pretrained_model_name_or_path,
|
||||||
|
_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
|
||||||
|
**cached_file_kwargs,
|
||||||
)
|
)
|
||||||
if resolved_archive_file is not None:
|
if resolved_archive_file is not None:
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
else:
|
else:
|
||||||
# This repo has no safetensors file of any kind, we switch to PyTorch.
|
# This repo has no safetensors file of any kind, we switch to PyTorch.
|
||||||
filename = WEIGHTS_NAME
|
filename = _add_variant(WEIGHTS_NAME, variant)
|
||||||
resolved_archive_file = cached_file(
|
resolved_archive_file = cached_file(
|
||||||
pretrained_model_name_or_path, WEIGHTS_NAME, **cached_file_kwargs
|
pretrained_model_name_or_path, filename, **cached_file_kwargs
|
||||||
)
|
)
|
||||||
if resolved_archive_file is None and filename == WEIGHTS_NAME:
|
if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
|
||||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
resolved_archive_file = cached_file(
|
resolved_archive_file = cached_file(
|
||||||
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
pretrained_model_name_or_path,
|
||||||
|
_add_variant(WEIGHTS_INDEX_NAME, variant),
|
||||||
|
**cached_file_kwargs,
|
||||||
)
|
)
|
||||||
if resolved_archive_file is not None:
|
if resolved_archive_file is not None:
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
@@ -2244,19 +2282,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
|
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
f" {WEIGHTS_NAME} but there is a file for TensorFlow weights. Use `from_tf=True` to"
|
f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights."
|
||||||
" load this model from those weights."
|
" Use `from_tf=True` to load this model from those weights."
|
||||||
)
|
)
|
||||||
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
|
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
f" {WEIGHTS_NAME} but there is a file for Flax weights. Use `from_flax=True` to load"
|
f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use"
|
||||||
" this model from those weights."
|
" `from_flax=True` to load this model from those weights."
|
||||||
|
)
|
||||||
|
elif variant is not None and has_file(
|
||||||
|
pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
|
||||||
|
):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
|
f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
|
||||||
|
f" {variant}. Use `variant=None` to load this model from those weights."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
f" {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
|
||||||
|
f" {FLAX_WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
|
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
|
||||||
@@ -2268,8 +2315,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||||
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||||
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
|
f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)},"
|
||||||
f" {FLAX_WEIGHTS_NAME}."
|
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_local:
|
if is_local:
|
||||||
|
|||||||
@@ -2958,6 +2958,138 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
||||||
self.assertTrue(torch.allclose(p1, p2))
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_checkpoint_variant_local(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, variant="v2")
|
||||||
|
|
||||||
|
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||||
|
|
||||||
|
weights_file = os.path.join(tmp_dir, weights_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_checkpoint_variant_local_sharded(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB")
|
||||||
|
|
||||||
|
weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
|
||||||
|
weights_index_file = os.path.join(tmp_dir, weights_index_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_index_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
|
||||||
|
|
||||||
|
for i in range(1, 6):
|
||||||
|
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["bin"])
|
||||||
|
weights_name_file = os.path.join(tmp_dir, weights_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_name_file))
|
||||||
|
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_checkpoint_variant_local_safe(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, variant="v2", safe_serialization=True)
|
||||||
|
|
||||||
|
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["safetensors"])
|
||||||
|
|
||||||
|
weights_file = os.path.join(tmp_dir, weights_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_checkpoint_variant_local_sharded_safe(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=True)
|
||||||
|
|
||||||
|
weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
|
||||||
|
weights_index_file = os.path.join(tmp_dir, weights_index_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_index_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
|
||||||
|
|
||||||
|
for i in range(1, 6):
|
||||||
|
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["safetensors"])
|
||||||
|
weights_name_file = os.path.join(tmp_dir, weights_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_name_file))
|
||||||
|
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_checkpoint_variant_hub(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir)
|
||||||
|
model = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_checkpoint_variant_hub_sharded(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant-sharded", cache_dir=tmp_dir
|
||||||
|
)
|
||||||
|
model = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant-sharded", cache_dir=tmp_dir, variant="v2"
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_checkpoint_variant_hub_safe(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-variant-safe", cache_dir=tmp_dir)
|
||||||
|
model = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant-safe", cache_dir=tmp_dir, variant="v2"
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_checkpoint_variant_hub_sharded_safe(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant-sharded-safe", cache_dir=tmp_dir
|
||||||
|
)
|
||||||
|
model = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant-sharded-safe", cache_dir=tmp_dir, variant="v2"
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
def test_from_pretrained_low_cpu_mem_usage_functional(self):
|
def test_from_pretrained_low_cpu_mem_usage_functional(self):
|
||||||
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
|
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
|
||||||
|
|||||||
Reference in New Issue
Block a user