Keep relevant weights in fp32 when model._keep_in_fp32_modules is set even when accelerate is not installed (#26225)

* fix bug where weight would not be kept in fp32

* nit

* address review comments

* fix test
This commit is contained in:
fxmarty
2023-09-21 19:00:03 +09:00
committed by GitHub
parent e3a4bd2bee
commit da971b2271
2 changed files with 33 additions and 25 deletions

View File

@@ -1046,15 +1046,30 @@ class T5ModelFp16Tests(unittest.TestCase):
r"""
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
"""
# Load without using `accelerate`
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
orig_import = __import__
accelerate_mock = unittest.mock.Mock()
# Load without in bf16
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
# mock import of accelerate
def import_accelerate_mock(name, *args, **kwargs):
if name == "accelerate":
if accelerate_available:
return accelerate_mock
else:
raise ImportError
return orig_import(name, *args, **kwargs)
# Load without using `accelerate`
with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock):
accelerate_available = False
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
# Load without in bf16
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
# Load using `accelerate` in bf16
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto")