Fix other PyTorch models

This commit is contained in:
Julien Chaumond
2019-11-06 16:43:09 +00:00
parent d5319793c4
commit 2f3a421018
2 changed files with 7 additions and 3 deletions

View File

@@ -450,8 +450,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape) # (bs, seq_length)
attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head