Nail in edge case of torch dtype being overriden permantly in the case of an error (#35845)

* Nail in edge case of torch dtype

* Rm unused func

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Refactor tests to only mock what we need, don't introduce injection functions

* SetUp/TearDown

* Do super

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
This commit is contained in:
Zach Mueller
2025-02-06 09:05:23 -05:00
committed by GitHub
parent e3458af726
commit 1ce0e2992e
2 changed files with 91 additions and 0 deletions

View File

@@ -39,6 +39,7 @@ from transformers import (
AutoModelForSequenceClassification,
DynamicCache,
LlavaForConditionalGeneration,
MistralForCausalLM,
OwlViTForObjectDetection,
PretrainedConfig,
is_torch_available,
@@ -318,6 +319,14 @@ def check_models_equal(model1, model2):
@require_torch
class ModelUtilsTest(TestCasePlus):
def setUp(self):
self.old_dtype = torch.get_default_dtype()
super().setUp()
def tearDown(self):
torch.set_default_dtype(self.old_dtype)
super().tearDown()
@slow
def test_model_from_pretrained(self):
model_name = "google-bert/bert-base-uncased"
@@ -1819,6 +1828,67 @@ class ModelUtilsTest(TestCasePlus):
self.assertIsNone(model_outputs.past_key_values)
self.assertTrue(model.training)
def test_restore_default_torch_dtype_from_pretrained(self):
"""
Tests that the default torch dtype is restored
when an error happens during the loading of a model.
"""
old_dtype = torch.get_default_dtype()
# set default type to float32
torch.set_default_dtype(torch.float32)
# Mock injection point which is right after the call to `_set_default_torch_dtype`
original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype
def debug(*args, **kwargs):
# call the method as usual, than raise a RuntimeError
original_set_default_torch_dtype(*args, **kwargs)
raise RuntimeError
with mock.patch(
"transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
side_effect=debug,
):
with self.assertRaises(RuntimeError):
_ = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto", torch_dtype=torch.float16)
# default should still be float32
assert torch.get_default_dtype() == torch.float32
torch.set_default_dtype(old_dtype)
def test_restore_default_torch_dtype_from_config(self):
"""
Tests that the default torch dtype is restored
when an error happens during the loading of a model.
"""
old_dtype = torch.get_default_dtype()
# set default type to float32
torch.set_default_dtype(torch.float32)
config = AutoConfig.from_pretrained(
TINY_MISTRAL,
)
# Mock injection point which is right after the call to `_set_default_torch_dtype`
original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype
def debug(*args, **kwargs):
# call the method as usual, than raise a RuntimeError
original_set_default_torch_dtype(*args, **kwargs)
raise RuntimeError
with mock.patch(
"transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
side_effect=debug,
):
with self.assertRaises(RuntimeError):
config.torch_dtype = torch.float16
_ = AutoModelForCausalLM.from_config(
config,
)
# default should still be float32
assert torch.get_default_dtype() == torch.float32
torch.set_default_dtype(old_dtype)
def test_unknown_quantization_config(self):
with tempfile.TemporaryDirectory() as tmpdir:
config = BertConfig(