Small fix on context manager detection (#37562)
* small fixes * Update modeling_utils.py * test * Update test_modeling_common.py * Update test_modeling_timm_backbone.py * more general * simpler
This commit is contained in:
@@ -177,7 +177,7 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
|
||||
def test_cannot_load_with_meta_device_context_manager(self):
|
||||
def test_can_load_with_meta_device_context_manager(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="model weights aren't tied in TimmBackbone.")
|
||||
|
||||
@@ -4590,7 +4590,7 @@ class ModelTesterMixin:
|
||||
unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
|
||||
)
|
||||
|
||||
def test_cannot_load_with_meta_device_context_manager(self):
|
||||
def test_can_load_with_meta_device_context_manager(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
# Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
|
||||
@@ -4600,10 +4600,17 @@ class ModelTesterMixin:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# This should raise an error with meta device
|
||||
with self.assertRaises(ValueError, msg="`from_pretrained` is not compatible with a meta device"):
|
||||
with torch.device("meta"):
|
||||
_ = model_class.from_pretrained(tmpdirname)
|
||||
with torch.device("meta"):
|
||||
new_model = model_class.from_pretrained(tmpdirname)
|
||||
unique_devices = {param.device for param in new_model.parameters()} | {
|
||||
buffer.device for buffer in new_model.buffers()
|
||||
}
|
||||
|
||||
self.assertEqual(
|
||||
unique_devices,
|
||||
{torch.device("meta")},
|
||||
f"All parameters should be on meta device, but found {unique_devices}.",
|
||||
)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
Reference in New Issue
Block a user