This commit is contained in:
Julien Chaumond
2019-11-06 16:21:00 +00:00
parent 27e015bd54
commit d5319793c4

View File

@@ -170,7 +170,7 @@ class BertEmbeddings(nn.Module):
position_ids = torch.arange(seq_length, dtype=torch.long, device=device) position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape) position_ids = position_ids.unsqueeze(0).expand(input_shape)
if token_type_ids is None: if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long) token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids) inputs_embeds = self.word_embeddings(input_ids)
@@ -655,11 +655,11 @@ class BertModel(BertPreTrainedModel):
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones(input_shape) attention_mask = torch.ones(input_shape, device=device)
if encoder_attention_mask is None: if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(input_shape) encoder_attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None: if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long) token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads. # ourselves in which case we just need to make it broadcastable to all heads.