[Seq2Seq] Fix a couple of bugs and clean examples (#7474)
* clean T5 * fix t5 tests * fix index typo * fix tf common test * fix examples * change positional ordering for Bart and FSTM * add signature test * clean docs and add tests * add docs to encoder decoder * clean docs * correct two doc strings * remove sig test for TF Elektra & Funnel * fix tf t5 slow tests * fix input_ids to inputs in tf * Update src/transformers/modeling_bart.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_bart.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * implement lysandre results * make style * fix encoder decoder typo * fix tf slow tests * fix slow tests * renaming * remove unused input Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
a42f62d34f
commit
62f5ae68ec
@@ -235,7 +235,7 @@ class T5ModelTester:
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past_key_value_states = outputs.to_tuple()
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
@@ -244,7 +244,7 @@ class T5ModelTester:
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
@@ -274,7 +274,7 @@ class T5ModelTester:
|
||||
attn_mask[:, half_seq_length:] = 0
|
||||
|
||||
# first forward pass
|
||||
output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
|
||||
output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
@@ -293,7 +293,7 @@ class T5ModelTester:
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask)[
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
|
||||
"last_hidden_state"
|
||||
]
|
||||
|
||||
@@ -305,7 +305,41 @@ class T5ModelTester:
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_generate_with_past_key_value_states(
|
||||
def create_and_check_decoder_model_past_large_inputs(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
model = T5Model(config=config).get_decoder().to(torch_device).eval()
|
||||
# first forward pass
|
||||
outputs = model(input_ids, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
# create hypothetical multiple next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||
|
||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_generate_with_past_key_values(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
@@ -439,7 +473,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = False
|
||||
is_encoder_decoder = True
|
||||
|
||||
@@ -470,9 +504,13 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||
|
||||
def test_generate_with_past_key_value_states(self):
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs)
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
def test_generate_with_past_key_values(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs)
|
||||
|
||||
def test_encoder_decoder_shared_weights(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -495,10 +533,11 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
config_and_inputs[1],
|
||||
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
|
||||
f"{tmpdirname}/t5_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=9,
|
||||
input_names=["input_ids", "decoder_input_ids"],
|
||||
)
|
||||
|
||||
|
||||
@@ -527,7 +566,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
|
||||
|
||||
expected_summaries = [
|
||||
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," the magazine says .',
|
||||
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video at the crash site . "one can hear cries of \'My God\' in several languages," one magazine says .',
|
||||
"the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .",
|
||||
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
|
||||
'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
|
||||
@@ -604,13 +643,6 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
"sous forme "
|
||||
"de points bleus."
|
||||
)
|
||||
# expected_translation = (
|
||||
# "Cette section d'images provenant de l'enregistrement infrarouge effectué par le "
|
||||
# "télescope Spitzer montre un « portrait familial » de générations innombrables de "
|
||||
# "étoiles : les plus anciennes sont observées sous forme de pointes bleues, "
|
||||
# "alors que les « nouveau-nés » de couleur rose dans la salle des accouchements doivent "
|
||||
# "être plus difficiles "
|
||||
# )
|
||||
|
||||
self.assertEqual(translation, new_truncated_translation)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user