Add TF generate sample tests with all logit processors (#15852)
* Add GPT2 TF generate sample test with all logits processors * Add T5 generate sample test
This commit is contained in:
@@ -442,7 +442,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_lm_generate_distilgpt2_batch_special(self):
|
def test_lm_generate_greedy_distilgpt2_batch_special(self):
|
||||||
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
|
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
|
||||||
|
|
||||||
@@ -468,6 +468,37 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
self.assertListEqual(output_strings, expected_output_string)
|
self.assertListEqual(output_strings, expected_output_string)
|
||||||
|
|
||||||
|
@slow
def test_lm_generate_sample_distilgpt2_batch_special(self):
    """Check seeded sampling from distilgpt2 on a left-padded two-prompt batch.

    A single ``generate`` call combines sampling with bad-word, n-gram,
    repetition-penalty, temperature, top-k and top-p settings; the decoded
    text is compared against fixed reference strings.
    """
    checkpoint = "distilgpt2"
    model = TFGPT2LMHeadModel.from_pretrained(checkpoint)
    tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)

    # Reuse EOS as the pad token and pad on the left.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    prompts = ["Today is a beautiful day and", "Yesterday was"]
    encoded = tokenizer(prompts, return_tensors="tf", padding=True)

    # Token-id sequences passed to generate() as bad_words_ids.
    banned_sequences = [tokenizer(phrase).input_ids for phrase in ("is", "angry about")]

    # Seed immediately before generating so the sampled sequence is reproducible.
    tf.random.set_seed(42)
    generated = model.generate(
        encoded.input_ids,
        do_sample=True,
        bad_words_ids=banned_sequences,
        no_repeat_ngram_size=2,
        repetition_penalty=1.3,
        temperature=1.5,
        top_k=500,
        top_p=0.9,
    )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

    reference = [
        "Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh",
        "Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say",
    ]
    self.assertListEqual(decoded, reference)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_lm_generate_gpt2(self):
|
def test_lm_generate_gpt2(self):
|
||||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||||
|
|||||||
@@ -480,6 +480,33 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual(expected_output_string, output_strings)
|
self.assertListEqual(expected_output_string, output_strings)
|
||||||
|
|
||||||
|
@slow
def test_sample_generate(self):
    """Check seeded sampling from t5-small on a padded two-prompt batch.

    A single ``generate`` call combines sampling with bad-word, n-gram,
    repetition-penalty, temperature, top-k and top-p settings; the decoded
    text is compared against fixed reference strings.
    """
    checkpoint = "t5-small"
    model = TFT5ForConditionalGeneration.from_pretrained(checkpoint)
    tokenizer = T5Tokenizer.from_pretrained(checkpoint)

    prompts = ["I really love my", "Translate English to German: the transformers are truly amazing"]
    encoded = tokenizer(prompts, return_tensors="tf", padding=True)

    # Token-id sequences passed to generate() as bad_words_ids.
    # NOTE(review): the reference output below still contains "my"; presumably
    # the tokenizer appends EOS to each sequence here, so the bare word is not
    # actually banned - confirm whether add_special_tokens=False was intended.
    banned_sequences = [tokenizer(phrase).input_ids for phrase in ("my", "ein schöner")]

    # Seed immediately before generating so the sampled sequence is reproducible.
    tf.random.set_seed(42)
    generated = model.generate(
        encoded.input_ids,
        do_sample=True,
        bad_words_ids=banned_sequences,
        no_repeat_ngram_size=3,
        repetition_penalty=2.2,
        temperature=0.8,
        top_k=500,
        top_p=0.9,
    )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

    reference = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"]
    self.assertListEqual(reference, decoded)
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
|
|||||||
Reference in New Issue
Block a user