AutoTokenizer supports mbart-large-en-ro (#5121)
This commit is contained in:
@@ -19,7 +19,7 @@ import logging
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||||
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
|
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, MBartConfig
|
||||||
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||||
@@ -80,6 +80,7 @@ CONFIG_MAPPING = OrderedDict(
|
|||||||
("camembert", CamembertConfig,),
|
("camembert", CamembertConfig,),
|
||||||
("xlm-roberta", XLMRobertaConfig,),
|
("xlm-roberta", XLMRobertaConfig,),
|
||||||
("marian", MarianConfig,),
|
("marian", MarianConfig,),
|
||||||
|
("mbart", MBartConfig,),
|
||||||
("bart", BartConfig,),
|
("bart", BartConfig,),
|
||||||
("reformer", ReformerConfig,),
|
("reformer", ReformerConfig,),
|
||||||
("longformer", LongformerConfig,),
|
("longformer", LongformerConfig,),
|
||||||
|
|||||||
@@ -133,3 +133,7 @@ class BartConfig(PretrainedConfig):
|
|||||||
if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
|
if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
|
||||||
logger.info("This configuration is a mixture of MBART and BART settings")
|
logger.info("This configuration is a mixture of MBART and BART settings")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class MBartConfig(BartConfig):
|
||||||
|
model_type = "mbart"
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from .configuration_auto import (
|
|||||||
FlaubertConfig,
|
FlaubertConfig,
|
||||||
GPT2Config,
|
GPT2Config,
|
||||||
LongformerConfig,
|
LongformerConfig,
|
||||||
|
MBartConfig,
|
||||||
OpenAIGPTConfig,
|
OpenAIGPTConfig,
|
||||||
ReformerConfig,
|
ReformerConfig,
|
||||||
RetriBertConfig,
|
RetriBertConfig,
|
||||||
@@ -43,7 +44,7 @@ from .configuration_auto import (
|
|||||||
from .configuration_marian import MarianConfig
|
from .configuration_marian import MarianConfig
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .tokenization_albert import AlbertTokenizer
|
from .tokenization_albert import AlbertTokenizer
|
||||||
from .tokenization_bart import BartTokenizer
|
from .tokenization_bart import BartTokenizer, MBartTokenizer
|
||||||
from .tokenization_bert import BertTokenizer, BertTokenizerFast
|
from .tokenization_bert import BertTokenizer, BertTokenizerFast
|
||||||
from .tokenization_bert_japanese import BertJapaneseTokenizer
|
from .tokenization_bert_japanese import BertJapaneseTokenizer
|
||||||
from .tokenization_camembert import CamembertTokenizer
|
from .tokenization_camembert import CamembertTokenizer
|
||||||
@@ -75,6 +76,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
|||||||
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
|
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
|
||||||
(AlbertConfig, (AlbertTokenizer, None)),
|
(AlbertConfig, (AlbertTokenizer, None)),
|
||||||
(CamembertConfig, (CamembertTokenizer, None)),
|
(CamembertConfig, (CamembertTokenizer, None)),
|
||||||
|
(MBartConfig, (MBartTokenizer, None)),
|
||||||
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
|
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
|
||||||
(MarianConfig, (MarianTokenizer, None)),
|
(MarianConfig, (MarianTokenizer, None)),
|
||||||
(BartConfig, (BartTokenizer, None)),
|
(BartConfig, (BartTokenizer, None)),
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BartModel,
|
BartModel,
|
||||||
BartForConditionalGeneration,
|
BartForConditionalGeneration,
|
||||||
@@ -38,7 +39,6 @@ if is_torch_available():
|
|||||||
BartForQuestionAnswering,
|
BartForQuestionAnswering,
|
||||||
BartConfig,
|
BartConfig,
|
||||||
BartTokenizer,
|
BartTokenizer,
|
||||||
MBartTokenizer,
|
|
||||||
BatchEncoding,
|
BatchEncoding,
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
@@ -218,15 +218,14 @@ class MBartIntegrationTests(unittest.TestCase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
checkpoint_name = "facebook/mbart-large-en-ro"
|
checkpoint_name = "facebook/mbart-large-en-ro"
|
||||||
cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
|
cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
|
||||||
cls.pad_token_id = 1
|
cls.pad_token_id = 1
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def model(self):
|
def model(self):
|
||||||
"""Only load the model if needed."""
|
"""Only load the model if needed."""
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
|
||||||
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
|
|
||||||
if "cuda" in torch_device:
|
if "cuda" in torch_device:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
return model
|
return model
|
||||||
|
|||||||
Reference in New Issue
Block a user