Use Accelerate in from_pretrained for big model inference (#17341)
* Initial work * More or less finished with first draft * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Fix randomly initialized weights * Update src/transformers/modeling_utils.py Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr> * Address review comments * Rename DeepSpeed folder to temporarily fix the test issue? * Revert to try if Accelerate fix works * Use latest Accelerate release * Quality and fixes * Style * Quality * Add doc * Test + fix * More blocks Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -38,6 +38,75 @@ for text generation, [`~generation_utils.GenerationMixin`] (for the PyTorch mode
|
||||
|
||||
<a id='from_pretrained-torch-dtype'></a>
|
||||
|
||||
### Large model loading
|
||||
|
||||
In Transformers 4.20.0, the [`~PreTrainedModel.from_pretrained`] method has been reworked to accommodate large models using [Accelerate](https://huggingface.co/docs/accelerate/big_modeling). This requires Accelerate >= 0.9.0 and PyTorch >= 1.9.0. Instead of creating the full model, then loading the pretrained weights inside it (which takes twice the size of the model in RAM, one for the randomly initialized model, one for the weights), there is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded.
|
||||
|
||||
This option can be activated with `low_cpu_mem_usage=True`. The model is first created on the Meta device (with empty weights) and the state dict is then loaded inside it (shard by shard in the case of a sharded checkpoint). This way the maximum RAM used is the full size of the model only.
|
||||
|
||||
```py
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
t0pp = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", low_cpu_mem_usage=True)
|
||||
```
|
||||
|
||||
Moreover, you can directly place the model on different devices if it doesn't fully fit in RAM (only works for inference for now). With `device_map="auto"`, Accelerate will determine where to put each layer to maximize the use of your fastest devices (GPUs) and offload the rest on the CPU, or even the hard drive if you don't have enough GPU RAM (or CPU RAM). Even if the model is split across several devices, it will run as you would normally expect.
|
||||
|
||||
When passing a `device_map`, `low_cpu_mem_usage` is automatically set to `True`, so you don't need to specify it:
|
||||
|
||||
```py
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
t0pp = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map="auto")
|
||||
```
|
||||
|
||||
You can inspect how the model was split across devices by looking at its `hf_device_map` attribute:
|
||||
|
||||
```py
|
||||
t0pp.hf_device_map
|
||||
```
|
||||
|
||||
```python out
|
||||
{'shared': 0,
|
||||
'decoder.embed_tokens': 0,
|
||||
'encoder': 0,
|
||||
'decoder.block.0': 0,
|
||||
'decoder.block.1': 1,
|
||||
'decoder.block.2': 1,
|
||||
'decoder.block.3': 1,
|
||||
'decoder.block.4': 1,
|
||||
'decoder.block.5': 1,
|
||||
'decoder.block.6': 1,
|
||||
'decoder.block.7': 1,
|
||||
'decoder.block.8': 1,
|
||||
'decoder.block.9': 1,
|
||||
'decoder.block.10': 1,
|
||||
'decoder.block.11': 1,
|
||||
'decoder.block.12': 1,
|
||||
'decoder.block.13': 1,
|
||||
'decoder.block.14': 1,
|
||||
'decoder.block.15': 1,
|
||||
'decoder.block.16': 1,
|
||||
'decoder.block.17': 1,
|
||||
'decoder.block.18': 1,
|
||||
'decoder.block.19': 1,
|
||||
'decoder.block.20': 1,
|
||||
'decoder.block.21': 1,
|
||||
'decoder.block.22': 'cpu',
|
||||
'decoder.block.23': 'cpu',
|
||||
'decoder.final_layer_norm': 'cpu',
|
||||
'decoder.dropout': 'cpu',
|
||||
'lm_head': 'cpu'}
|
||||
```
|
||||
|
||||
You can also write your own device map following the same format (a dictionary layer name to device). It should map all parameters of the model to a given device, but you don't have to detail where all the submosules of one layer go if that layer is entirely on the same device. For instance, the following device map would work properly for T0pp (as long as you have the GPU memory):
|
||||
|
||||
```python
|
||||
device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}
|
||||
```
|
||||
|
||||
Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`).
|
||||
|
||||
### Model Instantiation dtype
|
||||
|
||||
Under Pytorch a model normally gets instantiated with `torch.float32` format. This can be an issue if one tries to
|
||||
|
||||
2
setup.py
2
setup.py
@@ -97,7 +97,7 @@ if stale_egg_info.exists():
|
||||
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
|
||||
_deps = [
|
||||
"Pillow",
|
||||
"accelerate>=0.7.1",
|
||||
"accelerate>=0.9.0",
|
||||
"black~=22.0,>=22.3",
|
||||
"codecarbon==1.2.0",
|
||||
"cookiecutter==1.7.3",
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# 2. run `make deps_table_update``
|
||||
deps = {
|
||||
"Pillow": "Pillow",
|
||||
"accelerate": "accelerate>=0.7.1",
|
||||
"accelerate": "accelerate>=0.9.0",
|
||||
"black": "black~=22.0,>=22.3",
|
||||
"codecarbon": "codecarbon==1.2.0",
|
||||
"cookiecutter": "cookiecutter==1.7.3",
|
||||
|
||||
@@ -54,6 +54,7 @@ from .utils import (
|
||||
TF_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
ContextManagers,
|
||||
EntryNotFoundError,
|
||||
ModelOutput,
|
||||
PushToHubMixin,
|
||||
@@ -62,6 +63,7 @@ from .utils import (
|
||||
cached_path,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_accelerate_available,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
@@ -70,6 +72,15 @@ from .utils import (
|
||||
from .utils.versions import require_version_core
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
|
||||
from accelerate.utils import (
|
||||
load_offloaded_weights,
|
||||
offload_weight,
|
||||
save_offload_index,
|
||||
set_module_tensor_to_device,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -514,7 +525,19 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
|
||||
setattr(submodule, param_name, new_val)
|
||||
|
||||
|
||||
def _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix):
|
||||
def _load_state_dict_into_meta_model(
|
||||
model,
|
||||
state_dict,
|
||||
loaded_state_dict_keys, # left for now but could be removed, see below
|
||||
start_prefix,
|
||||
expected_keys,
|
||||
device_map=None,
|
||||
offload_folder=None,
|
||||
offload_index=None,
|
||||
state_dict_folder=None,
|
||||
state_dict_index=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||
params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
|
||||
@@ -532,23 +555,55 @@ def _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys,
|
||||
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
|
||||
# they won't get loaded.
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("low_cpu_mem_usage arg cannot currently be used with DeepSpeed ZeRO-3")
|
||||
|
||||
error_msgs = []
|
||||
|
||||
# materialize state_dict entries one by one on CPU
|
||||
for k in loaded_state_dict_keys:
|
||||
if k in state_dict:
|
||||
submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
|
||||
if submodule is not None:
|
||||
param_dtype = getattr(submodule, param_name).dtype
|
||||
new_val = state_dict[k].to(param_dtype)
|
||||
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
|
||||
new_val = torch.nn.Parameter(new_val)
|
||||
setattr(submodule, param_name, new_val)
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if "gamma" in key:
|
||||
new_key = key.replace("gamma", "weight")
|
||||
if "beta" in key:
|
||||
new_key = key.replace("beta", "bias")
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
return error_msgs
|
||||
for param_name, param in state_dict.items():
|
||||
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
|
||||
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
|
||||
print(param_name)
|
||||
continue
|
||||
|
||||
if param_name.startswith(start_prefix):
|
||||
param_name = param_name[len(start_prefix) :]
|
||||
|
||||
module_name = param_name
|
||||
# We convert floating dtypes to the `dtype` passed.
|
||||
if dtype is not None and not str(param.dtype).startswith("torch.int"):
|
||||
param = param.to(dtype)
|
||||
|
||||
if device_map is None:
|
||||
param_device = "cpu"
|
||||
else:
|
||||
# find next higher level module that is defined in device_map:
|
||||
# bert.lm_head.weight -> bert.lm_head -> bert -> ''
|
||||
while len(module_name) > 0 and module_name not in device_map:
|
||||
module_name = ".".join(module_name.split(".")[:-1])
|
||||
if module_name == "" and "" not in device_map:
|
||||
# TODO: group all errors and raise at the end.
|
||||
raise ValueError(f"{param_name} doesn't have any device set.")
|
||||
param_device = device_map[module_name]
|
||||
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||
if param_device == "disk":
|
||||
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
|
||||
elif param_device == "cpu" and state_dict_index is not None:
|
||||
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
|
||||
|
||||
return error_msgs, offload_index, state_dict_index
|
||||
|
||||
|
||||
class ModuleUtilsMixin:
|
||||
@@ -870,6 +925,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
base_model_prefix = ""
|
||||
main_input_name = "input_ids"
|
||||
_auto_class = None
|
||||
_no_split_modules = None
|
||||
|
||||
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
|
||||
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
|
||||
@@ -1664,12 +1720,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
</Tip>
|
||||
|
||||
low_cpu_mem_usage(`bool`, *optional*, defaults to `False`):
|
||||
> Parameters for big model inference
|
||||
|
||||
low_cpu_mem_usage(`bool`, *optional*):
|
||||
Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
This is an experimental feature and a subject to change at any moment.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
|
||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
||||
offload_state_dict (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
|
||||
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
||||
@@ -1750,7 +1820,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
_fast_init = kwargs.pop("_fast_init", True)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
offload_folder = kwargs.pop("offload_folder", None)
|
||||
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
||||
|
||||
if device_map is not None:
|
||||
if low_cpu_mem_usage is None:
|
||||
low_cpu_mem_usage = True
|
||||
elif not low_cpu_mem_usage:
|
||||
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
# low_cpu_mem_usage requires PyTorch >= 1.9 to have the meta device.
|
||||
require_version_core("torch>=1.9")
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError(
|
||||
"DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`."
|
||||
)
|
||||
elif not is_accelerate_available():
|
||||
raise ImportError(
|
||||
"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
|
||||
)
|
||||
|
||||
from_pt = not (from_tf | from_flax)
|
||||
|
||||
@@ -1845,10 +1937,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
filename = WEIGHTS_NAME
|
||||
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=filename,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -2013,18 +2102,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
# Instantiate model.
|
||||
init_contexts = [no_init_weights(_enable=_fast_init)]
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
# this immediately partitions the model across all gpus, to avoid the overhead in time
|
||||
# and memory copying it on CPU or each GPU first
|
||||
with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
|
||||
with no_init_weights(_enable=_fast_init):
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
else:
|
||||
with no_init_weights(_enable=_fast_init):
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
|
||||
elif low_cpu_mem_usage:
|
||||
init_contexts.append(init_empty_weights())
|
||||
|
||||
with ContextManagers(init_contexts):
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
if device_map == "auto":
|
||||
if model._no_split_modules is None:
|
||||
raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.")
|
||||
no_split_modules = model._no_split_modules
|
||||
device_map = infer_auto_device_map(model, no_split_module_classes=no_split_modules, dtype=torch_dtype)
|
||||
|
||||
if from_tf:
|
||||
if resolved_archive_file.endswith(".index"):
|
||||
@@ -2071,6 +2166,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
sharded_metadata=sharded_metadata,
|
||||
_fast_init=_fast_init,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
device_map=device_map,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
|
||||
# make sure token embedding weights are still tied if needed
|
||||
@@ -2079,6 +2178,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
# Dispatch model with hooks on all devices if necessary
|
||||
if device_map is not None:
|
||||
dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
|
||||
|
||||
if output_loading_info:
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
@@ -2102,6 +2205,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
sharded_metadata=None,
|
||||
_fast_init=True,
|
||||
low_cpu_mem_usage=False,
|
||||
device_map=None,
|
||||
offload_folder=None,
|
||||
offload_state_dict=False,
|
||||
dtype=None,
|
||||
):
|
||||
# Retrieve missing & unexpected_keys
|
||||
model_state_dict = model.state_dict()
|
||||
@@ -2149,8 +2256,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
# retrieve weights on meta device and put them back on CPU.
|
||||
# This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
|
||||
if low_cpu_mem_usage:
|
||||
for key in missing_keys:
|
||||
if key.startswith(prefix):
|
||||
key = ".".join(key.split(".")[1:])
|
||||
param = model_state_dict[key]
|
||||
if param.device == torch.device("meta"):
|
||||
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
|
||||
|
||||
# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
|
||||
if _fast_init:
|
||||
# retrieve unintialized modules and initialize
|
||||
uninitialized_modules = model.retrieve_modules_from_names(
|
||||
missing_keys, add_prefix=add_prefix_to_model, remove_prefix=remove_prefix_from_model
|
||||
)
|
||||
@@ -2169,6 +2286,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
|
||||
"properly saved?"
|
||||
)
|
||||
if device_map is not None:
|
||||
device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
|
||||
|
||||
def _find_mismatched_keys(
|
||||
state_dict,
|
||||
@@ -2199,10 +2318,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
model_state_dict = None # free references to model's params to allow memory freeing
|
||||
_move_model_to_meta(model, loaded_keys, start_prefix)
|
||||
|
||||
if state_dict is not None:
|
||||
# Whole checkpoint
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
@@ -2223,12 +2338,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
error_msgs = []
|
||||
mismatched_keys = []
|
||||
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
|
||||
if offload_state_dict:
|
||||
state_dict_folder = tempfile.mkdtemp()
|
||||
state_dict_index = {}
|
||||
else:
|
||||
state_dict_folder = None
|
||||
state_dict_index = None
|
||||
|
||||
for shard_file in resolved_archive_file:
|
||||
state_dict = load_state_dict(shard_file)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
model_state_dict = model.state_dict()
|
||||
|
||||
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||
# matching the weights in the model.
|
||||
mismatched_keys += _find_mismatched_keys(
|
||||
@@ -2241,9 +2361,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
error_msgs += _load_state_dict_into_meta_model(
|
||||
model_to_load, state_dict, loaded_keys, start_prefix
|
||||
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
loaded_keys,
|
||||
start_prefix,
|
||||
expected_keys,
|
||||
device_map=device_map,
|
||||
offload_folder=offload_folder,
|
||||
offload_index=offload_index,
|
||||
state_dict_folder=state_dict_folder,
|
||||
state_dict_index=state_dict_index,
|
||||
dtype=dtype,
|
||||
)
|
||||
error_msgs += new_error_msgs
|
||||
else:
|
||||
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
||||
|
||||
@@ -2251,6 +2382,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
save_offload_index(offload_index, offload_folder)
|
||||
|
||||
if offload_state_dict:
|
||||
# Load back temporarily offloaded state dict
|
||||
load_offloaded_weights(model, state_dict_index, state_dict_folder)
|
||||
shutil.rmtree(state_dict_folder)
|
||||
|
||||
if len(error_msgs) > 0:
|
||||
error_msg = "\n\t".join(error_msgs)
|
||||
if "size mismatch" in error_msg:
|
||||
|
||||
@@ -448,6 +448,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "transformer"
|
||||
is_parallelizable = True
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["GPT2Block"]
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
@@ -358,6 +358,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
|
||||
load_tf_weights = load_tf_weights_in_gpt_neo
|
||||
base_model_prefix = "transformer"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["GPTNeoBlock"]
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
@@ -334,6 +334,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "transformer"
|
||||
is_parallelizable = True
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["GPTJBlock"]
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
@@ -747,6 +747,7 @@ class T5PreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "transformer"
|
||||
is_parallelizable = True
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["T5Block"]
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
|
||||
@@ -94,6 +94,8 @@ if is_torch_available():
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
AdaptiveEmbedding,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BertConfig,
|
||||
BertModel,
|
||||
PreTrainedModel,
|
||||
@@ -2595,6 +2597,22 @@ class ModelUtilsTest(TestCasePlus):
|
||||
# functionality to load models directly on gpu, this test can be rewritten to use torch's
|
||||
# cuda memory tracking and then we should be able to do a much more precise test.
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@slow
|
||||
def test_model_parallelism_gpt2(self):
|
||||
device_map = {"transformer.wte": 0, "transformer.wpe": 0, "lm_head": 0, "transformer.ln_f": 1}
|
||||
for i in range(12):
|
||||
device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=device_map)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
inputs = tokenizer("Hello, my name is", return_tensors="pt")
|
||||
output = model.generate(inputs["input_ids"].to(0))
|
||||
|
||||
text_output = tokenizer.decode(output[0].tolist())
|
||||
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
|
||||
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
|
||||
Reference in New Issue
Block a user