[modeling utils] revamp from_pretrained(..., low_cpu_mem_usage=True) + tests (#16657)
* add low_cpu_mem_usage tests * wip: revamping * wip * install /usr/bin/time * wip * cleanup * cleanup * cleanup * cleanup * cleanup * fix assert * put the wrapper back * cleanup; switch to bert-base-cased * Trigger CI * Trigger CI
This commit is contained in:
@@ -217,7 +217,7 @@ jobs:
|
|||||||
keys:
|
keys:
|
||||||
- v0.4-torch-{{ checksum "setup.py" }}
|
- v0.4-torch-{{ checksum "setup.py" }}
|
||||||
- v0.4-{{ checksum "setup.py" }}
|
- v0.4-{{ checksum "setup.py" }}
|
||||||
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng
|
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng time
|
||||||
- run: pip install --upgrade pip
|
- run: pip install --upgrade pip
|
||||||
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
|
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
|
||||||
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
|
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
|
||||||
|
|||||||
@@ -400,6 +400,95 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
|
|||||||
return error_msgs
|
return error_msgs
|
||||||
|
|
||||||
|
|
||||||
|
def find_submodule_and_param_name(model, long_key, start_prefix):
|
||||||
|
"""
|
||||||
|
A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed
|
||||||
|
from the start of the key
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(start_prefix) > 0 and long_key.startswith(start_prefix):
|
||||||
|
long_key = ".".join(long_key.split(".")[1:])
|
||||||
|
|
||||||
|
split_key = long_key.split(".")
|
||||||
|
submodule = model
|
||||||
|
while len(split_key) > 1:
|
||||||
|
if hasattr(submodule, split_key[0]):
|
||||||
|
submodule = getattr(submodule, split_key[0])
|
||||||
|
del split_key[0]
|
||||||
|
else:
|
||||||
|
submodule = None
|
||||||
|
break
|
||||||
|
if submodule == model:
|
||||||
|
submodule = None
|
||||||
|
return submodule, split_key[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
|
||||||
|
"""
|
||||||
|
Moves `loaded_state_dict_keys` in model to meta device which frees up the memory taken by those params.
|
||||||
|
|
||||||
|
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
|
||||||
|
`bert.pooler.dense.weight`
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# meta device was added in pt=1.9
|
||||||
|
require_version_core("torch>=1.9")
|
||||||
|
|
||||||
|
# dematerialize param storage for keys that are going to be replaced by state_dict, by
|
||||||
|
# putting those on the meta device
|
||||||
|
for k in loaded_state_dict_keys:
|
||||||
|
submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
|
||||||
|
if submodule is not None:
|
||||||
|
# selectively switch to the meta device only those params/buffers that will
|
||||||
|
# be next replaced from state_dict. This a complex way to do p.to_("meta")
|
||||||
|
# since we have no in-place to_ for tensors.
|
||||||
|
new_val = getattr(submodule, param_name)
|
||||||
|
if isinstance(new_val, torch.nn.Parameter):
|
||||||
|
# isinstance returns False for Params on meta device, so switch after the check
|
||||||
|
new_val = torch.nn.Parameter(new_val.to("meta"))
|
||||||
|
else:
|
||||||
|
new_val = new_val.to("meta")
|
||||||
|
setattr(submodule, param_name, new_val)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix):
|
||||||
|
"""
|
||||||
|
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||||
|
params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
|
||||||
|
params back to the normal device, but only for `loaded_state_dict_keys`.
|
||||||
|
|
||||||
|
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
|
||||||
|
`bert.pooler.dense.weight`
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
|
||||||
|
# - deepspeed zero 3 support
|
||||||
|
# - need to copy metadata if any - see _load_state_dict_into_model
|
||||||
|
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
|
||||||
|
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
|
||||||
|
# they won't get loaded.
|
||||||
|
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
raise ValueError("low_cpu_mem_usage arg cannot currently be used with DeepSpeed ZeRO-3")
|
||||||
|
|
||||||
|
error_msgs = []
|
||||||
|
|
||||||
|
# materialize state_dict entries one by one on CPU
|
||||||
|
for k in loaded_state_dict_keys:
|
||||||
|
if k in state_dict:
|
||||||
|
submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
|
||||||
|
if submodule is not None:
|
||||||
|
param_dtype = getattr(submodule, param_name).dtype
|
||||||
|
new_val = state_dict[k].to(param_dtype)
|
||||||
|
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
|
||||||
|
new_val = torch.nn.Parameter(new_val)
|
||||||
|
setattr(submodule, param_name, new_val)
|
||||||
|
|
||||||
|
return error_msgs
|
||||||
|
|
||||||
|
|
||||||
class ModuleUtilsMixin:
|
class ModuleUtilsMixin:
|
||||||
"""
|
"""
|
||||||
A few utilities for `torch.nn.Modules`, to be used as a mixin.
|
A few utilities for `torch.nn.Modules`, to be used as a mixin.
|
||||||
@@ -1529,7 +1618,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
>>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
|
>>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
|
||||||
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
|
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
|
||||||
>>> model = BertModel.from_pretrained("bert-base-uncased", from_flax=True)
|
>>> model = BertModel.from_pretrained("bert-base-uncased", from_flax=True)
|
||||||
```"""
|
```
|
||||||
|
|
||||||
|
* `low_cpu_mem_usage` algorithm:
|
||||||
|
|
||||||
|
This is an experimental function that loads the model using ~1x model size CPU memory
|
||||||
|
|
||||||
|
Here is how it works:
|
||||||
|
|
||||||
|
1. save which state_dict keys we have
|
||||||
|
2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
|
||||||
|
3. after the model has been instantiated switch to the meta device all params/buffers that
|
||||||
|
are going to be replaced from the loaded state_dict
|
||||||
|
4. load state_dict 2nd time
|
||||||
|
5. replace the params/buffers from the state_dict
|
||||||
|
|
||||||
|
Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors
|
||||||
|
|
||||||
|
"""
|
||||||
config = kwargs.pop("config", None)
|
config = kwargs.pop("config", None)
|
||||||
state_dict = kwargs.pop("state_dict", None)
|
state_dict = kwargs.pop("state_dict", None)
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
@@ -1778,6 +1884,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if not is_sharded and state_dict is None:
|
if not is_sharded and state_dict is None:
|
||||||
# Time to load the checkpoint
|
# Time to load the checkpoint
|
||||||
state_dict = load_state_dict(resolved_archive_file)
|
state_dict = load_state_dict(resolved_archive_file)
|
||||||
|
|
||||||
# set dtype to instantiate the model under:
|
# set dtype to instantiate the model under:
|
||||||
# 1. If torch_dtype is not None, we use that dtype
|
# 1. If torch_dtype is not None, we use that dtype
|
||||||
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
||||||
@@ -1801,13 +1908,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||||
|
|
||||||
if low_cpu_mem_usage:
|
|
||||||
# save the keys
|
|
||||||
if is_sharded:
|
if is_sharded:
|
||||||
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
||||||
else:
|
else:
|
||||||
loaded_state_dict_keys = [k for k in state_dict.keys()]
|
loaded_state_dict_keys = [k for k in state_dict.keys()]
|
||||||
del state_dict # free CPU memory - will reload again later
|
if low_cpu_mem_usage:
|
||||||
|
state_dict = None
|
||||||
|
|
||||||
config.name_or_path = pretrained_model_name_or_path
|
config.name_or_path = pretrained_model_name_or_path
|
||||||
|
|
||||||
@@ -1825,11 +1931,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
with no_init_weights(_enable=_fast_init):
|
with no_init_weights(_enable=_fast_init):
|
||||||
model = cls(config, *model_args, **model_kwargs)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
|
|
||||||
if from_pt:
|
|
||||||
# restore default dtype
|
|
||||||
if dtype_orig is not None:
|
|
||||||
torch.set_default_dtype(dtype_orig)
|
|
||||||
|
|
||||||
if from_tf:
|
if from_tf:
|
||||||
if resolved_archive_file.endswith(".index"):
|
if resolved_archive_file.endswith(".index"):
|
||||||
# Load from a TensorFlow 1.X checkpoint - provided by original authors
|
# Load from a TensorFlow 1.X checkpoint - provided by original authors
|
||||||
@@ -1859,17 +1960,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
raise
|
raise
|
||||||
elif from_pt:
|
elif from_pt:
|
||||||
|
|
||||||
if low_cpu_mem_usage:
|
# restore default dtype
|
||||||
cls._load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
|
if dtype_orig is not None:
|
||||||
else:
|
torch.set_default_dtype(dtype_orig)
|
||||||
|
|
||||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||||
model,
|
model,
|
||||||
state_dict,
|
state_dict,
|
||||||
|
loaded_state_dict_keys, # XXX: rename?
|
||||||
resolved_archive_file,
|
resolved_archive_file,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
sharded_metadata=sharded_metadata,
|
sharded_metadata=sharded_metadata,
|
||||||
_fast_init=_fast_init,
|
_fast_init=_fast_init,
|
||||||
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure token embedding weights are still tied if needed
|
# make sure token embedding weights are still tied if needed
|
||||||
@@ -1894,16 +1998,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
cls,
|
cls,
|
||||||
model,
|
model,
|
||||||
state_dict,
|
state_dict,
|
||||||
|
loaded_keys,
|
||||||
resolved_archive_file,
|
resolved_archive_file,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
ignore_mismatched_sizes=False,
|
ignore_mismatched_sizes=False,
|
||||||
sharded_metadata=None,
|
sharded_metadata=None,
|
||||||
_fast_init=True,
|
_fast_init=True,
|
||||||
|
low_cpu_mem_usage=False,
|
||||||
):
|
):
|
||||||
# Retrieve missing & unexpected_keys
|
# Retrieve missing & unexpected_keys
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
expected_keys = list(model_state_dict.keys())
|
expected_keys = list(model_state_dict.keys())
|
||||||
loaded_keys = list(state_dict.keys()) if state_dict is not None else sharded_metadata["all_checkpoint_keys"]
|
|
||||||
prefix = model.base_model_prefix
|
prefix = model.base_model_prefix
|
||||||
|
|
||||||
def _fix_key(key):
|
def _fix_key(key):
|
||||||
@@ -1994,9 +2099,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||||
)
|
)
|
||||||
del state_dict[checkpoint_key]
|
del state_dict[checkpoint_key]
|
||||||
|
|
||||||
return mismatched_keys
|
return mismatched_keys
|
||||||
|
|
||||||
|
if low_cpu_mem_usage:
|
||||||
|
model_state_dict = None # free references to model's params to allow memory freeing
|
||||||
|
_move_model_to_meta(model, loaded_keys, start_prefix)
|
||||||
|
|
||||||
if state_dict is not None:
|
if state_dict is not None:
|
||||||
# Whole checkpoint
|
# Whole checkpoint
|
||||||
mismatched_keys = _find_mismatched_keys(
|
mismatched_keys = _find_mismatched_keys(
|
||||||
@@ -2009,7 +2117,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
||||||
else:
|
else:
|
||||||
# Sharded checkpoint
|
# Sharded checkpoint or whole but low_cpu_mem_usage==True
|
||||||
|
|
||||||
# This should always be a list but, just to be sure.
|
# This should always be a list but, just to be sure.
|
||||||
if not isinstance(resolved_archive_file, list):
|
if not isinstance(resolved_archive_file, list):
|
||||||
resolved_archive_file = [resolved_archive_file]
|
resolved_archive_file = [resolved_archive_file]
|
||||||
@@ -2018,6 +2127,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
mismatched_keys = []
|
mismatched_keys = []
|
||||||
for shard_file in resolved_archive_file:
|
for shard_file in resolved_archive_file:
|
||||||
state_dict = load_state_dict(shard_file)
|
state_dict = load_state_dict(shard_file)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage:
|
||||||
|
model_state_dict = model.state_dict()
|
||||||
|
|
||||||
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||||
# matching the weights in the model.
|
# matching the weights in the model.
|
||||||
mismatched_keys += _find_mismatched_keys(
|
mismatched_keys += _find_mismatched_keys(
|
||||||
@@ -2028,6 +2141,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
remove_prefix_from_model,
|
remove_prefix_from_model,
|
||||||
ignore_mismatched_sizes,
|
ignore_mismatched_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage:
|
||||||
|
error_msgs += _load_state_dict_into_meta_model(
|
||||||
|
model_to_load, state_dict, loaded_keys, start_prefix
|
||||||
|
)
|
||||||
|
else:
|
||||||
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
||||||
|
|
||||||
if len(error_msgs) > 0:
|
if len(error_msgs) > 0:
|
||||||
@@ -2093,13 +2212,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
return retrieved_modules
|
return retrieved_modules
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file):
|
def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file, start_prefix=""):
|
||||||
"""
|
"""
|
||||||
This is an experimental function that loads the model using ~1.x model size CPU memory
|
This is an experimental function that loads the model using ~1.x model size CPU memory
|
||||||
|
|
||||||
Before it gets called we do:
|
Before you call it do:
|
||||||
|
|
||||||
1. save which state_dict keys we have
|
1. save which state_dict keys are available
|
||||||
2. drop state_dict before model is created, since the latter takes 1x model size memory
|
2. drop state_dict before model is created, since the latter takes 1x model size memory
|
||||||
|
|
||||||
Here then we continue:
|
Here then we continue:
|
||||||
@@ -2110,58 +2229,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
|
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
|
||||||
"""
|
"""
|
||||||
require_version_core("torch>=1.9")
|
|
||||||
if is_deepspeed_zero3_enabled():
|
|
||||||
raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")
|
|
||||||
|
|
||||||
# a helper util to find the last sub-module and the param/buffer name
|
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
|
||||||
def find_submodule_and_param_name(model, long_key):
|
state_dict = load_state_dict(resolved_archive_file)
|
||||||
split_key = long_key.split(".")
|
error_msgs = _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix)
|
||||||
submodule = model
|
return error_msgs
|
||||||
while len(split_key) > 1:
|
|
||||||
if hasattr(submodule, split_key[0]):
|
|
||||||
submodule = getattr(submodule, split_key[0])
|
|
||||||
del split_key[0]
|
|
||||||
else:
|
|
||||||
submodule = None
|
|
||||||
break
|
|
||||||
return submodule, split_key[0]
|
|
||||||
|
|
||||||
# dematerialize param storage for keys that are going to be replaced by state_dict, by
|
|
||||||
# putting those on the meta device
|
|
||||||
for k in loaded_state_dict_keys:
|
|
||||||
submodule, param_name = find_submodule_and_param_name(model, k)
|
|
||||||
if submodule is not None:
|
|
||||||
# selectively switch to the meta device only those params/buffers that will
|
|
||||||
# be next replaced from state_dict. This a complex way to do p.to_("meta")
|
|
||||||
# since we have no in-place to_ for tensors.
|
|
||||||
new_val = getattr(submodule, param_name)
|
|
||||||
if isinstance(new_val, torch.nn.Parameter):
|
|
||||||
# isinstance returns False for Params on meta device, so switch after the check
|
|
||||||
new_val = torch.nn.Parameter(new_val.to("meta"))
|
|
||||||
else:
|
|
||||||
new_val = new_val.to("meta")
|
|
||||||
setattr(submodule, param_name, new_val)
|
|
||||||
|
|
||||||
# only now can load state_dict(s)
|
|
||||||
if not isinstance(resolved_archive_file, list):
|
|
||||||
resolved_archive_file = [resolved_archive_file]
|
|
||||||
|
|
||||||
for archive_file in resolved_archive_file:
|
|
||||||
state_dict = torch.load(archive_file, map_location="cpu")
|
|
||||||
|
|
||||||
# materialize state_dict entries one by one on CPU
|
|
||||||
for k in loaded_state_dict_keys:
|
|
||||||
if k in state_dict:
|
|
||||||
submodule, param_name = find_submodule_and_param_name(model, k)
|
|
||||||
if submodule is not None:
|
|
||||||
param_dtype = getattr(submodule, param_name).dtype
|
|
||||||
new_val = state_dict[k].to(param_dtype)
|
|
||||||
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
|
|
||||||
new_val = torch.nn.Parameter(new_val)
|
|
||||||
setattr(submodule, param_name, new_val)
|
|
||||||
|
|
||||||
del state_dict
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_for_auto_class(cls, auto_class="AutoModel"):
|
def register_for_auto_class(cls, auto_class="AutoModel"):
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shlex
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -667,6 +668,20 @@ def require_librosa(test_case):
|
|||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_exists(cmd):
|
||||||
|
return shutil.which(cmd) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def require_usr_bin_time(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires `/usr/bin/time`
|
||||||
|
"""
|
||||||
|
if not cmd_exists("/usr/bin/time"):
|
||||||
|
return unittest.skip("test requires /usr/bin/time")(test_case)
|
||||||
|
else:
|
||||||
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
def get_gpu_count():
|
def get_gpu_count():
|
||||||
"""
|
"""
|
||||||
Return the number of available gpus (regardless of whether torch, tf or jax is used)
|
Return the number of available gpus (regardless of whether torch, tf or jax is used)
|
||||||
@@ -1178,6 +1193,39 @@ class TestCasePlus(unittest.TestCase):
|
|||||||
|
|
||||||
return tmp_dir
|
return tmp_dir
|
||||||
|
|
||||||
|
def python_one_liner_max_rss(self, one_liner_str):
|
||||||
|
"""
|
||||||
|
Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the
|
||||||
|
program.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
one_liner_str (`string`):
|
||||||
|
a python one liner code that gets passed to `python -c`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run.
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
this helper needs `/usr/bin/time` to be installed (`apt install time`)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```
|
||||||
|
one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("t5-large")'
|
||||||
|
max_rss = self.python_one_liner_max_rss(one_liner_str)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not cmd_exists("/usr/bin/time"):
|
||||||
|
raise ValueError("/usr/bin/time is required, install with `apt install time`")
|
||||||
|
|
||||||
|
cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'")
|
||||||
|
with CaptureStd() as cs:
|
||||||
|
execute_subprocess_async(cmd, env=self.get_env())
|
||||||
|
# returned data is in KB so convert to bytes
|
||||||
|
max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024
|
||||||
|
return max_rss
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
|
||||||
# get_auto_remove_tmp_dir feature: remove registered temp dirs
|
# get_auto_remove_tmp_dir feature: remove registered temp dirs
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ from transformers.testing_utils import (
|
|||||||
is_staging_test,
|
is_staging_test,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
|
require_usr_bin_time,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
@@ -2489,6 +2490,56 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
||||||
self.assertTrue(torch.allclose(p1, p2))
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_from_pretrained_low_cpu_mem_usage_functional(self):
|
||||||
|
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
|
||||||
|
# sharded models
|
||||||
|
|
||||||
|
mnames = [
|
||||||
|
"hf-internal-testing/tiny-random-bert-sharded",
|
||||||
|
"hf-internal-testing/tiny-random-bert",
|
||||||
|
]
|
||||||
|
for mname in mnames:
|
||||||
|
_ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)
|
||||||
|
|
||||||
|
@require_usr_bin_time
|
||||||
|
def test_from_pretrained_low_cpu_mem_usage_measured(self):
|
||||||
|
# test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
|
||||||
|
|
||||||
|
mname = "bert-base-cased"
|
||||||
|
|
||||||
|
preamble = "from transformers import AutoModel"
|
||||||
|
one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)'
|
||||||
|
max_rss_normal = self.python_one_liner_max_rss(one_liner_str)
|
||||||
|
# print(f"{max_rss_normal=}")
|
||||||
|
|
||||||
|
one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)'
|
||||||
|
max_rss_low_mem = self.python_one_liner_max_rss(one_liner_str)
|
||||||
|
# print(f"{max_rss_low_mem=}")
|
||||||
|
|
||||||
|
diff_bytes = max_rss_normal - max_rss_low_mem
|
||||||
|
diff_percent = diff_bytes / max_rss_low_mem
|
||||||
|
# print(f"{diff_bytes=}, {diff_percent=}")
|
||||||
|
# ideally we would compare that the diff is close to ~1x checkpoint size in bytes, but
|
||||||
|
# measuring cpu memory on linux is very tricky and inconsistent, so instead let's check that
|
||||||
|
# it's at least 15% less cpu memory consumed
|
||||||
|
|
||||||
|
self.assertGreater(
|
||||||
|
diff_percent,
|
||||||
|
0.15,
|
||||||
|
"should use less CPU memory for low_cpu_mem_usage=True, "
|
||||||
|
f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# if you want to compare things manually, let's first look at the size of the model in bytes
|
||||||
|
# model = BertModel.from_pretrained(mname, low_cpu_mem_usage=False)
|
||||||
|
# total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
|
||||||
|
# total_bytes = total_numel * 4 # 420MB
|
||||||
|
# Now the diff_bytes should be very close to total_bytes, but the reports are inconsistent.
|
||||||
|
# The easiest way to test this is to switch the model and torch.load to do all the work on
|
||||||
|
# gpu - that way one can measure exactly the total and peak memory used. Perhaps once we add
|
||||||
|
# functionality to load models directly on gpu, this test can be rewritten to use torch's
|
||||||
|
# cuda memory tracking and then we should be able to do a much more precise test.
|
||||||
|
|
||||||
def test_cached_files_are_used_when_internet_is_down(self):
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
# A mock response for an HTTP head request to emulate server down
|
# A mock response for an HTTP head request to emulate server down
|
||||||
response_mock = mock.Mock()
|
response_mock = mock.Mock()
|
||||||
|
|||||||
Reference in New Issue
Block a user