Adding the state-of-the-art contrastive search decoding methods for the codebase of generation_utils.py (#19477)

* add: the contrastive search for generaton_utils

* add: testing scripts for contrastive search under examples/text-generation

* update the quality of codes

* revise the docstring; make the generation_contrastive_search.py scripts;

* revise the examples/pytorch/text-generation/run_generation_contrastive_search.py to the auto-APIs format

* revise the necessary documents

* fix: revise the docstring of generation_contrastive_search.py

* Fix the code indentation

* fix: revise the nits and examples in contrastive_search docstring.

* fix the copyright

* delete generation_contrastive_search.py

* revise the logic in contrastive_search

* update the intergration test and the docstring

* run the tests over

* add the slow decorate to the contrastive_search intergrate test

* add more test

* do the style, quality, consistency checks
This commit is contained in:
GMFTBY
2022-10-19 17:17:46 +08:00
committed by GitHub
parent fc5fdc109d
commit 71786b10c5
5 changed files with 853 additions and 5 deletions

View File

@@ -27,6 +27,7 @@ if is_torch_available():
import torch
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BartForConditionalGeneration,
@@ -34,8 +35,10 @@ if is_torch_available():
GPT2LMHeadModel,
GPT2Tokenizer,
ImageGPTForCausalImageModeling,
OPTForCausalLM,
Speech2TextForConditionalGeneration,
SpeechEncoderDecoderModel,
T5ForConditionalGeneration,
VisionEncoderDecoderModel,
pipeline,
top_k_top_p_filtering,
@@ -1693,6 +1696,140 @@ class GenerationIntegrationTests(unittest.TestCase):
],
)
@slow
def test_contrastive_search_bart(self):
article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York.
A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband.
Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other.
In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage.
Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the
2010 marriage license application, according to court documents.
Prosecutors said the marriages were part of an immigration scam.
On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further.
After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective
Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002.
All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say.
Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages.
Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted.
The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s
Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali.
Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force.
If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.
"""
bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
input_ids = bart_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
).input_ids.to(torch_device)
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""Liana Barrientos, 39, pleaded not guilty to two counts of "offering a false instrument" Prosecutors say the marriages were part of an immigration scam. In total, Barriento has been married 10 times, with nine of her marriages occurring between 1999 and 2002."""
],
)
@slow
def test_contrastive_search_t5(self):
article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York.
A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband.
Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other.
In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage.
Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the
2010 marriage license application, according to court documents.
Prosecutors said the marriages were part of an immigration scam.
On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further.
After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective
Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002.
All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say.
Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages.
Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted.
The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s
Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali.
Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force.
If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.
"""
article = "summarize: " + article.strip()
t5_tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-cnn-dm")
t5_model = T5ForConditionalGeneration.from_pretrained("flax-community/t5-base-cnn-dm").to(torch_device)
input_ids = t5_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
).input_ids.to(torch_device)
outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for permanent residence after the marriages, prosecutors say."""
],
)
@slow
def test_contrastive_search_opt(self):
article = r"""A chat between a curious human and the Statue of Liberty.
Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?"""
opt_tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-6.7b")
opt_model = OPTForCausalLM.from_pretrained("facebook/opt-6.7b").to(torch_device)
input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256)
generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived there?\nStatue: Since 1884.\nHuman: Why did you come to America?\nStatue: I was given to the United States by France as a gift for helping the French during the Franco-Prussian War.\nHuman: What do you think of America?\nStatue: I love it. It is the greatest country in the world.\nHuman: Whats the weather like in New York?\nStatue: It is cold.\nHuman: Is it safe to walk around at night?\nStatue: Yes. There are policemen everywhere.\nHuman: Do you have any children?\nStatue: Not yet. My pedestal is empty.\nHuman: What would you like to say to people who want to immigrate to America?\nStatue: Come on over. You will be happy here. We have everything you need.\nSource: http://www.statueofliberty.org/index.cf"""
],
)
@slow
def test_contrastive_search_gptj(self):
article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"""
opt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
opt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B").to(torch_device)
input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google\'s parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating a company that would apply deep learning to problems in healthcare, energy, transportation, and other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 million in cash and stock.[3] The acquisition was seen as a move to strengthen Google\'s position in the fast-growing field of artificial intelligence (AI), which it had invested in since 2010.[4] Google CEO Larry Page said that the company was "excited to have DeepMind on board" and that "this is a step towards our goal of building AI that works for everyone, not just a few".[5]\n\nDeepMind\'s co-founders, Demis Hassabis and Mustafa Suleyman, were named CEO and C"""
],
)
@slow
def test_contrastive_search_gpt2(self):
article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"""
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(torch_device)
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = gpt2_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as Google Now, which helps users find the information they\'re looking for on the web. But the company is not the only one to collect data on its users. Facebook, for example, has its own facial recognition technology, as well as a database of millions of photos that it uses to personalize its News Feed.\n\nFacebook\'s use of data is a hot topic in the tech industry, with privacy advocates concerned about the company\'s ability to keep users\' information private. In a blog post last year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, but said in a statement to The Associated Press that"""
],
)
def test_max_length_backward_compat_greedy(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
@@ -2050,6 +2187,134 @@ class GenerationIntegrationTests(unittest.TestCase):
with self.assertRaises(ValueError):
bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
t5_model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5").to(torch_device)
input_ids = t5_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 56])
max_new_tokens = 3
t5_model.config.max_length = 20
t5_model.config.eos_token_id = None
# Encoder decoder call
outputs = t5_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = t5_model.generate(
decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
)
# 56 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 59])
# Encoder decoder call > 20
outputs = t5_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
t5_model.generate(
decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
)
def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 29])
max_new_tokens = 3
bart_model.config.max_length = 20
bart_model.config.eos_token_id = None
# Encoder decoder call
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = bart_model.generate(
decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
)
# 29 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 32])
# Encoder decoder call > 20
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
bart_model.generate(
decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
)
def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
article = """Justin Timberlake."""
gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
gptj_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj").to(torch_device)
input_ids = gptj_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 9])
max_new_tokens = 3
gptj_model.config.max_length = 20
# call < 20
outputs = gptj_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 9 input_ids + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 12])
# call > 20
outputs = gptj_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
gptj_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
article = """Justin Timberlake."""
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 9])
max_new_tokens = 3
gpt2_model.config.max_length = 20
# call < 20
outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 9 input_ids + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 12])
# call > 20
outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
def test_max_new_tokens_decoder_only(self):
article = """Justin Timberlake."""
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")