Generate: contrastive search with full optional outputs (#19963)

* Use beam search functionality; Add extra outputs and test

* Add full tests for contrastive search

* Add error message on unconventional cache format
This commit is contained in:
Joao Gante
2022-11-01 18:15:36 +00:00
committed by GitHub
parent ab74ac11e4
commit 831590f6a9
8 changed files with 499 additions and 300 deletions

View File

@@ -1181,6 +1181,52 @@ class BartModelIntegrationTests(unittest.TestCase):
)
assert generated_summaries == EXPECTED
@slow
def test_contrastive_search_bart(self):
    """Pin the summary that `facebook/bart-large-cnn` generates for a fixed news article.

    The article is tokenized without special tokens and truncated to 512 tokens, then `generate` is called
    with `penalty_alpha=0.5`, `top_k=5` (the contrastive search settings this test is named after) and
    `max_length=64`. The decoded batch must match the reference summary exactly.
    """
    checkpoint = "facebook/bart-large-cnn"
    news_article = (
        " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
        " year later, she got married again in Westchester County, but to a different man and without divorcing"
        " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
        ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
        " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
        ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
        ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
        " license application, according to court documents. Prosecutors said the marriages were part of an"
        " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
        " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
        " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
        " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
        " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
        " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
        " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
        " said the immigration scam involved some of her husbands, who filed for permanent residence status"
        " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
        " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
        " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
        ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
        " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
        " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
        " up to four years in prison. Her next court appearance is scheduled for May 18."
    )

    tokenizer = BartTokenizer.from_pretrained(checkpoint)
    model = BartForConditionalGeneration.from_pretrained(checkpoint).to(torch_device)

    # Encode the raw article; the expected summary below was recorded with exactly these tokenizer settings.
    encoding = tokenizer(
        news_article,
        add_special_tokens=False,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    input_ids = encoding.input_ids.to(torch_device)

    generated_ids = model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
    decoded_batch = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_batch = [
        "Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
        "Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
        "accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
        "to four years in"
    ]
    self.assertListEqual(decoded_batch, expected_batch)
class BartStandaloneDecoderModelTester:
def __init__(

View File

@@ -763,3 +763,37 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
@slow
def test_contrastive_search_gpt2(self):
    """Pin the continuation that `gpt2-large` generates for a fixed prompt.

    `generate` is called with `penalty_alpha=0.6`, `top_k=4` (the contrastive search settings this test is
    named after) and `max_length=256`. The decoded batch, prompt included, must match the reference text
    exactly.
    """
    checkpoint = "gpt2-large"
    prompt = (
        "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
        "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
    )

    tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
    model = GPT2LMHeadModel.from_pretrained(checkpoint).to(torch_device)

    encoding = tokenizer(prompt, return_tensors="pt")
    input_ids = encoding.input_ids.to(torch_device)

    generated_ids = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
    decoded_batch = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_batch = [
        "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
        "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
        "United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
        "Google Now, which helps users find the information they're looking for on the web. But the company "
        "is not the only one to collect data on its users. Facebook, for example, has its own facial "
        "recognition technology, as well as a database of millions of photos that it uses to personalize its "
        "News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
        "concerned about the company's ability to keep users' information private. In a blog post last "
        'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
        'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
        'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
        'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
        "but said in a statement to The Associated Press that"
    ]
    self.assertListEqual(decoded_batch, expected_batch)

View File

@@ -572,3 +572,38 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase):
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
@tooslow
def test_contrastive_search_gptj(self):
    """Pin the continuation that `EleutherAI/gpt-j-6B` generates for a fixed prompt.

    The model is loaded from its `float16` revision with `torch_dtype=torch.float16`. `generate` is called
    with `penalty_alpha=0.6`, `top_k=4` (the contrastive search settings this test is named after) and
    `max_length=256`. The decoded batch, prompt included, must match the reference text exactly.
    """
    checkpoint = "EleutherAI/gpt-j-6B"
    prompt = (
        "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and "
        "research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
    )

    gptj_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    gptj_model = GPTJForCausalLM.from_pretrained(checkpoint, revision="float16", torch_dtype=torch.float16)
    gptj_model = gptj_model.to(torch_device)

    encoding = gptj_tokenizer(prompt, return_tensors="pt")
    input_ids = encoding.input_ids.to(torch_device)

    generated_ids = gptj_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
    decoded_batch = gptj_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_batch = [
        "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
        "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
        "United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, "
        "Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google's "
        "parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating "
        "a company that would apply deep learning to problems in healthcare, energy, transportation, and "
        "other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 "
        "million in cash and stock.[3] The acquisition was seen as a way for Google to enter the "
        "fast-growing field of artificial intelligence (AI), which it had so far avoided due to concerns "
        'about ethical and social implications.[4] Google co-founder Sergey Brin said that he was "thrilled" '
        'to have acquired DeepMind, and that it would "help us push the boundaries of AI even further."'
        "[5]\n\nDeepMind's founders, Demis Hassabis and Mustafa Suleyman, were joined by a number of Google "
        "employees"
    ]
    self.assertListEqual(decoded_batch, expected_batch)

View File

@@ -490,3 +490,34 @@ class OPTGenerationTest(unittest.TestCase):
self.assertFalse(
torch.isnan(outputs.logits[0]).any().item()
) # the first logits could contain NaNs if it fails
@slow
def test_contrastive_search_opt(self):
    """Pin the dialogue continuation that `facebook/opt-1.3b` generates for a fixed prompt.

    The prompt is tokenized with `GPT2Tokenizer` loaded from the OPT checkpoint. `generate` is called with
    `penalty_alpha=0.6`, `top_k=5` (the contrastive search settings this test is named after) and
    `max_length=256`. The decoded batch, prompt included, must match the reference text exactly.
    """
    checkpoint = "facebook/opt-1.3b"
    prompt = (
        "A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
        "Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
        "there?"
    )

    tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
    model = OPTForCausalLM.from_pretrained(checkpoint).to(torch_device)

    encoding = tokenizer(prompt, return_tensors="pt")
    input_ids = encoding.input_ids.to(torch_device)

    generated_ids = model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256)
    decoded_batch = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_batch = [
        "A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I "
        "am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have "
        "you lived there?\nStatue: A hundred years.\nHuman: And youre from what country?\nStatue: The United "
        "States of America.\nHuman: Why did you come to America?\nStatue: I came to escape the tyranny of my "
        "country.\nHuman: What tyranny?\nStatue: They didnt let me speak my mind.\nHuman: What was your "
        "country?\nStatue: It was a country of immigrants.\nHuman: Who were the immigrants?\nStatue: They "
        "were from all over the world.\nHuman: What language did they speak?\nStatue: French, Spanish, "
        "Italian, German, English—you name it.\nHuman: And where did they come from?\nStatue: They came from "
        "every country in the world.\nHuman: And you were born in what country?\nStatue: I was born in "
        "France.\nHuman: And your parents were French?\nStatue"
    ]
    self.assertListEqual(decoded_batch, expected_batch)

