PegasusForConditionalGeneration (torch version) (#6340)
Co-authored-by: Jingqing Zhang <jingqing.zhang15@imperial.ac.uk>
This commit is contained in:
@@ -45,6 +45,7 @@ if is_torch_available():
|
||||
_prepare_bart_decoder_inputs,
|
||||
SinusoidalPositionalEmbedding,
|
||||
)
|
||||
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -478,7 +479,6 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
self.assertFalse(model.config.is_valid_mbart())
|
||||
tok = BartTokenizer.from_pretrained("facebook/bart-large")
|
||||
|
||||
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
|
||||
dct = tok.batch_encode_plus(
|
||||
[PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",
|
||||
|
||||
@@ -23,14 +23,13 @@ RO_CODE = 250020
|
||||
|
||||
|
||||
@require_torch
|
||||
class AbstractMBartIntegrationTest(unittest.TestCase):
|
||||
|
||||
class AbstractSeq2SeqIntegrationTest(unittest.TestCase):
|
||||
maxDiff = 1000 # longer string compare tracebacks
|
||||
checkpoint_name = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
|
||||
cls.pad_token_id = 1
|
||||
return cls
|
||||
|
||||
@cached_property
|
||||
@@ -43,7 +42,7 @@ class AbstractMBartIntegrationTest(unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
|
||||
class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
checkpoint_name = "facebook/mbart-large-en-ro"
|
||||
src_text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
@@ -73,7 +72,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
|
||||
]
|
||||
),
|
||||
}
|
||||
net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
|
||||
net_input["attention_mask"] = net_input["input_ids"].ne(1)
|
||||
with torch.no_grad():
|
||||
logits, *other_stuff = model(**net_input)
|
||||
|
||||
@@ -125,7 +124,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
|
||||
|
||||
|
||||
@require_torch
|
||||
class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
|
||||
class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
checkpoint_name = "facebook/mbart-large-cc25"
|
||||
src_text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
|
||||
79
tests/test_modeling_pegasus.py
Normal file
79
tests/test_modeling_pegasus.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import unittest
|
||||
|
||||
from transformers import AutoConfig, is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_modeling_bart import PGE_ARTICLE
|
||||
from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """
|
||||
|
||||
|
||||
@require_torch
|
||||
class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
checkpoint_name = "google/pegasus-xsum"
|
||||
src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER]
|
||||
tgt_text = [
|
||||
"California's largest electricity provider has turned off power to tens of thousands of customers.",
|
||||
"N-Dubz have revealed they weren't expecting to get four nominations at this year's Mobo Awards.",
|
||||
]
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name).to(torch_device)
|
||||
|
||||
@slow
|
||||
def test_pegasus_xsum_summary(self):
|
||||
assert self.tokenizer.model_max_length == 512
|
||||
inputs = self.tokenizer(self.src_text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(
|
||||
torch_device
|
||||
)
|
||||
assert inputs.input_ids.shape == (2, 421)
|
||||
translated_tokens = self.model.generate(**inputs)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
self.assertEqual(self.tgt_text, decoded)
|
||||
|
||||
if "cuda" not in torch_device:
|
||||
return
|
||||
# Demonstrate fp16 issue, Contributions welcome!
|
||||
self.model.half()
|
||||
translated_tokens_fp16 = self.model.generate(**inputs, max_length=10)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
|
||||
bad_fp16_result = ["unk_7unk_7unk_7unk_7unk_7unk_7unk_7", "unk_7unk_7unk_7unk_7unk_7unk_7unk_7"]
|
||||
self.assertListEqual(decoded, bad_fp16_result)
|
||||
|
||||
|
||||
class PegasusConfigTests(unittest.TestCase):
|
||||
def test_all_config_max_lengths(self):
|
||||
expected_max_length = {
|
||||
# See appendix C of paper
|
||||
"xsum": 64,
|
||||
"cnn_dailymail": 128,
|
||||
"newsroom": 128,
|
||||
"wikihow": 256,
|
||||
"multi_news": 256,
|
||||
"reddit_tifu": 128,
|
||||
"big_patent": 256,
|
||||
"arxiv": 256,
|
||||
"pubmed": 256,
|
||||
"gigaword": 32,
|
||||
"aeslc": 32,
|
||||
"billsum": 256,
|
||||
}
|
||||
failures = []
|
||||
pegasus_prefix = "google/pegasus"
|
||||
for dataset, max_len in expected_max_length.items():
|
||||
mname = f"{pegasus_prefix}-{dataset}"
|
||||
cfg = AutoConfig.from_pretrained(mname)
|
||||
if cfg.max_length != max_len:
|
||||
failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}")
|
||||
if failures == []:
|
||||
return
|
||||
# error
|
||||
all_fails = "\n".join(failures)
|
||||
raise AssertionError(f"The following configs have unexpected settings: {all_fails}")
|
||||
69
tests/test_tokenization_pegasus.py
Normal file
69
tests/test_tokenization_pegasus.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.tokenization_pegasus import PegasusTokenizer
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = PegasusTokenizer
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
save_dir = Path(self.tmpdirname)
|
||||
spm_file = PegasusTokenizer.vocab_files_names["vocab_file"]
|
||||
if not (save_dir / spm_file).exists():
|
||||
tokenizer = self.pegasus_large_tokenizer
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
@cached_property
|
||||
def pegasus_large_tokenizer(self):
|
||||
return PegasusTokenizer.from_pretrained("google/pegasus-large")
|
||||
|
||||
@unittest.skip("add_tokens does not work yet")
|
||||
def test_swap_special_token(self):
|
||||
pass
|
||||
|
||||
def get_tokenizer(self, **kwargs) -> PegasusTokenizer:
|
||||
if not kwargs:
|
||||
return self.pegasus_large_tokenizer
|
||||
else:
|
||||
return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
return ("This is a test", "This is a test")
|
||||
|
||||
def test_pegasus_large_tokenizer_settings(self):
|
||||
tokenizer = self.pegasus_large_tokenizer
|
||||
# The tracebacks for the following asserts are **better** without messages or self.assertEqual
|
||||
assert tokenizer.vocab_size == 96103
|
||||
assert tokenizer.pad_token_id == 0
|
||||
assert tokenizer.eos_token_id == 1
|
||||
assert tokenizer.offset == 103
|
||||
assert tokenizer.unk_token_id == tokenizer.offset + 2 == 105
|
||||
assert tokenizer.unk_token == "<unk>"
|
||||
assert tokenizer.mask_token is None
|
||||
assert tokenizer.mask_token_id is None
|
||||
assert tokenizer.model_max_length == 1024
|
||||
raw_input_str = "To ensure a smooth flow of bank resolutions."
|
||||
desired_result = [413, 615, 114, 2291, 1971, 113, 1679, 10710, 107, 1]
|
||||
ids = tokenizer([raw_input_str], return_tensors=None).input_ids[0]
|
||||
self.assertListEqual(desired_result, ids)
|
||||
assert tokenizer.convert_ids_to_tokens([0, 1, 2]) == ["<pad>", "</s>", "unk_2"]
|
||||
|
||||
@require_torch
|
||||
def test_pegasus_large_seq2seq_truncation(self):
|
||||
src_texts = ["This is going to be way too long" * 10000, "short example"]
|
||||
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
|
||||
batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
|
||||
assert batch.input_ids.shape == (2, 1024)
|
||||
assert batch.attention_mask.shape == (2, 1024)
|
||||
assert "decoder_input_ids" in batch # because tgt_texts was specified
|
||||
assert batch.decoder_input_ids.shape == (2, 5)
|
||||
assert batch.decoder_attention_mask.shape == (2, 5)
|
||||
assert len(batch) == 4 # no extra keys
|
||||
Reference in New Issue
Block a user