From 026a5d088871d983f45cba353d1fa0e9ab068217 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 18 May 2020 17:25:58 +0200 Subject: [PATCH] [T5 fp16] Fix fp16 in T5 (#4436) * fix fp16 in t5 * make style * refactor invert_attention_mask fn * fix typo --- src/transformers/modeling_t5.py | 11 +++++++++-- src/transformers/modeling_utils.py | 13 ++++++++++++- tests/test_modeling_t5.py | 15 +++++++++++++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index d8d61e0f02..6fcc9453d2 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -149,8 +149,12 @@ class T5LayerNorm(nn.Module): self.variance_epsilon = eps def forward(self, x): - variance = x.pow(2).mean(-1, keepdim=True) + # layer norm should always be calculated in float32 + variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) x = x / torch.sqrt(variance + self.variance_epsilon) + + if self.weight.dtype == torch.float16: + x = x.to(torch.float16) return self.weight * x @@ -691,7 +695,9 @@ class T5Stack(T5PreTrainedModel): attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(inputs_embeds.device) + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) # initialize past_key_value_states with `None` if past does not exist if past_key_value_states is None: @@ -733,6 +739,7 @@ class T5Stack(T5PreTrainedModel): # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) hidden_states, present_key_value_state = layer_outputs[:2] + if i == 0: # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index cd98907a94..80b51b2912 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -128,7 +128,18 @@ class ModuleUtilsMixin: # encoder_extended_attention_mask = (encoder_extended_attention_mask == # encoder_extended_attention_mask.transpose(-1, -2)) encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 + + if self.dtype == torch.float16: + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4 + elif self.dtype == torch.float32: + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 + else: + raise ValueError( + "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format( + self.dtype + ) + ) + return encoder_extended_attention_mask def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device): diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 5209719b59..f8e3114ad6 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -304,6 +304,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + def create_and_check_t5_model_fp16_forward( + self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, + ): + model = T5Model(config=config) + model.to(torch_device) + model.half() + model.eval() + output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[0] + self.parent.assertFalse(torch.isnan(output).any().item()) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -355,6 +365,11 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs) + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_t5_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_t5_model_fp16_forward(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: