fix typo in test gpt2
This commit is contained in:
@@ -343,7 +343,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_lm_generate_gpt2(self):
|
def test_lm_generate_gpt2(self):
|
||||||
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||||
input_ids = torch.tensor([[463, 3290]], dtype=torch.long, device=torch_device) # The dog
|
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||||
expected_output_ids = [
|
expected_output_ids = [
|
||||||
464,
|
464,
|
||||||
3290,
|
3290,
|
||||||
|
|||||||
Reference in New Issue
Block a user