BartForQuestionAnswering (#4908)
This commit is contained in:
@@ -35,6 +35,7 @@ if is_torch_available():
|
||||
BartModel,
|
||||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartForQuestionAnswering,
|
||||
BartConfig,
|
||||
BartTokenizer,
|
||||
MBartTokenizer,
|
||||
@@ -375,6 +376,19 @@ class BartHeadTests(unittest.TestCase):
|
||||
loss = outputs[0]
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
|
||||
def test_question_answering_forward(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
|
||||
model = BartForQuestionAnswering(config)
|
||||
model.to(torch_device)
|
||||
loss, start_logits, end_logits, _ = model(
|
||||
input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,
|
||||
)
|
||||
|
||||
self.assertEqual(start_logits.shape, input_ids.shape)
|
||||
self.assertEqual(end_logits.shape, input_ids.shape)
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
|
||||
@timeout_decorator.timeout(1)
|
||||
def test_lm_forward(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
|
||||
Reference in New Issue
Block a user