Fix other PyTorch models
This commit is contained in:
@@ -309,10 +309,12 @@ class XxxModel(XxxPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(input_shape)
|
attention_mask = torch.ones(input_shape, device=device)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
|
|||||||
@@ -450,8 +450,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(input_shape) # (bs, seq_length)
|
attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
|||||||
Reference in New Issue
Block a user