From b370cc7e99c5b8c7436154d4694c33b461ea0f08 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 26 Feb 2020 21:48:49 +0000 Subject: [PATCH] [gpu] Fixup fdd61b19928e87a5354c36923182e801bfedb31b --- tests/test_modeling_gpt2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 3a8a9c541a..21fc873234 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -205,7 +205,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): model.eval() # create attention mask - attn_mask = torch.ones(input_ids.shape).long() + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) half_seq_length = self.seq_length // 2 attn_mask[:, half_seq_length:] = 0 @@ -222,7 +222,9 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): # append to next input_ids and attn_mask next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) - attn_mask = torch.cat([attn_mask, torch.ones((attn_mask.shape[0], 1)).long()], dim=1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1 + ) # get two different outputs output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)