Fix BERT
This commit is contained in:
@@ -170,7 +170,7 @@ class BertEmbeddings(nn.Module):
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
||||
if token_type_ids is None:
|
||||
- token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
@@ -655,11 +655,11 @@ class BertModel(BertPreTrainedModel):
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
- attention_mask = torch.ones(input_shape)
+ attention_mask = torch.ones(input_shape, device=device)
|
||||
if encoder_attention_mask is None:
|
||||
- encoder_attention_mask = torch.ones(input_shape)
+ encoder_attention_mask = torch.ones(input_shape, device=device)
|
||||
if token_type_ids is None:
|
||||
- token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
|
||||
Reference in New Issue
Block a user