Add keep_in_fp32_modules support (#20683)
* add `keep_in_fp32_modules` support * pass it as class attribute * few modifs - make tests `slow` - fix logic * better logic * fix failing test * `bfloat16` support * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix * simplify tests * simplify tests * fix test * modify message * more checks * fix failing tests * add more conditions - add `is_accelerate_available` - fixes pipleine tests that failed * add suggestions * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix failing `bnb` test * add last safety checker Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -155,6 +155,13 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
# Check this does not throw an error
|
||||
_ = self.model_fp16.float()
|
||||
|
||||
def test_fp32_int8_conversion(self):
|
||||
r"""
|
||||
Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly.
|
||||
"""
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||
|
||||
|
||||
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
||||
def setUp(self):
|
||||
|
||||
@@ -19,7 +19,14 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import T5Config, is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@@ -820,6 +827,50 @@ def use_task_specific_params(model, task):
|
||||
model.config.update(model.config.task_specific_params[task])
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_accelerate
|
||||
@require_tokenizers
|
||||
@slow
|
||||
class T5ModelFp16Tests(unittest.TestCase):
|
||||
def test_fp16_fp32_conversion(self):
|
||||
r"""
|
||||
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
|
||||
"""
|
||||
# Load without using `accelerate`
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
|
||||
|
||||
# Load without in bf16
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
|
||||
|
||||
# Load using `accelerate` in bf16
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto")
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
|
||||
|
||||
# Load using `accelerate` in bf16
|
||||
model = T5ForConditionalGeneration.from_pretrained(
|
||||
"t5-small", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
|
||||
)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
|
||||
|
||||
# Load without using `accelerate`
|
||||
model = T5ForConditionalGeneration.from_pretrained(
|
||||
"t5-small", torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
|
||||
|
||||
# Load using `accelerate`
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16, device_map="auto")
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
|
||||
Reference in New Issue
Block a user