Regression test for pegasus bugfix (#6606)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import AutoConfig, is_torch_available
|
||||
from transformers import AutoConfig, AutoTokenizer, is_torch_available
|
||||
from transformers.configuration_pegasus import max_gen_length, max_model_length
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
@@ -50,28 +51,28 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
class PegasusConfigTests(unittest.TestCase):
|
||||
def test_all_config_max_lengths(self):
|
||||
expected_max_length = {
|
||||
# See appendix C of paper
|
||||
"xsum": 64,
|
||||
"cnn_dailymail": 128,
|
||||
"newsroom": 128,
|
||||
"wikihow": 256,
|
||||
"multi_news": 256,
|
||||
"reddit_tifu": 128,
|
||||
"big_patent": 256,
|
||||
"arxiv": 256,
|
||||
"pubmed": 256,
|
||||
"gigaword": 32,
|
||||
"aeslc": 32,
|
||||
"billsum": 256,
|
||||
}
|
||||
failures = []
|
||||
pegasus_prefix = "google/pegasus"
|
||||
for dataset, max_len in expected_max_length.items():
|
||||
for dataset, max_len in max_gen_length.items():
|
||||
mname = f"{pegasus_prefix}-{dataset}"
|
||||
cfg = AutoConfig.from_pretrained(mname)
|
||||
|
||||
if cfg.max_length != max_len:
|
||||
failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}")
|
||||
|
||||
if cfg.max_position_embeddings < max_model_length[dataset]:
|
||||
# otherwise you get IndexError for e.g. position 513
|
||||
# see https://github.com/huggingface/transformers/issues/6599
|
||||
failures.append(
|
||||
f"config for {mname} had max_position_embeddings: {cfg.max_position_embeddings}, expected {max_model_length[dataset]}"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(mname)
|
||||
if max_model_length[dataset] != tokenizer.model_max_length:
|
||||
failures.append(
|
||||
f"tokenizer.model_max_length {tokenizer.model_max_length} expected {max_model_length[dataset]}"
|
||||
)
|
||||
|
||||
if failures == []:
|
||||
return
|
||||
# error
|
||||
|
||||
Reference in New Issue
Block a user