AutoTokenizer supports mbart-large-en-ro (#5121)
This commit is contained in:
@@ -31,6 +31,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BartModel,
|
||||
BartForConditionalGeneration,
|
||||
@@ -38,7 +39,6 @@ if is_torch_available():
|
||||
BartForQuestionAnswering,
|
||||
BartConfig,
|
||||
BartTokenizer,
|
||||
MBartTokenizer,
|
||||
BatchEncoding,
|
||||
pipeline,
|
||||
)
|
||||
@@ -218,15 +218,14 @@ class MBartIntegrationTests(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
checkpoint_name = "facebook/mbart-large-en-ro"
|
||||
cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
|
||||
cls.pad_token_id = 1
|
||||
return cls
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
"""Only load the model if needed."""
|
||||
|
||||
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
|
||||
if "cuda" in torch_device:
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user