Add TF generate sample tests with all logit processors (#15852)

* Add GPT2 TF generate sample test with all logits processor

* Add T5 generate sample test
This commit is contained in:
Joao Gante
2022-03-02 09:48:11 +00:00
committed by GitHub
parent 40040727ab
commit 8a133490bf
2 changed files with 59 additions and 1 deletions

View File

@@ -480,6 +480,33 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(expected_output_string, output_strings)
@slow
def test_sample_generate(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
sentences = ["I really love my", "Translate English to German: the transformers are truly amazing"]
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
generation_kwargs = {
"do_sample": True,
"bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids],
"no_repeat_ngram_size": 3,
"repetition_penalty": 2.2,
"temperature": 0.8,
"top_k": 500,
"top_p": 0.9,
}
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"]
self.assertListEqual(expected_output_string, output_strings)
@require_tf
@require_sentencepiece