Keep relevant weights in fp32 when model._keep_in_fp32_modules is set even when accelerate is not installed (#26225)
* fix bug where weight would not be kept in fp32 * nit * address review comments * fix test
This commit is contained in:
@@ -2950,26 +2950,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (
|
||||
(cls._keep_in_fp32_modules is not None)
|
||||
and is_accelerate_available()
|
||||
and (torch_dtype == torch.float16 or load_in_4bit or load_in_8bit)
|
||||
)
|
||||
if (
|
||||
(cls._keep_in_fp32_modules is not None)
|
||||
and not is_accelerate_available()
|
||||
and torch_dtype == torch.float16
|
||||
):
|
||||
logger.warning(
|
||||
"For stability purposes, it is recommended to have accelerate installed when using this model in"
|
||||
" torch.float16, please install it with `pip install accelerate`"
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
torch_dtype == torch.float16 or load_in_4bit or load_in_8bit
|
||||
)
|
||||
|
||||
if is_sharded:
|
||||
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
||||
else:
|
||||
loaded_state_dict_keys = list(state_dict.keys())
|
||||
if low_cpu_mem_usage or use_keep_in_fp32_modules:
|
||||
if low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()):
|
||||
# In case some weights need to be kept in float32 and accelerate is not installed,
|
||||
# we later on want to take the path where state_dict is not None, that is the one
|
||||
# that do not require accelerate.
|
||||
state_dict = None
|
||||
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
@@ -2990,6 +2982,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
# Check first if we are `from_pt`
|
||||
if use_keep_in_fp32_modules:
|
||||
if is_accelerate_available():
|
||||
low_cpu_mem_usage = True
|
||||
keep_in_fp32_modules = model._keep_in_fp32_modules
|
||||
else:
|
||||
@@ -3465,7 +3458,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if keep_in_fp32_modules is not None:
|
||||
for name, param in model.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
||||
param = param.to(torch.float32)
|
||||
# param = param.to(torch.float32) does not work here as only in the local scope.
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
# Make sure we are able to load base models as well as derived models (with heads)
|
||||
start_prefix = ""
|
||||
@@ -3592,7 +3586,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
remove_prefix_from_model,
|
||||
ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
|
||||
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||
|
||||
@@ -1046,7 +1046,22 @@ class T5ModelFp16Tests(unittest.TestCase):
|
||||
r"""
|
||||
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
|
||||
"""
|
||||
orig_import = __import__
|
||||
accelerate_mock = unittest.mock.Mock()
|
||||
|
||||
# mock import of accelerate
|
||||
def import_accelerate_mock(name, *args, **kwargs):
|
||||
if name == "accelerate":
|
||||
if accelerate_available:
|
||||
return accelerate_mock
|
||||
else:
|
||||
raise ImportError
|
||||
return orig_import(name, *args, **kwargs)
|
||||
|
||||
# Load without using `accelerate`
|
||||
with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock):
|
||||
accelerate_available = False
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
|
||||
|
||||
Reference in New Issue
Block a user