Bart-CNN (#3059)
`generate` code that produces 99% identical summarizations to fairseq on CNN test data, with caching.
This commit is contained in:
@@ -26,7 +26,7 @@ _bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/
|
||||
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"bart-large": _bart_large_url,
|
||||
"bart-large-mnli": _bart_large_url, # fine as same
|
||||
"bart-cnn": None, # not done
|
||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
|
||||
}
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ class BartConfig(PretrainedConfig):
|
||||
classifier_dropout=0.0,
|
||||
output_past=False,
|
||||
num_labels=3,
|
||||
bos_token_id=0,
|
||||
**common_kwargs
|
||||
):
|
||||
r"""
|
||||
@@ -67,12 +68,16 @@ class BartConfig(PretrainedConfig):
|
||||
config = BartConfig.from_pretrained('bart-large')
|
||||
model = BartModel(config)
|
||||
"""
|
||||
super().__init__(num_labels=num_labels, output_past=output_past, pad_token_id=pad_token_id, **common_kwargs)
|
||||
|
||||
super().__init__(
|
||||
num_labels=num_labels,
|
||||
output_past=output_past,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
**common_kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.encoder_layers = self.num_hidden_layers = encoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
|
||||
Reference in New Issue
Block a user