From c8e0e603de9b3d49161a15fe6e8ea84badfb5d02 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 15 Apr 2025 09:59:20 +0200
Subject: [PATCH] Detect and use device context manager or global device in `from_pretrained` (#37216)

* Update modeling_utils.py

* improve

* Update modeling_utils.py

* Update test_modeling_common.py

* Update test_modeling_timm_backbone.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* CIs
---
 src/transformers/modeling_utils.py            | 38 ++++++++++-
 .../test_modeling_timm_backbone.py            | 12 ++++
 tests/test_modeling_common.py                 | 67 +++++++++++++++++++
 3 files changed, 116 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 5a8ebc6847..7c7209db51 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -287,6 +287,26 @@ def restore_default_torch_dtype(func):
     return _wrapper
 
 
+def get_torch_context_manager_or_global_device():
+    """
+    Test if a device context manager is currently in use, or if it is not the case, check if the default device
+    is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
+
+    Returns:
+        `torch.device` or `None`: The device of the active context manager or the global default device, or `None`
+        if no context manager is in use and the default device is "cpu".
+    """
+    device_in_context = torch.tensor([]).device
+    # `torch.get_default_device` only exists in torch>=2.3 -> assume "cpu" is the global default on older versions
+    default_device = torch.get_default_device() if hasattr(torch, "get_default_device") else torch.device("cpu")
+    # This case means no context manager was used -> we still check if the default that was potentially set is not cpu
+    if device_in_context == default_device:
+        if default_device != torch.device("cpu"):
+            return default_device
+        return None
+    return device_in_context
+
+
 def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     try:
         return next(parameter.parameters()).device
@@ -4153,6 +4173,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         else:
             _adapter_model_path = None
 
+        # Potentially detect context manager or global device, and use it (only if no device_map was provided)
+        if device_map is None:
+            device_in_context = get_torch_context_manager_or_global_device()
+            if device_in_context == torch.device("meta"):
+                raise ValueError(
+                    (
+                        "`from_pretrained` is not compatible with a meta device context manager or `torch.set_default_device('meta')` "
+                        "as its purpose is to load weights. If you want to initialize a model on the meta device, use the context manager "
+                        "or global device with `from_config`, or `ModelClass(config)`"
+                    )
+                )
+            device_map = device_in_context
+
         # change device_map into a map if we passed an int, a str or a torch.device
         if isinstance(device_map, torch.device):
             device_map = {"": device_map}
@@ -4177,7 +4210,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                 raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
             if not is_accelerate_available():
                 raise ValueError(
-                    "Using a `device_map` or `tp_plan` requires `accelerate`. You can install it with `pip install accelerate`"
+                    (
+                        "Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
+                        "requires `accelerate`. You can install it with `pip install accelerate`"
+                    )
                 )
 
         # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py
index af48bc76dc..582bdab0b5 100644
--- a/tests/models/timm_backbone/test_modeling_timm_backbone.py
+++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py
@@ -168,6 +168,18 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
     def test_save_load_low_cpu_mem_usage_no_safetensors(self):
         pass
 
+    @unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
+    def test_can_load_with_device_context_manager(self):
+        pass
+
+    @unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
+    def test_can_load_with_global_device_set(self):
+        pass
+
+    @unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
+    def test_cannot_load_with_meta_device_context_manager(self):
+        pass
+
     @unittest.skip(reason="model weights aren't tied in TimmBackbone.")
     def test_tie_model_weights(self):
         pass
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 7aa8134e5d..5323760fe0 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4454,6 +4454,73 @@ class ModelTesterMixin:
                 ),
             )
 
+    @require_torch_accelerator
+    def test_can_load_with_device_context_manager(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        # Need to specify index 0 here, as `torch_device` is simply the str of the type, e.g. "cuda"
+        device = torch.device(torch_device, index=0)
+        for model_class in self.all_model_classes:
+            # Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
+            # is not supported for e.g. dpt_hybrid)
+            model = model_class(copy.deepcopy(config))
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                with device:
+                    new_model = model_class.from_pretrained(tmpdirname)
+                unique_devices = {param.device for param in new_model.parameters()} | {
+                    buffer.device for buffer in new_model.buffers()
+                }
+
+            self.assertEqual(
+                unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
+            )
+
+    @require_torch_accelerator
+    def test_can_load_with_global_device_set(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        # Need to specify index 0 here, as `torch_device` is simply the str of the type, e.g. "cuda"
+        device = torch.device(torch_device, index=0)
+        default_device = torch.get_default_device()
+        for model_class in self.all_model_classes:
+            # Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
+            # is not supported for e.g. dpt_hybrid)
+            model = model_class(copy.deepcopy(config))
+
+            # set a global gpu device
+            torch.set_default_device(device)
+            # Always set back the correct device, even if saving/loading fails, to avoid leaking it to other tests
+            try:
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    model.save_pretrained(tmpdirname)
+
+                    new_model = model_class.from_pretrained(tmpdirname)
+                    unique_devices = {param.device for param in new_model.parameters()} | {
+                        buffer.device for buffer in new_model.buffers()
+                    }
+            finally:
+                torch.set_default_device(default_device)
+
+            self.assertEqual(
+                unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
+            )
+
+    def test_cannot_load_with_meta_device_context_manager(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            # Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
+            # is not supported for e.g. dpt_hybrid)
+            model = model_class(copy.deepcopy(config))
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                # This should raise an error with meta device (and we check the error message itself)
+                with self.assertRaisesRegex(ValueError, "`from_pretrained` is not compatible with a meta device"):
+                    with torch.device("meta"):
+                        _ = model_class.from_pretrained(tmpdirname)
+
 
 global_rng = random.Random()
 