diff --git a/docs/source/model_doc/ctrl.rst b/docs/source/model_doc/ctrl.rst
index 4d5be2c882..94b7a61ca7 100644
--- a/docs/source/model_doc/ctrl.rst
+++ b/docs/source/model_doc/ctrl.rst
@@ -97,3 +97,8 @@ TFCTRLLMHeadModel
 .. autoclass:: transformers.TFCTRLLMHeadModel
     :members: call
 
+TFCTRLForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFCTRLForSequenceClassification
+    :members: call
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index a12550db6f..e6d0fcbeb0 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -756,6 +756,7 @@ if is_tf_available():
     )
     from .models.ctrl import (
         TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
+        TFCTRLForSequenceClassification,
         TFCTRLLMHeadModel,
         TFCTRLModel,
         TFCTRLPreTrainedModel,
diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py
index 0ad0348007..c0cf193b52 100644
--- a/src/transformers/models/auto/modeling_tf_auto.py
+++ b/src/transformers/models/auto/modeling_tf_auto.py
@@ -53,7 +53,7 @@ from ..camembert.modeling_tf_camembert import (
     TFCamembertForTokenClassification,
     TFCamembertModel,
 )
-from ..ctrl.modeling_tf_ctrl import TFCTRLLMHeadModel, TFCTRLModel
+from ..ctrl.modeling_tf_ctrl import TFCTRLForSequenceClassification, TFCTRLLMHeadModel, TFCTRLModel
 from ..distilbert.modeling_tf_distilbert import (
     TFDistilBertForMaskedLM,
     TFDistilBertForMultipleChoice,
@@ -342,6 +342,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
         (GPT2Config, TFGPT2ForSequenceClassification),
         (MPNetConfig, TFMPNetForSequenceClassification),
         (OpenAIGPTConfig, TFOpenAIGPTForSequenceClassification),
+        (CTRLConfig, TFCTRLForSequenceClassification),
     ]
 )
 
