Add BloomForQuestionAnswering (#19310)

* add bloom for question answering

- attempt to add Bloom for question answering
- adapted from `GPTJForQuestionAnswering`
- Fixed `num_labels` to `2` for common tests
- Added a bit of docstring
- All common tests pass

* Update src/transformers/models/bloom/modeling_bloom.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* revert changes related to `num_labels`

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2022-10-04 17:52:13 +02:00
committed by GitHub
parent 6dce9e0cdd
commit 587d84b178
7 changed files with 120 additions and 0 deletions

View File

@@ -31,6 +31,7 @@ if is_torch_available():
from transformers import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM,
BloomForQuestionAnswering,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel,
@@ -274,6 +275,14 @@ class BloomModelTester:
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_question_answering_model(self, config, input_ids, input_mask, *args):
    """Run a forward pass of `BloomForQuestionAnswering` and check its output shapes.

    Args:
        config: Model configuration used to instantiate the model.
        input_ids: Token ids fed to the model.
        input_mask: Attention mask passed as `attention_mask`.
        *args: Extra values supplied by the shared tester plumbing; unused here.
    """
    model = BloomForQuestionAnswering(config)
    model.to(torch_device)
    model.eval()

    result = model(input_ids, attention_mask=input_mask)

    # A span-extraction head yields one start score and one end score per token,
    # exposed as `start_logits` / `end_logits` rather than a single `logits`
    # tensor, so each is expected to be (batch_size, seq_length).
    # NOTE(review): assumes the model returns `QuestionAnsweringModelOutput`,
    # like the GPT-J head it was adapted from -- confirm against modeling_bloom.py.
    expected_shape = (self.batch_size, self.seq_length)
    self.parent.assertEqual(result.start_logits.shape, expected_shape)
    self.parent.assertEqual(result.end_logits.shape, expected_shape)
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, *args, gradient_checkpointing=False
):
@@ -314,6 +323,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomForQuestionAnswering,
)
if is_torch_available()
else ()