AutoTokenizer supports mbart-large-en-ro (#5121)

This commit is contained in:
Sam Shleifer
2020-06-18 20:47:37 -04:00
committed by GitHub
parent 2db1e2f415
commit 84be482f66
4 changed files with 12 additions and 6 deletions

View File

@@ -31,6 +31,7 @@ if is_torch_available():
from transformers import (
AutoModel,
AutoModelForSequenceClassification,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BartModel,
BartForConditionalGeneration,
@@ -38,7 +39,6 @@ if is_torch_available():
BartForQuestionAnswering,
BartConfig,
BartTokenizer,
MBartTokenizer,
BatchEncoding,
pipeline,
)
@@ -218,15 +218,14 @@ class MBartIntegrationTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
checkpoint_name = "facebook/mbart-large-en-ro"
cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
cls.pad_token_id = 1
return cls
@cached_property
def model(self):
"""Only load the model if needed."""
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
if "cuda" in torch_device:
model = model.half()
return model