From 41c559415ae72fc84afdf4150981e2db025b9c61 Mon Sep 17 00:00:00 2001 From: tucan9389 <37643248+tucan9389@users.noreply.github.com> Date: Tue, 31 Aug 2021 19:19:04 +0900 Subject: [PATCH] Add GPT2ForTokenClassification (#13290) * Add GPT2ForTokenClassification * Fix dropout exception for GPT2 NER * Remove sequence label in test * Change TokenClassifierOutput to TokenClassifierOutputWithPast * Fix for black formatter * Remove dummy * Update docs for GPT2ForTokenClassification * Fix check_inits ci fail * Update dummy_pt_objects after make fix-copies * Remove TokenClassifierOutputWithPast * Fix tuple input issue Co-authored-by: danielsejong55@gmail.com --- docs/source/model_doc/gpt2.rst | 7 ++ src/transformers/__init__.py | 2 + src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/gpt2/__init__.py | 2 + src/transformers/models/gpt2/modeling_gpt2.py | 103 ++++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 9 ++ tests/test_modeling_gpt2.py | 17 ++- 7 files changed, 140 insertions(+), 1 deletion(-) diff --git a/docs/source/model_doc/gpt2.rst b/docs/source/model_doc/gpt2.rst index 56fb564bfd..4d1734da7d 100644 --- a/docs/source/model_doc/gpt2.rst +++ b/docs/source/model_doc/gpt2.rst @@ -108,6 +108,13 @@ GPT2ForSequenceClassification :members: forward +GPT2ForTokenClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.GPT2ForTokenClassification + :members: forward + + TFGPT2Model ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 623b8d610d..7f8cdb3c1f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -807,6 +807,7 @@ if is_torch_available(): "GPT2_PRETRAINED_MODEL_ARCHIVE_LIST", "GPT2DoubleHeadsModel", "GPT2ForSequenceClassification", + "GPT2ForTokenClassification", "GPT2LMHeadModel", "GPT2Model", "GPT2PreTrainedModel", @@ -2460,6 +2461,7 @@ if TYPE_CHECKING: GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, GPT2DoubleHeadsModel, GPT2ForSequenceClassification, + GPT2ForTokenClassification, GPT2LMHeadModel, GPT2Model, GPT2PreTrainedModel, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 6e44f4d954..26680c0cb2 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -399,6 +399,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("mpnet", "MPNetForTokenClassification"), ("deberta", "DebertaForTokenClassification"), ("deberta-v2", "DebertaV2ForTokenClassification"), + ("gpt2", "GPT2ForTokenClassification"), ("ibert", "IBertForTokenClassification"), ] ) diff --git a/src/transformers/models/gpt2/__init__.py b/src/transformers/models/gpt2/__init__.py index 1df4f9daf6..7169ddc63f 100644 --- a/src/transformers/models/gpt2/__init__.py +++ b/src/transformers/models/gpt2/__init__.py @@ -34,6 +34,7 @@ if is_torch_available(): "GPT2_PRETRAINED_MODEL_ARCHIVE_LIST", "GPT2DoubleHeadsModel", "GPT2ForSequenceClassification", + "GPT2ForTokenClassification", "GPT2LMHeadModel", "GPT2Model", "GPT2PreTrainedModel", @@ -66,6 +67,7 @@ if TYPE_CHECKING: GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, GPT2DoubleHeadsModel, GPT2ForSequenceClassification, + GPT2ForTokenClassification, GPT2LMHeadModel, GPT2Model, GPT2PreTrainedModel, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 4642522cbb..43419e6615 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -36,6 +36,7 @@ from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, + TokenClassifierOutput, ) from ...modeling_utils import ( Conv1D, @@ -1331,3 +1332,105 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel): hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="microsoft/DialogRPT-updown", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c99a5c9876..ad62c51006 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1781,6 +1781,15 @@ class GPT2ForSequenceClassification: requires_backends(cls, ["torch"]) +class GPT2ForTokenClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class GPT2LMHeadModel: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 94abcbfecf..a770729303 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -32,6 +32,7 @@ if is_torch_available(): GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, GPT2DoubleHeadsModel, GPT2ForSequenceClassification, + GPT2ForTokenClassification, GPT2LMHeadModel, GPT2Model, GPT2Tokenizer, @@ -366,6 +367,16 @@ class GPT2ModelTester: result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + def create_and_check_gpt2_for_token_classification( + self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args + ): + config.num_labels = self.num_labels + model = GPT2ForTokenClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() @@ -394,7 +405,7 @@ class GPT2ModelTester: class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = ( - (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification) + (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification, GPT2ForTokenClassification) if is_torch_available() else () ) @@ -462,6 +473,10 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs) + def test_gpt2_token_classification_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gpt2_for_token_classification(*config_and_inputs) + def test_gpt2_gradient_checkpointing(self): config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)