Remove device parameter from create_extended_attention_mask_for_decoder (#16894)

This commit is contained in:
Pavel Belevich
2022-05-03 11:06:11 -04:00
committed by GitHub
parent dd739f7045
commit 39f8eafc1b
31 changed files with 48 additions and 42 deletions

View File

@@ -876,7 +876,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]