[T5] Add T5ForQuestionAnswering and MT5ForQuestionAnswering (#24481)
* Adding T5ForQuestionAnswering
* Changed weight initialization that results in better initial loss when fine-tuning
* Update to class variables
* Running make fixup
* Running make fix-copies
* Remove model_parallel
* Adding MT5ForQuestionAnswering
* Adding docs
* Fix wrong doc
* Update src/transformers/models/mt5/modeling_mt5.py

  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Update src/transformers/models/t5/modeling_t5.py

  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* File formatting
* Undoing change

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -43,6 +43,7 @@ if is_torch_available():
|
||||
ByT5Tokenizer,
|
||||
T5EncoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
T5ForQuestionAnswering,
|
||||
T5Model,
|
||||
T5Tokenizer,
|
||||
)
|
||||
@@ -520,7 +521,7 @@ class T5ModelTester:
|
||||
|
||||
@require_torch
|
||||
class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||
all_model_classes = (T5Model, T5ForConditionalGeneration, T5ForQuestionAnswering) if is_torch_available() else ()
|
||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
@@ -529,6 +530,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
"summarization": T5ForConditionalGeneration,
|
||||
"text2text-generation": T5ForConditionalGeneration,
|
||||
"translation": T5ForConditionalGeneration,
|
||||
"question-answering": T5ForQuestionAnswering,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
|
||||
Reference in New Issue
Block a user