updated all tests

This commit is contained in:
patrickvonplaten
2020-03-08 15:29:10 +01:00
parent e03129ad44
commit 575976144a
13 changed files with 1465 additions and 112 deletions

View File

@@ -328,13 +328,35 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"bos_token_id": 50256, "eos_token_id": 50256}
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow
def test_lm_generate_gpt2(self):
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
expected_output_ids = [
464,
3290,
373,
1043,
287,
257,
2214,
1474,
262,
16246,
286,
2688,
290,
2688,
27262,
13,
198,
198,
464,
3290,
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@slow
def test_lm_generate_distilgpt2(self):
@@ -363,11 +385,5 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
2635,
] # The president of the United States, and the president of the United Kingdom, have been in the White
output_ids = model.generate(
input_ids,
do_sample=False,
bos_token_id=self.special_tokens["bos_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
)
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)