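[core] Fix silent bug in `keep_in_fp32` modules (#26589)
The previous check tested each `keep_in_fp32_modules` entry as a substring of the full parameter name, so an entry could match unrelated parameters; the fix compares it against the dot-separated components of the name instead.
* fix silent bug in `keep_in_fp32` modules
* final fix
* added a common test
* Trigger CI
* revert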
This commit is contained in:
@@ -693,7 +693,9 @@ def _load_state_dict_into_meta_model(
|
|||||||
if dtype is not None and torch.is_floating_point(param):
|
if dtype is not None and torch.is_floating_point(param):
|
||||||
if (
|
if (
|
||||||
keep_in_fp32_modules is not None
|
keep_in_fp32_modules is not None
|
||||||
and any(module_to_keep_in_fp32 in param_name for module_to_keep_in_fp32 in keep_in_fp32_modules)
|
and any(
|
||||||
|
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
|
||||||
|
)
|
||||||
and dtype == torch.float16
|
and dtype == torch.float16
|
||||||
):
|
):
|
||||||
param = param.to(torch.float32)
|
param = param.to(torch.float32)
|
||||||
@@ -3534,7 +3536,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if (
|
if (
|
||||||
keep_in_fp32_modules is not None
|
keep_in_fp32_modules is not None
|
||||||
and dtype == torch.float16
|
and dtype == torch.float16
|
||||||
and any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules)
|
and any(
|
||||||
|
module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
|
||||||
|
)
|
||||||
):
|
):
|
||||||
target_dtype = torch.float32
|
target_dtype = torch.float32
|
||||||
|
|
||||||
@@ -3561,7 +3565,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Set some modules to fp32 if any
|
# Set some modules to fp32 if any
|
||||||
if keep_in_fp32_modules is not None:
|
if keep_in_fp32_modules is not None:
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
||||||
# param = param.to(torch.float32) does not work here as only in the local scope.
|
# param = param.to(torch.float32) does not work here as only in the local scope.
|
||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
|
|||||||
@@ -533,7 +533,7 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
|
|||||||
def test_inference_vicuna_7b(self):
|
def test_inference_vicuna_7b(self):
|
||||||
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
||||||
model = InstructBlipForConditionalGeneration.from_pretrained(
|
model = InstructBlipForConditionalGeneration.from_pretrained(
|
||||||
"Salesforce/instructblip-vicuna-7b", load_in_8bit=True
|
"Salesforce/instructblip-vicuna-7b", load_in_8bit=True, low_cpu_mem_usage=True
|
||||||
)
|
)
|
||||||
|
|
||||||
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
|
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
|
||||||
@@ -569,6 +569,7 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
|
|||||||
model = InstructBlipForConditionalGeneration.from_pretrained(
|
model = InstructBlipForConditionalGeneration.from_pretrained(
|
||||||
"Salesforce/instructblip-flan-t5-xl",
|
"Salesforce/instructblip-flan-t5-xl",
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
|
|
||||||
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
|
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
|
||||||
|
|||||||
@@ -275,6 +275,24 @@ class ModelTesterMixin:
|
|||||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
self.assertTrue(torch.equal(p1, p2))
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
def test_keep_in_fp32_modules(self):
    """Check that modules listed in ``_keep_in_fp32_modules`` stay in float32.

    For every class in ``self.all_model_classes`` that declares
    ``_keep_in_fp32_modules``, a model is built from the common test config,
    saved to a temporary directory and reloaded with
    ``torch_dtype=torch.float16``. A parameter is expected to be float32 when
    any dot-separated component of its name is listed in
    ``_keep_in_fp32_modules``, and float16 otherwise.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        fp32_modules = model_class._keep_in_fp32_modules
        if fp32_modules is None:
            # Nothing to verify for this class. Use `continue`, not `return`,
            # so the remaining model classes are still checked.
            continue

        model = model_class(config)
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)

            for name, param in model.named_parameters():
                # Match whole path components, not substrings of the full name.
                kept_in_fp32 = any(part in fp32_modules for part in name.split("."))
                expected_dtype = torch.float32 if kept_in_fp32 else torch.float16
                # Pass the parameter name so a failure identifies the offender.
                self.assertEqual(param.dtype, expected_dtype, name)
|
||||||
|
|
||||||
def test_save_load_keys_to_ignore_on_save(self):
|
def test_save_load_keys_to_ignore_on_save(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user