fix typo in test gpt2

This commit is contained in:
patrickvonplaten
2020-03-08 15:35:08 +01:00
parent 314bdc7c14
commit 66c827656f

View File

@@ -343,7 +343,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
@slow @slow
def test_lm_generate_gpt2(self): def test_lm_generate_gpt2(self):
model = GPT2LMHeadModel.from_pretrained("gpt2") model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = torch.tensor([[463, 3290]], dtype=torch.long, device=torch_device) # The dog input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
expected_output_ids = [ expected_output_ids = [
464, 464,
3290, 3290,