best current version and make style

This commit is contained in:
patrickvonplaten
2020-03-06 22:19:01 +01:00
committed by Patrick von Platen
parent c62444da39
commit 2acfe63964
4 changed files with 45 additions and 36 deletions

View File

@@ -104,7 +104,9 @@ def prepare_bart_inputs_dict(
@require_torch
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
all_model_classes = (
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
)
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
# TODO(SS): fix the below in a separate PR
@@ -451,9 +453,9 @@ class BartModelIntegrationTest(unittest.TestCase):
EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway."
dct = tok.batch_encode_plus(
# [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
# [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
[IRAN_ARTICLE, ARTICLE_SUBWAY],
# [FRANCE_ARTICLE, SHORTER_ARTICLE],
# [FRANCE_ARTICLE, SHORTER_ARTICLE],
max_length=1024,
pad_to_max_length=True,
return_tensors="pt",
@@ -472,17 +474,17 @@ class BartModelIntegrationTest(unittest.TestCase):
min_length=min_length + 1,
no_repeat_ngram_size=3,
do_sample=False,
early_stopping=True
early_stopping=True,
)
decoded = [
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
]
self.assertListEqual(
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
[EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER],
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER],
decoded,
)
# TODO(SS): run fairseq again with num_beams=2, min_len=20.