Add `generate` code, with caching, that produces summaries 99% identical to fairseq's on the CNN test data.
This commit is contained in:
Sam Shleifer
2020-03-02 10:35:53 -05:00
committed by GitHub
parent 6b1ff25084
commit b54ef78d0c
8 changed files with 544 additions and 152 deletions

View File

@@ -26,7 +26,7 @@ _bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bart-large": _bart_large_url,
"bart-large-mnli": _bart_large_url, # fine as same
"bart-cnn": None, # not done
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
}
@@ -59,6 +59,7 @@ class BartConfig(PretrainedConfig):
classifier_dropout=0.0,
output_past=False,
num_labels=3,
bos_token_id=0,
**common_kwargs
):
r"""
@@ -67,12 +68,16 @@ class BartConfig(PretrainedConfig):
config = BartConfig.from_pretrained('bart-large')
model = BartModel(config)
"""
super().__init__(num_labels=num_labels, output_past=output_past, pad_token_id=pad_token_id, **common_kwargs)
super().__init__(
num_labels=num_labels,
output_past=output_past,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
**common_kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
self.eos_token_id = eos_token_id
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = self.num_hidden_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads