Remove device parameter from create_extended_attention_mask_for_decoder (#16894)
This commit is contained in:
@@ -137,7 +137,7 @@ class RetrievalQAEmbedder(nn.Module):
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
|
||||
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
|
||||
attention_mask, input_shape, device
|
||||
attention_mask, input_shape
|
||||
)
|
||||
|
||||
# define function for checkpointing
|
||||
|
||||
Reference in New Issue
Block a user