Compare commits

...

4 Commits

Author SHA1 Message Date
Sylvain Gugger
5e3b19a805 Patch release: v4.27.3
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2023-03-23 14:04:40 -04:00
Sylvain Gugger
62d9baa53c Enforce max_memory for device_map strategies (#22311)
Enforce  for device_map strategies
2023-03-23 14:04:10 -04:00
Sylvain Gugger
68287689f2 Patch release: v4.27.2
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2023-03-20 12:02:35 -04:00
Sylvain Gugger
1e39734c4b Fix balanced and auto device_map (#22271) 2023-03-20 12:01:08 -04:00
3 changed files with 5 additions and 3 deletions

View File

@@ -418,7 +418,7 @@ install_requires = [
setup(
name="transformers",
version="4.27.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.27.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.27.1"
__version__ = "4.27.3"
from typing import TYPE_CHECKING

View File

@@ -2563,7 +2563,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
kwargs = {"no_split_module_classes": no_split_modules, "max_memory": max_memory}
kwargs = {"no_split_module_classes": no_split_modules}
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
kwargs["special_dtypes"] = special_dtypes
elif len(special_dtypes) > 0:
@@ -2576,8 +2576,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model,
dtype=torch_dtype,
low_zero=(device_map == "balanced_low_0"),
max_memory=max_memory,
**kwargs,
)
kwargs["max_memory"] = max_memory
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, dtype=torch_dtype if not load_in_8bit else torch.int8, **kwargs)