fix T5 head mask in model_parallel (#9726)
* fix head mask in model_parallel
* pass correct head mask
This commit is contained in:
@@ -920,6 +920,8 @@ class T5Stack(T5PreTrainedModel):
|
|||||||
hidden_states = self.dropout(inputs_embeds)
|
hidden_states = self.dropout(inputs_embeds)
|
||||||
|
|
||||||
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
|
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
|
||||||
|
layer_head_mask = head_mask[i]
|
||||||
|
encoder_layer_head_mask = encoder_head_mask[i]
|
||||||
# Model parallel
|
# Model parallel
|
||||||
if self.model_parallel:
|
if self.model_parallel:
|
||||||
torch.cuda.set_device(hidden_states.device)
|
torch.cuda.set_device(hidden_states.device)
|
||||||
@@ -934,10 +936,10 @@ class T5Stack(T5PreTrainedModel):
|
|||||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
|
encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
|
||||||
if encoder_decoder_position_bias is not None:
|
if encoder_decoder_position_bias is not None:
|
||||||
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
|
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
|
||||||
if not (isinstance(head_mask, list) and head_mask[0] is None):
|
if layer_head_mask is not None:
|
||||||
head_mask = head_mask.to(hidden_states.device)
|
layer_head_mask = layer_head_mask.to(hidden_states.device)
|
||||||
if not (isinstance(encoder_head_mask, list) and encoder_head_mask[0] is None):
|
if encoder_layer_head_mask is not None:
|
||||||
encoder_head_mask = encoder_head_mask.to(hidden_states.device)
|
encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
@@ -948,8 +950,8 @@ class T5Stack(T5PreTrainedModel):
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_extended_attention_mask,
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||||
layer_head_mask=head_mask[i],
|
layer_head_mask=layer_head_mask,
|
||||||
encoder_layer_head_mask=encoder_head_mask[i] if encoder_head_mask is not None else None,
|
encoder_layer_head_mask=encoder_layer_head_mask,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
|||||||
Reference in New Issue
Block a user