[T5 fp16] Fix fp16 in T5 (#4436)
* fix fp16 in t5 * make style * refactor invert_attention_mask fn * fix typo
This commit is contained in:
committed by
GitHub
parent
fa6113f9a0
commit
026a5d0888
@@ -304,6 +304,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
|
||||
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
|
||||
|
||||
def create_and_check_t5_model_fp16_forward(
    self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
    """Run a half-precision forward pass through T5Model and assert the output has no NaNs.

    The model is built from ``config``, moved to ``torch_device``, cast with
    ``half()`` and switched to eval mode before being called.

    ``decoder_input_ids``, ``decoder_attention_mask`` and ``lm_labels`` are
    accepted but not used in this check; the caller unpacks its prepared
    inputs positionally into this method.
    """
    fp16_model = T5Model(config=config)
    fp16_model.to(torch_device)
    fp16_model.half()
    fp16_model.eval()

    # The encoder ids are also passed as the decoder ids here.
    # NOTE(review): presumably intentional for a NaN smoke test -- confirm.
    model_outputs = fp16_model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)
    first_output = model_outputs[0]

    contains_nan = torch.isnan(first_output).any().item()
    self.parent.assertFalse(contains_nan)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -355,6 +365,11 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
|
||||
|
||||
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_t5_model_fp16_forward(self):
    """Exercise the model tester's fp16 forward check; skipped when ``torch_device`` is ``"cpu"``."""
    prepared_inputs = self.model_tester.prepare_config_and_inputs()
    # The prepared inputs are unpacked positionally into the tester's check method.
    self.model_tester.create_and_check_t5_model_fp16_forward(*prepared_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
|
||||
Reference in New Issue
Block a user