updated all tests
This commit is contained in:
@@ -223,7 +223,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
# append to next input_ids and attn_mask
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
attn_mask = torch.cat(
|
||||
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1
|
||||
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1,
|
||||
)
|
||||
|
||||
# get two different outputs
|
||||
@@ -343,39 +343,36 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_gpt2(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
input_ids = torch.Tensor([[464, 3290, 318, 13779]]).long() # The dog is cute
|
||||
input_ids = torch.tensor([[463, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||
expected_output_ids = [
|
||||
464,
|
||||
3290,
|
||||
318,
|
||||
13779,
|
||||
1165,
|
||||
13,
|
||||
632,
|
||||
7832,
|
||||
284,
|
||||
6437,
|
||||
319,
|
||||
502,
|
||||
373,
|
||||
1043,
|
||||
287,
|
||||
257,
|
||||
2214,
|
||||
1474,
|
||||
262,
|
||||
16246,
|
||||
286,
|
||||
2688,
|
||||
290,
|
||||
318,
|
||||
922,
|
||||
329,
|
||||
502,
|
||||
357,
|
||||
1169,
|
||||
2688,
|
||||
27262,
|
||||
13,
|
||||
198,
|
||||
198,
|
||||
464,
|
||||
3290,
|
||||
] # The dog is cute too. It likes to rub on me and is good for me (the dog
|
||||
torch.manual_seed(0)
|
||||
|
||||
output_ids = model.generate(input_ids)
|
||||
|
||||
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
input_ids = torch.Tensor([[464, 1893]]).long() # The president
|
||||
input_ids = torch.tensor([[463, 1893]], dtype=torch.long, device=torch_device) # The president
|
||||
expected_output_ids = [
|
||||
464,
|
||||
1893,
|
||||
|
||||
Reference in New Issue
Block a user