[gpu] Fixup fdd61b1992
This commit is contained in:
@@ -205,7 +205,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# create attention mask
|
# create attention mask
|
||||||
attn_mask = torch.ones(input_ids.shape).long()
|
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
half_seq_length = self.seq_length // 2
|
half_seq_length = self.seq_length // 2
|
||||||
attn_mask[:, half_seq_length:] = 0
|
attn_mask[:, half_seq_length:] = 0
|
||||||
|
|
||||||
@@ -222,7 +222,9 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# append to next input_ids and attn_mask
|
# append to next input_ids and attn_mask
|
||||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
attn_mask = torch.cat([attn_mask, torch.ones((attn_mask.shape[0], 1)).long()], dim=1)
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
# get two different outputs
|
# get two different outputs
|
||||||
output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
|
output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
|
||||||
|
|||||||
Reference in New Issue
Block a user