From 66db33ddc8e8405bcffcdaf463242cd788dbb54d Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:29:51 +0100 Subject: [PATCH] Fix mismatching loading in from_pretrained with/without accelerate (#28414) * fix mismatching behavior in from_pretrained with/without accelerate * meaningful refactor * remove added space * add test * fix model on the hub * comment * use tiny model * style --- src/transformers/modeling_utils.py | 23 ++++++++++++++--------- tests/test_modeling_utils.py | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 76ff2db343..e863d880e7 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -756,18 +756,23 @@ def _load_state_dict_into_meta_model( else: param = param.to(dtype) - # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model - if dtype is None: - old_param = model - splits = param_name.split(".") - for split in splits: - old_param = getattr(old_param, split) - if old_param is None: - break + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which + # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. + # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 + old_param = model + splits = param_name.split(".") + for split in splits: + old_param = getattr(old_param, split) + if old_param is None: + break - if old_param is not None: + if old_param is not None: + if dtype is None: param = param.to(old_param.dtype) + if old_param.is_contiguous(): + param = param.contiguous() + set_module_kwargs["value"] = param if device_map is None: diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index ec72cdab82..6e16b184b2 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -34,6 +34,7 @@ from requests.exceptions import HTTPError from transformers import ( AutoConfig, AutoModel, + OwlViTForObjectDetection, PretrainedConfig, is_torch_available, logging, @@ -835,6 +836,23 @@ class ModelUtilsTest(TestCasePlus): outputs2 = new_model_with_offload(inputs) self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu())) + @slow + @require_torch + def test_from_pretrained_non_contiguous_checkpoint(self): + # See: https://github.com/huggingface/transformers/pull/28414 + # Tiny models on the Hub have contiguous weights, contrarily to google/owlvit + model = OwlViTForObjectDetection.from_pretrained("fxmarty/owlvit-tiny-non-contiguous-weight") + self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous()) + + model = OwlViTForObjectDetection.from_pretrained( + "fxmarty/owlvit-tiny-non-contiguous-weight", device_map="auto" + ) + self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous()) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=False) + model.save_pretrained(tmp_dir, safe_serialization=True) + def test_cached_files_are_used_when_internet_is_down(self): # A mock response for an HTTP head request to emulate server down response_mock = mock.Mock()