TF: add beam search tests (#16202)
This commit is contained in:
@@ -548,6 +548,29 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
@slow
|
||||
def test_beam_search_generate(self):
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
|
||||
sentences = ["I really love my", "Translate English to German: the transformers are truly amazing"]
|
||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
|
||||
|
||||
generation_kwargs = {
|
||||
"bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids],
|
||||
"no_repeat_ngram_size": 3,
|
||||
"do_sample": False,
|
||||
"repetition_penalty": 2.2,
|
||||
"num_beams": 4,
|
||||
}
|
||||
|
||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"]
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_sentencepiece
|
||||
|
||||
Reference in New Issue
Block a user