Nail in edge case of torch dtype being overriden permantly in the case of an error (#35845)
* Nail in edge case of torch dtype * Rm unused func * Apply suggestions from code review Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> * Refactor tests to only mock what we need, don't introduce injection functions * SetUp/TearDown * Do super --------- Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
This commit is contained in:
@@ -39,6 +39,7 @@ from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
DynamicCache,
|
||||
LlavaForConditionalGeneration,
|
||||
MistralForCausalLM,
|
||||
OwlViTForObjectDetection,
|
||||
PretrainedConfig,
|
||||
is_torch_available,
|
||||
@@ -318,6 +319,14 @@ def check_models_equal(model1, model2):
|
||||
|
||||
@require_torch
|
||||
class ModelUtilsTest(TestCasePlus):
|
||||
def setUp(self):
|
||||
self.old_dtype = torch.get_default_dtype()
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
torch.set_default_dtype(self.old_dtype)
|
||||
super().tearDown()
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model_name = "google-bert/bert-base-uncased"
|
||||
@@ -1819,6 +1828,67 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertIsNone(model_outputs.past_key_values)
|
||||
self.assertTrue(model.training)
|
||||
|
||||
def test_restore_default_torch_dtype_from_pretrained(self):
|
||||
"""
|
||||
Tests that the default torch dtype is restored
|
||||
when an error happens during the loading of a model.
|
||||
"""
|
||||
old_dtype = torch.get_default_dtype()
|
||||
# set default type to float32
|
||||
torch.set_default_dtype(torch.float32)
|
||||
|
||||
# Mock injection point which is right after the call to `_set_default_torch_dtype`
|
||||
original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype
|
||||
|
||||
def debug(*args, **kwargs):
|
||||
# call the method as usual, than raise a RuntimeError
|
||||
original_set_default_torch_dtype(*args, **kwargs)
|
||||
raise RuntimeError
|
||||
|
||||
with mock.patch(
|
||||
"transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
|
||||
side_effect=debug,
|
||||
):
|
||||
with self.assertRaises(RuntimeError):
|
||||
_ = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto", torch_dtype=torch.float16)
|
||||
# default should still be float32
|
||||
assert torch.get_default_dtype() == torch.float32
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
def test_restore_default_torch_dtype_from_config(self):
|
||||
"""
|
||||
Tests that the default torch dtype is restored
|
||||
when an error happens during the loading of a model.
|
||||
"""
|
||||
old_dtype = torch.get_default_dtype()
|
||||
# set default type to float32
|
||||
torch.set_default_dtype(torch.float32)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
TINY_MISTRAL,
|
||||
)
|
||||
|
||||
# Mock injection point which is right after the call to `_set_default_torch_dtype`
|
||||
original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype
|
||||
|
||||
def debug(*args, **kwargs):
|
||||
# call the method as usual, than raise a RuntimeError
|
||||
original_set_default_torch_dtype(*args, **kwargs)
|
||||
raise RuntimeError
|
||||
|
||||
with mock.patch(
|
||||
"transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
|
||||
side_effect=debug,
|
||||
):
|
||||
with self.assertRaises(RuntimeError):
|
||||
config.torch_dtype = torch.float16
|
||||
_ = AutoModelForCausalLM.from_config(
|
||||
config,
|
||||
)
|
||||
# default should still be float32
|
||||
assert torch.get_default_dtype() == torch.float32
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
def test_unknown_quantization_config(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = BertConfig(
|
||||
|
||||
Reference in New Issue
Block a user