From e6d250e4cdc1dc800ff7d4056a297a6219f44812 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 5 Oct 2023 14:44:31 +0200 Subject: [PATCH] [`core`] fix silent bug `keep_in_fp32` modules (#26589) * fix silent bug `keep_in_fp32` modules * final fix * added a common test. * Trigger CI * revert --- src/transformers/modeling_utils.py | 10 +++++++--- .../instructblip/test_modeling_instructblip.py | 3 ++- tests/test_modeling_common.py | 18 ++++++++++++++++++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5c3a121836..54f31ab926 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -693,7 +693,9 @@ def _load_state_dict_into_meta_model( if dtype is not None and torch.is_floating_point(param): if ( keep_in_fp32_modules is not None - and any(module_to_keep_in_fp32 in param_name for module_to_keep_in_fp32 in keep_in_fp32_modules) + and any( + module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules + ) and dtype == torch.float16 ): param = param.to(torch.float32) @@ -3534,7 +3536,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if ( keep_in_fp32_modules is not None and dtype == torch.float16 - and any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules) + and any( + module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules + ) ): target_dtype = torch.float32 @@ -3561,7 +3565,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Set some modules to fp32 if any if keep_in_fp32_modules is not None: for name, param in model.named_parameters(): - if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): # param = param.to(torch.float32) does not work here as only in the local scope. param.data = param.data.to(torch.float32) diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index 1c8af01118..f0fd193b64 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -533,7 +533,7 @@ class InstructBlipModelIntegrationTest(unittest.TestCase): def test_inference_vicuna_7b(self): processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b") model = InstructBlipForConditionalGeneration.from_pretrained( - "Salesforce/instructblip-vicuna-7b", load_in_8bit=True + "Salesforce/instructblip-vicuna-7b", load_in_8bit=True, low_cpu_mem_usage=True ) url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg" @@ -569,6 +569,7 @@ class InstructBlipModelIntegrationTest(unittest.TestCase): model = InstructBlipForConditionalGeneration.from_pretrained( "Salesforce/instructblip-flan-t5-xl", torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, ).to(torch_device) url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 0a17c13a01..5a239cf0fb 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -275,6 +275,24 @@ class ModelTesterMixin: for p1, p2 in zip(model.parameters(), new_model.parameters()): self.assertTrue(torch.equal(p1, p2)) + def test_keep_in_fp32_modules(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + if model_class._keep_in_fp32_modules is None: + return + + model = model_class(config) + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16) + + for name, param in model.named_parameters(): + if any(n in model_class._keep_in_fp32_modules for n in name.split(".")): + self.assertTrue(param.dtype == torch.float32) + else: + self.assertTrue(param.dtype == torch.float16, name) + def test_save_load_keys_to_ignore_on_save(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()