Fix other PyTorch models
This commit is contained in:
@@ -450,8 +450,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
- attention_mask = torch.ones(input_shape) # (bs, seq_length)
|
||||
+ attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
|
||||
Reference in New Issue
Block a user