fix typo in test
This commit is contained in:
@@ -372,7 +372,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_lm_generate_distilgpt2(self):
|
def test_lm_generate_distilgpt2(self):
|
||||||
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
|
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||||
input_ids = torch.tensor([[463, 1893]], dtype=torch.long, device=torch_device) # The president
|
input_ids = torch.tensor([[464, 1893]], dtype=torch.long, device=torch_device) # The president
|
||||||
expected_output_ids = [
|
expected_output_ids = [
|
||||||
464,
|
464,
|
||||||
1893,
|
1893,
|
||||||
|
|||||||
Reference in New Issue
Block a user