This commit is contained in:
Patrick von Platen
2021-01-21 11:17:13 +01:00
committed by GitHub
parent c8ea582ed6
commit ca422e3d7d

View File

@@ -934,9 +934,9 @@ class T5Stack(T5PreTrainedModel):
encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
if encoder_decoder_position_bias is not None:
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
if head_mask is not None:
if not (isinstance(head_mask, list) and head_mask[0] is None):
head_mask = head_mask.to(hidden_states.device)
if encoder_head_mask is not None:
if not (isinstance(encoder_head_mask, list) and encoder_head_mask[0] is None):
encoder_head_mask = encoder_head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)