[T5] Fix torchquant issue (#21843)

* fix torchquant issue

* add tests
This commit is contained in:
Younes Belkada
2023-02-28 15:09:44 +01:00
committed by GitHub
parent 2d506ea4c4
commit ae9230af40
5 changed files with 43 additions and 6 deletions

View File

@@ -880,6 +880,19 @@ class T5ModelIntegrationTests(unittest.TestCase):
def tokenizer(self):
return T5Tokenizer.from_pretrained("t5-base")
@slow
def test_torch_quant(self):
r"""
Test that a simple `torch.quantization.quantize_dynamic` call works on a T5 model.
"""
model_name = "google/flan-t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
input_text = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
_ = model.generate(input_ids)
@slow
def test_small_generation(self):
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)