@@ -275,7 +275,11 @@ class LongT5DenseActDense(nn.Module):
|
|||||||
hidden_states = self.wi(hidden_states)
|
hidden_states = self.wi(hidden_states)
|
||||||
hidden_states = self.act(hidden_states)
|
hidden_states = self.act(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
if (
|
||||||
|
isinstance(self.wo.weight, torch.Tensor)
|
||||||
|
and hidden_states.dtype != self.wo.weight.dtype
|
||||||
|
and self.wo.weight.dtype != torch.int8
|
||||||
|
):
|
||||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||||
hidden_states = self.wo(hidden_states)
|
hidden_states = self.wo(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -145,7 +145,11 @@ class MT5DenseActDense(nn.Module):
|
|||||||
hidden_states = self.wi(hidden_states)
|
hidden_states = self.wi(hidden_states)
|
||||||
hidden_states = self.act(hidden_states)
|
hidden_states = self.act(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
if (
|
||||||
|
isinstance(self.wo.weight, torch.Tensor)
|
||||||
|
and hidden_states.dtype != self.wo.weight.dtype
|
||||||
|
and self.wo.weight.dtype != torch.int8
|
||||||
|
):
|
||||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||||
hidden_states = self.wo(hidden_states)
|
hidden_states = self.wo(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -170,7 +174,11 @@ class MT5DenseGatedActDense(nn.Module):
|
|||||||
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
|
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
|
||||||
# See https://github.com/huggingface/transformers/issues/20287
|
# See https://github.com/huggingface/transformers/issues/20287
|
||||||
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
|
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
|
||||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
if (
|
||||||
|
isinstance(self.wo.weight, torch.Tensor)
|
||||||
|
and hidden_states.dtype != self.wo.weight.dtype
|
||||||
|
and self.wo.weight.dtype != torch.int8
|
||||||
|
):
|
||||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||||
|
|
||||||
hidden_states = self.wo(hidden_states)
|
hidden_states = self.wo(hidden_states)
|
||||||
|
|||||||
@@ -272,7 +272,11 @@ class SwitchTransformersDenseActDense(nn.Module):
|
|||||||
hidden_states = self.wi(hidden_states)
|
hidden_states = self.wi(hidden_states)
|
||||||
hidden_states = self.act(hidden_states)
|
hidden_states = self.act(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
if (
|
||||||
|
isinstance(self.wo.weight, torch.Tensor)
|
||||||
|
and hidden_states.dtype != self.wo.weight.dtype
|
||||||
|
and self.wo.weight.dtype != torch.int8
|
||||||
|
):
|
||||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||||
hidden_states = self.wo(hidden_states)
|
hidden_states = self.wo(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -288,7 +288,11 @@ class T5DenseActDense(nn.Module):
|
|||||||
hidden_states = self.wi(hidden_states)
|
hidden_states = self.wi(hidden_states)
|
||||||
hidden_states = self.act(hidden_states)
|
hidden_states = self.act(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
if (
|
||||||
|
isinstance(self.wo.weight, torch.Tensor)
|
||||||
|
and hidden_states.dtype != self.wo.weight.dtype
|
||||||
|
and self.wo.weight.dtype != torch.int8
|
||||||
|
):
|
||||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||||
hidden_states = self.wo(hidden_states)
|
hidden_states = self.wo(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -312,7 +316,11 @@ class T5DenseGatedActDense(nn.Module):
|
|||||||
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
|
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
|
||||||
# See https://github.com/huggingface/transformers/issues/20287
|
# See https://github.com/huggingface/transformers/issues/20287
|
||||||
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
|
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
|
||||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
if (
|
||||||
|
isinstance(self.wo.weight, torch.Tensor)
|
||||||
|
and hidden_states.dtype != self.wo.weight.dtype
|
||||||
|
and self.wo.weight.dtype != torch.int8
|
||||||
|
):
|
||||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||||
|
|
||||||
hidden_states = self.wo(hidden_states)
|
hidden_states = self.wo(hidden_states)
|
||||||
|
|||||||
@@ -880,6 +880,19 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
|||||||
def tokenizer(self):
|
def tokenizer(self):
|
||||||
return T5Tokenizer.from_pretrained("t5-base")
|
return T5Tokenizer.from_pretrained("t5-base")
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_torch_quant(self):
|
||||||
|
r"""
|
||||||
|
Test that a simple `torch.quantization.quantize_dynamic` call works on a T5 model.
|
||||||
|
"""
|
||||||
|
model_name = "google/flan-t5-small"
|
||||||
|
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_name)
|
||||||
|
model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
|
||||||
|
input_text = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
|
||||||
|
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
|
||||||
|
_ = model.generate(input_ids)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_small_generation(self):
|
def test_small_generation(self):
|
||||||
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
|
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user