From ae9230af40ecc8ccb940765830b2f8727049a845 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 28 Feb 2023 15:09:44 +0100 Subject: [PATCH] [`T5`] Fix torchquant issue (#21843) * fix torchquant issue * add tests --- src/transformers/models/longt5/modeling_longt5.py | 6 +++++- src/transformers/models/mt5/modeling_mt5.py | 12 ++++++++++-- .../modeling_switch_transformers.py | 6 +++++- src/transformers/models/t5/modeling_t5.py | 12 ++++++++++-- tests/models/t5/test_modeling_t5.py | 13 +++++++++++++ 5 files changed, 43 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 316781c623..366fbc4f7c 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -275,7 +275,11 @@ class LongT5DenseActDense(nn.Module): hidden_states = self.wi(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) - if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) return hidden_states diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 771850690e..951a68cb76 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -145,7 +145,11 @@ class MT5DenseActDense(nn.Module): hidden_states = self.wi(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) - if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) return hidden_states @@ -170,7 +174,11 @@ class MT5DenseGatedActDense(nn.Module): # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. # See https://github.com/huggingface/transformers/issues/20287 # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` - if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 61c232b5cc..de24797c67 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -272,7 +272,11 @@ class SwitchTransformersDenseActDense(nn.Module): hidden_states = self.wi(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) - if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) return hidden_states diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 648fe45398..9768998631 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -288,7 +288,11 @@ class T5DenseActDense(nn.Module): hidden_states = self.wi(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) - if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) return hidden_states @@ -312,7 +316,11 @@ class T5DenseGatedActDense(nn.Module): # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. # See https://github.com/huggingface/transformers/issues/20287 # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` - if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index c6c0ede071..8833898649 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -880,6 +880,19 @@ class T5ModelIntegrationTests(unittest.TestCase): def tokenizer(self): return T5Tokenizer.from_pretrained("t5-base") + @slow + def test_torch_quant(self): + r""" + Test that a simple `torch.quantization.quantize_dynamic` call works on a T5 model. + """ + model_name = "google/flan-t5-small" + tokenizer = T5Tokenizer.from_pretrained(model_name) + model = T5ForConditionalGeneration.from_pretrained(model_name) + model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) + input_text = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?" + input_ids = tokenizer(input_text, return_tensors="pt").input_ids + _ = model.generate(input_ids) + @slow def test_small_generation(self): model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)