Add TF generate sample tests with all logit processors (#15852)

* Add GPT2 TF generate sample test with all logits processor

* Add T5 generate sample test
This commit is contained in:
Joao Gante
2022-03-02 09:48:11 +00:00
committed by GitHub
parent 40040727ab
commit 8a133490bf
2 changed files with 59 additions and 1 deletions

View File

@@ -442,7 +442,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@slow
def test_lm_generate_distilgpt2_batch_special(self):
def test_lm_generate_greedy_distilgpt2_batch_special(self):
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
@@ -468,6 +468,37 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
]
self.assertListEqual(output_strings, expected_output_string)
@slow
def test_lm_generate_sample_distilgpt2_batch_special(self):
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
sentences = ["Today is a beautiful day and", "Yesterday was"]
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
generation_kwargs = {
"do_sample": True,
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
"no_repeat_ngram_size": 2,
"repetition_penalty": 1.3,
"temperature": 1.5,
"top_k": 500,
"top_p": 0.9,
}
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
"Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh",
"Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say",
]
self.assertListEqual(output_strings, expected_output_string)
@slow
def test_lm_generate_gpt2(self):
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

View File

@@ -480,6 +480,33 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(expected_output_string, output_strings)
@slow
def test_sample_generate(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
sentences = ["I really love my", "Translate English to German: the transformers are truly amazing"]
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
generation_kwargs = {
"do_sample": True,
"bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids],
"no_repeat_ngram_size": 3,
"repetition_penalty": 2.2,
"temperature": 0.8,
"top_k": 500,
"top_p": 0.9,
}
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"]
self.assertListEqual(expected_output_string, output_strings)
@require_tf
@require_sentencepiece