Nail in edge case of torch dtype being overriden permantly in the case of an error (#35845)
* Nail in edge case of torch dtype * Rm unused func * Apply suggestions from code review Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> * Refactor tests to only mock what we need, don't introduce injection functions * SetUp/TearDown * Do super --------- Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
This commit is contained in:
@@ -246,6 +246,25 @@ def set_zero3_state():
|
|||||||
_is_ds_init_called = False
|
_is_ds_init_called = False
|
||||||
|
|
||||||
|
|
||||||
|
def restore_default_torch_dtype(func):
|
||||||
|
"""
|
||||||
|
Decorator to restore the default torch dtype
|
||||||
|
at the end of the function. Serves
|
||||||
|
as a backup in case calling the function raises
|
||||||
|
an error after the function has changed the default dtype but before it could restore it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def _wrapper(*args, **kwargs):
|
||||||
|
old_dtype = torch.get_default_dtype()
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
torch.set_default_dtype(old_dtype)
|
||||||
|
|
||||||
|
return _wrapper
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
|
def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
|
||||||
try:
|
try:
|
||||||
return next(parameter.parameters()).device
|
return next(parameter.parameters()).device
|
||||||
@@ -1407,6 +1426,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
self.model_tags.append(tag)
|
self.model_tags.append(tag)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@restore_default_torch_dtype
|
||||||
def _from_config(cls, config, **kwargs):
|
def _from_config(cls, config, **kwargs):
|
||||||
"""
|
"""
|
||||||
All context managers that the model should be initialized under go here.
|
All context managers that the model should be initialized under go here.
|
||||||
@@ -3142,6 +3162,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
return super().float(*args)
|
return super().float(*args)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@restore_default_torch_dtype
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls: Type[SpecificPreTrainedModelType],
|
cls: Type[SpecificPreTrainedModelType],
|
||||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from transformers import (
|
|||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
DynamicCache,
|
DynamicCache,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
|
MistralForCausalLM,
|
||||||
OwlViTForObjectDetection,
|
OwlViTForObjectDetection,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -318,6 +319,14 @@ def check_models_equal(model1, model2):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ModelUtilsTest(TestCasePlus):
|
class ModelUtilsTest(TestCasePlus):
|
||||||
|
def setUp(self):
|
||||||
|
self.old_dtype = torch.get_default_dtype()
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
torch.set_default_dtype(self.old_dtype)
|
||||||
|
super().tearDown()
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
model_name = "google-bert/bert-base-uncased"
|
model_name = "google-bert/bert-base-uncased"
|
||||||
@@ -1819,6 +1828,67 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
self.assertIsNone(model_outputs.past_key_values)
|
self.assertIsNone(model_outputs.past_key_values)
|
||||||
self.assertTrue(model.training)
|
self.assertTrue(model.training)
|
||||||
|
|
||||||
|
def test_restore_default_torch_dtype_from_pretrained(self):
|
||||||
|
"""
|
||||||
|
Tests that the default torch dtype is restored
|
||||||
|
when an error happens during the loading of a model.
|
||||||
|
"""
|
||||||
|
old_dtype = torch.get_default_dtype()
|
||||||
|
# set default type to float32
|
||||||
|
torch.set_default_dtype(torch.float32)
|
||||||
|
|
||||||
|
# Mock injection point which is right after the call to `_set_default_torch_dtype`
|
||||||
|
original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype
|
||||||
|
|
||||||
|
def debug(*args, **kwargs):
|
||||||
|
# call the method as usual, than raise a RuntimeError
|
||||||
|
original_set_default_torch_dtype(*args, **kwargs)
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
with mock.patch(
|
||||||
|
"transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
|
||||||
|
side_effect=debug,
|
||||||
|
):
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
_ = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto", torch_dtype=torch.float16)
|
||||||
|
# default should still be float32
|
||||||
|
assert torch.get_default_dtype() == torch.float32
|
||||||
|
torch.set_default_dtype(old_dtype)
|
||||||
|
|
||||||
|
def test_restore_default_torch_dtype_from_config(self):
|
||||||
|
"""
|
||||||
|
Tests that the default torch dtype is restored
|
||||||
|
when an error happens during the loading of a model.
|
||||||
|
"""
|
||||||
|
old_dtype = torch.get_default_dtype()
|
||||||
|
# set default type to float32
|
||||||
|
torch.set_default_dtype(torch.float32)
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
TINY_MISTRAL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock injection point which is right after the call to `_set_default_torch_dtype`
|
||||||
|
original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype
|
||||||
|
|
||||||
|
def debug(*args, **kwargs):
|
||||||
|
# call the method as usual, than raise a RuntimeError
|
||||||
|
original_set_default_torch_dtype(*args, **kwargs)
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
with mock.patch(
|
||||||
|
"transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
|
||||||
|
side_effect=debug,
|
||||||
|
):
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
config.torch_dtype = torch.float16
|
||||||
|
_ = AutoModelForCausalLM.from_config(
|
||||||
|
config,
|
||||||
|
)
|
||||||
|
# default should still be float32
|
||||||
|
assert torch.get_default_dtype() == torch.float32
|
||||||
|
torch.set_default_dtype(old_dtype)
|
||||||
|
|
||||||
def test_unknown_quantization_config(self):
|
def test_unknown_quantization_config(self):
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
config = BertConfig(
|
config = BertConfig(
|
||||||
|
|||||||
Reference in New Issue
Block a user