[Bart] Question Answering Model is added to tests (#5024)
* fix test * Update tests/test_modeling_common.py * Update tests/test_modeling_common.py
This commit is contained in:
committed by
GitHub
parent
bbad4c6989
commit
ebba39e4e1
@@ -38,6 +38,7 @@ if is_torch_available():
|
||||
BertConfig,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
|
||||
@@ -180,8 +181,13 @@ class ModelTesterMixin:
|
||||
correct_outlen = 4
|
||||
decoder_attention_idx = 1
|
||||
|
||||
if "lm_labels" in inputs_dict: # loss will come first
|
||||
correct_outlen += 1 # compute loss
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
decoder_attention_idx += 1
|
||||
# Question Answering model returns start_logits and end_logits
|
||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||
decoder_attention_idx += 1
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user