From d0e42a7bed3de9271ae39c575d7eeb54cf985921 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 21 Aug 2020 17:22:54 +0530 Subject: [PATCH] CamembertForCausalLM (#6577) * added CamembertForCausalLM * add in __init__ and auto model * style * doc --- docs/source/model_doc/camembert.rst | 7 +++++++ src/transformers/__init__.py | 1 + src/transformers/modeling_auto.py | 2 ++ src/transformers/modeling_camembert.py | 13 +++++++++++++ 4 files changed, 23 insertions(+) diff --git a/docs/source/model_doc/camembert.rst b/docs/source/model_doc/camembert.rst index 5ccdfe5b87..8f0d578848 100644 --- a/docs/source/model_doc/camembert.rst +++ b/docs/source/model_doc/camembert.rst @@ -49,6 +49,13 @@ CamembertModel :members: +CamembertForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.CamembertForCausalLM + :members: + + CamembertForMaskedLM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f789535f08..f2f4d7c2d1 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -331,6 +331,7 @@ if is_torch_available(): CamembertForMultipleChoice, CamembertForTokenClassification, CamembertForQuestionAnswering, + CamembertForCausalLM, CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, ) from .modeling_encoder_decoder import EncoderDecoderModel diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 02088f565c..4f56e473c5 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -73,6 +73,7 @@ from .modeling_bert import ( BertModel, ) from .modeling_camembert import ( + CamembertForCausalLM, CamembertForMaskedLM, CamembertForMultipleChoice, CamembertForQuestionAnswering, @@ -253,6 +254,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( [ + (CamembertConfig, CamembertForCausalLM), (RobertaConfig, RobertaForCausalLM), (BertConfig, BertLMHeadModel), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), diff --git a/src/transformers/modeling_camembert.py b/src/transformers/modeling_camembert.py index 2e9a24d4d2..797b4f06e1 100644 --- a/src/transformers/modeling_camembert.py +++ b/src/transformers/modeling_camembert.py @@ -20,6 +20,7 @@ import logging from .configuration_camembert import CamembertConfig from .file_utils import add_start_docstrings from .modeling_roberta import ( + RobertaForCausalLM, RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForQuestionAnswering, @@ -133,3 +134,15 @@ class CamembertForQuestionAnswering(RobertaForQuestionAnswering): """ config_class = CamembertConfig + + +@add_start_docstrings( + """CamemBERT Model with a `language modeling` head on top for CLM fine-tuning. """, CAMEMBERT_START_DOCSTRING +) +class CamembertForCausalLM(RobertaForCausalLM): + """ + This class overrides :class:`~transformers.RobertaForCausalLM`. Please check the + superclass for the appropriate documentation alongside usage examples. + """ + + config_class = CamembertConfig