Add keep_in_fp32_modules support (#20683)

* add `keep_in_fp32_modules` support

* pass it as class attribute

* few modifs

- make tests `slow`
- fix logic

* better logic

* fix failing test

* `bfloat16` support

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix

* simplify tests

* simplify tests

* fix test

* modify message

* more checks

* fix failing tests

* add more conditions

- add `is_accelerate_available`
- fixes pipleine tests that failed

* add suggestions

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix failing `bnb` test

* add last safety checker

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2022-12-13 11:59:57 +01:00
committed by GitHub
parent d4bf9ee1ff
commit 1af4bee896
5 changed files with 127 additions and 6 deletions

View File

@@ -19,7 +19,14 @@ import tempfile
import unittest
from transformers import T5Config, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.testing_utils import (
require_accelerate,
require_sentencepiece,
require_tokenizers,
require_torch,
slow,
torch_device,
)
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
@@ -820,6 +827,50 @@ def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])
@require_torch
@require_accelerate
@require_tokenizers
@slow
class T5ModelFp16Tests(unittest.TestCase):
def test_fp16_fp32_conversion(self):
r"""
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
"""
# Load without using `accelerate`
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
# Load without in bf16
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
# Load using `accelerate` in bf16
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto")
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
# Load using `accelerate` in bf16
model = T5ForConditionalGeneration.from_pretrained(
"t5-small", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
# Load without using `accelerate`
model = T5ForConditionalGeneration.from_pretrained(
"t5-small", torch_dtype=torch.float16, low_cpu_mem_usage=True
)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
# Load using `accelerate`
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16, device_map="auto")
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
@require_torch
@require_sentencepiece
@require_tokenizers