Fix mismatching loading in from_pretrained with/without accelerate (#28414)
* fix mismatching behavior in from_pretrained with/without accelerate
* meaningful refactor
* remove added space
* add test
* fix model on the hub
* comment
* use tiny model
* style
This commit is contained in:
@@ -756,18 +756,23 @@ def _load_state_dict_into_meta_model(
|
|||||||
else:
|
else:
|
||||||
param = param.to(dtype)
|
param = param.to(dtype)
|
||||||
|
|
||||||
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
|
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
||||||
if dtype is None:
|
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
||||||
old_param = model
|
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
||||||
splits = param_name.split(".")
|
old_param = model
|
||||||
for split in splits:
|
splits = param_name.split(".")
|
||||||
old_param = getattr(old_param, split)
|
for split in splits:
|
||||||
if old_param is None:
|
old_param = getattr(old_param, split)
|
||||||
break
|
if old_param is None:
|
||||||
|
break
|
||||||
|
|
||||||
if old_param is not None:
|
if old_param is not None:
|
||||||
|
if dtype is None:
|
||||||
param = param.to(old_param.dtype)
|
param = param.to(old_param.dtype)
|
||||||
|
|
||||||
|
if old_param.is_contiguous():
|
||||||
|
param = param.contiguous()
|
||||||
|
|
||||||
set_module_kwargs["value"] = param
|
set_module_kwargs["value"] = param
|
||||||
|
|
||||||
if device_map is None:
|
if device_map is None:
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from requests.exceptions import HTTPError
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
|
OwlViTForObjectDetection,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -835,6 +836,23 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
outputs2 = new_model_with_offload(inputs)
|
outputs2 = new_model_with_offload(inputs)
|
||||||
self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
|
self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_from_pretrained_non_contiguous_checkpoint(self):
|
||||||
|
# See: https://github.com/huggingface/transformers/pull/28414
|
||||||
|
# Tiny models on the Hub have contiguous weights, contrarily to google/owlvit
|
||||||
|
model = OwlViTForObjectDetection.from_pretrained("fxmarty/owlvit-tiny-non-contiguous-weight")
|
||||||
|
self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous())
|
||||||
|
|
||||||
|
model = OwlViTForObjectDetection.from_pretrained(
|
||||||
|
"fxmarty/owlvit-tiny-non-contiguous-weight", device_map="auto"
|
||||||
|
)
|
||||||
|
self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous())
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
|
||||||
def test_cached_files_are_used_when_internet_is_down(self):
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
# A mock response for an HTTP head request to emulate server down
|
# A mock response for an HTTP head request to emulate server down
|
||||||
response_mock = mock.Mock()
|
response_mock = mock.Mock()
|
||||||
|
|||||||
Reference in New Issue
Block a user