fix bart tests (#10060)
This commit is contained in:
committed by
GitHub
parent
b01483faa0
commit
9a0399e18d
@@ -42,7 +42,6 @@ if is_torch_available():
|
|||||||
BartForSequenceClassification,
|
BartForSequenceClassification,
|
||||||
BartModel,
|
BartModel,
|
||||||
BartTokenizer,
|
BartTokenizer,
|
||||||
BartTokenizerFast,
|
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
from transformers.models.bart.modeling_bart import BartDecoder, BartEncoder, shift_tokens_right
|
from transformers.models.bart.modeling_bart import BartDecoder, BartEncoder, shift_tokens_right
|
||||||
@@ -566,10 +565,6 @@ class BartModelIntegrationTests(unittest.TestCase):
|
|||||||
def default_tokenizer(self):
|
def default_tokenizer(self):
|
||||||
return BartTokenizer.from_pretrained("facebook/bart-large")
|
return BartTokenizer.from_pretrained("facebook/bart-large")
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def default_tokenizer_fast(self):
|
|
||||||
return BartTokenizerFast.from_pretrained("facebook/bart-large")
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_inference_no_head(self):
|
def test_inference_no_head(self):
|
||||||
model = BartModel.from_pretrained("facebook/bart-large").to(torch_device)
|
model = BartModel.from_pretrained("facebook/bart-large").to(torch_device)
|
||||||
@@ -589,14 +584,14 @@ class BartModelIntegrationTests(unittest.TestCase):
|
|||||||
pbase = pipeline(task="fill-mask", model="facebook/bart-base")
|
pbase = pipeline(task="fill-mask", model="facebook/bart-base")
|
||||||
src_text = [" I went to the <mask>."]
|
src_text = [" I went to the <mask>."]
|
||||||
results = [x["token_str"] for x in pbase(src_text)]
|
results = [x["token_str"] for x in pbase(src_text)]
|
||||||
assert "Ġbathroom" in results
|
assert " bathroom" in results
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_large_mask_filling(self):
|
def test_large_mask_filling(self):
|
||||||
plarge = pipeline(task="fill-mask", model="facebook/bart-large")
|
plarge = pipeline(task="fill-mask", model="facebook/bart-large")
|
||||||
src_text = [" I went to the <mask>."]
|
src_text = [" I went to the <mask>."]
|
||||||
results = [x["token_str"] for x in plarge(src_text)]
|
results = [x["token_str"] for x in plarge(src_text)]
|
||||||
expected_results = ["Ġbathroom", "Ġgym", "Ġwrong", "Ġmovies", "Ġhospital"]
|
expected_results = [" bathroom", " gym", " wrong", " movies", " hospital"]
|
||||||
self.assertListEqual(results, expected_results)
|
self.assertListEqual(results, expected_results)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
Reference in New Issue
Block a user