From 8a133490bf185f707b8155a801a438622c816316 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 2 Mar 2022 09:48:11 +0000 Subject: [PATCH] Add TF generate sample tests with all logit processors (#15852) * Add GPT2 TF generate sample test with all logits processor * Add T5 generate sample test --- tests/gpt2/test_modeling_tf_gpt2.py | 33 ++++++++++++++++++++++++++++- tests/t5/test_modeling_tf_t5.py | 27 +++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/tests/gpt2/test_modeling_tf_gpt2.py b/tests/gpt2/test_modeling_tf_gpt2.py index 27d30a630a..0952ff8f70 100644 --- a/tests/gpt2/test_modeling_tf_gpt2.py +++ b/tests/gpt2/test_modeling_tf_gpt2.py @@ -442,7 +442,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) @slow - def test_lm_generate_distilgpt2_batch_special(self): + def test_lm_generate_greedy_distilgpt2_batch_special(self): model = TFGPT2LMHeadModel.from_pretrained("distilgpt2") tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2") @@ -468,6 +468,37 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): ] self.assertListEqual(output_strings, expected_output_string) + @slow + def test_lm_generate_sample_distilgpt2_batch_special(self): + model = TFGPT2LMHeadModel.from_pretrained("distilgpt2") + tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2") + + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + sentences = ["Today is a beautiful day and", "Yesterday was"] + input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + + generation_kwargs = { + "do_sample": True, + "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids], + "no_repeat_ngram_size": 2, + "repetition_penalty": 1.3, + "temperature": 1.5, + "top_k": 500, + "top_p": 0.9, + } + tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation + + output_ids = model.generate(input_ids, **generation_kwargs) + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + expected_output_string = [ + "Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh", + "Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say", + ] + self.assertListEqual(output_strings, expected_output_string) + @slow def test_lm_generate_gpt2(self): model = TFGPT2LMHeadModel.from_pretrained("gpt2") diff --git a/tests/t5/test_modeling_tf_t5.py b/tests/t5/test_modeling_tf_t5.py index d83e30bc16..77d0f67468 100644 --- a/tests/t5/test_modeling_tf_t5.py +++ b/tests/t5/test_modeling_tf_t5.py @@ -480,6 +480,33 @@ class TFT5GenerationIntegrationTests(unittest.TestCase): self.assertListEqual(expected_output_string, output_strings) + @slow + def test_sample_generate(self): + model = TFT5ForConditionalGeneration.from_pretrained("t5-small") + tokenizer = T5Tokenizer.from_pretrained("t5-small") + + sentences = ["I really love my", "Translate English to German: the transformers are truly amazing"] + input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + + generation_kwargs = { + "do_sample": True, + "bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids], + "no_repeat_ngram_size": 3, + "repetition_penalty": 2.2, + "temperature": 0.8, + "top_k": 500, + "top_p": 0.9, + } + tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation + + output_ids = model.generate(input_ids, **generation_kwargs) + + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"] + + self.assertListEqual(expected_output_string, output_strings) + @require_tf @require_sentencepiece