From c749a543fa735422b429a043a66becf0767afcbb Mon Sep 17 00:00:00 2001 From: maximeilluin <60709375+maximeilluin@users.noreply.github.com> Date: Fri, 21 Feb 2020 18:01:02 +0100 Subject: [PATCH] Added CamembertForQuestionAnswering (#2746) * Added CamembertForQuestionAnswering * fixed camembert tokenizer case --- examples/run_squad.py | 13 ++++++++++--- src/transformers/__init__.py | 1 + src/transformers/data/processors/squad.py | 2 +- src/transformers/modeling_camembert.py | 21 ++++++++++++++++++--- 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 4cd555fa73..f94fb22098 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -38,6 +38,9 @@ from transformers import ( BertConfig, BertForQuestionAnswering, BertTokenizer, + CamembertConfig, + CamembertForQuestionAnswering, + CamembertTokenizer, DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer, @@ -70,12 +73,16 @@ except ImportError: logger = logging.getLogger(__name__) ALL_MODELS = sum( - (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), + ( + tuple(conf.pretrained_config_archive_map.keys()) + for conf in (BertConfig, CamembertConfig, RobertaConfig, XLNetConfig, XLMConfig) + ), (), ) MODEL_CLASSES = { "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer), + "camembert": (CamembertConfig, CamembertForQuestionAnswering, CamembertTokenizer), "roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer), "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), @@ -212,7 +219,7 @@ def train(args, train_dataset, model, tokenizer): "end_positions": batch[4], } - if args.model_type in ["xlm", "roberta", "distilbert"]: + if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]: del inputs["token_type_ids"] if args.model_type in ["xlnet", "xlm"]: @@ -327,7 +334,7 @@ def evaluate(args, model, tokenizer, prefix=""): "token_type_ids": batch[2], } - if args.model_type in ["xlm", "roberta", "distilbert"]: + if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]: del inputs["token_type_ids"] example_indices = batch[3] diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 93bd94c622..ebacba6fdc 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -221,6 +221,7 @@ if is_torch_available(): CamembertModel, CamembertForSequenceClassification, CamembertForTokenClassification, + CamembertForQuestionAnswering, CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ) from .modeling_distilbert import ( diff --git a/src/transformers/data/processors/squad.py b/src/transformers/data/processors/squad.py index 3eaec5f23d..e6e6a589a1 100644 --- a/src/transformers/data/processors/squad.py +++ b/src/transformers/data/processors/squad.py @@ -123,7 +123,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length) sequence_added_tokens = ( tokenizer.max_len - tokenizer.max_len_single_sentence + 1 - if "roberta" in str(type(tokenizer)) + if "roberta" in str(type(tokenizer)) or "camembert" in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence ) sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair diff --git a/src/transformers/modeling_camembert.py b/src/transformers/modeling_camembert.py index 12877dff16..8f0f06872b 100644 --- a/src/transformers/modeling_camembert.py +++ b/src/transformers/modeling_camembert.py @@ -15,7 +15,6 @@ # limitations under the License. """PyTorch CamemBERT model. """ - import logging from .configuration_camembert import CamembertConfig @@ -23,6 +22,7 @@ from .file_utils import add_start_docstrings from .modeling_roberta import ( RobertaForMaskedLM, RobertaForMultipleChoice, + RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification, RobertaModel, @@ -37,7 +37,6 @@ CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = { "umberto-wikipedia-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/Musixmatch/umberto-wikipedia-uncased-v1/pytorch_model.bin", } - CAMEMBERT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. @@ -46,7 +45,8 @@ CAMEMBERT_START_DOCSTRING = r""" Parameters: config (:class:`~transformers.CamembertConfig`): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the configuration. + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. """ @@ -121,3 +121,18 @@ class CamembertForTokenClassification(RobertaForTokenClassification): config_class = CamembertConfig pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP + + +@add_start_docstrings( + """CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits` """, + CAMEMBERT_START_DOCSTRING, +) +class CamembertForQuestionAnswering(RobertaForQuestionAnswering): + """ + This class overrides :class:`~transformers.RobertaForQuestionAnswering`. Please check the + superclass for the appropriate documentation alongside usage examples. + """ + + config_class = CamembertConfig + pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP