Add variant to transformers (#21332)
* Bump onnx in /examples/research_projects/decision_transformer Bumps [onnx](https://github.com/onnx/onnx) from 1.11.0 to 1.13.0. - [Release notes](https://github.com/onnx/onnx/releases) - [Changelog](https://github.com/onnx/onnx/blob/main/docs/Changelog.md) - [Commits](https://github.com/onnx/onnx/compare/v1.11.0...v1.13.0) --- updated-dependencies: - dependency-name: onnx dependency-type: direct:production ... Signed-off-by: dependabot[bot] <support@github.com> * adapt * finish * Update examples/research_projects/decision_transformer/requirements.txt * up * add tests * Apply suggestions from code review Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * fix test --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
bc44e947f3
commit
90cddfa824
@@ -667,6 +667,15 @@ def _load_state_dict_into_meta_model(
|
||||
return error_msgs, offload_index, state_dict_index
|
||||
|
||||
|
||||
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
if variant is not None:
|
||||
splits = weights_name.split(".")
|
||||
splits = splits[:-1] + [variant] + splits[-1:]
|
||||
weights_name = ".".join(splits)
|
||||
|
||||
return weights_name
|
||||
|
||||
|
||||
class ModuleUtilsMixin:
|
||||
"""
|
||||
A few utilities for `torch.nn.Modules`, to be used as a mixin.
|
||||
@@ -1567,6 +1576,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
push_to_hub: bool = False,
|
||||
max_shard_size: Union[int, str] = "10GB",
|
||||
safe_serialization: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -1604,6 +1614,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
safe_serialization (`bool`, *optional*, defaults to `False`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||
|
||||
kwargs:
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
@@ -1675,6 +1687,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
# Shard the model if it is too big.
|
||||
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||
weights_name = _add_variant(weights_name, variant)
|
||||
|
||||
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
|
||||
|
||||
# Clean the folder from a previous save
|
||||
@@ -1701,10 +1715,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
save_function(shard, os.path.join(save_directory, shard_file))
|
||||
|
||||
if index is None:
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
|
||||
path_to_weights = os.path.join(save_directory, _add_variant(WEIGHTS_NAME, variant))
|
||||
logger.info(f"Model weights saved in {path_to_weights}")
|
||||
else:
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
||||
save_index_file = os.path.join(save_directory, save_index_file)
|
||||
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
||||
# Save the index as well
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
@@ -1931,6 +1946,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||
specify the folder name here.
|
||||
variant (`str`, *optional*):
|
||||
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
||||
ignored when using `from_tf` or `from_flax`.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
@@ -2017,6 +2035,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
|
||||
if trust_remote_code is True:
|
||||
logger.warning(
|
||||
@@ -2132,42 +2151,57 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Load from a Flax checkpoint in priority if from_flax
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
||||
elif is_safetensors_available() and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
|
||||
):
|
||||
# Load from a safetensors checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
|
||||
archive_file = os.path.join(
|
||||
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
|
||||
)
|
||||
elif is_safetensors_available() and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
|
||||
os.path.join(
|
||||
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||
)
|
||||
):
|
||||
# Load from a sharded safetensors checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
|
||||
archive_file = os.path.join(
|
||||
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||
)
|
||||
is_sharded = True
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
|
||||
elif os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
|
||||
):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)):
|
||||
archive_file = os.path.join(
|
||||
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
|
||||
)
|
||||
elif os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
|
||||
):
|
||||
# Load from a sharded PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
|
||||
archive_file = os.path.join(
|
||||
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
|
||||
)
|
||||
is_sharded = True
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
elif os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
||||
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
|
||||
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
|
||||
"weights."
|
||||
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
|
||||
f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
|
||||
" `from_tf=True` to load this model from those weights."
|
||||
)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
|
||||
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
|
||||
"weights."
|
||||
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
|
||||
f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
|
||||
" to load this model from those weights."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
|
||||
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME},"
|
||||
f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
|
||||
f" {pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
@@ -2190,9 +2224,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
elif from_flax:
|
||||
filename = FLAX_WEIGHTS_NAME
|
||||
elif is_safetensors_available():
|
||||
filename = SAFE_WEIGHTS_NAME
|
||||
filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
|
||||
else:
|
||||
filename = WEIGHTS_NAME
|
||||
filename = _add_variant(WEIGHTS_NAME, variant)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
@@ -2213,23 +2247,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
||||
# result when internet is up, the repo and revision exist, but the file does not.
|
||||
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
|
||||
if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
|
||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||
pretrained_model_name_or_path,
|
||||
_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
|
||||
**cached_file_kwargs,
|
||||
)
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
else:
|
||||
# This repo has no safetensors file of any kind, we switch to PyTorch.
|
||||
filename = WEIGHTS_NAME
|
||||
filename = _add_variant(WEIGHTS_NAME, variant)
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, WEIGHTS_NAME, **cached_file_kwargs
|
||||
pretrained_model_name_or_path, filename, **cached_file_kwargs
|
||||
)
|
||||
if resolved_archive_file is None and filename == WEIGHTS_NAME:
|
||||
if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
|
||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||
pretrained_model_name_or_path,
|
||||
_add_variant(WEIGHTS_INDEX_NAME, variant),
|
||||
**cached_file_kwargs,
|
||||
)
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
@@ -2244,19 +2282,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
f" {WEIGHTS_NAME} but there is a file for TensorFlow weights. Use `from_tf=True` to"
|
||||
" load this model from those weights."
|
||||
f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights."
|
||||
" Use `from_tf=True` to load this model from those weights."
|
||||
)
|
||||
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
f" {WEIGHTS_NAME} but there is a file for Flax weights. Use `from_flax=True` to load"
|
||||
" this model from those weights."
|
||||
f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use"
|
||||
" `from_flax=True` to load this model from those weights."
|
||||
)
|
||||
elif variant is not None and has_file(
|
||||
pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
|
||||
):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
|
||||
f" {variant}. Use `variant=None` to load this model from those weights."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
|
||||
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
f" {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
|
||||
f" {FLAX_WEIGHTS_NAME}."
|
||||
)
|
||||
except EnvironmentError:
|
||||
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
|
||||
@@ -2268,8 +2315,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
|
||||
f" {FLAX_WEIGHTS_NAME}."
|
||||
f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)},"
|
||||
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
||||
)
|
||||
|
||||
if is_local:
|
||||
|
||||
@@ -2958,6 +2958,138 @@ class ModelUtilsTest(TestCasePlus):
|
||||
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
||||
self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_checkpoint_variant_local(self):
    """
    Save a single-file PyTorch checkpoint under the "v2" variant and load it back.

    Checks that only the variant-named weights file is written, that loading without `variant` raises
    `EnvironmentError`, and that loading with `variant="v2"` gives parameters matching the saved model.
    """
    variant = "v2"
    reference = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

    with tempfile.TemporaryDirectory() as save_dir:
        reference.save_pretrained(save_dir, variant=variant)

        # Expected name: the default weights name with the variant placed before the extension.
        stem_parts = WEIGHTS_NAME.split(".")[:-1]
        variant_path = os.path.join(save_dir, ".".join([*stem_parts, variant, "bin"]))
        default_path = os.path.join(save_dir, WEIGHTS_NAME)

        self.assertTrue(os.path.isfile(variant_path))
        self.assertFalse(os.path.isfile(default_path))

        # The default-named file does not exist, so loading without the variant must fail.
        with self.assertRaises(EnvironmentError):
            BertModel.from_pretrained(save_dir)

        reloaded = BertModel.from_pretrained(save_dir, variant=variant)

        for saved_param, loaded_param in zip(reference.parameters(), reloaded.parameters()):
            self.assertTrue(torch.allclose(saved_param, loaded_param))
|
||||
|
||||
def test_checkpoint_variant_local_sharded(self):
    """
    Save a sharded PyTorch checkpoint under the "v2" variant and load it back.

    Checks that the variant-named index file and every one of the six expected shard files are written, that
    the default-named index is absent, that loading without `variant` raises `EnvironmentError`, and that
    loading with `variant="v2"` gives parameters matching the saved model.
    """
    model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    # Number of shards the file names below expect for this model at a 50kB shard size.
    num_shards = 6

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB")

        weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
        weights_index_file = os.path.join(tmp_dir, weights_index_name)
        self.assertTrue(os.path.isfile(weights_index_file))
        self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))

        # Check every shard, including the last one (the previous `range(1, 6)` stopped at shard 5 of 6).
        for i in range(1, num_shards + 1):
            shard_tag = f"v2-{i:05d}-of-{num_shards:05d}"
            weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [shard_tag] + ["bin"])
            weights_name_file = os.path.join(tmp_dir, weights_name)
            self.assertTrue(os.path.isfile(weights_name_file))

        # The default-named index does not exist, so loading without the variant must fail.
        with self.assertRaises(EnvironmentError):
            _ = BertModel.from_pretrained(tmp_dir)

        new_model = BertModel.from_pretrained(tmp_dir, variant="v2")

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
@require_safetensors
def test_checkpoint_variant_local_safe(self):
    """
    Save a single-file safetensors checkpoint under the "v2" variant and load it back.

    Checks that only the variant-named safetensors file is written, that loading without `variant` raises
    `EnvironmentError`, and that loading with `variant="v2"` gives parameters matching the saved model.
    """
    variant = "v2"
    reference = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

    with tempfile.TemporaryDirectory() as save_dir:
        reference.save_pretrained(save_dir, variant=variant, safe_serialization=True)

        # Expected name: the default safetensors name with the variant placed before the extension.
        stem_parts = SAFE_WEIGHTS_NAME.split(".")[:-1]
        variant_path = os.path.join(save_dir, ".".join([*stem_parts, variant, "safetensors"]))
        default_path = os.path.join(save_dir, SAFE_WEIGHTS_NAME)

        self.assertTrue(os.path.isfile(variant_path))
        self.assertFalse(os.path.isfile(default_path))

        # The default-named file does not exist, so loading without the variant must fail.
        with self.assertRaises(EnvironmentError):
            BertModel.from_pretrained(save_dir)

        reloaded = BertModel.from_pretrained(save_dir, variant=variant)

        for saved_param, loaded_param in zip(reference.parameters(), reloaded.parameters()):
            self.assertTrue(torch.allclose(saved_param, loaded_param))
|
||||
|
||||
@require_safetensors
def test_checkpoint_variant_local_sharded_safe(self):
    """
    Save a sharded safetensors checkpoint under the "v2" variant and load it back.

    Checks that the variant-named index file and every one of the six expected shard files are written, that
    the default-named index is absent, that loading without `variant` raises `EnvironmentError`, and that
    loading with `variant="v2"` gives parameters matching the saved model.
    """
    model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    # Number of shards the file names below expect for this model at a 50kB shard size.
    num_shards = 6

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=True)

        weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
        weights_index_file = os.path.join(tmp_dir, weights_index_name)
        self.assertTrue(os.path.isfile(weights_index_file))
        self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))

        # Check every shard, including the last one (the previous `range(1, 6)` stopped at shard 5 of 6).
        for i in range(1, num_shards + 1):
            shard_tag = f"v2-{i:05d}-of-{num_shards:05d}"
            weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [shard_tag] + ["safetensors"])
            weights_name_file = os.path.join(tmp_dir, weights_name)
            self.assertTrue(os.path.isfile(weights_name_file))

        # The default-named index does not exist, so loading without the variant must fail.
        with self.assertRaises(EnvironmentError):
            _ = BertModel.from_pretrained(tmp_dir)

        new_model = BertModel.from_pretrained(tmp_dir, variant="v2")

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_checkpoint_variant_hub(self):
    """
    Load the "v2" variant of a single-file PyTorch checkpoint from a Hub repository.

    Loading without `variant` is expected to raise `EnvironmentError`; loading with `variant="v2"` is
    expected to return a model.
    """
    repo_id = "hf-internal-testing/tiny-random-bert-variant"

    with tempfile.TemporaryDirectory() as cache_dir:
        # No variant requested: the load is expected to fail.
        with self.assertRaises(EnvironmentError):
            BertModel.from_pretrained(repo_id, cache_dir=cache_dir)

        loaded = BertModel.from_pretrained(repo_id, cache_dir=cache_dir, variant="v2")
        self.assertIsNotNone(loaded)
|
||||
|
||||
def test_checkpoint_variant_hub_sharded(self):
    """
    Load the "v2" variant of a sharded PyTorch checkpoint from a Hub repository.

    Loading without `variant` is expected to raise `EnvironmentError`; loading with `variant="v2"` is
    expected to return a model.
    """
    repo_id = "hf-internal-testing/tiny-random-bert-variant-sharded"

    with tempfile.TemporaryDirectory() as cache_dir:
        # No variant requested: the load is expected to fail.
        with self.assertRaises(EnvironmentError):
            BertModel.from_pretrained(repo_id, cache_dir=cache_dir)

        loaded = BertModel.from_pretrained(repo_id, cache_dir=cache_dir, variant="v2")
        self.assertIsNotNone(loaded)
|
||||
|
||||
@require_safetensors
def test_checkpoint_variant_hub_safe(self):
    """
    Load the "v2" variant of a single-file safetensors checkpoint from a Hub repository.

    Loading without `variant` is expected to raise `EnvironmentError`; loading with `variant="v2"` is
    expected to return a model.
    """
    repo_id = "hf-internal-testing/tiny-random-bert-variant-safe"

    with tempfile.TemporaryDirectory() as cache_dir:
        # No variant requested: the load is expected to fail.
        with self.assertRaises(EnvironmentError):
            BertModel.from_pretrained(repo_id, cache_dir=cache_dir)

        loaded = BertModel.from_pretrained(repo_id, cache_dir=cache_dir, variant="v2")
        self.assertIsNotNone(loaded)
|
||||
|
||||
@require_safetensors
def test_checkpoint_variant_hub_sharded_safe(self):
    """
    Load the "v2" variant of a sharded safetensors checkpoint from a Hub repository.

    Loading without `variant` is expected to raise `EnvironmentError`; loading with `variant="v2"` is
    expected to return a model.
    """
    repo_id = "hf-internal-testing/tiny-random-bert-variant-sharded-safe"

    with tempfile.TemporaryDirectory() as cache_dir:
        # No variant requested: the load is expected to fail.
        with self.assertRaises(EnvironmentError):
            BertModel.from_pretrained(repo_id, cache_dir=cache_dir)

        loaded = BertModel.from_pretrained(repo_id, cache_dir=cache_dir, variant="v2")
        self.assertIsNotNone(loaded)
|
||||
|
||||
@require_accelerate
|
||||
def test_from_pretrained_low_cpu_mem_usage_functional(self):
|
||||
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
|
||||
|
||||
Reference in New Issue
Block a user