Multilingual BART - (#3602)

- support mbart-en-ro weights
- add MBartTokenizer
This commit is contained in:
Sam Shleifer
2020-04-10 11:25:39 -04:00
committed by GitHub
parent f98d0ef2a2
commit 7a7fdf71f8
7 changed files with 232 additions and 38 deletions

View File

@@ -27,6 +27,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
}
@@ -61,6 +62,9 @@ class BartConfig(PretrainedConfig):
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
normalize_before=False,
add_final_layer_norm=False,
scale_embedding=False,
**common_kwargs
):
r"""
@@ -90,6 +94,11 @@ class BartConfig(PretrainedConfig):
self.max_position_embeddings = max_position_embeddings
self.init_std = init_std # Normal(0, this parameter)
self.activation_function = activation_function
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
# True for mbart, False otherwise
self.normalize_before = normalize_before # combo of fairseq's encoder_ and decoder_normalize_before
self.add_final_layer_norm = add_final_layer_norm
# 3 Types of Dropout
self.attention_dropout = attention_dropout
@@ -100,9 +109,17 @@ class BartConfig(PretrainedConfig):
self.classif_dropout = classifier_dropout
@property
def num_attention_heads(self):
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
@property
def hidden_size(self):
def hidden_size(self) -> int:
return self.d_model
# Reports whether this configuration has all three MBART-specific flags enabled.
# Returns True only when normalize_before, add_final_layer_norm and
# scale_embedding are all truthy; returns False in every other case.
def is_valid_mbart(self) -> bool:
"""Is the configuration aligned with the MBART paper."""
# Fully MBART-style setup: every one of the three flags is set.
if self.normalize_before and self.add_final_layer_norm and self.scale_embedding:
return True
# At least one flag is set but not all three: emit an info-level log about the
# mixed setup, then fall through to the False result below.
if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
logger.info("This configuration is a mixture of MBART and BART settings")
return False