[Bart] Question Answering Model is added to tests (#5024)
* fix test * Update tests/test_modeling_common.py * Update tests/test_modeling_common.py
This commit is contained in:
committed by
GitHub
parent
bbad4c6989
commit
ebba39e4e1
@@ -113,7 +113,9 @@ def prepare_bart_inputs_dict(
|
|||||||
@require_torch
|
@require_torch
|
||||||
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
(BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ if is_torch_available():
|
|||||||
BertConfig,
|
BertConfig,
|
||||||
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
top_k_top_p_filtering,
|
top_k_top_p_filtering,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -180,8 +181,13 @@ class ModelTesterMixin:
|
|||||||
correct_outlen = 4
|
correct_outlen = 4
|
||||||
decoder_attention_idx = 1
|
decoder_attention_idx = 1
|
||||||
|
|
||||||
if "lm_labels" in inputs_dict: # loss will come first
|
# loss is at first position
|
||||||
correct_outlen += 1 # compute loss
|
if "labels" in inputs_dict:
|
||||||
|
correct_outlen += 1 # loss is added to beginning
|
||||||
|
decoder_attention_idx += 1
|
||||||
|
# Question Answering model returns start_logits and end_logits
|
||||||
|
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||||
|
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||||
decoder_attention_idx += 1
|
decoder_attention_idx += 1
|
||||||
self.assertEqual(out_len, correct_outlen)
|
self.assertEqual(out_len, correct_outlen)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user