best current version and make style
This commit is contained in:
committed by
Patrick von Platen
parent
c62444da39
commit
2acfe63964
@@ -104,7 +104,9 @@ def prepare_bart_inputs_dict(
|
||||
@require_torch
|
||||
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
||||
)
|
||||
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
# TODO(SS): fix the below in a separate PR
|
||||
@@ -451,9 +453,9 @@ class BartModelIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway."
|
||||
|
||||
dct = tok.batch_encode_plus(
|
||||
# [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
|
||||
# [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
|
||||
[IRAN_ARTICLE, ARTICLE_SUBWAY],
|
||||
# [FRANCE_ARTICLE, SHORTER_ARTICLE],
|
||||
# [FRANCE_ARTICLE, SHORTER_ARTICLE],
|
||||
max_length=1024,
|
||||
pad_to_max_length=True,
|
||||
return_tensors="pt",
|
||||
@@ -472,17 +474,17 @@ class BartModelIntegrationTest(unittest.TestCase):
|
||||
min_length=min_length + 1,
|
||||
no_repeat_ngram_size=3,
|
||||
do_sample=False,
|
||||
early_stopping=True
|
||||
early_stopping=True,
|
||||
)
|
||||
|
||||
|
||||
decoded = [
|
||||
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
|
||||
]
|
||||
|
||||
self.assertListEqual(
|
||||
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
|
||||
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
|
||||
[EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
|
||||
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER],
|
||||
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER],
|
||||
decoded,
|
||||
)
|
||||
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
|
||||
|
||||
Reference in New Issue
Block a user