diff --git a/docs/source/model_doc/flaubert.rst b/docs/source/model_doc/flaubert.rst index c542fc0507..4a0e4ca581 100644 --- a/docs/source/model_doc/flaubert.rst +++ b/docs/source/model_doc/flaubert.rst @@ -61,6 +61,13 @@ FlaubertForSequenceClassification :members: +FlaubertForTokenClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaubertForTokenClassification + :members: + + FlaubertForQuestionAnsweringSimple ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -114,4 +121,4 @@ TFFlaubertForQuestionAnsweringSimple ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.TFFlaubertForQuestionAnsweringSimple - :members: \ No newline at end of file + :members: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e70ec00247..a28387aed4 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -353,6 +353,7 @@ if is_torch_available(): FlaubertModel, FlaubertWithLMHeadModel, FlaubertForSequenceClassification, + FlaubertForTokenClassification, FlaubertForQuestionAnswering, FlaubertForQuestionAnsweringSimple, FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 15e55cb866..7991254034 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -100,6 +100,7 @@ from .modeling_encoder_decoder import EncoderDecoderModel from .modeling_flaubert import ( FlaubertForQuestionAnsweringSimple, FlaubertForSequenceClassification, + FlaubertForTokenClassification, FlaubertModel, FlaubertWithLMHeadModel, ) @@ -326,6 +327,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict( [ (DistilBertConfig, DistilBertForTokenClassification), (CamembertConfig, CamembertForTokenClassification), + (FlaubertConfig, FlaubertForTokenClassification), (XLMConfig, XLMForTokenClassification), (XLMRobertaConfig, XLMRobertaForTokenClassification), (LongformerConfig, LongformerForTokenClassification), @@ -1552,6 +1554,7 @@ class AutoModelForTokenClassification: - isInstance of `bert` configuration class: :class:`~transformers.BertModelForTokenClassification` (Bert model) - isInstance of `albert` configuration class: :class:`~transformers.AlbertForTokenClassification` (AlBert model) - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetModelForTokenClassification` (XLNet model) + - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForTokenClassification` (Flaubert model) - isInstance of `camembert` configuration class: :class:`~transformers.CamembertModelForTokenClassification` (Camembert model) - isInstance of `roberta` configuration class: :class:`~transformers.RobertaModelForTokenClassification` (Roberta model) - isInstance of `electra` configuration class: :class:`~transformers.ElectraForTokenClassification` (Electra model) @@ -1589,6 +1592,7 @@ class AutoModelForTokenClassification: - `camembert`: :class:`~transformers.CamembertForTokenClassification` (Camembert model) - `bert`: :class:`~transformers.BertForTokenClassification` (Bert model) - `xlnet`: :class:`~transformers.XLNetForTokenClassification` (XLNet model) + - `flaubert`: :class:`~transformers.FlaubertForTokenClassification` (Flaubert model) - `roberta`: :class:`~transformers.RobertaForTokenClassification` (Roberta model) - `electra`: :class:`~transformers.ElectraForTokenClassification` (Electra model) diff --git a/src/transformers/modeling_flaubert.py b/src/transformers/modeling_flaubert.py index c4ef66b8bf..aeda892f7f 100644 --- a/src/transformers/modeling_flaubert.py +++ b/src/transformers/modeling_flaubert.py @@ -28,6 +28,7 @@ from .modeling_xlm import ( XLMForQuestionAnswering, XLMForQuestionAnsweringSimple, XLMForSequenceClassification, + XLMForTokenClassification, XLMModel, XLMWithLMHeadModel, get_masks, @@ -326,6 +327,25 @@ class FlaubertForSequenceClassification(XLMForSequenceClassification): self.init_weights() +@add_start_docstrings( + """Flaubert Model with a token classification head on top (a linear layer on top of + the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, + FLAUBERT_START_DOCSTRING, +) +class FlaubertForTokenClassification(XLMForTokenClassification): + """ + This class overrides :class:`~transformers.XLMForTokenClassification`. Please check the + superclass for the appropriate documentation alongside usage examples. + """ + + config_class = FlaubertConfig + + def __init__(self, config): + super().__init__(config) + self.transformer = FlaubertModel(config) + self.init_weights() + + @add_start_docstrings( """Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """, diff --git a/tests/test_modeling_flaubert.py b/tests/test_modeling_flaubert.py index 03e4af6ad8..af2918cb94 100644 --- a/tests/test_modeling_flaubert.py +++ b/tests/test_modeling_flaubert.py @@ -31,6 +31,7 @@ if is_torch_available(): FlaubertForQuestionAnswering, FlaubertForQuestionAnsweringSimple, FlaubertForSequenceClassification, + FlaubertForTokenClassification, ) from transformers.modeling_flaubert import FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST @@ -294,6 +295,30 @@ class FlaubertModelTester(object): self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]) + def create_and_check_flaubert_token_classif( + self, + config, + input_ids, + token_type_ids, + input_lengths, + sequence_labels, + token_labels, + is_impossible_labels, + input_mask, + ): + config.num_labels = self.num_labels + model = FlaubertForTokenClassification(config) + model.to(torch_device) + model.eval() + + loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels) + result = { + "loss": loss, + "logits": logits, + } + self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) + self.check_loss_output(result) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -320,6 +345,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): FlaubertForQuestionAnswering, FlaubertForQuestionAnsweringSimple, FlaubertForSequenceClassification, + FlaubertForTokenClassification, ) if is_torch_available() else () @@ -352,6 +378,10 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_flaubert_sequence_classif(*config_and_inputs) + def test_flaubert_token_classif(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_flaubert_token_classif(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: