Generate: contrastive search with full optional outputs (#19963)
* Use beam search functionality; Add extra outputs and test * Add full tests for contrastive search * Add error message on unconventional cache format
This commit is contained in:
@@ -30,7 +30,14 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import ByT5Tokenizer, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
ByT5Tokenizer,
|
||||
T5EncoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
T5Model,
|
||||
T5Tokenizer,
|
||||
)
|
||||
from transformers.models.t5.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@@ -1216,6 +1223,51 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
self.assertEqual(translation, expected_translation)
|
||||
|
||||
@slow
|
||||
def test_contrastive_search_t5(self):
|
||||
article = (
|
||||
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
|
||||
" year later, she got married again in Westchester County, but to a different man and without divorcing"
|
||||
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
|
||||
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
|
||||
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
|
||||
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
|
||||
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
|
||||
" license application, according to court documents. Prosecutors said the marriages were part of an"
|
||||
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
|
||||
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
|
||||
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
|
||||
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
|
||||
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
|
||||
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
|
||||
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
|
||||
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
|
||||
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
|
||||
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
|
||||
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
|
||||
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
|
||||
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
|
||||
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
|
||||
" up to four years in prison. Her next court appearance is scheduled for May 18."
|
||||
)
|
||||
article = "summarize: " + article.strip()
|
||||
t5_tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-cnn-dm")
|
||||
t5_model = T5ForConditionalGeneration.from_pretrained("flax-community/t5-base-cnn-dm").to(torch_device)
|
||||
input_ids = t5_tokenizer(
|
||||
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
|
||||
).input_ids.to(torch_device)
|
||||
|
||||
outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
|
||||
generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for "
|
||||
"permanent residence after the marriages, prosecutors say."
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TestAsymmetricT5(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user