Generate: contrastive search with full optional outputs (#19963)
* Use beam search functionality; Add extra outputs and test * Add full tests for contrastive search * Add error message on unconventional cache format
This commit is contained in:
@@ -18,7 +18,7 @@ import inspect
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, tooslow, torch_device
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ..test_modeling_common import floats_tensor, ids_tensor
|
||||
|
||||
@@ -35,7 +35,6 @@ if is_torch_available():
|
||||
GPT2LMHeadModel,
|
||||
GPT2Tokenizer,
|
||||
ImageGPTForCausalImageModeling,
|
||||
OPTForCausalLM,
|
||||
Speech2TextForConditionalGeneration,
|
||||
SpeechEncoderDecoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
@@ -623,6 +622,76 @@ class GenerationTesterMixin:
|
||||
)
|
||||
return output_generate, output_group_beam_search
|
||||
|
||||
def _contrastive_generate(
|
||||
self,
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
max_length,
|
||||
output_scores=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
contrastive_search_kwargs = {
|
||||
"penalty_alpha": 0.6,
|
||||
"top_k": 5,
|
||||
}
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1],
|
||||
eos_token_id=model.config.eos_token_id,
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
kwargs = {}
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**logits_process_kwargs,
|
||||
**model_kwargs,
|
||||
**contrastive_search_kwargs,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
|
||||
with torch.no_grad():
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
|
||||
output_contrastive = model.contrastive_search(
|
||||
input_ids,
|
||||
stopping_criteria=stopping_criteria,
|
||||
logits_processor=logits_processor,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
**model_kwargs,
|
||||
**contrastive_search_kwargs,
|
||||
)
|
||||
return output_contrastive, output_generate
|
||||
|
||||
def test_greedy_generate(self):
|
||||
# check `generate()` and `greedy_search()` are equal
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -1336,6 +1405,64 @@ class GenerationTesterMixin:
|
||||
for output in (output_beam_search, output_generate):
|
||||
self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
|
||||
|
||||
def test_contrastive_generate(self):
|
||||
# check `generate()` and `contrastive_search()` are equal
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
||||
# TODO: Fix Bloom. Bloom fails because `past` has a different shape.
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
|
||||
return
|
||||
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
# test old generation output for backwards compatibility
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_contrastive, output_generate = self._contrastive_generate(
|
||||
model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
|
||||
)
|
||||
self.assertListEqual(output_contrastive.tolist(), output_generate.tolist())
|
||||
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
||||
# TODO: Fix Bloom. Bloom fails because `past` has a different shape.
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
|
||||
return
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_contrastive, output_generate = self._contrastive_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertListEqual(output_generate.sequences.tolist(), output_contrastive.sequences.tolist())
|
||||
|
||||
for output in (output_contrastive, output_generate):
|
||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
@@ -1696,197 +1823,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_contrastive_search_bart(self):
|
||||
article = (
|
||||
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
|
||||
" year later, she got married again in Westchester County, but to a different man and without divorcing"
|
||||
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
|
||||
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
|
||||
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
|
||||
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
|
||||
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
|
||||
" license application, according to court documents. Prosecutors said the marriages were part of an"
|
||||
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
|
||||
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
|
||||
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
|
||||
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
|
||||
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
|
||||
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
|
||||
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
|
||||
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
|
||||
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
|
||||
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
|
||||
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
|
||||
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
|
||||
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
|
||||
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
|
||||
" up to four years in prison. Her next court appearance is scheduled for May 18."
|
||||
)
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
|
||||
input_ids = bart_tokenizer(
|
||||
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
|
||||
).input_ids.to(torch_device)
|
||||
|
||||
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
|
||||
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
|
||||
"Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
|
||||
"accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
|
||||
"to four years in"
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_contrastive_search_t5(self):
|
||||
article = (
|
||||
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
|
||||
" year later, she got married again in Westchester County, but to a different man and without divorcing"
|
||||
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
|
||||
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
|
||||
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
|
||||
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
|
||||
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
|
||||
" license application, according to court documents. Prosecutors said the marriages were part of an"
|
||||
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
|
||||
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
|
||||
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
|
||||
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
|
||||
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
|
||||
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
|
||||
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
|
||||
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
|
||||
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
|
||||
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
|
||||
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
|
||||
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
|
||||
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
|
||||
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
|
||||
" up to four years in prison. Her next court appearance is scheduled for May 18."
|
||||
)
|
||||
article = "summarize: " + article.strip()
|
||||
t5_tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-cnn-dm")
|
||||
t5_model = T5ForConditionalGeneration.from_pretrained("flax-community/t5-base-cnn-dm").to(torch_device)
|
||||
input_ids = t5_tokenizer(
|
||||
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
|
||||
).input_ids.to(torch_device)
|
||||
|
||||
outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
|
||||
generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for "
|
||||
"permanent residence after the marriages, prosecutors say."
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_contrastive_search_opt(self):
|
||||
article = (
|
||||
"A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
|
||||
"Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
|
||||
"there?"
|
||||
)
|
||||
|
||||
opt_tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-1.3b")
|
||||
opt_model = OPTForCausalLM.from_pretrained("facebook/opt-1.3b").to(torch_device)
|
||||
input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256)
|
||||
generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I "
|
||||
"am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have "
|
||||
"you lived there?\nStatue: A hundred years.\nHuman: And you’re from what country?\nStatue: The United "
|
||||
"States of America.\nHuman: Why did you come to America?\nStatue: I came to escape the tyranny of my "
|
||||
"country.\nHuman: What tyranny?\nStatue: They didn’t let me speak my mind.\nHuman: What was your "
|
||||
"country?\nStatue: It was a country of immigrants.\nHuman: Who were the immigrants?\nStatue: They "
|
||||
"were from all over the world.\nHuman: What language did they speak?\nStatue: French, Spanish, "
|
||||
"Italian, German, English—you name it.\nHuman: And where did they come from?\nStatue: They came from "
|
||||
"every country in the world.\nHuman: And you were born in what country?\nStatue: I was born in "
|
||||
"France.\nHuman: And your parents were French?\nStatue"
|
||||
],
|
||||
)
|
||||
|
||||
@tooslow
|
||||
def test_contrastive_search_gptj(self):
|
||||
article = (
|
||||
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and "
|
||||
"research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
|
||||
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
|
||||
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
|
||||
"United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, "
|
||||
"Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google's "
|
||||
"parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating "
|
||||
"a company that would apply deep learning to problems in healthcare, energy, transportation, and "
|
||||
"other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 "
|
||||
"million in cash and stock.[3] The acquisition was seen as a way for Google to enter the "
|
||||
"fast-growing field of artificial intelligence (AI), which it had so far avoided due to concerns "
|
||||
'about ethical and social implications.[4] Google co-founder Sergey Brin said that he was "thrilled" '
|
||||
'to have acquired DeepMind, and that it would "help us push the boundaries of AI even further."'
|
||||
"[5]\n\nDeepMind's founders, Demis Hassabis and Mustafa Suleyman, were joined by a number of Google "
|
||||
"employees"
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_contrastive_search_gpt2(self):
|
||||
article = (
|
||||
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
|
||||
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
|
||||
)
|
||||
|
||||
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
|
||||
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(torch_device)
|
||||
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = gpt2_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
|
||||
|
||||
generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
|
||||
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
|
||||
"United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
|
||||
"Google Now, which helps users find the information they're looking for on the web. But the company "
|
||||
"is not the only one to collect data on its users. Facebook, for example, has its own facial "
|
||||
"recognition technology, as well as a database of millions of photos that it uses to personalize its "
|
||||
"News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
|
||||
"concerned about the company's ability to keep users' information private. In a blog post last "
|
||||
'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
|
||||
'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
|
||||
'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
|
||||
'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
|
||||
"but said in a statement to The Associated Press that"
|
||||
],
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_greedy(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
@@ -3045,6 +2981,31 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, force_words_ids=[[[-1]]])
|
||||
|
||||
def test_contrastive_search_batched(self):
|
||||
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
|
||||
articles = ["Foo", "Bar Baz"]
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
|
||||
|
||||
model.config.eos_token_id = None
|
||||
input_ids_batched = tokenizer(articles, padding=True, return_tensors="pt").input_ids.to(torch_device)
|
||||
input_ids = tokenizer(articles[1], return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
output_sequences_batched = model.generate(
|
||||
input_ids=input_ids_batched, penalty_alpha=0.6, top_k=4, return_dict_in_generate=True, output_scores=True
|
||||
)
|
||||
output_sequences = model.generate(
|
||||
input_ids=input_ids, penalty_alpha=0.6, top_k=4, return_dict_in_generate=True, output_scores=True
|
||||
)
|
||||
|
||||
batched_out = tokenizer.decode(output_sequences_batched.sequences[1], skip_special_tokens=True)
|
||||
out = tokenizer.decode(output_sequences.sequences[0], skip_special_tokens=True)
|
||||
self.assertEqual(batched_out, out)
|
||||
|
||||
# output_sequences_batched.scores[0][1] -> 1st set of logits, 2nd sequence
|
||||
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
|
||||
self.assertTrue(max_score_diff < 1e-5)
|
||||
|
||||
def test_validate_generation_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
Reference in New Issue
Block a user