add _keep_in_fp32_modules_strict (#39058)
* add _keep_in_fp32_modules_strict * complete test
This commit is contained in:
@@ -30,6 +30,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
require_accelerate,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_sdpa,
|
||||
@@ -615,6 +616,81 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
||||
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_accelerate
|
||||
@slow
|
||||
class KyutaiSpeechToTextBf16Test(unittest.TestCase):
|
||||
def test_bf16_fp32_conversion(self):
|
||||
r"""
|
||||
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
|
||||
"""
|
||||
model_checkpoint = "kyutai/stt-2.6b-en-trfs"
|
||||
orig_import = __import__
|
||||
accelerate_mock = unittest.mock.Mock()
|
||||
|
||||
# mock import of accelerate
|
||||
def import_accelerate_mock(name, *args, **kwargs):
|
||||
if name == "accelerate":
|
||||
if accelerate_available:
|
||||
return accelerate_mock
|
||||
else:
|
||||
raise ImportError
|
||||
return orig_import(name, *args, **kwargs)
|
||||
|
||||
# Load without using `accelerate`
|
||||
with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock):
|
||||
accelerate_available = False
|
||||
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint, torch_dtype=torch.float16
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.float16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
|
||||
|
||||
# Load without in bf16
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint, torch_dtype=torch.bfloat16
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
|
||||
|
||||
# Load using `accelerate` in bf16
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
|
||||
|
||||
# Load using `accelerate` in bf16
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
|
||||
|
||||
# Load without using `accelerate`
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.float16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
|
||||
|
||||
# Load using `accelerate`
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint, torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.float16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
|
||||
|
||||
|
||||
class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCase):
|
||||
_dataset = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user