diff --git a/src/transformers/models/ctrl/__init__.py b/src/transformers/models/ctrl/__init__.py
index 6308d386fa..734a7ef159 100644
--- a/src/transformers/models/ctrl/__init__.py
+++ b/src/transformers/models/ctrl/__init__.py
@@ -33,6 +33,7 @@ if is_torch_available():
 if is_tf_available():
     from .modeling_tf_ctrl import (
         TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
+        TFCTRLForSequenceClassification,
         TFCTRLLMHeadModel,
         TFCTRLModel,
         TFCTRLPreTrainedModel,
diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py
index abbb5d0a57..9dceb3163e 100644
--- a/src/transformers/models/ctrl/modeling_tf_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -19,11 +19,13 @@ import numpy as np
 import tensorflow as tf
 
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
+from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast, TFSequenceClassifierOutput
 from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
+    TFSequenceClassificationLoss,
     TFSharedEmbeddings,
+    get_initializer,
     input_processing,
     keras_serializable,
     shape_list,
@@ -726,3 +728,165 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The CTRL Model transformer with a sequence classification head on top (linear layer).
+
+    :class:`~transformers.TFCTRLForSequenceClassification` uses the last token in order to do the classification, as
+    other causal models (e.g. GPT-1, GPT-2) do.
+
+    Since it does classification on the last token, it requires to know the position of the last token. If a
+    :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
+    row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
+    guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
+    the last value in each row of the batch).
+    """,
+    CTRL_START_DOCSTRING,
+)
+class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassificationLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+        self.classifier = tf.keras.layers.Dense(
+            config.num_labels,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="classifier",
+            use_bias=False,
+        )
+        self.transformer = TFCTRLMainLayer(config, name="transformer")
+
+    def get_output_embeddings(self):
+        return self.transformer.w
+
+    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="ctrl",
+        output_type=TFSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids=None,
+        past=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        labels=None,
+        training=False,
+        **kwargs,
+    ):
+        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
+            config.num_labels - 1]``.
+        """
+        inputs = input_processing(
+            func=self.call,
+            config=self.config,
+            input_ids=input_ids,
+            past=past,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            labels=labels,
+            training=training,
+            kwargs_call=kwargs,
+        )
+
+        transformer_outputs = self.transformer(
+            input_ids=inputs["input_ids"],
+            past=inputs["past"],
+            attention_mask=inputs["attention_mask"],
+            token_type_ids=inputs["token_type_ids"],
+            position_ids=inputs["position_ids"],
+            head_mask=inputs["head_mask"],
+            inputs_embeds=inputs["inputs_embeds"],
+            use_cache=inputs["use_cache"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
+            training=inputs["training"],
+        )
+
+        hidden_states = transformer_outputs[0]
+        # Classify every position first; a single position per row is selected below.
+        logits = self.classifier(hidden_states)
+        logits_shape = shape_list(logits)
+        in_logits = None
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if inputs["input_ids"] is not None:
+                # Per row: (number of non-padding tokens) - 1, used as the index of the token to classify.
+                # NOTE(review): this is the last real token only when padding sits on the right -- confirm.
+                sequence_lengths = (
+                    tf.reduce_sum(
+                        tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
+                        -1,
+                        keepdims=False,
+                    )
+                    - 1
+                )
+
+                def get_seq_element(sequence_position, input_batch):
+                    return tf.strided_slice(
+                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
+                    )
+
+                result = tf.map_fn(
+                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
+                )
+                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+        loss = None
+
+        if inputs["labels"] is not None:
+            # Read the processed inputs here: the raw `input_ids` argument may be a dict bundling every input.
+            if inputs["input_ids"] is not None:
+                batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
+            else:
+                batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
+            assert (
+                self.config.pad_token_id is not None or batch_size == 1
+            ), "Cannot handle batch sizes > 1 if no padding token is defined."
+
+            # `sequence_lengths` is still the plain int -1 when no per-row index was computed: use the last position.
+            if not tf.is_tensor(sequence_lengths):
+                in_logits = logits[0:batch_size, sequence_lengths]
+
+            loss = self.compute_loss(
+                tf.reshape(inputs["labels"], [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels])
+            )
+
+        pooled_logits = in_logits if in_logits is not None else logits
+
+        if not inputs["return_dict"]:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFSequenceClassifierOutput(
+            loss=loss,
+            logits=pooled_logits,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py
index 9bd27d3825..5285b83cf7 100644
--- a/src/transformers/utils/dummy_tf_objects.py
+++ b/src/transformers/utils/dummy_tf_objects.py
@@ -429,6 +429,15 @@ class TFCamembertModel:
 TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None
 
 
+class TFCTRLForSequenceClassification:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
+
+
 class TFCTRLLMHeadModel:
     def __init__(self, *args, **kwargs):
         requires_tf(self)
diff --git a/tests/test_modeling_tf_ctrl.py b/tests/test_modeling_tf_ctrl.py
index 74dd95bfe2..4231c151bd 100644
--- a/tests/test_modeling_tf_ctrl.py
+++ b/tests/test_modeling_tf_ctrl.py
@@ -28,6 +28,7 @@ if is_tf_available():
 
     from transformers.models.ctrl.modeling_tf_ctrl import (
         TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
+        TFCTRLForSequenceClassification,
         TFCTRLLMHeadModel,
         TFCTRLModel,
     )
@@ -61,6 +62,7 @@ class TFCTRLModelTester(object):
         self.num_labels = 3
         self.num_choices = 4
         self.scope = None
+        self.pad_token_id = self.vocab_size - 1
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -98,6 +100,7 @@ class TFCTRLModelTester(object):
             n_ctx=self.max_position_embeddings,
             # type_vocab_size=self.type_vocab_size,
             # initializer_range=self.initializer_range,
+            pad_token_id=self.pad_token_id,
         )
 
         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -132,6 +135,20 @@ class TFCTRLModelTester(object):
         result = model(inputs)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
+    def create_and_check_ctrl_for_sequence_classification(
+        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
+    ):
+        config.num_labels = self.num_labels
+        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+        inputs = {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "labels": sequence_labels,
+        }
+        model = TFCTRLForSequenceClassification(config)
+        result = model(inputs)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
 
@@ -154,7 +171,7 @@ class TFCTRLModelTester(object):
 
 @require_tf
 class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
-    all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()
+    all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
     all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
 
     def setUp(self):
@@ -172,6 +189,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_ctrl_lm_head(*config_and_inputs)
 
+    def test_ctrl_sequence_classification_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_ctrl_for_sequence_classification(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: