diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 3a8a9c541a..21fc873234 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -205,7 +205,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): model.eval() # create attention mask - attn_mask = torch.ones(input_ids.shape).long() + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) half_seq_length = self.seq_length // 2 attn_mask[:, half_seq_length:] = 0 @@ -222,7 +222,9 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): # append to next input_ids and attn_mask next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) - attn_mask = torch.cat([attn_mask, torch.ones((attn_mask.shape[0], 1)).long()], dim=1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1 + ) # get two different outputs output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)