From 248fa1ae72f01460c76c2450d747fa1cc29d85b0 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 21 Jan 2021 16:46:14 +0530 Subject: [PATCH] fix T5 head mask in model_parallel (#9726) * fix head mask in model_parallel * pass correct head mask --- src/transformers/models/t5/modeling_t5.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index f27ede05e1..1eb35a2875 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -920,6 +920,8 @@ class T5Stack(T5PreTrainedModel): hidden_states = self.dropout(inputs_embeds) for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + encoder_layer_head_mask = encoder_head_mask[i] # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) @@ -934,10 +936,10 @@ class T5Stack(T5PreTrainedModel): encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) if encoder_decoder_position_bias is not None: encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) - if not (isinstance(head_mask, list) and head_mask[0] is None): - head_mask = head_mask.to(hidden_states.device) - if not (isinstance(encoder_head_mask, list) and encoder_head_mask[0] is None): - encoder_head_mask = encoder_head_mask.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if encoder_layer_head_mask is not None: + encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -948,8 +950,8 @@ class T5Stack(T5PreTrainedModel): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=head_mask[i], - encoder_layer_head_mask=encoder_head_mask[i] if encoder_head_mask is not None else None, + layer_head_mask=layer_head_mask, + encoder_layer_head_mask=encoder_layer_head_mask, past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions,