[BART] add bart-large-xsum weights (#3422)
This commit is contained in:
@@ -450,6 +450,38 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
model = BartModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_xsum_summarization_same_as_fairseq(self):
|
||||
model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").to(torch_device)
|
||||
tok = BartTokenizer.from_pretrained("bart-large")
|
||||
|
||||
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
|
||||
dct = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt",)
|
||||
|
||||
hypotheses_batch = model.generate(
|
||||
input_ids=dct["input_ids"].to(torch_device),
|
||||
attention_mask=dct["attention_mask"].to(torch_device),
|
||||
num_beams=2,
|
||||
max_length=62,
|
||||
min_length=11,
|
||||
length_penalty=1.0,
|
||||
no_repeat_ngram_size=3,
|
||||
early_stopping=True,
|
||||
decoder_start_token_id=model.config.eos_token_ids[0],
|
||||
)
|
||||
|
||||
decoded = [
|
||||
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
|
||||
]
|
||||
self.assertEqual(EXPECTED_SUMMARY, decoded[0])
|
||||
|
||||
def test_xsum_config_generation_params(self):
|
||||
config = BartConfig.from_pretrained("bart-large-xsum")
|
||||
expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0)
|
||||
config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()}
|
||||
self.assertDictEqual(expected_params, config_params)
|
||||
|
||||
@slow
|
||||
def test_cnn_summarization_same_as_fairseq(self):
|
||||
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user