Detect and use device context manager or global device in from_pretrained (#37216)
* Update modeling_utils.py * improve * Update modeling_utils.py * Update test_modeling_common.py * Update test_modeling_timm_backbone.py * Update test_modeling_common.py * Update test_modeling_common.py * Update test_modeling_common.py * Update test_modeling_common.py * CIs
This commit is contained in:
@@ -287,6 +287,21 @@ def restore_default_torch_dtype(func):
|
|||||||
return _wrapper
|
return _wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def get_torch_context_manager_or_global_device():
|
||||||
|
"""
|
||||||
|
Test if a device context manager is currently in use, or if it is not the case, check if the default device
|
||||||
|
is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
|
||||||
|
"""
|
||||||
|
device_in_context = torch.tensor([]).device
|
||||||
|
default_device = torch.get_default_device()
|
||||||
|
# This case means no context manager was used -> we still check if the default that was potentially set is not cpu
|
||||||
|
if device_in_context == default_device:
|
||||||
|
if default_device != torch.device("cpu"):
|
||||||
|
return default_device
|
||||||
|
return None
|
||||||
|
return device_in_context
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
|
def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
|
||||||
try:
|
try:
|
||||||
return next(parameter.parameters()).device
|
return next(parameter.parameters()).device
|
||||||
@@ -4153,6 +4168,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
else:
|
else:
|
||||||
_adapter_model_path = None
|
_adapter_model_path = None
|
||||||
|
|
||||||
|
# Potentially detect context manager or global device, and use it (only if no device_map was provided)
|
||||||
|
if device_map is None:
|
||||||
|
device_in_context = get_torch_context_manager_or_global_device()
|
||||||
|
if device_in_context == torch.device("meta"):
|
||||||
|
raise ValueError(
|
||||||
|
(
|
||||||
|
"`from_pretrained` is not compatible with a meta device context manager or `torch.set_default_device('meta')` "
|
||||||
|
"as its purpose is to load weights. If you want to initialize a model on the meta device, use the context manager "
|
||||||
|
"or global device with `from_config`, or `ModelClass(config)`"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device_map = device_in_context
|
||||||
|
|
||||||
# change device_map into a map if we passed an int, a str or a torch.device
|
# change device_map into a map if we passed an int, a str or a torch.device
|
||||||
if isinstance(device_map, torch.device):
|
if isinstance(device_map, torch.device):
|
||||||
device_map = {"": device_map}
|
device_map = {"": device_map}
|
||||||
@@ -4177,7 +4205,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
|
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
|
||||||
if not is_accelerate_available():
|
if not is_accelerate_available():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Using a `device_map` or `tp_plan` requires `accelerate`. You can install it with `pip install accelerate`"
|
(
|
||||||
|
"Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
|
||||||
|
"requires `accelerate`. You can install it with `pip install accelerate`"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
|
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
|
||||||
|
|||||||
@@ -168,6 +168,18 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
|
|||||||
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
|
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
|
||||||
|
def test_can_load_with_device_context_manager(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
|
||||||
|
def test_can_load_with_global_device_set(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
|
||||||
|
def test_cannot_load_with_meta_device_context_manager(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="model weights aren't tied in TimmBackbone.")
|
@unittest.skip(reason="model weights aren't tied in TimmBackbone.")
|
||||||
def test_tie_model_weights(self):
|
def test_tie_model_weights(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -4454,6 +4454,73 @@ class ModelTesterMixin:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch_accelerator
|
||||||
|
def test_can_load_with_device_context_manager(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
# Need to specify index 0 here, as `torch_device` is simply the str of the type, e.g. "cuda"
|
||||||
|
device = torch.device(torch_device, index=0)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
# Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
|
||||||
|
# is not supported for e.g. dpt_hybrid)
|
||||||
|
model = model_class(copy.deepcopy(config))
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
with device:
|
||||||
|
new_model = model_class.from_pretrained(tmpdirname)
|
||||||
|
unique_devices = {param.device for param in new_model.parameters()} | {
|
||||||
|
buffer.device for buffer in new_model.buffers()
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch_accelerator
|
||||||
|
def test_can_load_with_global_device_set(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
# Need to specify index 0 here, as `torch_device` is simply the str of the type, e.g. "cuda"
|
||||||
|
device = torch.device(torch_device, index=0)
|
||||||
|
default_device = torch.get_default_device()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
# Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
|
||||||
|
# is not supported for e.g. dpt_hybrid)
|
||||||
|
model = model_class(copy.deepcopy(config))
|
||||||
|
|
||||||
|
# set a global gpu device
|
||||||
|
torch.set_default_device(device)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
new_model = model_class.from_pretrained(tmpdirname)
|
||||||
|
unique_devices = {param.device for param in new_model.parameters()} | {
|
||||||
|
buffer.device for buffer in new_model.buffers()
|
||||||
|
}
|
||||||
|
|
||||||
|
# set back the correct device
|
||||||
|
torch.set_default_device(default_device)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_cannot_load_with_meta_device_context_manager(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
# Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
|
||||||
|
# is not supported for e.g. dpt_hybrid)
|
||||||
|
model = model_class(copy.deepcopy(config))
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
# This should raise an error with meta device
|
||||||
|
with self.assertRaises(ValueError, msg="`from_pretrained` is not compatible with a meta device"):
|
||||||
|
with torch.device("meta"):
|
||||||
|
_ = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user