From 41a1d27cdefd6417c298518198f99e3b8431a5c0 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 8 Jun 2020 21:22:37 -0400 Subject: [PATCH] Add XLMRobertaForQuestionAnswering (#4855) * Add XLMRobertaForQuestionAnswering * Formatting * Make test happy --- docs/source/model_doc/xlmroberta.rst | 7 +++++++ src/transformers/__init__.py | 1 + src/transformers/modeling_auto.py | 2 ++ src/transformers/modeling_xlm_roberta.py | 15 +++++++++++++++ 4 files changed, 25 insertions(+) diff --git a/docs/source/model_doc/xlmroberta.rst b/docs/source/model_doc/xlmroberta.rst index 0743b9d308..c4c27d6420 100644 --- a/docs/source/model_doc/xlmroberta.rst +++ b/docs/source/model_doc/xlmroberta.rst @@ -84,6 +84,13 @@ XLMRobertaForTokenClassification :members: +XLMRobertaForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.XLMRobertaForQuestionAnswering + :members: + + TFXLMRobertaModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 25adbaea69..34dd885d86 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -298,6 +298,7 @@ if is_torch_available(): XLMRobertaForMultipleChoice, XLMRobertaForSequenceClassification, XLMRobertaForTokenClassification, + XLMRobertaForQuestionAnswering, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, ) from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 278bdbb39c..ca25c5cb76 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -121,6 +121,7 @@ from .modeling_xlm import ( from .modeling_xlm_roberta import ( XLMRobertaForMaskedLM, XLMRobertaForMultipleChoice, + XLMRobertaForQuestionAnswering, XLMRobertaForSequenceClassification, XLMRobertaForTokenClassification, XLMRobertaModel, @@ -230,6 +231,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( (DistilBertConfig, DistilBertForQuestionAnswering), (AlbertConfig, AlbertForQuestionAnswering), (LongformerConfig, LongformerForQuestionAnswering), + (XLMRobertaConfig, XLMRobertaForQuestionAnswering), (RobertaConfig, RobertaForQuestionAnswering), (BertConfig, BertForQuestionAnswering), (XLNetConfig, XLNetForQuestionAnsweringSimple), diff --git a/src/transformers/modeling_xlm_roberta.py b/src/transformers/modeling_xlm_roberta.py index 132621a731..b76d974440 100644 --- a/src/transformers/modeling_xlm_roberta.py +++ b/src/transformers/modeling_xlm_roberta.py @@ -23,6 +23,7 @@ from .file_utils import add_start_docstrings from .modeling_roberta import ( RobertaForMaskedLM, RobertaForMultipleChoice, + RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification, RobertaModel, @@ -120,3 +121,17 @@ class XLMRobertaForTokenClassification(RobertaForTokenClassification): """ config_class = XLMRobertaConfig + + +@add_start_docstrings( + """XLM-RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a + linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).""", + XLM_ROBERTA_START_DOCSTRING, +) +class XLMRobertaForQuestionAnswering(RobertaForQuestionAnswering): + """ + This class overrides :class:`~transformers.RobertaForQuestionAnswering`. Please check the + superclass for the appropriate documentation alongside usage examples. + """ + + config_class = XLMRobertaConfig