From 0f58903bb62870342eae52f5a02c9105ec6f9b1e Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sat, 29 Aug 2020 17:43:32 -0400 Subject: [PATCH] Pegasus finetune script: add --adafactor (#6811) --- examples/seq2seq/finetune_pegasus_xsum.sh | 2 +- src/transformers/configuration_pegasus.py | 55 ++++++------------- .../convert_pegasus_tf_to_pytorch.py | 22 +++++--- tests/test_modeling_pegasus.py | 39 ++++++------- 4 files changed, 48 insertions(+), 70 deletions(-) diff --git a/examples/seq2seq/finetune_pegasus_xsum.sh b/examples/seq2seq/finetune_pegasus_xsum.sh index bdd4d6f9ad..ec7ff98557 100755 --- a/examples/seq2seq/finetune_pegasus_xsum.sh +++ b/examples/seq2seq/finetune_pegasus_xsum.sh @@ -10,5 +10,5 @@ python finetune.py \ --n_val 1000 \ --val_check_interval 0.25 \ --max_source_length 512 --max_target_length 56 \ - --freeze_embeds --max_target_length 56 --label_smoothing 0.1 \ + --freeze_embeds --label_smoothing 0.1 --adafactor --task summarization_xsum \ "$@" diff --git a/src/transformers/configuration_pegasus.py b/src/transformers/configuration_pegasus.py index 4c3564fd10..694759b7ed 100644 --- a/src/transformers/configuration_pegasus.py +++ b/src/transformers/configuration_pegasus.py @@ -47,46 +47,23 @@ DEFAULTS = dict( activation_function="relu", ) # Config values that vary between checkpoints: for testing and conversion -max_gen_length = { - # See appendix C of paper - "xsum": 64, - "cnn_dailymail": 128, - "newsroom": 128, - "wikihow": 256, - "multi_news": 256, - "reddit_tifu": 128, - "big_patent": 256, - "arxiv": 256, - "pubmed": 256, - "gigaword": 32, - "aeslc": 32, - "billsum": 256, - "large": 256, # @sshleifer chose arbitrarily +task_specific_params = { + # These are task specific params for pegasus-large and normal params for finetuned checkpoints + "summarization_xsum": {"length_penalty": 0.8, "max_length": 64, "max_position_embeddings": 512}, + "summarization_cnn_dailymail": {"length_penalty": 0.8, "max_length": 128, "max_position_embeddings": 1024}, + "summarization_newsroom": {"length_penalty": 0.8, "max_length": 128, "max_position_embeddings": 512}, + "summarization_wikihow": {"length_penalty": 0.6, "max_length": 256, "max_position_embeddings": 512}, + "summarization_multi_news": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024}, + "summarization_reddit_tifu": {"length_penalty": 0.6, "max_length": 128, "max_position_embeddings": 512}, + "summarization_big_patent": {"length_penalty": 0.7, "max_length": 256, "max_position_embeddings": 1024}, + "summarization_arxiv": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024}, + "summarization_pubmed": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024}, + "summarization_gigaword": {"length_penalty": 0.6, "max_length": 32, "max_position_embeddings": 128}, + "summarization_aeslc": {"length_penalty": 0.6, "max_length": 32, "max_position_embeddings": 512}, + "summarization_billsum": {"length_penalty": 0.6, "max_length": 256, "max_position_embeddings": 1024}, + # this last entry is useless -- just for consistency + "summarization_large": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024}, } -max_model_length = { - "xsum": 512, - "cnn_dailymail": 1024, - "newsroom": 512, - "wikihow": 512, - "multi_news": 1024, - "reddit_tifu": 512, - "big_patent": 1024, - "arxiv": 1024, - "pubmed": 1024, - "gigaword": 128, - "aeslc": 512, - "billsum": 1024, - "large": 1024, -} -expected_alpha = { - "multinews": 0.9, - "wikihow": 0.6, - "reddit_tifu": 0.6, - "big_patent": 0.7, - "gigaword": 0.6, - "aeslc": 0.6, - "billsum": 0.6, -} # otherwise 0.8 @add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC) diff --git a/src/transformers/convert_pegasus_tf_to_pytorch.py b/src/transformers/convert_pegasus_tf_to_pytorch.py index e3b8614d4e..edf0498f37 100644 --- a/src/transformers/convert_pegasus_tf_to_pytorch.py +++ b/src/transformers/convert_pegasus_tf_to_pytorch.py @@ -14,6 +14,7 @@ # limitations under the License. import argparse +import os from pathlib import Path from typing import Dict @@ -22,7 +23,7 @@ import torch from tqdm import tqdm from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer -from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length +from transformers.configuration_pegasus import DEFAULTS, task_specific_params PATTERNS = [ @@ -101,23 +102,25 @@ def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict: return tf_weights -def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir): +def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str): # save tokenizer first dataset = Path(ckpt_path).parent.name - desired_max_model_length = max_model_length[dataset] + desired_max_model_length = task_specific_params[f"summarization_{dataset}"]["max_position_embeddings"] tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length) assert tok.model_max_length == desired_max_model_length tok.save_pretrained(save_dir) # convert model tf_weights = get_tf_weights_as_numpy(ckpt_path) - cfg_updates = dict( - max_length=max_gen_length[dataset], - length_penalty=expected_alpha.get(dataset, 0.8), - max_position_embeddings=desired_max_model_length, - ) + cfg_updates = task_specific_params[f"summarization_{dataset}"] + if dataset == "large": + cfg_updates["task_specific_params"] = task_specific_params torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates) torch_model.save_pretrained(save_dir) + sd = torch_model.state_dict() + sd.pop("model.decoder.embed_positions.weight") + sd.pop("model.encoder.embed_positions.weight") + torch.save(sd, Path(save_dir) / "pytorch_model.bin") if __name__ == "__main__": @@ -127,5 +130,6 @@ if __name__ == "__main__": parser.add_argument("save_dir", default=None, type=str, help="Path to the output PyTorch model.") args = parser.parse_args() if args.save_dir is None: - args.save_dir = f"pegasus/{Path(args.tf_ckpt_path).parent.name}" + dataset = Path(args.tf_ckpt_path).parent.name + args.save_dir = os.path.join("pegasus", dataset) convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir) diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py index 6fb387daa7..68cb5d6e04 100644 --- a/tests/test_modeling_pegasus.py +++ b/tests/test_modeling_pegasus.py @@ -1,9 +1,10 @@ import unittest from transformers import AutoConfig, AutoTokenizer, is_torch_available -from transformers.configuration_pegasus import max_gen_length, max_model_length +from transformers.configuration_pegasus import task_specific_params from transformers.file_utils import cached_property from transformers.testing_utils import require_torch, slow, torch_device +from transformers.utils.logging import ERROR, set_verbosity from .test_modeling_bart import PGE_ARTICLE from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest @@ -14,6 +15,8 @@ if is_torch_available(): XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """ +set_verbosity(ERROR) + @require_torch class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): @@ -50,31 +53,25 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): class PegasusConfigTests(unittest.TestCase): - def test_all_config_max_lengths(self): + @slow + def test_task_specific_params(self): + """Test that task_specific params['summarization_xsum'] == config['pegasus_xsum'] """ failures = [] pegasus_prefix = "google/pegasus" - for dataset, max_len in max_gen_length.items(): + n_prefix_chars = len("summarization_") + for task, desired_settings in task_specific_params.items(): + dataset = task[n_prefix_chars:] mname = f"{pegasus_prefix}-{dataset}" cfg = AutoConfig.from_pretrained(mname) - - if cfg.max_length != max_len: - failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}") - - if cfg.max_position_embeddings < max_model_length[dataset]: - # otherwise you get IndexError for e.g. position 513 - # see https://github.com/huggingface/transformers/issues/6599 - failures.append( - f"config for {mname} had max_position_embeddings: {cfg.max_position_embeddings}, expected {max_model_length[dataset]}" - ) - + for k, v in desired_settings.items(): + actual_value = getattr(cfg, k) + if actual_value != v: + failures.append(f"config for {mname} had {k}: {actual_value}, expected {v}") tokenizer = AutoTokenizer.from_pretrained(mname) - if max_model_length[dataset] != tokenizer.model_max_length: - failures.append( - f"tokenizer.model_max_length {tokenizer.model_max_length} expected {max_model_length[dataset]}" - ) + n_pos_embeds = desired_settings["max_position_embeddings"] + if n_pos_embeds != tokenizer.model_max_length: + failures.append(f"tokenizer.model_max_length {tokenizer.model_max_length} expected {n_pos_embeds}") - if failures == []: - return # error all_fails = "\n".join(failures) - raise AssertionError(f"The following configs have unexpected settings: {all_fails}") + assert not failures, f"The following configs have unexpected settings: {all_fails}"