From 5bf4465e6c795131fbad2695bd80dae889247e1d Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 20 Aug 2020 15:34:43 -0400 Subject: [PATCH] Regression test for pegasus bugfix (#6606) --- src/transformers/configuration_pegasus.py | 46 ++++++++++++++++-- .../convert_pegasus_tf_to_pytorch.py | 48 +++---------------- src/transformers/modeling_pegasus.py | 7 +++ tests/test_modeling_pegasus.py | 35 +++++++------- 4 files changed, 73 insertions(+), 63 deletions(-) diff --git a/src/transformers/configuration_pegasus.py b/src/transformers/configuration_pegasus.py index e17ead2b27..93c6736d03 100644 --- a/src/transformers/configuration_pegasus.py +++ b/src/transformers/configuration_pegasus.py @@ -22,6 +22,7 @@ from .file_utils import add_start_docstrings_to_callable logger = logging.getLogger(__name__) +# These config values do not vary between checkpoints DEFAULTS = dict( vocab_size=96103, max_position_embeddings=512, @@ -46,6 +47,47 @@ DEFAULTS = dict( num_beams=8, activation_function="relu", ) +# Config values that vary between checkpoints: for testing and conversion +max_gen_length = { + # See appendix C of paper + "xsum": 64, + "cnn_dailymail": 128, + "newsroom": 128, + "wikihow": 256, + "multi_news": 256, + "reddit_tifu": 128, + "big_patent": 256, + "arxiv": 256, + "pubmed": 256, + "gigaword": 32, + "aeslc": 32, + "billsum": 256, + "large": 256, # @sshleifer chose arbitrarily +} +max_model_length = { + "xsum": 512, + "cnn_dailymail": 1024, + "newsroom": 512, + "wikihow": 512, + "multi_news": 1024, + "reddit_tifu": 512, + "big_patent": 1024, + "arxiv": 1024, + "pubmed": 1024, + "gigaword": 128, + "aeslc": 512, + "billsum": 1024, + "large": 1024, +} +expected_alpha = { + "multinews": 0.9, + "wikihow": 0.6, + "reddit_tifu": 0.6, + "big_patent": 0.7, + "gigaword": 0.6, + "aeslc": 0.6, + "billsum": 0.6, +} # otherwise 0.8 @add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC) @@ -56,7 +98,3 @@ class PegasusConfig(BartConfig): """ model_type = "pegasus" # The implementation of the config object is in BartConfig - - @property - def default_config_parameters(self): - return DEFAULTS diff --git a/src/transformers/convert_pegasus_tf_to_pytorch.py b/src/transformers/convert_pegasus_tf_to_pytorch.py index 719ca5f04e..e3b8614d4e 100644 --- a/src/transformers/convert_pegasus_tf_to_pytorch.py +++ b/src/transformers/convert_pegasus_tf_to_pytorch.py @@ -22,7 +22,7 @@ import torch from tqdm import tqdm from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer -from transformers.configuration_pegasus import DEFAULTS +from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length PATTERNS = [ @@ -52,47 +52,7 @@ def rename_state_dict_key(k): # See appendix C of paper for all hyperparams -max_gen_length = { - # See appendix C of paper - "xsum": 64, - "cnn_dailymail": 128, - "newsroom": 128, - "wikihow": 256, - "multi_news": 256, - "reddit_tifu": 128, - "big_patent": 256, - "arxiv": 256, - "pubmed": 256, - "gigaword": 32, - "aeslc": 32, - "billsum": 256, - "large": 256, # @sshleifer chose arbitrarily -} -max_model_length = { - "xsum": 512, - "cnn_dailymail": 1024, - "newsroom": 512, - "wikihow": 512, - "multi_news": 1024, - "reddit_tifu": 512, - "big_patent": 1024, - "arxiv": 1024, - "pubmed": 1024, - "gigaword": 128, - "aeslc": 512, - "billsum": 1024, - "large": 1024, -} -expected_alpha = { - "multinews": 0.9, - "wikihow": 0.6, - "reddit_tifu": 0.6, - "big_patent": 0.7, - "gigaword": 0.6, - "aeslc": 0.6, - "billsum": 0.6, -} # otherwise 0.8 # TODO(SS): one constant @@ -151,7 +111,11 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir): # convert model tf_weights = get_tf_weights_as_numpy(ckpt_path) - cfg_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8)) + cfg_updates = dict( + max_length=max_gen_length[dataset], + length_penalty=expected_alpha.get(dataset, 0.8), + max_position_embeddings=desired_max_model_length, + ) torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates) torch_model.save_pretrained(save_dir) diff --git a/src/transformers/modeling_pegasus.py b/src/transformers/modeling_pegasus.py index 785dd375a1..88b0f77f12 100644 --- a/src/transformers/modeling_pegasus.py +++ b/src/transformers/modeling_pegasus.py @@ -23,6 +23,13 @@ from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration @add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING) class PegasusForConditionalGeneration(BartForConditionalGeneration): config_class = PegasusConfig + authorized_missing_keys = [ + r"final_logits_bias", + r"encoder\.version", + r"decoder\.version", + r"model.encoder.embed_positions", + "model.decoder.embed_positions", + ] r""" Pytorch version of google's pegasus model for summarization. Model API is identical to BartForConditionalGeneration. diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py index b11f4f2b36..ac6be27210 100644 --- a/tests/test_modeling_pegasus.py +++ b/tests/test_modeling_pegasus.py @@ -1,6 +1,7 @@ import unittest -from transformers import AutoConfig, is_torch_available +from transformers import AutoConfig, AutoTokenizer, is_torch_available +from transformers.configuration_pegasus import max_gen_length, max_model_length from transformers.file_utils import cached_property from transformers.testing_utils import require_torch, slow, torch_device @@ -50,28 +51,28 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): class PegasusConfigTests(unittest.TestCase): def test_all_config_max_lengths(self): - expected_max_length = { - # See appendix C of paper - "xsum": 64, - "cnn_dailymail": 128, - "newsroom": 128, - "wikihow": 256, - "multi_news": 256, - "reddit_tifu": 128, - "big_patent": 256, - "arxiv": 256, - "pubmed": 256, - "gigaword": 32, - "aeslc": 32, - "billsum": 256, - } failures = [] pegasus_prefix = "google/pegasus" - for dataset, max_len in expected_max_length.items(): + for dataset, max_len in max_gen_length.items(): mname = f"{pegasus_prefix}-{dataset}" cfg = AutoConfig.from_pretrained(mname) + if cfg.max_length != max_len: failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}") + + if cfg.max_position_embeddings < max_model_length[dataset]: + # otherwise you get IndexError for e.g. position 513 + # see https://github.com/huggingface/transformers/issues/6599 + failures.append( + f"config for {mname} had max_position_embeddings: {cfg.max_position_embeddings}, expected {max_model_length[dataset]}" + ) + + tokenizer = AutoTokenizer.from_pretrained(mname) + if max_model_length[dataset] != tokenizer.model_max_length: + failures.append( + f"tokenizer.model_max_length {tokenizer.model_max_length} expected {max_model_length[dataset]}" + ) + if failures == []: return # error