From e2e393c6f25205739b5dc9fddd460d7bfab85150 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 24 Jan 2023 18:14:38 +0100 Subject: [PATCH] [`t5`] Fix T5 inference in `float16` + `bnb` error (#21281) * attempts to fix: - upcast input for `T5DenseActDense` - add the condition `self.wo.weight.dtype != torch.int8` - added tests on `test/mixed_int8` - `make fixup` * fix ci test --- .../models/longt5/modeling_longt5.py | 2 + src/transformers/models/mt5/modeling_mt5.py | 5 +- .../modeling_switch_transformers.py | 2 + src/transformers/models/t5/modeling_t5.py | 5 +- tests/mixed_int8/test_mixed_int8.py | 64 +++++++++++++++++++ 5 files changed, 76 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 2101d247c1..486eda3876 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -276,6 +276,8 @@ class LongT5DenseActDense(nn.Module): hidden_states = self.wi(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) + if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: + hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) return hidden_states diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 1cada1b235..dac08695b6 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -146,6 +146,8 @@ class MT5DenseActDense(nn.Module): hidden_states = self.wi(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) + if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: + hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) return hidden_states @@ -168,7 +170,8 @@ class MT5DenseGatedActDense(nn.Module): # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. # See https://github.com/huggingface/transformers/issues/20287 - if hidden_states.dtype != self.wo.weight.dtype: + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 42aae23014..b19016bfab 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -273,6 +273,8 @@ class SwitchTransformersDenseActDense(nn.Module): hidden_states = self.wi(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) + if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: + hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) return hidden_states diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 592f33cf22..f329f32d42 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -289,6 +289,8 @@ class T5DenseActDense(nn.Module): hidden_states = self.wi(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) + if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: + hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) return hidden_states @@ -310,7 +312,8 @@ class T5DenseGatedActDense(nn.Module): # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. # See https://github.com/huggingface/transformers/issues/20287 - if hidden_states.dtype != self.wo.weight.dtype: + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 56ce10638d..b1e8ab1a33 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -163,6 +163,70 @@ class MixedInt8Test(BaseMixedInt8Test): self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) +@require_bitsandbytes +@require_accelerate +@require_torch +@require_torch_gpu +@slow +class MixedInt8T5Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_name = "t5-small" + cls.dense_act_model_name = "google/flan-t5-small" # flan-t5 uses dense-act instead of dense-relu-dense + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.input_text = "Translate in German: Hello, my dog is cute" + + def tearDown(self): + r""" + TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to + avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 + """ + gc.collect() + torch.cuda.empty_cache() + + def test_inference_without_keep_in_fp32(self): + r""" + Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly. + `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test + both cases. + """ + from transformers import T5ForConditionalGeneration + + T5ForConditionalGeneration._keep_in_fp32_modules = None + + # test with `t5-small` + model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + _ = model.generate(**encoded_input) + + # test with `flan-t5-small` + model = T5ForConditionalGeneration.from_pretrained( + self.dense_act_model_name, load_in_8bit=True, device_map="auto" + ) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + _ = model.generate(**encoded_input) + + def test_inference_with_keep_in_fp32(self): + r""" + Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly. + `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test + both cases. + """ + from transformers import T5ForConditionalGeneration + + # test with `t5-small` + model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + _ = model.generate(**encoded_input) + + # test with `flan-t5-small` + model = T5ForConditionalGeneration.from_pretrained( + self.dense_act_model_name, load_in_8bit=True, device_map="auto" + ) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + _ = model.generate(**encoded_input) + + class MixedInt8ModelClassesTest(BaseMixedInt8Test): def setUp(self): super().setUp()