Add bart-base (#5014)
This commit is contained in:
@@ -40,6 +40,7 @@ if is_torch_available():
|
||||
BartTokenizer,
|
||||
MBartTokenizer,
|
||||
BatchEncoding,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.modeling_bart import (
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
@@ -565,6 +566,22 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
|
||||
@slow
|
||||
def test_bart_base_mask_filling(self):
|
||||
pbase = pipeline(task="fill-mask", model="facebook/bart-base")
|
||||
src_text = [" I went to the <mask>."]
|
||||
results = [x["token_str"] for x in pbase(src_text)]
|
||||
expected_results = ["Ġbathroom", "Ġrestroom", "Ġhospital", "Ġkitchen", "Ġcar"]
|
||||
self.assertListEqual(results, expected_results)
|
||||
|
||||
@slow
|
||||
def test_bart_large_mask_filling(self):
|
||||
pbase = pipeline(task="fill-mask", model="facebook/bart-large")
|
||||
src_text = [" I went to the <mask>."]
|
||||
results = [x["token_str"] for x in pbase(src_text)]
|
||||
expected_results = ["Ġbathroom", "Ġgym", "Ġwrong", "Ġmovies", "Ġhospital"]
|
||||
self.assertListEqual(results, expected_results)
|
||||
|
||||
@slow
|
||||
def test_mnli_inference(self):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user