Keep relevant weights in fp32 when model._keep_in_fp32_modules is set even when accelerate is not installed (#26225)
* fix bug where weight would not be kept in fp32 * nit * address review comments * fix test
This commit is contained in:
@@ -2950,26 +2950,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||||
|
|
||||||
# Check if `_keep_in_fp32_modules` is not None
|
# Check if `_keep_in_fp32_modules` is not None
|
||||||
use_keep_in_fp32_modules = (
|
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||||
(cls._keep_in_fp32_modules is not None)
|
torch_dtype == torch.float16 or load_in_4bit or load_in_8bit
|
||||||
and is_accelerate_available()
|
|
||||||
and (torch_dtype == torch.float16 or load_in_4bit or load_in_8bit)
|
|
||||||
)
|
)
|
||||||
if (
|
|
||||||
(cls._keep_in_fp32_modules is not None)
|
|
||||||
and not is_accelerate_available()
|
|
||||||
and torch_dtype == torch.float16
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"For stability purposes, it is recommended to have accelerate installed when using this model in"
|
|
||||||
" torch.float16, please install it with `pip install accelerate`"
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_sharded:
|
if is_sharded:
|
||||||
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
||||||
else:
|
else:
|
||||||
loaded_state_dict_keys = list(state_dict.keys())
|
loaded_state_dict_keys = list(state_dict.keys())
|
||||||
if low_cpu_mem_usage or use_keep_in_fp32_modules:
|
if low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()):
|
||||||
|
# In case some weights need to be kept in float32 and accelerate is not installed,
|
||||||
|
# we later on want to take the path where state_dict is not None, that is the one
|
||||||
|
# that do not require accelerate.
|
||||||
state_dict = None
|
state_dict = None
|
||||||
|
|
||||||
config.name_or_path = pretrained_model_name_or_path
|
config.name_or_path = pretrained_model_name_or_path
|
||||||
@@ -2990,7 +2982,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# Check first if we are `from_pt`
|
# Check first if we are `from_pt`
|
||||||
if use_keep_in_fp32_modules:
|
if use_keep_in_fp32_modules:
|
||||||
low_cpu_mem_usage = True
|
if is_accelerate_available():
|
||||||
|
low_cpu_mem_usage = True
|
||||||
keep_in_fp32_modules = model._keep_in_fp32_modules
|
keep_in_fp32_modules = model._keep_in_fp32_modules
|
||||||
else:
|
else:
|
||||||
keep_in_fp32_modules = []
|
keep_in_fp32_modules = []
|
||||||
@@ -3465,7 +3458,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if keep_in_fp32_modules is not None:
|
if keep_in_fp32_modules is not None:
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
||||||
param = param.to(torch.float32)
|
# param = param.to(torch.float32) does not work here as only in the local scope.
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
# Make sure we are able to load base models as well as derived models (with heads)
|
# Make sure we are able to load base models as well as derived models (with heads)
|
||||||
start_prefix = ""
|
start_prefix = ""
|
||||||
@@ -3592,7 +3586,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
remove_prefix_from_model,
|
remove_prefix_from_model,
|
||||||
ignore_mismatched_sizes,
|
ignore_mismatched_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
if low_cpu_mem_usage:
|
if low_cpu_mem_usage:
|
||||||
if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
|
if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
|
||||||
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||||
|
|||||||
@@ -1046,15 +1046,30 @@ class T5ModelFp16Tests(unittest.TestCase):
|
|||||||
r"""
|
r"""
|
||||||
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
|
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
|
||||||
"""
|
"""
|
||||||
# Load without using `accelerate`
|
orig_import = __import__
|
||||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
|
accelerate_mock = unittest.mock.Mock()
|
||||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
|
||||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
|
|
||||||
|
|
||||||
# Load without in bf16
|
# mock import of accelerate
|
||||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
|
def import_accelerate_mock(name, *args, **kwargs):
|
||||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
|
if name == "accelerate":
|
||||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
|
if accelerate_available:
|
||||||
|
return accelerate_mock
|
||||||
|
else:
|
||||||
|
raise ImportError
|
||||||
|
return orig_import(name, *args, **kwargs)
|
||||||
|
|
||||||
|
# Load without using `accelerate`
|
||||||
|
with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock):
|
||||||
|
accelerate_available = False
|
||||||
|
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
|
||||||
|
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||||
|
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
|
||||||
|
|
||||||
|
# Load without in bf16
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
|
||||||
|
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
|
||||||
|
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
|
||||||
|
|
||||||
# Load using `accelerate` in bf16
|
# Load using `accelerate` in bf16
|
||||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto")
|
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto")
|
||||||
|
|||||||
Reference in New Issue
Block a user