TF: GPT-2 generation supports left-padding (#17426)
* TF GPT-2 now properly works with left padding * throw a warning when eos token == pad token and there is no attention mask
This commit is contained in:
@@ -456,7 +456,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
sentences = ["Today is a beautiful day and", "Yesterday was"]
|
||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
|
||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||
|
||||
generation_kwargs = {
|
||||
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
|
||||
@@ -465,12 +465,12 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
"repetition_penalty": 1.3,
|
||||
}
|
||||
|
||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||
output_ids = model.generate(**input_ids, **generation_kwargs)
|
||||
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
expected_output_string = [
|
||||
"Today is a beautiful day and I am so happy to be able take part in this amazing event.",
|
||||
"Yesterday was a very busy day for the first time since I started writing this post",
|
||||
"Yesterday was a very interesting time for the world to see how much of this is",
|
||||
]
|
||||
self.assertListEqual(output_strings, expected_output_string)
|
||||
|
||||
@@ -483,7 +483,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
sentences = ["Today is a beautiful day and", "Yesterday was"]
|
||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
|
||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||
|
||||
generation_kwargs = {
|
||||
"do_sample": True,
|
||||
@@ -498,13 +498,13 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
with tf.device(":/CPU:0"):
|
||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||
output_ids = model.generate(**input_ids, **generation_kwargs)
|
||||
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
expected_output_string = [
|
||||
"Today is a beautiful day and we will make you feel very hot/terrific in all",
|
||||
"Yesterday was another solid success as news coverage became standard American domestic television hit.",
|
||||
"Today is a beautiful day and we will make you feel very hot/terrific in all your",
|
||||
"Yesterday was known by national television networks as Le Big Show or Wild Dog Jeopard",
|
||||
]
|
||||
self.assertListEqual(output_strings, expected_output_string)
|
||||
|
||||
@@ -517,7 +517,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
sentences = ["Today is a beautiful day and", "Yesterday was"]
|
||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
|
||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||
|
||||
generation_kwargs = {
|
||||
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
|
||||
@@ -526,37 +526,69 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
"num_beams": 2,
|
||||
}
|
||||
|
||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||
output_ids = model.generate(**input_ids, **generation_kwargs)
|
||||
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
expected_output_string = [
|
||||
"Today is a beautiful day and a great day for all of us.\n\nI’m",
|
||||
"Yesterday was the first day of the year for the second time in a row,",
|
||||
"Yesterday was the first time that a person has been arrested in the United States for",
|
||||
]
|
||||
self.assertListEqual(output_strings, expected_output_string)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2_left_padding(self):
|
||||
"""Tests that the generated text is the same, regarless of left padding"""
|
||||
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
generation_kwargs = {
|
||||
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
|
||||
"no_repeat_ngram_size": 2,
|
||||
"do_sample": False,
|
||||
"repetition_penalty": 1.3,
|
||||
}
|
||||
expected_output_string = (
|
||||
"Today is a beautiful day and I am so happy to be able take part in this amazing event."
|
||||
)
|
||||
|
||||
sentences = ["Today is a beautiful day and"]
|
||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||
# using default length
|
||||
output_ids = model.generate(**input_ids, **generation_kwargs)
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertEqual(output_strings[0], expected_output_string)
|
||||
|
||||
sentences = ["Today is a beautiful day and", "This is a very long input that we absolutely don't care about"]
|
||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||
# longer max length to capture the full length (remember: it is left padded)
|
||||
output_ids = model.generate(**input_ids, **generation_kwargs, max_length=27)
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertEqual(output_strings[0], expected_output_string)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_gpt2_greedy_xla(self):
|
||||
# TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
|
||||
# the underlying problem)
|
||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
sentences = ["The dog"]
|
||||
sentences = ["The dog", "The flying machine"]
|
||||
expected_output_strings = [
|
||||
"The dog was found in a field near the intersection of West and West Streets.\n\nThe dog",
|
||||
"The dog was found in a field near the intersection of West and West Streets.\n\nThe",
|
||||
"The flying machine is a small, lightweight, and lightweight aircraft that can be used for any type of",
|
||||
]
|
||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
|
||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
output_ids = model.generate(**input_ids, do_sample=False)
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(output_strings, expected_output_strings)
|
||||
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
output_ids = xla_generate(input_ids, do_sample=False)
|
||||
output_ids = xla_generate(**input_ids, do_sample=False)
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(output_strings, expected_output_strings)
|
||||
|
||||
@@ -574,21 +606,24 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
sentence = ["The dog"]
|
||||
sentence = ["The dog", "The flying machine"]
|
||||
expected_output_string = [
|
||||
"The dog owner asked why did our vet decide there needed to be extra ventilation inside because most"
|
||||
" puppies"
|
||||
" puppies",
|
||||
"The flying machine was made by an artist who found it difficult to control it as it did not use",
|
||||
]
|
||||
expected_output_string_xla = [
|
||||
"The dog has been named in connection with the murder of a 20-year-old man in!"
|
||||
"The dog has been named in connection with the murder of a 20-year-old man in",
|
||||
"The flying machine is a new and improved system to operate and operate a new system and system "
|
||||
"system system",
|
||||
]
|
||||
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
|
||||
input_ids = tokenizer(sentence, return_tensors="tf", padding=True)
|
||||
|
||||
output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
|
||||
output_ids = model.generate(**input_ids, do_sample=True, seed=[7, 0])
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(output_strings, expected_output_string)
|
||||
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
|
||||
output_ids = xla_generate(**input_ids, do_sample=True, seed=[7, 0])
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(output_strings, expected_output_string_xla)
|
||||
|
||||
Reference in New Issue
Block a user