[Bart] Question Answering Model is added to tests (#5024)

* fix test

* Update tests/test_modeling_common.py

* Update tests/test_modeling_common.py
This commit is contained in:
Patrick von Platen
2020-06-15 22:50:09 +02:00
committed by GitHub
parent bbad4c6989
commit ebba39e4e1
2 changed files with 11 additions and 3 deletions

View File

@@ -38,6 +38,7 @@ if is_torch_available():
BertConfig,
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
top_k_top_p_filtering,
)
@@ -180,8 +181,13 @@ class ModelTesterMixin:
correct_outlen = 4
decoder_attention_idx = 1
if "lm_labels" in inputs_dict: # loss will come first
correct_outlen += 1 # compute loss
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
decoder_attention_idx += 1
# Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
decoder_attention_idx += 1
self.assertEqual(out_len, correct_outlen)