Remove device parameter from create_extended_attention_mask_for_decoder (#16894)

This commit is contained in:
Pavel Belevich
2022-05-03 11:06:11 -04:00
committed by GitHub
parent dd739f7045
commit 39f8eafc1b
31 changed files with 48 additions and 42 deletions

View File

@@ -137,7 +137,7 @@ class RetrievalQAEmbedder(nn.Module):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
attention_mask, input_shape, device
attention_mask, input_shape
)
# define function for checkpointing