From 2f3a4210185f5311f6cfab3c91b30616c9a30fc8 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 6 Nov 2019 16:43:09 +0000 Subject: [PATCH] Fix other PyTorch models --- templates/adding_a_new_model/modeling_xxx.py | 6 ++++-- transformers/modeling_distilbert.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/templates/adding_a_new_model/modeling_xxx.py b/templates/adding_a_new_model/modeling_xxx.py index d023f565f5..1f98c6406f 100644 --- a/templates/adding_a_new_model/modeling_xxx.py +++ b/templates/adding_a_new_model/modeling_xxx.py @@ -309,10 +309,12 @@ class XxxModel(XxxPreTrainedModel): else: raise ValueError("You have to specify either input_ids or inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if attention_mask is None: - attention_mask = torch.ones(input_shape) + attention_mask = torch.ones(input_shape, device=device) if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long) + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] diff --git a/transformers/modeling_distilbert.py b/transformers/modeling_distilbert.py index aca1670852..00106627a8 100644 --- a/transformers/modeling_distilbert.py +++ b/transformers/modeling_distilbert.py @@ -450,8 +450,10 @@ class DistilBertModel(DistilBertPreTrainedModel): else: raise ValueError("You have to specify either input_ids or inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if attention_mask is None: - attention_mask = torch.ones(input_shape) # (bs, seq_length) + attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head