Keep relevant weights in fp32 when model._keep_in_fp32_modules is set even when accelerate is not installed (#26225)
* fix bug where weight would not be kept in fp32 * nit * address review comments * fix test
This commit is contained in:
@@ -1046,15 +1046,30 @@ class T5ModelFp16Tests(unittest.TestCase):
|
||||
r"""
|
||||
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
|
||||
"""
|
||||
# Load without using `accelerate`
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
|
||||
orig_import = __import__
|
||||
accelerate_mock = unittest.mock.Mock()
|
||||
|
||||
# Load without in bf16
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
|
||||
# mock import of accelerate
|
||||
def import_accelerate_mock(name, *args, **kwargs):
|
||||
if name == "accelerate":
|
||||
if accelerate_available:
|
||||
return accelerate_mock
|
||||
else:
|
||||
raise ImportError
|
||||
return orig_import(name, *args, **kwargs)
|
||||
|
||||
# Load without using `accelerate`
|
||||
with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock):
|
||||
accelerate_available = False
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
|
||||
|
||||
# Load without in bf16
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
|
||||
|
||||
# Load using `accelerate` in bf16
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto")
|
||||
|
||||
Reference in New Issue
Block a user