View File

@@ -30,7 +30,14 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
import torch
from transformers import ByT5Tokenizer, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer
from transformers import (
AutoTokenizer,
ByT5Tokenizer,
T5EncoderModel,
T5ForConditionalGeneration,
T5Model,
T5Tokenizer,
)
from transformers.models.t5.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -1216,6 +1223,51 @@ class T5ModelIntegrationTests(unittest.TestCase):
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertEqual(translation, expected_translation)
@slow
def test_contrastive_search_t5(self):
    """Pin the summary that `flax-community/t5-base-cnn-dm` generates for a fixed news article.

    The article is stripped and prefixed with "summarize: ", tokenized without special tokens and truncated
    to 512 tokens. `generate` is then called with `penalty_alpha=0.5`, `top_k=5` (the contrastive search
    settings this test is named after) and `max_length=64`. The decoded batch must match the reference
    summary exactly.
    """
    checkpoint = "flax-community/t5-base-cnn-dm"
    news_article = (
        " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
        " year later, she got married again in Westchester County, but to a different man and without divorcing"
        " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
        ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
        " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
        ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
        ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
        " license application, according to court documents. Prosecutors said the marriages were part of an"
        " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
        " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
        " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
        " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
        " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
        " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
        " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
        " said the immigration scam involved some of her husbands, who filed for permanent residence status"
        " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
        " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
        " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
        ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
        " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
        " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
        " up to four years in prison. Her next court appearance is scheduled for May 18."
    )
    # The task prefix is prepended to the whitespace-stripped article before tokenization.
    prompt = f"summarize: {news_article.strip()}"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(checkpoint).to(torch_device)

    encoding = tokenizer(
        prompt,
        add_special_tokens=False,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    input_ids = encoding.input_ids.to(torch_device)

    generated_ids = model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
    decoded_batch = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_batch = [
        "Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for "
        "permanent residence after the marriages, prosecutors say."
    ]
    self.assertListEqual(decoded_batch, expected_batch)
@require_torch
class TestAsymmetricT5(unittest.TestCase):