[T5 fp16] Fix fp16 in T5 (#4436)

* fix fp16 in t5

* make style

* refactor invert_attention_mask fn

* fix typo
This commit is contained in:
Patrick von Platen
2020-05-18 17:25:58 +02:00
committed by GitHub
parent fa6113f9a0
commit 026a5d0888
3 changed files with 36 additions and 3 deletions

View File

@@ -304,6 +304,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
def create_and_check_t5_model_fp16_forward(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config)
model.to(torch_device)
model.half()
model.eval()
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[0]
self.parent.assertFalse(torch.isnan(output).any().item())
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -355,6 +365,11 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_t5_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_model_fp16_forward(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: