From 84be482f6698fac822a5113735f2242c6d3abc76 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 18 Jun 2020 20:47:37 -0400 Subject: [PATCH] AutoTokenizer supports mbart-large-en-ro (#5121) --- src/transformers/configuration_auto.py | 3 ++- src/transformers/configuration_bart.py | 4 ++++ src/transformers/tokenization_auto.py | 4 +++- tests/test_modeling_bart.py | 7 +++---- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index e2c436f882..09a76abe26 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -19,7 +19,7 @@ import logging from collections import OrderedDict from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig -from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig +from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, MBartConfig from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig @@ -80,6 +80,7 @@ CONFIG_MAPPING = OrderedDict( ("camembert", CamembertConfig,), ("xlm-roberta", XLMRobertaConfig,), ("marian", MarianConfig,), + ("mbart", MBartConfig,), ("bart", BartConfig,), ("reformer", ReformerConfig,), ("longformer", LongformerConfig,), diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py index 00f3337b2e..5f13f8ccb5 100644 --- a/src/transformers/configuration_bart.py +++ b/src/transformers/configuration_bart.py @@ -133,3 +133,7 @@ class BartConfig(PretrainedConfig): if self.normalize_before or self.add_final_layer_norm or self.scale_embedding: logger.info("This configuration is a mixture of MBART and BART settings") return False + + +class MBartConfig(BartConfig): + model_type = "mbart" diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index 308136970f..39f4fa3dcb 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -30,6 +30,7 @@ from .configuration_auto import ( FlaubertConfig, GPT2Config, LongformerConfig, + MBartConfig, OpenAIGPTConfig, ReformerConfig, RetriBertConfig, @@ -43,7 +44,7 @@ from .configuration_auto import ( from .configuration_marian import MarianConfig from .configuration_utils import PretrainedConfig from .tokenization_albert import AlbertTokenizer -from .tokenization_bart import BartTokenizer +from .tokenization_bart import BartTokenizer, MBartTokenizer from .tokenization_bert import BertTokenizer, BertTokenizerFast from .tokenization_bert_japanese import BertJapaneseTokenizer from .tokenization_camembert import CamembertTokenizer @@ -75,6 +76,7 @@ TOKENIZER_MAPPING = OrderedDict( (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)), (AlbertConfig, (AlbertTokenizer, None)), (CamembertConfig, (CamembertTokenizer, None)), + (MBartConfig, (MBartTokenizer, None)), (XLMRobertaConfig, (XLMRobertaTokenizer, None)), (MarianConfig, (MarianTokenizer, None)), (BartConfig, (BartTokenizer, None)), diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 8ceee5e263..208c896cff 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -31,6 +31,7 @@ if is_torch_available(): from transformers import ( AutoModel, AutoModelForSequenceClassification, + AutoModelForSeq2SeqLM, AutoTokenizer, BartModel, BartForConditionalGeneration, @@ -38,7 +39,6 @@ if is_torch_available(): BartForQuestionAnswering, BartConfig, BartTokenizer, - MBartTokenizer, BatchEncoding, pipeline, ) @@ -218,15 +218,14 @@ class MBartIntegrationTests(unittest.TestCase): @classmethod def setUpClass(cls): checkpoint_name = "facebook/mbart-large-en-ro" - cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name) + cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name) cls.pad_token_id = 1 return cls @cached_property def model(self): """Only load the model if needed.""" - - model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device) + model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro").to(torch_device) if "cuda" in torch_device: model = model.half() return model