updated all tests

This commit is contained in:
patrickvonplaten
2020-03-08 15:29:10 +01:00
parent e03129ad44
commit 575976144a
13 changed files with 1465 additions and 112 deletions

View File

@@ -223,7 +223,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1,
)
# get two different outputs
@@ -343,39 +343,36 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_gpt2(self):
model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = torch.Tensor([[464, 3290, 318, 13779]]).long() # The dog is cute
input_ids = torch.tensor([[463, 3290]], dtype=torch.long, device=torch_device) # The dog
expected_output_ids = [
464,
3290,
318,
13779,
1165,
13,
632,
7832,
284,
6437,
319,
502,
373,
1043,
287,
257,
2214,
1474,
262,
16246,
286,
2688,
290,
318,
922,
329,
502,
357,
1169,
2688,
27262,
13,
198,
198,
464,
3290,
] # The dog is cute too. It likes to rub on me and is good for me (the dog
torch.manual_seed(0)
output_ids = model.generate(input_ids)
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@slow
def test_lm_generate_distilgpt2(self):
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
input_ids = torch.Tensor([[464, 1893]]).long() # The president
input_ids = torch.tensor([[463, 1893]], dtype=torch.long, device=torch_device) # The president
expected_output_ids = [
464,
1893,