From 4d1a3ffde810ecaaf641a48b7e449295535c1e85 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 1 Sep 2020 21:56:39 +0200 Subject: [PATCH] [EncoderDecoder] Add xlm-roberta to encoder decoder (#6878) * finish xlm-roberta * finish docs * expose XLMRobertaForCausalLM --- docs/source/model_doc/xlmroberta.rst | 9 ++++++++- src/transformers/__init__.py | 1 + src/transformers/modeling_auto.py | 2 ++ src/transformers/modeling_xlm_roberta.py | 14 ++++++++++++++ 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/source/model_doc/xlmroberta.rst b/docs/source/model_doc/xlmroberta.rst index c4c27d6420..e7ccecfdbd 100644 --- a/docs/source/model_doc/xlmroberta.rst +++ b/docs/source/model_doc/xlmroberta.rst @@ -56,6 +56,13 @@ XLMRobertaModel :members: +XLMRobertaForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.XLMRobertaForCausalLM + :members: + + XLMRobertaForMaskedLM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -130,4 +137,4 @@ TFXLMRobertaForQuestionAnswering ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.TFXLMRobertaForQuestionAnswering - :members: \ No newline at end of file + :members: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e5540d864a..5e8283a402 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -418,6 +418,7 @@ if is_torch_available(): ) from .modeling_xlm_roberta import ( XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + XLMRobertaForCausalLM, XLMRobertaForMaskedLM, XLMRobertaForMultipleChoice, XLMRobertaForQuestionAnswering, diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 287aa06778..30ae410cd1 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -156,6 +156,7 @@ from .modeling_xlm import ( XLMWithLMHeadModel, ) from .modeling_xlm_roberta import ( + XLMRobertaForCausalLM, XLMRobertaForMaskedLM, XLMRobertaForMultipleChoice, XLMRobertaForQuestionAnswering, @@ -255,6 +256,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( [ (CamembertConfig, CamembertForCausalLM), + (XLMRobertaConfig, XLMRobertaForCausalLM), (RobertaConfig, RobertaForCausalLM), (BertConfig, BertLMHeadModel), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), diff --git a/src/transformers/modeling_xlm_roberta.py b/src/transformers/modeling_xlm_roberta.py index 31bd816844..e5375bc043 100644 --- a/src/transformers/modeling_xlm_roberta.py +++ b/src/transformers/modeling_xlm_roberta.py @@ -18,6 +18,7 @@ from .configuration_xlm_roberta import XLMRobertaConfig from .file_utils import add_start_docstrings from .modeling_roberta import ( + RobertaForCausalLM, RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForQuestionAnswering, @@ -67,6 +68,19 @@ class XLMRobertaModel(RobertaModel): config_class = XLMRobertaConfig +@add_start_docstrings( + "XLM-RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.", + XLM_ROBERTA_START_DOCSTRING, +) +class XLMRobertaForCausalLM(RobertaForCausalLM): + """ + This class overrides :class:`~transformers.RobertaForCausalLM`. Please check the + superclass for the appropriate documentation alongside usage examples. + """ + + config_class = XLMRobertaConfig + + @add_start_docstrings( """XLM-RoBERTa Model with a `language modeling` head on top. """, XLM_ROBERTA_START_DOCSTRING,