[T5 fp16] Fix fp16 in T5 (#4436)
* fix fp16 in t5 * make style * refactor invert_attention_mask fn * fix typo
This commit is contained in:
committed by
GitHub
parent
fa6113f9a0
commit
026a5d0888
@@ -149,8 +149,12 @@ class T5LayerNorm(nn.Module):
|
|||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
variance = x.pow(2).mean(-1, keepdim=True)
|
# layer norm should always be calculated in float32
|
||||||
|
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||||
x = x / torch.sqrt(variance + self.variance_epsilon)
|
x = x / torch.sqrt(variance + self.variance_epsilon)
|
||||||
|
|
||||||
|
if self.weight.dtype == torch.float16:
|
||||||
|
x = x.to(torch.float16)
|
||||||
return self.weight * x
|
return self.weight * x
|
||||||
|
|
||||||
|
|
||||||
@@ -691,7 +695,9 @@ class T5Stack(T5PreTrainedModel):
|
|||||||
attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
|
attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
|
||||||
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
|
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
|
||||||
encoder_seq_length = encoder_hidden_states.shape[1]
|
encoder_seq_length = encoder_hidden_states.shape[1]
|
||||||
encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(inputs_embeds.device)
|
encoder_attention_mask = torch.ones(
|
||||||
|
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
|
||||||
|
)
|
||||||
|
|
||||||
# initialize past_key_value_states with `None` if past does not exist
|
# initialize past_key_value_states with `None` if past does not exist
|
||||||
if past_key_value_states is None:
|
if past_key_value_states is None:
|
||||||
@@ -733,6 +739,7 @@ class T5Stack(T5PreTrainedModel):
|
|||||||
# layer_outputs is a tuple with:
|
# layer_outputs is a tuple with:
|
||||||
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||||
hidden_states, present_key_value_state = layer_outputs[:2]
|
hidden_states, present_key_value_state = layer_outputs[:2]
|
||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
# We share the position biases between the layers - the first layer store them
|
# We share the position biases between the layers - the first layer store them
|
||||||
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||||
|
|||||||
@@ -128,7 +128,18 @@ class ModuleUtilsMixin:
|
|||||||
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
|
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
|
||||||
# encoder_extended_attention_mask.transpose(-1, -2))
|
# encoder_extended_attention_mask.transpose(-1, -2))
|
||||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
|
|
||||||
|
if self.dtype == torch.float16:
|
||||||
|
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
|
||||||
|
elif self.dtype == torch.float32:
|
||||||
|
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
|
||||||
|
self.dtype
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return encoder_extended_attention_mask
|
return encoder_extended_attention_mask
|
||||||
|
|
||||||
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
|
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
|
||||||
|
|||||||
@@ -304,6 +304,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
|
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
|
||||||
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
|
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
|
||||||
|
|
||||||
|
def create_and_check_t5_model_fp16_forward(
|
||||||
|
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||||
|
):
|
||||||
|
model = T5Model(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.half()
|
||||||
|
model.eval()
|
||||||
|
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[0]
|
||||||
|
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -355,6 +365,11 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
|
self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
|
||||||
|
def test_t5_model_fp16_forward(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_t5_model_fp16_forward(*config_and_inputs)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
|
|||||||
Reference in New Issue
Block a user