fix typo in test

This commit is contained in:
patrickvonplaten
2020-03-08 15:34:20 +01:00
parent 575976144a
commit 314bdc7c14

View File

@@ -372,7 +372,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_distilgpt2(self):
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
input_ids = torch.tensor([[463, 1893]], dtype=torch.long, device=torch_device) # The president
input_ids = torch.tensor([[464, 1893]], dtype=torch.long, device=torch_device) # The president
expected_output_ids = [
464,
1893,