Black 20 release
This commit is contained in:
@@ -244,7 +244,8 @@ class GPT2ModelTester:
|
||||
# append to next input_ids and attn_mask
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
attn_mask = torch.cat(
|
||||
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1,
|
||||
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# get two different outputs
|
||||
|
||||
Reference in New Issue
Block a user