Multilingual BART - (#3602)
- support mbart-en-ro weights - add MBartTokenizer
This commit is contained in:
@@ -27,6 +27,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
|
||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
|
||||
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
|
||||
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
|
||||
}
|
||||
|
||||
|
||||
@@ -61,6 +62,9 @@ class BartConfig(PretrainedConfig):
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
normalize_before=False,
|
||||
add_final_layer_norm=False,
|
||||
scale_embedding=False,
|
||||
**common_kwargs
|
||||
):
|
||||
r"""
|
||||
@@ -90,6 +94,11 @@ class BartConfig(PretrainedConfig):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.init_std = init_std # Normal(0, this parameter)
|
||||
self.activation_function = activation_function
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
|
||||
# True for mbart, False otherwise
|
||||
self.normalize_before = normalize_before # combo of fairseq's encoder_ and decoder_normalize_before
|
||||
self.add_final_layer_norm = add_final_layer_norm
|
||||
|
||||
# 3 Types of Dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
@@ -100,9 +109,17 @@ class BartConfig(PretrainedConfig):
|
||||
self.classif_dropout = classifier_dropout
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
def is_valid_mbart(self) -> bool:
|
||||
"""Is the configuration aligned with the MBART paper."""
|
||||
if self.normalize_before and self.add_final_layer_norm and self.scale_embedding:
|
||||
return True
|
||||
if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
|
||||
logger.info("This configuration is a mixture of MBART and BART settings")
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user