@@ -880,6 +880,19 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
def tokenizer(self):
    """Return the `T5Tokenizer` loaded from the "t5-base" pretrained checkpoint."""
    checkpoint = "t5-base"
    return T5Tokenizer.from_pretrained(checkpoint)
|
||||
|
||||
@slow
def test_torch_quant(self):
    r"""
    Smoke test for dynamic quantization of a T5 model.

    Loads "google/flan-t5-small", passes it through
    `torch.quantization.quantize_dynamic` with `torch.nn.Linear` as the target
    module type and `torch.qint8` as the dtype, then calls `generate` on one
    tokenized prompt. The generated ids are discarded: the test makes no
    assertion on them and only checks that the calls complete without raising.
    """
    checkpoint = "google/flan-t5-small"
    t5_tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    full_precision_model = T5ForConditionalGeneration.from_pretrained(checkpoint)

    # Only Linear modules are targeted for dynamic int8 quantization.
    quantized_model = torch.quantization.quantize_dynamic(
        full_precision_model,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )

    prompt = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
    encoded_prompt = t5_tokenizer(prompt, return_tensors="pt")

    # The output is intentionally ignored; reaching this point without an
    # exception is the whole check.
    quantized_model.generate(encoded_prompt.input_ids)
|
||||
|
||||
@slow
|
||||
def test_small_generation(self):
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user