BartForQuestionAnswering (#4908)

This commit is contained in:
Suraj Patil
2020-06-13 01:17:57 +05:30
committed by GitHub
parent 538531cde5
commit e93ccb3290
5 changed files with 146 additions and 1 deletions

View File

@@ -35,6 +35,7 @@ if is_torch_available():
BartModel,
BartForConditionalGeneration,
BartForSequenceClassification,
BartForQuestionAnswering,
BartConfig,
BartTokenizer,
MBartTokenizer,
@@ -375,6 +376,19 @@ class BartHeadTests(unittest.TestCase):
loss = outputs[0]
self.assertIsInstance(loss.item(), float)
def test_question_answering_forward(self):
config, input_ids, batch_size = self._get_config_and_data()
sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
model = BartForQuestionAnswering(config)
model.to(torch_device)
loss, start_logits, end_logits, _ = model(
input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,
)
self.assertEqual(start_logits.shape, input_ids.shape)
self.assertEqual(end_logits.shape, input_ids.shape)
self.assertIsInstance(loss.item(), float)
@timeout_decorator.timeout(1)
def test_lm_forward(self):
config, input_ids, batch_size = self._get_config_and_data()