From 8b240a06617455eae59e1116af6a1a016664e963 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 13 Oct 2021 00:10:34 +0200 Subject: [PATCH] Add TFEncoderDecoderModel + Add cross-attention to some TF models (#13222) * Add cross attentions to TFGPT2Model * Add TFEncoderDecoderModel * Add TFBaseModelOutputWithPoolingAndCrossAttentions * Add cross attentions to TFBertModel * Fix past or past_key_values argument issue * Fix generation * Fix save and load * Add some checks and comments * Clean the code that deals with past keys/values * Add kwargs to processing_inputs * Add serving_output to TFEncoderDecoderModel * Some cleaning + fix use_cache value issue * Fix tests + add bert2bert/bert2gpt2 tests * Fix more tests * Ignore crossattention.bias when loading GPT2 weights into TFGPT2 * Fix return_dict_in_generate in tf generation * Fix is_token_logit_eos_token bug in tf generation * Finalize the tests after fixing some bugs * Fix another is_token_logit_eos_token bug in tf generation * Add/Update docs * Add TFBertEncoderDecoderModelTest * Clean test script * Add TFEncoderDecoderModel to the library * Add cross attentions to TFRobertaModel * Add TFRobertaEncoderDecoderModelTest * make style * Change the way of position_ids computation * bug fix * Fix copies in tf_albert * Remove some copied from and apply some fix-copies * Remove some copied * Add cross attentions to some other TF models * Remove encoder_hidden_states from TFLayoutLMModel.call for now * Make style * Fix TFRemBertForCausalLM * Revert the change to longformer + Remove copies * Revert the change to albert and convbert + Remove copies * make quality * make style * Add TFRembertEncoderDecoderModelTest * make quality and fix-copies * test TFRobertaForCausalLM * Fixes for failed tests * Fixes for failed tests * fix more tests * Fixes for failed tests * Fix Auto mapping order * Fix TFRemBertEncoder return value * fix tf_rembert * Check copies are OK * Fix missing TFBaseModelOutputWithPastAndCrossAttentions is not defined * Add TFEncoderDecoderModelSaveLoadTests * fix tf weight loading * check the change of use_cache * Revert the change * Add missing test_for_causal_lm for TFRobertaModelTest * Try cleaning past * fix _reorder_cache * Revert some files to original versions * Keep as many copies as possible * Apply suggested changes - Use raise ValueError instead of assert * Move import to top * Fix wrong require_torch * Replace more assert by raise ValueError * Add test_pt_tf_model_equivalence (the test won't pass for now) * add test for loading/saving * finish * finish * Remove test_pt_tf_model_equivalence * Update tf modeling template * Remove pooling, added in the prev. commit, from MainLayer * Update tf modeling test template * Move inputs["use_cache"] = False to modeling_tf_utils.py * Fix torch.Tensor in the comment * fix use_cache * Fix missing use_cache in ElectraConfig * Add a note to from_pretrained * Fix style * Change test_encoder_decoder_save_load_from_encoder_decoder_from_pt * Fix TFMLP (in TFGPT2) activation issue * Fix None past_key_values value in serving_output * Don't call get_encoderdecoder_model in TFEncoderDecoderModelTest.test_configuration_tie until we have a TF checkpoint on Hub * Apply review suggestions - style for cross_attns in serving_output * Apply review suggestions - change assert + docstrings * break the error message to respect the char limit * deprecate the argument past * fix docstring style * Update the encoder-decoder rst file * fix Unknown interpreted text role "method" * fix typo Co-authored-by: ydshieh Co-authored-by: Patrick von Platen --- docs/source/index.rst | 2 +- docs/source/main_classes/output.rst | 21 + docs/source/model_doc/encoderdecoder.rst | 26 + docs/source/model_doc/roberta.rst | 7 + src/transformers/__init__.py | 4 + .../convert_pytorch_checkpoint_to_tf2.py | 2 + src/transformers/generation_tf_utils.py | 21 +- src/transformers/modeling_tf_outputs.py | 91 +++ src/transformers/modeling_tf_utils.py | 15 +- .../models/albert/modeling_tf_albert.py | 6 +- .../models/auto/modeling_tf_auto.py | 2 + .../models/bert/modeling_tf_bert.py | 369 ++++++++- .../models/convbert/modeling_tf_convbert.py | 5 +- .../models/electra/configuration_electra.py | 5 + .../models/electra/modeling_tf_electra.py | 333 +++++++- .../models/encoder_decoder/__init__.py | 8 +- .../modeling_tf_encoder_decoder.py | 647 +++++++++++++++ .../models/layoutlm/modeling_tf_layoutlm.py | 185 ++++- .../longformer/modeling_tf_longformer.py | 18 +- .../mobilebert/modeling_tf_mobilebert.py | 1 - .../models/mpnet/modeling_tf_mpnet.py | 1 - .../models/openai/modeling_tf_openai.py | 1 - .../models/rembert/modeling_tf_rembert.py | 373 ++++++++- src/transformers/models/roberta/__init__.py | 2 + .../models/roberta/modeling_tf_roberta.py | 474 ++++++++++- .../xlm_roberta/modeling_tf_xlm_roberta.py | 14 + src/transformers/utils/dummy_tf_objects.py | 18 + ...tf_{{cookiecutter.lowercase_modelname}}.py | 376 ++++++++- ...tf_{{cookiecutter.lowercase_modelname}}.py | 29 +- tests/test_modeling_tf_bert.py | 29 +- tests/test_modeling_tf_common.py | 16 + tests/test_modeling_tf_encoder_decoder.py | 765 ++++++++++++++++++ tests/test_modeling_tf_rembert.py | 29 +- tests/test_modeling_tf_roberta.py | 42 +- utils/check_repo.py | 2 + 35 files changed, 3738 insertions(+), 201 deletions(-) create mode 100644 src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py create mode 100644 tests/test_modeling_tf_encoder_decoder.py diff --git a/docs/source/index.rst b/docs/source/index.rst index bd0ca334c6..c4ce38f105 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -379,7 +379,7 @@ Flax), PyTorch, and/or TensorFlow. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ -| Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ | +| Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ diff --git a/docs/source/main_classes/output.rst b/docs/source/main_classes/output.rst index 5d0bdc7bc6..be1ecca3cf 100644 --- a/docs/source/main_classes/output.rst +++ b/docs/source/main_classes/output.rst @@ -210,6 +210,13 @@ TFBaseModelOutputWithPooling :members: +TFBaseModelOutputWithPoolingAndCrossAttentions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_tf_outputs.TFBaseModelOutputWithPoolingAndCrossAttentions + :members: + + TFBaseModelOutputWithPast ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -217,6 +224,13 @@ TFBaseModelOutputWithPast :members: +TFBaseModelOutputWithPastAndCrossAttentions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_tf_outputs.TFBaseModelOutputWithPastAndCrossAttentions + :members: + + TFSeq2SeqModelOutput ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -231,6 +245,13 @@ TFCausalLMOutput :members: +TFCausalLMOutputWithCrossAttentions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_tf_outputs.TFCausalLMOutputWithCrossAttentions + :members: + + TFCausalLMOutputWithPast ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/encoderdecoder.rst b/docs/source/model_doc/encoderdecoder.rst index 02b902ecc5..5b6759528a 100644 --- a/docs/source/model_doc/encoderdecoder.rst +++ b/docs/source/model_doc/encoderdecoder.rst @@ -27,6 +27,25 @@ An application of this architecture could be to leverage two pretrained :class:` and decoder for a summarization model as was shown in: `Text Summarization with Pretrained Encoders `__ by Yang Liu and Mirella Lapata. +The :meth:`~transformers.TFEncoderDecoderModel.from_pretrained` currently doesn't support initializing the model from a +pytorch checkpoint. Passing ``from_pt=True`` to this method will throw an exception. If there are only pytorch +checkpoints for a particular encoder-decoder model, a workaround is: + +.. code-block:: + + >>> # a workaround to load from pytorch checkpoint + >>> _model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") + >>> _model.encoder.save_pretrained("./encoder") + >>> _model.decoder.save_pretrained("./decoder") + >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained( + ... "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True + ... ) + >>> # This is only for copying some specific attributes of this particular model. + >>> model.config = _model.config + +This model was contributed by `thomwolf `__. This model's TensorFlow and Flax versions +were contributed by `ydshieh `__. + EncoderDecoderConfig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -42,6 +61,13 @@ EncoderDecoderModel :members: forward, from_encoder_decoder_pretrained +TFEncoderDecoderModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFEncoderDecoderModel + :members: call, from_encoder_decoder_pretrained + + FlaxEncoderDecoderModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/roberta.rst b/docs/source/model_doc/roberta.rst index f1eac9c173..b9d533e804 100644 --- a/docs/source/model_doc/roberta.rst +++ b/docs/source/model_doc/roberta.rst @@ -126,6 +126,13 @@ TFRobertaModel :members: call +TFRobertaForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFRobertaForCausalLM + :members: call + + TFRobertaForMaskedLM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 44fe890fe3..a344669d72 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1444,6 +1444,7 @@ if is_tf_available(): "TFElectraPreTrainedModel", ] ) + _import_structure["models.encoder_decoder"].append("TFEncoderDecoderModel") _import_structure["models.flaubert"].extend( [ "TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1596,6 +1597,7 @@ if is_tf_available(): _import_structure["models.roberta"].extend( [ "TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFRobertaForCausalLM", "TFRobertaForMaskedLM", "TFRobertaForMultipleChoice", "TFRobertaForQuestionAnswering", @@ -3096,6 +3098,7 @@ if TYPE_CHECKING: TFElectraModel, TFElectraPreTrainedModel, ) + from .models.encoder_decoder import TFEncoderDecoderModel from .models.flaubert import ( TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TFFlaubertForMultipleChoice, @@ -3205,6 +3208,7 @@ if TYPE_CHECKING: ) from .models.roberta import ( TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + TFRobertaForCausalLM, TFRobertaForMaskedLM, TFRobertaForMultipleChoice, TFRobertaForQuestionAnswering, diff --git a/src/transformers/convert_pytorch_checkpoint_to_tf2.py b/src/transformers/convert_pytorch_checkpoint_to_tf2.py index 315afeccd9..bcf69be478 100755 --- a/src/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/src/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -76,6 +76,7 @@ from . import ( TFLxmertForPreTraining, TFLxmertVisualFeatureEncoder, TFOpenAIGPTLMHeadModel, + TFRobertaForCausalLM, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, TFT5ForConditionalGeneration, @@ -215,6 +216,7 @@ MODEL_CLASSES = { ), "roberta": ( RobertaConfig, + TFRobertaForCausalLM, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 0652516674..d91ff8ce6f 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -650,7 +650,11 @@ class TFGenerationMixin: # current position and vocab size cur_len = shape_list(input_ids)[1] # unused - vocab_size = self.config.vocab_size + vocab_size = getattr(self.config, "vocab_size", None) + if vocab_size is None and self.config.is_encoder_decoder: + decoder_config = getattr(self.config, "decoder", None) + if decoder_config is not None: + vocab_size = getattr(self.config.decoder, "vocab_size", None) # set effective batch size and effective batch multiplier according to do_sample if do_sample: @@ -678,6 +682,7 @@ class TFGenerationMixin: attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict_in_generate, ) if return_dict_in_generate: if output_attentions: @@ -911,7 +916,7 @@ class TFGenerationMixin: if eos_token_id is not None and cur_len < min_length: # create eos_token_id boolean mask is_token_logit_eos_token = tf.convert_to_tensor( - [True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool + [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool ) eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size]) @@ -1142,7 +1147,7 @@ class TFGenerationMixin: num_batch_hypotheses = batch_size * num_beams is_token_logit_eos_token = tf.convert_to_tensor( - [True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool + [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool ) eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size]) @@ -1446,11 +1451,17 @@ class TFGenerationMixin: Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in the generate method. """ + vocab_size = getattr(self.config, "vocab_size", None) + if vocab_size is None and self.config.is_encoder_decoder: + decoder_config = getattr(self.config, "decoder", None) + if decoder_config is not None: + vocab_size = getattr(self.config.decoder, "vocab_size", None) + if cur_len == 1 and forced_bos_token_id is not None: - vocab_range = tf.constant(range(self.config.vocab_size)) + vocab_range = tf.constant(range(vocab_size)) return tf.where(vocab_range != forced_bos_token_id, -1e8, logits) elif cur_len == max_length - 1 and forced_eos_token_id is not None: - vocab_range = tf.constant(range(self.config.vocab_size)) + vocab_range = tf.constant(range(vocab_size)) return tf.where(vocab_range != forced_eos_token_id, -1e8, logits) else: return logits diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py index fefc65ec9b..123ae76db7 100644 --- a/src/transformers/modeling_tf_outputs.py +++ b/src/transformers/modeling_tf_outputs.py @@ -80,6 +80,54 @@ class TFBaseModelOutputWithPooling(ModelOutput): attentions: Optional[Tuple[tf.Tensor]] = None +@dataclass +class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence + prediction (classification) objective during pretraining. + + This output is usually *not* a good summary of the semantic content of the input, you're often better with + averaging or pooling the sequence of hidden-states for the whole input sequence. + past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, + num_heads, sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + :obj:`past_key_values` input) to speed up sequential decoding. + hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: tf.Tensor = None + pooler_output: tf.Tensor = None + past_key_values: Optional[List[tf.Tensor]] = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + cross_attentions: Optional[Tuple[tf.Tensor]] = None + + @dataclass class TFBaseModelOutputWithPast(ModelOutput): """ @@ -317,6 +365,49 @@ class TFCausalLMOutputWithPast(ModelOutput): attentions: Optional[Tuple[tf.Tensor]] = None +@dataclass +class TFCausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (:obj:`tf.Tensor` of shape :obj:`(n,)`, `optional`, where n is the number of non-masked labels, returned when :obj:`labels` is provided): + Language modeling loss (for next-token prediction). + logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, + num_heads, sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + :obj:`past_key_values` input) to speed up sequential decoding. + """ + + loss: Optional[tf.Tensor] = None + logits: tf.Tensor = None + past_key_values: Optional[List[tf.Tensor]] = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + cross_attentions: Optional[Tuple[tf.Tensor]] = None + + @dataclass class TFMaskedLMOutput(ModelOutput): """ diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 80291c9305..2692949aa8 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -296,7 +296,9 @@ def booleans_processing(config, **kwargs): ) if "use_cache" in kwargs: - final_booleans["use_cache"] = kwargs["use_cache"] if kwargs["use_cache"] is not None else config.use_cache + final_booleans["use_cache"] = ( + kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None) + ) else: if ( kwargs["output_attentions"] not in (None, config.output_attentions) @@ -318,7 +320,7 @@ def booleans_processing(config, **kwargs): final_booleans["return_dict"] = True if "use_cache" in kwargs: - final_booleans["use_cache"] = config.use_cache + final_booleans["use_cache"] = getattr(config, "use_cache", None) return final_booleans @@ -362,6 +364,15 @@ def input_processing(func, config, input_ids, **kwargs): ) output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states") + if "past" in kwargs["kwargs_call"] and "past_key_values" in kwargs: + warnings.warn( + "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past") + elif "past_key_values" in kwargs["kwargs_call"] and "past" in kwargs: + kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values") + if len(kwargs["kwargs_call"]) > 0: raise ValueError( f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}." diff --git a/src/transformers/models/albert/modeling_tf_albert.py b/src/transformers/models/albert/modeling_tf_albert.py index 8f7ef5c9ef..ba54f36940 100644 --- a/src/transformers/models/albert/modeling_tf_albert.py +++ b/src/transformers/models/albert/modeling_tf_albert.py @@ -157,6 +157,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer): position_ids: tf.Tensor = None, token_type_ids: tf.Tensor = None, inputs_embeds: tf.Tensor = None, + past_key_values_length=0, training: bool = False, ) -> tf.Tensor: """ @@ -177,7 +178,9 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer): token_type_ids = tf.fill(dims=input_shape, value=0) if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) + position_ids = tf.expand_dims( + tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 + ) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1)) @@ -829,7 +832,6 @@ class TFAlbertModel(TFAlbertPreTrainedModel): return outputs - # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 19a97532be..f7ff78e629 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -133,6 +133,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( # Model for Causal LM mapping ("rembert", "TFRemBertForCausalLM"), ("roformer", "TFRoFormerForCausalLM"), + ("roberta", "TFRobertaForCausalLM"), ("bert", "TFBertLMHeadModel"), ("openai-gpt", "TFOpenAIGPTLMHeadModel"), ("gpt2", "TFGPT2LMHeadModel"), @@ -181,6 +182,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("blenderbot", "TFBlenderbotForConditionalGeneration"), ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"), ("bart", "TFBartForConditionalGeneration"), + ("encoder-decoder", "TFEncoderDecoderModel"), ] ) diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py index 7e01627311..3791cd1c04 100644 --- a/src/transformers/models/bert/modeling_tf_bert.py +++ b/src/transformers/models/bert/modeling_tf_bert.py @@ -25,6 +25,7 @@ import tensorflow as tf from ...activations_tf import get_tf_activation from ...file_utils import ( + DUMMY_INPUTS, MULTIPLE_CHOICE_DUMMY_INPUTS, ModelOutput, add_code_sample_docstrings, @@ -33,9 +34,9 @@ from ...file_utils import ( replace_return_docstrings, ) from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, - TFCausalLMOutput, + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, TFMaskedLMOutput, TFMultipleChoiceModelOutput, TFNextSentencePredictorOutput, @@ -174,6 +175,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer): position_ids: tf.Tensor = None, token_type_ids: tf.Tensor = None, inputs_embeds: tf.Tensor = None, + past_key_values_length=0, training: bool = False, ) -> tf.Tensor: """ @@ -194,7 +196,9 @@ class TFBertEmbeddings(tf.keras.layers.Layer): token_type_ids = tf.fill(dims=input_shape, value=0) if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) + position_ids = tf.expand_dims( + tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 + ) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1)) @@ -232,6 +236,8 @@ class TFBertSelfAttention(tf.keras.layers.Layer): ) self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + self.is_decoder = config.is_decoder + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) @@ -244,16 +250,49 @@ class TFBertSelfAttention(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: batch_size = shape_list(hidden_states)[0] mixed_query_layer = self.query(inputs=hidden_states) - mixed_key_layer = self.key(inputs=hidden_states) - mixed_value_layer = self.value(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) + value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. # (batch size, num_heads, seq_len_q, seq_len_k) @@ -283,6 +322,8 @@ class TFBertSelfAttention(tf.keras.layers.Layer): attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) return outputs @@ -319,6 +360,9 @@ class TFBertAttention(tf.keras.layers.Layer): input_tensor: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: @@ -326,13 +370,17 @@ class TFBertAttention(tf.keras.layers.Layer): hidden_states=input_tensor, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) attention_output = self.dense_output( hidden_states=self_outputs[0], input_tensor=input_tensor, training=training ) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] return outputs @@ -380,6 +428,12 @@ class TFBertLayer(tf.keras.layers.Layer): super().__init__(**kwargs) self.attention = TFBertAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFBertAttention(config, name="crossattention") self.intermediate = TFBertIntermediate(config, name="intermediate") self.bert_output = TFBertOutput(config, name="output") @@ -388,22 +442,69 @@ class TFBertLayer(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_value: Optional[Tuple[tf.Tensor]], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: - attention_outputs = self.attention( + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( input_tensor=hidden_states, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, output_attentions=output_attentions, training=training, ) - attention_output = attention_outputs[0] + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers " + "by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + intermediate_output = self.intermediate(hidden_states=attention_output) layer_output = self.bert_output( hidden_states=intermediate_output, input_tensor=attention_output, training=training ) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) return outputs @@ -411,7 +512,7 @@ class TFBertLayer(tf.keras.layers.Layer): class TFBertEncoder(tf.keras.layers.Layer): def __init__(self, config: BertConfig, **kwargs): super().__init__(**kwargs) - + self.config = config self.layer = [TFBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] def call( @@ -419,39 +520,61 @@ class TFBertEncoder(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_values: Optional[Tuple[Tuple[tf.Tensor]]], + use_cache: Optional[bool], output_attentions: bool, output_hidden_states: bool, return_dict: bool, training: bool = False, - ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + past_key_value = past_key_values[i] if past_key_values is not None else None + layer_outputs = layer_module( hidden_states=hidden_states, attention_mask=attention_mask, head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -579,6 +702,7 @@ class TFBertMainLayer(tf.keras.layers.Layer): super().__init__(**kwargs) self.config = config + self.is_decoder = config.is_decoder self.embeddings = TFBertEmbeddings(config, name="embeddings") self.encoder = TFBertEncoder(config, name="encoder") @@ -606,12 +730,16 @@ class TFBertMainLayer(tf.keras.layers.Layer): position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False, **kwargs, - ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: inputs = input_processing( func=self.call, config=self.config, @@ -621,6 +749,10 @@ class TFBertMainLayer(tf.keras.layers.Layer): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -628,6 +760,9 @@ class TFBertMainLayer(tf.keras.layers.Layer): kwargs_call=kwargs, ) + if not self.config.is_decoder: + inputs["use_cache"] = False + if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif inputs["input_ids"] is not None: @@ -637,8 +772,16 @@ class TFBertMainLayer(tf.keras.layers.Layer): else: raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + + if inputs["past_key_values"] is None: + past_key_values_length = 0 + inputs["past_key_values"] = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] + if inputs["attention_mask"] is None: - inputs["attention_mask"] = tf.fill(dims=input_shape, value=1) + inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) if inputs["token_type_ids"] is None: inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) @@ -648,6 +791,7 @@ class TFBertMainLayer(tf.keras.layers.Layer): position_ids=inputs["position_ids"], token_type_ids=inputs["token_type_ids"], inputs_embeds=inputs["inputs_embeds"], + past_key_values_length=past_key_values_length, training=inputs["training"], ) @@ -656,7 +800,29 @@ class TFBertMainLayer(tf.keras.layers.Layer): # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1])) + attention_mask_shape = shape_list(inputs["attention_mask"]) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) + extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + else: + extended_attention_mask = tf.reshape( + inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for @@ -668,6 +834,29 @@ class TFBertMainLayer(tf.keras.layers.Layer): ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and inputs["encoder_attention_mask"] is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + inputs["encoder_attention_mask"] = tf.cast( + inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype + ) + num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -682,6 +871,10 @@ class TFBertMainLayer(tf.keras.layers.Layer): hidden_states=embedding_output, attention_mask=extended_attention_mask, head_mask=inputs["head_mask"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -697,11 +890,13 @@ class TFBertMainLayer(tf.keras.layers.Layer): pooled_output, ) + encoder_outputs[1:] - return TFBaseModelOutputWithPooling( + return TFBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, ) @@ -714,6 +909,24 @@ class TFBertPreTrainedModel(TFPreTrainedModel): config_class = BertConfig base_model_prefix = "bert" + @property + def dummy_inputs(self): + """ + Dummy inputs to build the network. + + Returns: + :obj:`Dict[str, tf.Tensor]`: The dummy inputs. + """ + dummy = {"input_ids": tf.constant(DUMMY_INPUTS)} + # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized + if self.config.add_cross_attention: + batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape + shape = (batch_size, seq_len) + (self.config.hidden_size,) + h = tf.random.uniform(shape=shape) + dummy["encoder_hidden_states"] = h + + return dummy + @dataclass class TFBertForPreTrainingOutput(ModelOutput): @@ -853,7 +1066,7 @@ class TFBertModel(TFBertPreTrainedModel): @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPooling, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def call( @@ -864,12 +1077,36 @@ class TFBertModel(TFBertPreTrainedModel): position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: Optional[bool] = False, **kwargs, - ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation + """ inputs = input_processing( func=self.call, config=self.config, @@ -879,6 +1116,10 @@ class TFBertModel(TFBertPreTrainedModel): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -892,6 +1133,10 @@ class TFBertModel(TFBertPreTrainedModel): position_ids=inputs["position_ids"], head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -900,15 +1145,24 @@ class TFBertModel(TFBertPreTrainedModel): return outputs - def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: + def serving_output( + self, output: TFBaseModelOutputWithPoolingAndCrossAttentions + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None - return TFBaseModelOutputWithPooling( + return TFBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=output.last_hidden_state, pooler_output=output.pooler_output, + past_key_values=pkv, hidden_states=hs, attentions=attns, + cross_attentions=cross_attns, ) @@ -960,11 +1214,11 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss): **kwargs, ) -> Union[TFBertForPreTrainingOutput, Tuple[tf.Tensor]]: r""" - labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`): + labels (:obj:`tf.Tensor` of shape ``(batch_size, sequence_length)``, `optional`): Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` - next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): + next_sentence_label (``tf.Tensor`` of shape ``(batch_size,)``, `optional`): Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: @@ -1184,10 +1438,22 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name + def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs): + # cut decoder_input_ids if past is used + if past: + inputs = tf.expand_dims(inputs[:, -1], -1) + + return { + "input_ids": inputs, + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": model_kwargs["use_cache"], + } + @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutput, + output_type=TFCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def call( @@ -1198,14 +1464,36 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None, training: Optional[bool] = False, **kwargs, - ) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]: + ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]: r""" + encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the cross entropy classification loss. Indices should be in ``[0, ..., config.vocab_size - 1]``. @@ -1219,6 +1507,10 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -1233,6 +1525,10 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): position_ids=inputs["position_ids"], head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -1252,18 +1548,27 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output - return TFCausalLMOutput( + return TFCausalLMOutputWithCrossAttentions( loss=loss, logits=logits, + past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) - def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput: + def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None - return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns) + return TFCausalLMOutputWithCrossAttentions( + logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns + ) @add_start_docstrings( diff --git a/src/transformers/models/convbert/modeling_tf_convbert.py b/src/transformers/models/convbert/modeling_tf_convbert.py index 85eadc7a7f..0cd7e6fa02 100644 --- a/src/transformers/models/convbert/modeling_tf_convbert.py +++ b/src/transformers/models/convbert/modeling_tf_convbert.py @@ -110,6 +110,7 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer): position_ids: tf.Tensor = None, token_type_ids: tf.Tensor = None, inputs_embeds: tf.Tensor = None, + past_key_values_length=0, training: bool = False, ) -> tf.Tensor: """ @@ -130,7 +131,9 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer): token_type_ids = tf.fill(dims=input_shape, value=0) if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) + position_ids = tf.expand_dims( + tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 + ) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1)) diff --git a/src/transformers/models/electra/configuration_electra.py b/src/transformers/models/electra/configuration_electra.py index f4dd18bf27..b0fb6ea73c 100644 --- a/src/transformers/models/electra/configuration_electra.py +++ b/src/transformers/models/electra/configuration_electra.py @@ -104,6 +104,9 @@ class ElectraConfig(PretrainedConfig): `__. For more information on :obj:`"relative_key_query"`, please refer to `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) `__. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if ``config.is_decoder=True``. classifier_dropout (:obj:`float`, `optional`): The dropout ratio for the classification head. @@ -143,6 +146,7 @@ class ElectraConfig(PretrainedConfig): summary_last_dropout=0.1, pad_token_id=0, position_embedding_type="absolute", + use_cache=True, classifier_dropout=None, **kwargs ): @@ -167,4 +171,5 @@ class ElectraConfig(PretrainedConfig): self.summary_activation = summary_activation self.summary_last_dropout = summary_last_dropout self.position_embedding_type = position_embedding_type + self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/electra/modeling_tf_electra.py b/src/transformers/models/electra/modeling_tf_electra.py index c862b3664f..cd03c997a2 100644 --- a/src/transformers/models/electra/modeling_tf_electra.py +++ b/src/transformers/models/electra/modeling_tf_electra.py @@ -23,6 +23,7 @@ import tensorflow as tf from ...activations_tf import get_tf_activation from ...file_utils import ( + DUMMY_INPUTS, MULTIPLE_CHOICE_DUMMY_INPUTS, ModelOutput, add_code_sample_docstrings, @@ -31,7 +32,7 @@ from ...file_utils import ( replace_return_docstrings, ) from ...modeling_tf_outputs import ( - TFBaseModelOutput, + TFBaseModelOutputWithPastAndCrossAttentions, TFMaskedLMOutput, TFMultipleChoiceModelOutput, TFQuestionAnsweringModelOutput, @@ -99,6 +100,8 @@ class TFElectraSelfAttention(tf.keras.layers.Layer): ) self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + self.is_decoder = config.is_decoder + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) @@ -111,16 +114,49 @@ class TFElectraSelfAttention(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: batch_size = shape_list(hidden_states)[0] mixed_query_layer = self.query(inputs=hidden_states) - mixed_key_layer = self.key(inputs=hidden_states) - mixed_value_layer = self.value(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) + value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. # (batch size, num_heads, seq_len_q, seq_len_k) @@ -150,6 +186,8 @@ class TFElectraSelfAttention(tf.keras.layers.Layer): attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) return outputs @@ -188,6 +226,9 @@ class TFElectraAttention(tf.keras.layers.Layer): input_tensor: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: @@ -195,13 +236,17 @@ class TFElectraAttention(tf.keras.layers.Layer): hidden_states=input_tensor, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) attention_output = self.dense_output( hidden_states=self_outputs[0], input_tensor=input_tensor, training=training ) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] return outputs @@ -252,6 +297,12 @@ class TFElectraLayer(tf.keras.layers.Layer): super().__init__(**kwargs) self.attention = TFElectraAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFElectraAttention(config, name="crossattention") self.intermediate = TFElectraIntermediate(config, name="intermediate") self.bert_output = TFElectraOutput(config, name="output") @@ -260,22 +311,69 @@ class TFElectraLayer(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_value: Optional[Tuple[tf.Tensor]], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: - attention_outputs = self.attention( + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( input_tensor=hidden_states, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, output_attentions=output_attentions, training=training, ) - attention_output = attention_outputs[0] + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers " + "by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + intermediate_output = self.intermediate(hidden_states=attention_output) layer_output = self.bert_output( hidden_states=intermediate_output, input_tensor=attention_output, training=training ) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) return outputs @@ -284,7 +382,7 @@ class TFElectraLayer(tf.keras.layers.Layer): class TFElectraEncoder(tf.keras.layers.Layer): def __init__(self, config: ElectraConfig, **kwargs): super().__init__(**kwargs) - + self.config = config self.layer = [TFElectraLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] def call( @@ -292,39 +390,61 @@ class TFElectraEncoder(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_values: Optional[Tuple[Tuple[tf.Tensor]]], + use_cache: Optional[bool], output_attentions: bool, output_hidden_states: bool, return_dict: bool, training: bool = False, - ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + past_key_value = past_key_values[i] if past_key_values is not None else None + layer_outputs = layer_module( hidden_states=hidden_states, attention_mask=attention_mask, head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -396,6 +516,7 @@ class TFElectraEmbeddings(tf.keras.layers.Layer): position_ids: tf.Tensor = None, token_type_ids: tf.Tensor = None, inputs_embeds: tf.Tensor = None, + past_key_values_length=0, training: bool = False, ) -> tf.Tensor: """ @@ -416,7 +537,9 @@ class TFElectraEmbeddings(tf.keras.layers.Layer): token_type_ids = tf.fill(dims=input_shape, value=0) if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) + position_ids = tf.expand_dims( + tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 + ) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1)) @@ -471,6 +594,25 @@ class TFElectraPreTrainedModel(TFPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"] _keys_to_ignore_on_load_missing = [r"dropout"] + @property + # Copied from transformers.models.bert.modeling_tf_bert.TFBertPreTrainedModel.dummy_inputs + def dummy_inputs(self): + """ + Dummy inputs to build the network. + + Returns: + :obj:`Dict[str, tf.Tensor]`: The dummy inputs. + """ + dummy = {"input_ids": tf.constant(DUMMY_INPUTS)} + # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized + if self.config.add_cross_attention: + batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape + shape = (batch_size, seq_len) + (self.config.hidden_size,) + h = tf.random.uniform(shape=shape) + dummy["encoder_hidden_states"] = h + + return dummy + @keras_serializable class TFElectraMainLayer(tf.keras.layers.Layer): @@ -480,13 +622,14 @@ class TFElectraMainLayer(tf.keras.layers.Layer): super().__init__(**kwargs) self.config = config + self.is_decoder = config.is_decoder + self.embeddings = TFElectraEmbeddings(config, name="embeddings") if config.embedding_size != config.hidden_size: self.embeddings_project = tf.keras.layers.Dense(config.hidden_size, name="embeddings_project") self.encoder = TFElectraEncoder(config, name="encoder") - self.config = config def get_input_embeddings(self): return self.embeddings @@ -502,24 +645,50 @@ class TFElectraMainLayer(tf.keras.layers.Layer): """ raise NotImplementedError - def get_extended_attention_mask(self, attention_mask, input_shape, dtype): + def get_extended_attention_mask(self, attention_mask, input_shape, dtype, past_key_values_length=0): + batch_size, seq_length = input_shape + if attention_mask is None: - attention_mask = tf.fill(input_shape, 1) + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) + attention_mask_shape = shape_list(attention_mask) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + else: + extended_attention_mask = tf.reshape( + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype) - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + extended_attention_mask = tf.cast(extended_attention_mask, dtype=dtype) + one_cst = tf.constant(1.0, dtype=dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) return extended_attention_mask @@ -539,6 +708,10 @@ class TFElectraMainLayer(tf.keras.layers.Layer): position_ids=None, head_mask=None, inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -554,6 +727,10 @@ class TFElectraMainLayer(tf.keras.layers.Layer): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -561,6 +738,9 @@ class TFElectraMainLayer(tf.keras.layers.Layer): kwargs_call=kwargs, ) + if not self.config.is_decoder: + inputs["use_cache"] = False + if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif inputs["input_ids"] is not None: @@ -570,34 +750,71 @@ class TFElectraMainLayer(tf.keras.layers.Layer): else: raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + + if inputs["past_key_values"] is None: + past_key_values_length = 0 + inputs["past_key_values"] = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] + if inputs["attention_mask"] is None: - inputs["attention_mask"] = tf.fill(input_shape, 1) + inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) if inputs["token_type_ids"] is None: - inputs["token_type_ids"] = tf.fill(input_shape, 0) + inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) hidden_states = self.embeddings( - inputs["input_ids"], - inputs["position_ids"], - inputs["token_type_ids"], - inputs["inputs_embeds"], + input_ids=inputs["input_ids"], + position_ids=inputs["position_ids"], + token_type_ids=inputs["token_type_ids"], + inputs_embeds=inputs["inputs_embeds"], + past_key_values_length=past_key_values_length, training=inputs["training"], ) extended_attention_mask = self.get_extended_attention_mask( - inputs["attention_mask"], input_shape, hidden_states.dtype + inputs["attention_mask"], input_shape, hidden_states.dtype, past_key_values_length ) + + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and inputs["encoder_attention_mask"] is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + inputs["encoder_attention_mask"] = tf.cast( + inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype + ) + num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + inputs["head_mask"] = self.get_head_mask(inputs["head_mask"]) if hasattr(self, "embeddings_project"): hidden_states = self.embeddings_project(hidden_states, training=inputs["training"]) hidden_states = self.encoder( - hidden_states, - extended_attention_mask, - inputs["head_mask"], - inputs["output_attentions"], - inputs["output_hidden_states"], - inputs["return_dict"], + hidden_states=hidden_states, + attention_mask=extended_attention_mask, + head_mask=inputs["head_mask"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], + output_attentions=inputs["output_attentions"], + output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], training=inputs["training"], ) @@ -735,7 +952,7 @@ class TFElectraModel(TFElectraPreTrainedModel): @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutput, + output_type=TFBaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def call( @@ -746,12 +963,36 @@ class TFElectraModel(TFElectraPreTrainedModel): position_ids=None, head_mask=None, inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, training=False, **kwargs, ): + r""" + encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation + """ inputs = input_processing( func=self.call, config=self.config, @@ -761,6 +1002,10 @@ class TFElectraModel(TFElectraPreTrainedModel): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -773,6 +1018,10 @@ class TFElectraModel(TFElectraPreTrainedModel): token_type_ids=inputs["token_type_ids"], position_ids=inputs["position_ids"], head_mask=inputs["head_mask"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], inputs_embeds=inputs["inputs_embeds"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], @@ -782,12 +1031,22 @@ class TFElectraModel(TFElectraPreTrainedModel): return outputs - # Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output def serving_output(self, output): + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None - return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns) + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=output.last_hidden_state, + past_key_values=pkv, + hidden_states=hs, + attentions=attns, + cross_attentions=cross_attns, + ) @add_start_docstrings( diff --git a/src/transformers/models/encoder_decoder/__init__.py b/src/transformers/models/encoder_decoder/__init__.py index e6e36a8c7c..a5eafdf251 100644 --- a/src/transformers/models/encoder_decoder/__init__.py +++ b/src/transformers/models/encoder_decoder/__init__.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING -from ...file_utils import _LazyModule, is_flax_available, is_torch_available +from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available _import_structure = { @@ -28,6 +28,9 @@ _import_structure = { if is_torch_available(): _import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"] +if is_tf_available(): + _import_structure["modeling_tf_encoder_decoder"] = ["TFEncoderDecoderModel"] + if is_flax_available(): _import_structure["modeling_flax_encoder_decoder"] = ["FlaxEncoderDecoderModel"] @@ -37,6 +40,9 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_encoder_decoder import EncoderDecoderModel + if is_tf_available(): + from .modeling_tf_encoder_decoder import TFEncoderDecoderModel + if is_flax_available(): from .modeling_flax_encoder_decoder import FlaxEncoderDecoderModel diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py new file mode 100644 index 0000000000..c59e8c52f1 --- /dev/null +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -0,0 +1,647 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Classes to support TF Encoder-Decoder architectures """ + + +from typing import Optional + +import tensorflow as tf + +from ...configuration_utils import PretrainedConfig +from ...file_utils import ( + DUMMY_INPUTS, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput +from ...modeling_tf_utils import TFPreTrainedModel, input_processing +from ...utils import logging +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM +from .configuration_encoder_decoder import EncoderDecoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "EncoderDecoderConfig" + +ENCODER_DECODER_START_DOCSTRING = r""" + This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the + encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via + :meth:`~transformers.TFAutoModel.from_pretrained` function and the decoder is loaded via + :meth:`~transformers.TFAutoModelForCausalLM.from_pretrained` function. Cross-attention layers are automatically + added to the decoder and should be fine-tuned on a downstream generative task, like summarization. + + The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation + tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks + `__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi + Zhou, Wei Li, Peter J. Liu. + + After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models + (see the examples for more information). + + This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the + generic methods the library implements for all its model (such as downloading or saving, resizing the input + embeddings, pruning heads etc.) + + This model is also a `tf.keras.Model `__ subclass. Use + it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage + and behavior. + + Parameters: + config (:class:`~transformers.EncoderDecoderConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the + model weights. +""" + +ENCODER_DECODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`np.ndarray`, :obj:`tf.Tensor`, :obj:`List[tf.Tensor]` :obj:`Dict[str, tf.Tensor]` or :obj:`Dict[str, np.ndarray]` and each example must have the shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + + If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see + :obj:`past_key_values`). + + Provide for sequence to sequence training to the decoder. Indices can be obtained using + :class:`~transformers.PreTrainedTokenizer`. See :meth:`transformers.PreTrainedTokenizer.encode` and + :meth:`transformers.PreTrainedTokenizer.__call__` for details. + decoder_attention_mask (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + encoder_outputs (:obj:`tuple(tuple(tf.Tensor)`, `optional`): + This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: + :obj:`attentions`) :obj:`last_hidden_state` (:obj:`tf.Tensor` of shape :obj:`({0}, hidden_size)`) is a + tensor of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the + decoder. + past_key_values (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`({0})`. + inputs_embeds (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`decoder_input_ids` + indices into associated vectors than the model's internal embedding lookup matrix. + labels (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`): + Labels for computing the masked language modeling loss for the decoder. Indices should be in ``[-100, 0, + ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + If set to ``True``, the model will return a :class:`~transformers.file_utils.Seq2SeqLMOutput` instead of a + plain tuple. + training (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). + kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: + + - Without a prefix which will be input as ``**encoder_kwargs`` for the encoder forward function. + - With a `decoder_` prefix which will be input as ``**decoder_kwargs`` for the decoder forward function. +""" + + +@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) +class TFEncoderDecoderModel(TFPreTrainedModel): + r""" + :class:`~transformers.TFEncoderDecoder` is a generic model class that will be instantiated as a transformer + architecture with one of the base model classes of the library as encoder and another one as decoder when created + with the :meth`~transformers.TFAutoModel.from_pretrained` class method for the encoder and + :meth`~transformers.TFAutoModelForCausalLM.from_pretrained` class method for the decoder. + """ + config_class = EncoderDecoderConfig + base_model_prefix = "encoder_decoder" + load_weight_prefix = "tf_encoder_decoder_model_1" + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + encoder: Optional[TFPreTrainedModel] = None, + decoder: Optional[TFPreTrainedModel] = None, + ): + if config is None and (encoder is None or decoder is None): + raise ValueError("Either a configuration or an encoder and a decoder has to be provided") + if config is None: + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config) + else: + if not isinstance(config, self.config_class): + raise ValueError(f"config: {config} has to be of type {self.config_class}") + # initialize with config + super().__init__(config) + + if encoder is None: + encoder = TFAutoModel.from_config(config.encoder, name="encoder") + + if decoder is None: + decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder") + + self.encoder = encoder + self.decoder = decoder + + if self.encoder.config.to_dict() != self.config.encoder.to_dict(): + logger.warning( + f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}" + ) + if self.decoder.config.to_dict() != self.config.decoder.to_dict(): + logger.warning( + f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}" + ) + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + + if self.encoder.get_output_embeddings() is not None: + raise ValueError("The encoder {} should not have a LM Head. Please use a model without LM Head") + + @property + def dummy_inputs(self): + """ + Dummy inputs to build the network. + + Returns: + :obj:`Dict[str, tf.Tensor]`: The dummy inputs. + """ + # Add `decoder_input_ids` because `self.decoder` requires it. + input_ids = tf.constant(DUMMY_INPUTS) + dummy = {"input_ids": input_ids, "decoder_input_ids": input_ids} + # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized + if self.config.add_cross_attention: + batch_size, seq_len = input_ids.shape + shape = (batch_size, seq_len) + (self.config.hidden_size,) + h = tf.random.uniform(shape=shape) + dummy["encoder_hidden_states"] = h + + return dummy + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Initializing `TFEncoderDecoderModel` from a pytorch checkpoint is not supported currently. + + If there are only pytorch checkpoints for a particular encoder-decoder model, a workaround is:: + + >>> # a workaround to load from pytorch checkpoint + >>> _model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") + >>> _model.encoder.save_pretrained("./encoder") + >>> _model.decoder.save_pretrained("./decoder") + >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained( + ... "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True + ... ) + >>> # This is only for copying some specific attributes of this particular model. + >>> model.config = _model.config + + """ + + from_pt = kwargs.pop("from_pt", False) + if from_pt: + raise ValueError( + "Initializing `TFEncoderDecoderModel` from a pytorch checkpoint is not supported currently. " + "Use a tensorflow checkpoint instead. If only the pytorch checkpoints are available, " + "create the encoder and decoder models separately, and use them to initialize `TFEncoderDecoderModel`. " + "Check `TFEncoderDecoderModel.from_encoder_decoder_pretrained()` for more details." + ) + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + @classmethod + def from_encoder_decoder_pretrained( + cls, + encoder_pretrained_model_name_or_path: str = None, + decoder_pretrained_model_name_or_path: str = None, + *model_args, + **kwargs + ) -> TFPreTrainedModel: + r""" + Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model + checkpoints. + + + Params: + encoder_pretrained_model_name_or_path (:obj: `str`, `optional`): + Information necessary to initiate the encoder. Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `pytorch index checkpoint file` (e.g, ``./pt_model/``). In this case, + ``encoder_from_pt`` should be set to :obj:`True`. + + decoder_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`): + Information necessary to initiate the decoder. Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `pytorch checkpoint file` (e.g, ``./pt_model/``). In this case, + ``decoder_from_pt`` should be set to :obj:`True`. + + model_args (remaining positional arguments, `optional`): + All remaning positional arguments will be passed to the underlying model's ``__init__`` method. + + kwargs (remaining dictionary of keyword arguments, `optional`): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + :obj:`output_attentions=True`). + + - To update the encoder configuration, use the prefix `encoder_` for each configuration parameter. + - To update the decoder configuration, use the prefix `decoder_` for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a :obj:`config` is provided or automatically loaded. + + Example:: + + >>> from transformers import TFEncoderDecoderModel + >>> # initialize a bert2gpt2 from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized + >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'gpt2') + >>> # saving model after fine-tuning + >>> model.save_pretrained("./bert2gpt2") + >>> # load fine-tuned model + >>> model = TFEncoderDecoderModel.from_pretrained("./bert2gpt2") + + """ + + kwargs_encoder = { + argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove encoder, decoder kwargs from kwargs + for key in kwargs_encoder.keys(): + del kwargs["encoder_" + key] + for key in kwargs_decoder.keys(): + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + encoder = kwargs_encoder.pop("model", None) + if encoder is None: + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined" + ) + + if "config" not in kwargs_encoder: + + encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + + logger.info( + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_encoder["config"] = encoder_config + + kwargs_encoder["name"] = "encoder" + kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix + encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined" + ) + + if "config" not in kwargs_decoder: + + decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + kwargs_decoder["name"] = "decoder" + kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix + decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly. + if encoder.name != "encoder": + raise ValueError("encoder model must be created with the name `encoder`.") + if decoder.name != "decoder": + raise ValueError("decoder model must be created with the name `decoder`.") + + # instantiate config with corresponding kwargs + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + return cls(encoder=encoder, decoder=decoder, config=config) + + @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + **kwargs, + ): + r""" + Returns: + + Examples:: + + >>> from transformers import TFEncoderDecoderModel, BertTokenizer + + >>> # initialize a bert2gpt2 from a pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized + >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-cased', 'gpt2') + + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + + >>> # forward + >>> input_ids = tokenizer.encode("Hello, my dog is cute", add_special_tokens=True, return_tensors='tf') # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids) + + >>> # training + >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids) + >>> loss, logits = outputs.loss, outputs.logits + + >>> # save and load from pretrained + >>> model.save_pretrained("bert2gpt2") + >>> model = TFEncoderDecoderModel.from_pretrained("bert2gpt2") + + >>> # generation + >>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.bos_token_id) + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + if encoder_outputs is None: + + encoder_processing_inputs = { + "func": self.encoder.call, + "config": self.encoder.config, + "input_ids": input_ids, + "attention_mask": attention_mask, + "inputs_embeds": inputs_embeds, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + "training": training, + "kwargs_call": kwargs_encoder, + } + + # Add arguments to encoder from `kwargs_encoder` + for k, v in kwargs_encoder.items(): + encoder_processing_inputs[k] = v + kwargs_encoder = {} + + encoder_inputs = input_processing(**encoder_processing_inputs) + + # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`. + if "decoder_input_ids" in encoder_inputs: + decoder_input_ids = encoder_inputs.pop("decoder_input_ids") + # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`. + if "decoder_attention_mask" in encoder_inputs: + decoder_attention_mask = encoder_inputs.pop("decoder_attention_mask") + + encoder_outputs = self.encoder(**encoder_inputs) + + encoder_hidden_states = encoder_outputs[0] + + decoder_processing_inputs = { + "func": self.decoder.call, + "config": self.decoder.config, + "input_ids": decoder_input_ids, + "attention_mask": decoder_attention_mask, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": attention_mask, + "inputs_embeds": decoder_inputs_embeds, + "labels": labels, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "use_cache": use_cache, + "past_key_values": past_key_values, + "return_dict": return_dict, + "training": training, + "kwargs_call": kwargs_decoder, + } + + # Add arguments to decoder from `kwargs_decoder` + for k, v in kwargs_decoder.items(): + decoder_processing_inputs[k] = v + kwargs_decoder = {} + + decoder_inputs = input_processing(**decoder_processing_inputs) + decoder_outputs = self.decoder(**decoder_inputs) + + loss = None if decoder_inputs["labels"] is None else decoder_outputs[0] + logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1] + past_key_values = None + + if decoder_inputs["use_cache"]: + past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2] + # The starting index of the remaining elements in `decoder_outputs` + start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) + + past = (encoder_outputs[0], past_key_values) if past_key_values else None + + if not decoder_inputs["return_dict"]: + if not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() + output = (loss, logits, past) + decoder_outputs[start_index:] + encoder_outputs + output = tuple([x for x in output if x is not None]) + return output + + # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True + if not isinstance(encoder_outputs, TFBaseModelOutput): + encoder_outputs = TFBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + return TFSeq2SeqLMOutput( + loss=decoder_outputs.loss, + logits=decoder_outputs.logits, + past_key_values=past, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + cross_attns = ( + tf.convert_to_tensor(output.cross_attentions) + if self.config.output_attentions and output.cross_attentions is not None + else None + ) + + return TFSeq2SeqLMOutput( + logits=output.logits, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + cross_attentions=cross_attns, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past, + attention_mask, + use_cache=None, + **kwargs, + ): + if past is None or len(past) not in {1, 2}: + raise ValueError(f"past has to be an iterable of length 1,2 got {past}") + + if len(past) == 1: + if not isinstance(past[0], tf.Tensor): + raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}") + encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) + past_key_values = None + else: + if len(past) != 2: + raise ValueError( + "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." + ) + encoder_outputs, past_key_values = past + if isinstance(encoder_outputs, tuple): + if not isinstance(encoder_outputs[0], tf.Tensor): + raise ValueError( + f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" + ) + encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) + elif isinstance(encoder_outputs, tf.Tensor): + encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) + if not past_key_values: + raise ValueError( + f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past" + ) + decoder_input_ids = decoder_input_ids[:, -1:] + + if not isinstance(encoder_outputs, TFBaseModelOutput): + raise ValueError(f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}.") + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported." + "Please use the respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or model.decoder.resize_token_embeddings(...))" + ) + + def _reorder_cache(self, past, beam_idx): + # apply decoder cache reordering here + if len(past) == 1: + return past + + encoder_outputs, past_key_values = past + + return (encoder_outputs, self.decoder._reorder_cache(past_key_values, beam_idx)) diff --git a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py b/src/transformers/models/layoutlm/modeling_tf_layoutlm.py index d17924f9f4..3f73cfb8ac 100644 --- a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_tf_layoutlm.py @@ -24,8 +24,8 @@ import tensorflow as tf from ...activations_tf import get_tf_activation from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, TFMaskedLMOutput, TFSequenceClassifierOutput, TFTokenClassifierOutput, @@ -216,6 +216,8 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer): ) self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + self.is_decoder = config.is_decoder + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) @@ -228,16 +230,49 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: batch_size = shape_list(hidden_states)[0] mixed_query_layer = self.query(inputs=hidden_states) - mixed_key_layer = self.key(inputs=hidden_states) - mixed_value_layer = self.value(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) + value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. # (batch size, num_heads, seq_len_q, seq_len_k) @@ -267,6 +302,8 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer): attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) return outputs @@ -305,6 +342,9 @@ class TFLayoutLMAttention(tf.keras.layers.Layer): input_tensor: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: @@ -312,13 +352,17 @@ class TFLayoutLMAttention(tf.keras.layers.Layer): hidden_states=input_tensor, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) attention_output = self.dense_output( hidden_states=self_outputs[0], input_tensor=input_tensor, training=training ) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] return outputs @@ -369,6 +413,12 @@ class TFLayoutLMLayer(tf.keras.layers.Layer): super().__init__(**kwargs) self.attention = TFLayoutLMAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFLayoutLMAttention(config, name="crossattention") self.intermediate = TFLayoutLMIntermediate(config, name="intermediate") self.bert_output = TFLayoutLMOutput(config, name="output") @@ -377,22 +427,69 @@ class TFLayoutLMLayer(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_value: Optional[Tuple[tf.Tensor]], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: - attention_outputs = self.attention( + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( input_tensor=hidden_states, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, output_attentions=output_attentions, training=training, ) - attention_output = attention_outputs[0] + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers " + "by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + intermediate_output = self.intermediate(hidden_states=attention_output) layer_output = self.bert_output( hidden_states=intermediate_output, input_tensor=attention_output, training=training ) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) return outputs @@ -401,7 +498,7 @@ class TFLayoutLMLayer(tf.keras.layers.Layer): class TFLayoutLMEncoder(tf.keras.layers.Layer): def __init__(self, config: LayoutLMConfig, **kwargs): super().__init__(**kwargs) - + self.config = config self.layer = [TFLayoutLMLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] def call( @@ -409,39 +506,61 @@ class TFLayoutLMEncoder(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_values: Optional[Tuple[Tuple[tf.Tensor]]], + use_cache: Optional[bool], output_attentions: bool, output_hidden_states: bool, return_dict: bool, training: bool = False, - ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + past_key_value = past_key_values[i] if past_key_values is not None else None + layer_outputs = layer_module( hidden_states=hidden_states, attention_mask=attention_mask, head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -585,12 +704,14 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer): position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False, **kwargs, - ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: inputs = input_processing( func=self.call, config=self.config, @@ -665,6 +786,11 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer): hidden_states=embedding_output, attention_mask=extended_attention_mask, head_mask=inputs["head_mask"], + # Need to pass these required positional arguments to `Encoder` + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=None, + past_key_values=None, + use_cache=False, output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -680,11 +806,12 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer): pooled_output, ) + encoder_outputs[1:] - return TFBaseModelOutputWithPooling( + return TFBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, ) @@ -802,7 +929,9 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel): self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm") @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC + ) def call( self, input_ids: Optional[TFModelInputType] = None, @@ -812,12 +941,14 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel): position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: Optional[bool] = False, **kwargs, - ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: r""" Returns: @@ -859,6 +990,8 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -881,15 +1014,25 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel): return outputs - def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: + # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output + def serving_output( + self, output: TFBaseModelOutputWithPoolingAndCrossAttentions + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None - return TFBaseModelOutputWithPooling( + return TFBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=output.last_hidden_state, pooler_output=output.pooler_output, + past_key_values=pkv, hidden_states=hs, attentions=attns, + cross_attentions=cross_attns, ) diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index dfe620ffb6..7d7a881fc8 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -510,7 +510,7 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer): super().build(input_shape) - def create_position_ids_from_input_ids(self, input_ids): + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols are ignored. This is modified from fairseq's `utils.make_positions`. @@ -520,11 +520,19 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer): Returns: tf.Tensor """ mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) - incremental_indices = tf.math.cumsum(mask, axis=1) * mask + incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask return incremental_indices + self.padding_idx - def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False): + def call( + self, + input_ids=None, + position_ids=None, + token_type_ids=None, + inputs_embeds=None, + past_key_values_length=0, + training=False, + ): """ Applies embedding based on inputs tensor. @@ -544,7 +552,9 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer): if position_ids is None: if input_ids is not None: # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids) + position_ids = self.create_position_ids_from_input_ids( + input_ids=input_ids, past_key_values_length=past_key_values_length + ) else: position_ids = tf.expand_dims( tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 diff --git a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py index 920086d6d6..8357d09b55 100644 --- a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py @@ -982,7 +982,6 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel): return outputs - # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None diff --git a/src/transformers/models/mpnet/modeling_tf_mpnet.py b/src/transformers/models/mpnet/modeling_tf_mpnet.py index dff6324e6c..67a54e4f73 100644 --- a/src/transformers/models/mpnet/modeling_tf_mpnet.py +++ b/src/transformers/models/mpnet/modeling_tf_mpnet.py @@ -729,7 +729,6 @@ class TFMPNetModel(TFMPNetPreTrainedModel): ) return outputs - # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py index 97496ec63a..60bb101d15 100644 --- a/src/transformers/models/openai/modeling_tf_openai.py +++ b/src/transformers/models/openai/modeling_tf_openai.py @@ -673,7 +673,6 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin attentions=transformer_outputs.attentions, ) - # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput: hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None diff --git a/src/transformers/models/rembert/modeling_tf_rembert.py b/src/transformers/models/rembert/modeling_tf_rembert.py index 257231bc33..59c60ff6f3 100644 --- a/src/transformers/models/rembert/modeling_tf_rembert.py +++ b/src/transformers/models/rembert/modeling_tf_rembert.py @@ -23,15 +23,16 @@ import tensorflow as tf from ...activations_tf import get_tf_activation from ...file_utils import ( + DUMMY_INPUTS, MULTIPLE_CHOICE_DUMMY_INPUTS, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, ) from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, - TFCausalLMOutput, + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, TFMaskedLMOutput, TFMultipleChoiceModelOutput, TFQuestionAnsweringModelOutput, @@ -112,6 +113,7 @@ class TFRemBertEmbeddings(tf.keras.layers.Layer): position_ids: tf.Tensor = None, token_type_ids: tf.Tensor = None, inputs_embeds: tf.Tensor = None, + past_key_values_length=0, training: bool = False, ) -> tf.Tensor: """ @@ -131,7 +133,9 @@ class TFRemBertEmbeddings(tf.keras.layers.Layer): token_type_ids = tf.fill(dims=input_shape, value=0) if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) + position_ids = tf.expand_dims( + tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 + ) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1)) @@ -170,6 +174,8 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer): ) self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + self.is_decoder = config.is_decoder + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) @@ -182,16 +188,49 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: batch_size = shape_list(hidden_states)[0] mixed_query_layer = self.query(inputs=hidden_states) - mixed_key_layer = self.key(inputs=hidden_states) - mixed_value_layer = self.value(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) + value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. # (batch size, num_heads, seq_len_q, seq_len_k) @@ -221,6 +260,8 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer): attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) return outputs @@ -259,6 +300,9 @@ class TFRemBertAttention(tf.keras.layers.Layer): input_tensor: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: @@ -266,13 +310,17 @@ class TFRemBertAttention(tf.keras.layers.Layer): hidden_states=input_tensor, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) attention_output = self.dense_output( hidden_states=self_outputs[0], input_tensor=input_tensor, training=training ) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] return outputs @@ -323,6 +371,12 @@ class TFRemBertLayer(tf.keras.layers.Layer): super().__init__(**kwargs) self.attention = TFRemBertAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFRemBertAttention(config, name="crossattention") self.intermediate = TFRemBertIntermediate(config, name="intermediate") self.bert_output = TFRemBertOutput(config, name="output") @@ -331,22 +385,69 @@ class TFRemBertLayer(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_value: Optional[Tuple[tf.Tensor]], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: - attention_outputs = self.attention( + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( input_tensor=hidden_states, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, output_attentions=output_attentions, training=training, ) - attention_output = attention_outputs[0] + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers " + "by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + intermediate_output = self.intermediate(hidden_states=attention_output) layer_output = self.bert_output( hidden_states=intermediate_output, input_tensor=attention_output, training=training ) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) return outputs @@ -354,6 +455,7 @@ class TFRemBertLayer(tf.keras.layers.Layer): class TFRemBertEncoder(tf.keras.layers.Layer): def __init__(self, config: RemBertConfig, **kwargs): super().__init__(**kwargs) + self.config = config self.embedding_hidden_mapping_in = tf.keras.layers.Dense( units=config.hidden_size, @@ -367,40 +469,62 @@ class TFRemBertEncoder(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_values: Tuple[Tuple[tf.Tensor]], + use_cache: bool, output_attentions: bool, output_hidden_states: bool, return_dict: bool, training: bool = False, - ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states) all_hidden_states = (hidden_states,) if output_hidden_states else None all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + past_key_value = past_key_values[i] if past_key_values is not None else None + layer_outputs = layer_module( hidden_states=hidden_states, attention_mask=attention_mask, head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -500,6 +624,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): super().__init__(**kwargs) self.config = config + self.is_decoder = config.is_decoder self.embeddings = TFRemBertEmbeddings(config, name="embeddings") self.encoder = TFRemBertEncoder(config, name="encoder") @@ -519,6 +644,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): """ raise NotImplementedError + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call def call( self, input_ids: Optional[TFModelInputType] = None, @@ -527,12 +653,16 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False, **kwargs, - ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: inputs = input_processing( func=self.call, config=self.config, @@ -542,6 +672,10 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -549,6 +683,9 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): kwargs_call=kwargs, ) + if not self.config.is_decoder: + inputs["use_cache"] = False + if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif inputs["input_ids"] is not None: @@ -558,8 +695,16 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): else: raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + + if inputs["past_key_values"] is None: + past_key_values_length = 0 + inputs["past_key_values"] = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] + if inputs["attention_mask"] is None: - inputs["attention_mask"] = tf.fill(dims=input_shape, value=1) + inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) if inputs["token_type_ids"] is None: inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) @@ -569,6 +714,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): position_ids=inputs["position_ids"], token_type_ids=inputs["token_type_ids"], inputs_embeds=inputs["inputs_embeds"], + past_key_values_length=past_key_values_length, training=inputs["training"], ) @@ -577,7 +723,29 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1])) + attention_mask_shape = shape_list(inputs["attention_mask"]) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) + extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + else: + extended_attention_mask = tf.reshape( + inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for @@ -589,6 +757,29 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and inputs["encoder_attention_mask"] is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + inputs["encoder_attention_mask"] = tf.cast( + inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype + ) + num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -603,6 +794,10 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): hidden_states=embedding_output, attention_mask=extended_attention_mask, head_mask=inputs["head_mask"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -613,13 +808,18 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None if not inputs["return_dict"]: - return (sequence_output, pooled_output) + encoder_outputs[1:] + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] - return TFBaseModelOutputWithPooling( + return TFBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, ) @@ -632,6 +832,24 @@ class TFRemBertPreTrainedModel(TFPreTrainedModel): config_class = RemBertConfig base_model_prefix = "rembert" + @property + def dummy_inputs(self): + """ + Dummy inputs to build the network. + + Returns: + :obj:`Dict[str, tf.Tensor]`: The dummy inputs. + """ + dummy = {"input_ids": tf.constant(DUMMY_INPUTS)} + # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized + if self.config.add_cross_attention: + batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape + shape = (batch_size, seq_len) + (self.config.hidden_size,) + h = tf.random.uniform(shape=shape) + dummy["encoder_hidden_states"] = h + + return dummy + REMBERT_START_DOCSTRING = r""" @@ -740,7 +958,7 @@ class TFRemBertModel(TFRemBertPreTrainedModel): @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="rembert", - output_type=TFBaseModelOutputWithPooling, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def call( @@ -751,12 +969,36 @@ class TFRemBertModel(TFRemBertPreTrainedModel): position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: Optional[bool] = False, **kwargs, - ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation + """ inputs = input_processing( func=self.call, config=self.config, @@ -766,6 +1008,10 @@ class TFRemBertModel(TFRemBertPreTrainedModel): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -779,6 +1025,10 @@ class TFRemBertModel(TFRemBertPreTrainedModel): position_ids=inputs["position_ids"], head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -787,15 +1037,25 @@ class TFRemBertModel(TFRemBertPreTrainedModel): return outputs - def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: + # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output + def serving_output( + self, output: TFBaseModelOutputWithPoolingAndCrossAttentions + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None - return TFBaseModelOutputWithPooling( + return TFBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=output.last_hidden_state, pooler_output=output.pooler_output, + past_key_values=pkv, hidden_states=hs, attentions=attns, + cross_attentions=cross_attns, ) @@ -912,10 +1172,23 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos def get_lm_head(self) -> tf.keras.layers.Layer: return self.mlm.predictions + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs): + # cut decoder_input_ids if past is used + if past: + inputs = tf.expand_dims(inputs[:, -1], -1) + + return { + "input_ids": inputs, + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": model_kwargs["use_cache"], + } + @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="rembert", - output_type=TFCausalLMOutput, + output_type=TFCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def call( @@ -926,14 +1199,36 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None, training: Optional[bool] = False, **kwargs, - ) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]: + ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]: r""" + encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the cross entropy classification loss. Indices should be in ``[0, ..., config.vocab_size - 1]``. @@ -947,6 +1242,10 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -961,6 +1260,10 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos position_ids=inputs["position_ids"], head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -980,18 +1283,28 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output - return TFCausalLMOutput( + return TFCausalLMOutputWithCrossAttentions( loss=loss, logits=logits, + past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) - def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput: + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output + def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None - return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns) + return TFCausalLMOutputWithCrossAttentions( + logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns + ) @add_start_docstrings( diff --git a/src/transformers/models/roberta/__init__.py b/src/transformers/models/roberta/__init__.py index e76597f3b9..91058cf040 100644 --- a/src/transformers/models/roberta/__init__.py +++ b/src/transformers/models/roberta/__init__.py @@ -45,6 +45,7 @@ if is_torch_available(): if is_tf_available(): _import_structure["modeling_tf_roberta"] = [ "TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFRobertaForCausalLM", "TFRobertaForMaskedLM", "TFRobertaForMultipleChoice", "TFRobertaForQuestionAnswering", @@ -90,6 +91,7 @@ if TYPE_CHECKING: if is_tf_available(): from .modeling_tf_roberta import ( TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + TFRobertaForCausalLM, TFRobertaForMaskedLM, TFRobertaForMultipleChoice, TFRobertaForQuestionAnswering, diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py index 0de7be0846..e364ac691a 100644 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ b/src/transformers/models/roberta/modeling_tf_roberta.py @@ -24,14 +24,16 @@ import tensorflow as tf from ...activations_tf import get_tf_activation from ...file_utils import ( + DUMMY_INPUTS, MULTIPLE_CHOICE_DUMMY_INPUTS, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, ) from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, TFMaskedLMOutput, TFMultipleChoiceModelOutput, TFQuestionAnsweringModelOutput, @@ -39,6 +41,7 @@ from ...modeling_tf_outputs import ( TFTokenClassifierOutput, ) from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, TFMaskedLanguageModelingLoss, TFModelInputType, TFMultipleChoiceLoss, @@ -112,7 +115,7 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer): super().build(input_shape) - def create_position_ids_from_input_ids(self, input_ids): + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols are ignored. This is modified from fairseq's `utils.make_positions`. @@ -122,11 +125,19 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer): Returns: tf.Tensor """ mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) - incremental_indices = tf.math.cumsum(mask, axis=1) * mask + incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask return incremental_indices + self.padding_idx - def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False): + def call( + self, + input_ids=None, + position_ids=None, + token_type_ids=None, + inputs_embeds=None, + past_key_values_length=0, + training=False, + ): """ Applies embedding based on inputs tensor. @@ -146,7 +157,9 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer): if position_ids is None: if input_ids is not None: # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids) + position_ids = self.create_position_ids_from_input_ids( + input_ids=input_ids, past_key_values_length=past_key_values_length + ) else: position_ids = tf.expand_dims( tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 @@ -210,6 +223,8 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer): ) self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + self.is_decoder = config.is_decoder + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) @@ -222,16 +237,49 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: batch_size = shape_list(hidden_states)[0] mixed_query_layer = self.query(inputs=hidden_states) - mixed_key_layer = self.key(inputs=hidden_states) - mixed_value_layer = self.value(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) + value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. # (batch size, num_heads, seq_len_q, seq_len_k) @@ -261,6 +309,8 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer): attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) return outputs @@ -299,6 +349,9 @@ class TFRobertaAttention(tf.keras.layers.Layer): input_tensor: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: @@ -306,13 +359,17 @@ class TFRobertaAttention(tf.keras.layers.Layer): hidden_states=input_tensor, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) attention_output = self.dense_output( hidden_states=self_outputs[0], input_tensor=input_tensor, training=training ) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] return outputs @@ -363,6 +420,12 @@ class TFRobertaLayer(tf.keras.layers.Layer): super().__init__(**kwargs) self.attention = TFRobertaAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFRobertaAttention(config, name="crossattention") self.intermediate = TFRobertaIntermediate(config, name="intermediate") self.bert_output = TFRobertaOutput(config, name="output") @@ -371,22 +434,69 @@ class TFRobertaLayer(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_value: Optional[Tuple[tf.Tensor]], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: - attention_outputs = self.attention( + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( input_tensor=hidden_states, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, output_attentions=output_attentions, training=training, ) - attention_output = attention_outputs[0] + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers " + "by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + intermediate_output = self.intermediate(hidden_states=attention_output) layer_output = self.bert_output( hidden_states=intermediate_output, input_tensor=attention_output, training=training ) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) return outputs @@ -395,7 +505,7 @@ class TFRobertaLayer(tf.keras.layers.Layer): class TFRobertaEncoder(tf.keras.layers.Layer): def __init__(self, config: RobertaConfig, **kwargs): super().__init__(**kwargs) - + self.config = config self.layer = [TFRobertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] def call( @@ -403,39 +513,61 @@ class TFRobertaEncoder(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_values: Optional[Tuple[Tuple[tf.Tensor]]], + use_cache: Optional[bool], output_attentions: bool, output_hidden_states: bool, return_dict: bool, training: bool = False, - ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + past_key_value = past_key_values[i] if past_key_values is not None else None + layer_outputs = layer_module( hidden_states=hidden_states, attention_mask=attention_mask, head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -447,6 +579,8 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): super().__init__(**kwargs) self.config = config + self.is_decoder = config.is_decoder + self.num_hidden_layers = config.num_hidden_layers self.initializer_range = config.initializer_range self.output_attentions = config.output_attentions @@ -483,12 +617,16 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False, **kwargs, - ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: inputs = input_processing( func=self.call, config=self.config, @@ -498,6 +636,10 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -505,6 +647,9 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): kwargs_call=kwargs, ) + if not self.config.is_decoder: + inputs["use_cache"] = False + if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif inputs["input_ids"] is not None: @@ -514,8 +659,16 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): else: raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + + if inputs["past_key_values"] is None: + past_key_values_length = 0 + inputs["past_key_values"] = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] + if inputs["attention_mask"] is None: - inputs["attention_mask"] = tf.fill(dims=input_shape, value=1) + inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) if inputs["token_type_ids"] is None: inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) @@ -525,6 +678,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): position_ids=inputs["position_ids"], token_type_ids=inputs["token_type_ids"], inputs_embeds=inputs["inputs_embeds"], + past_key_values_length=past_key_values_length, training=inputs["training"], ) @@ -533,7 +687,29 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1])) + attention_mask_shape = shape_list(inputs["attention_mask"]) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) + extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + else: + extended_attention_mask = tf.reshape( + inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for @@ -545,6 +721,29 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and inputs["encoder_attention_mask"] is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + inputs["encoder_attention_mask"] = tf.cast( + inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype + ) + num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -559,6 +758,10 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): hidden_states=embedding_output, attention_mask=extended_attention_mask, head_mask=inputs["head_mask"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -574,11 +777,13 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): pooled_output, ) + encoder_outputs[1:] - return TFBaseModelOutputWithPooling( + return TFBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, ) @@ -591,6 +796,25 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel): config_class = RobertaConfig base_model_prefix = "roberta" + @property + # Copied from transformers.models.bert.modeling_tf_bert.TFBertPreTrainedModel.dummy_inputs + def dummy_inputs(self): + """ + Dummy inputs to build the network. + + Returns: + :obj:`Dict[str, tf.Tensor]`: The dummy inputs. + """ + dummy = {"input_ids": tf.constant(DUMMY_INPUTS)} + # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized + if self.config.add_cross_attention: + batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape + shape = (batch_size, seq_len) + (self.config.hidden_size,) + h = tf.random.uniform(shape=shape) + dummy["encoder_hidden_states"] = h + + return dummy + @tf.function( input_signature=[ { @@ -711,7 +935,7 @@ class TFRobertaModel(TFRobertaPreTrainedModel): @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPooling, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def call( @@ -722,12 +946,36 @@ class TFRobertaModel(TFRobertaPreTrainedModel): position_ids=None, head_mask=None, inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, training=False, **kwargs, ): + r""" + encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation + """ inputs = input_processing( func=self.call, config=self.config, @@ -737,6 +985,10 @@ class TFRobertaModel(TFRobertaPreTrainedModel): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -750,6 +1002,10 @@ class TFRobertaModel(TFRobertaPreTrainedModel): position_ids=inputs["position_ids"], head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -759,15 +1015,24 @@ class TFRobertaModel(TFRobertaPreTrainedModel): return outputs # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output - def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: + def serving_output( + self, output: TFBaseModelOutputWithPoolingAndCrossAttentions + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None - return TFBaseModelOutputWithPooling( + return TFBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=output.last_hidden_state, pooler_output=output.pooler_output, + past_key_values=pkv, hidden_states=hs, attentions=attns, + cross_attentions=cross_attns, ) @@ -922,6 +1187,163 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns) +class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] + + def __init__(self, config: RobertaConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + if not config.is_decoder: + logger.warning("If you want to use `TFRobertaLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + self.lm_head = TFRobertaLMHead(config, input_embeddings=self.roberta.embeddings, name="lm_head") + + def get_lm_head(self): + return self.lm_head + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.lm_head.name + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs): + # cut decoder_input_ids if past is used + if past: + inputs = tf.expand_dims(inputs[:, -1], -1) + + return { + "input_ids": inputs, + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": model_kwargs["use_cache"], + } + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFCausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + training=False, + **kwargs, + ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation + labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the cross entropy classification loss. Indices should be in ``[0, ..., + config.vocab_size - 1]``. + """ + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + labels=labels, + training=training, + kwargs_call=kwargs, + ) + outputs = self.roberta( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + token_type_ids=inputs["token_type_ids"], + position_ids=inputs["position_ids"], + head_mask=inputs["head_mask"], + inputs_embeds=inputs["inputs_embeds"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], + output_attentions=inputs["output_attentions"], + output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], + training=inputs["training"], + ) + + sequence_output = outputs[0] + logits = self.lm_head(hidden_states=sequence_output) + loss = None + + if inputs["labels"] is not None: + # shift labels to the left and cut last logit token + logits = logits[:, :-1] + labels = inputs["labels"][:, 1:] + loss = self.compute_loss(labels=labels, logits=logits) + + if not inputs["return_dict"]: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output + def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None + + return TFCausalLMOutputWithCrossAttentions( + logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns + ) + + class TFRobertaClassificationHead(tf.keras.layers.Layer): """Head for sentence-level classification tasks.""" diff --git a/src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py index 01dc6490ab..ff31ce0459 100644 --- a/src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py @@ -18,6 +18,7 @@ from ...file_utils import add_start_docstrings from ...utils import logging from ..roberta.modeling_tf_roberta import ( + TFRobertaForCausalLM, TFRobertaForMaskedLM, TFRobertaForMultipleChoice, TFRobertaForQuestionAnswering, @@ -85,6 +86,19 @@ class TFXLMRobertaModel(TFRobertaModel): config_class = XLMRobertaConfig +@add_start_docstrings( + "XLM-RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.", + XLM_ROBERTA_START_DOCSTRING, +) +class XLMRobertaForCausalLM(TFRobertaForCausalLM): + """ + This class overrides :class:`~transformers.TFRobertaForCausalLM`. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = XLMRobertaConfig + + @add_start_docstrings( """XLM-RoBERTa Model with a `language modeling` head on top. """, XLM_ROBERTA_START_DOCSTRING, diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 73e4afe59d..27301f5393 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -929,6 +929,15 @@ class TFElectraPreTrainedModel: requires_backends(cls, ["tf"]) +class TFEncoderDecoderModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -1712,6 +1721,15 @@ class TFRemBertPreTrainedModel: TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None +class TFRobertaForCausalLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + class TFRobertaForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 3825babc26..6732966fe1 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -24,15 +24,16 @@ import tensorflow as tf from ...activations_tf import get_tf_activation from ...file_utils import ( + DUMMY_INPUTS, MULTIPLE_CHOICE_DUMMY_INPUTS, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, ) from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, - TFCausalLMOutput, + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, TFMaskedLMOutput, TFMultipleChoiceModelOutput, TFQuestionAnsweringModelOutput, @@ -116,6 +117,7 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer): position_ids: tf.Tensor = None, token_type_ids: tf.Tensor = None, inputs_embeds: tf.Tensor = None, + past_key_values_length=0, training: bool = False, ) -> tf.Tensor: """ @@ -135,7 +137,9 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer): token_type_ids = tf.fill(dims=input_shape, value=0) if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) + position_ids = tf.expand_dims( + tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 + ) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1)) @@ -174,6 +178,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer) ) self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + self.is_decoder = config.is_decoder + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) @@ -186,16 +192,49 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer) hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: batch_size = shape_list(hidden_states)[0] mixed_query_layer = self.query(inputs=hidden_states) - mixed_key_layer = self.key(inputs=hidden_states) - mixed_value_layer = self.value(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) + value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. # (batch size, num_heads, seq_len_q, seq_len_k) @@ -225,6 +264,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer) attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) return outputs @@ -263,6 +304,9 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer): input_tensor: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: @@ -270,13 +314,17 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer): hidden_states=input_tensor, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) attention_output = self.dense_output( hidden_states=self_outputs[0], input_tensor=input_tensor, training=training ) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] return outputs @@ -327,6 +375,12 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer): super().__init__(**kwargs) self.attention = TF{{cookiecutter.camelcase_modelname}}Attention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TF{{cookiecutter.camelcase_modelname}}Attention(config, name="crossattention") self.intermediate = TF{{cookiecutter.camelcase_modelname}}Intermediate(config, name="intermediate") self.bert_output = TF{{cookiecutter.camelcase_modelname}}Output(config, name="output") @@ -335,20 +389,69 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_value: Optional[Tuple[tf.Tensor]], output_attentions: bool, training: bool = False, ) -> Tuple[tf.Tensor]: - attention_outputs = self.attention( + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( input_tensor=hidden_states, attention_mask=attention_mask, head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, output_attentions=output_attentions, training=training, ) - attention_output = attention_outputs[0] + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers " + "by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output(hidden_states=intermediate_output, input_tensor=attention_output, training=training) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) return outputs @@ -357,7 +460,7 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer): class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs): super().__init__(**kwargs) - + self.config = config self.layer = [TF{{cookiecutter.camelcase_modelname}}Layer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] def call( @@ -365,39 +468,61 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): hidden_states: tf.Tensor, attention_mask: tf.Tensor, head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_values: Optional[Tuple[Tuple[tf.Tensor]]], + use_cache: Optional[bool], output_attentions: bool, output_hidden_states: bool, return_dict: bool, training: bool = False, - ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + past_key_value = past_key_values[i] if past_key_values is not None else None + layer_outputs = layer_module( hidden_states=hidden_states, attention_mask=attention_mask, head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, output_attentions=output_attentions, training=training, ) hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -492,6 +617,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): super().__init__(**kwargs) self.config = config + self.is_decoder = config.is_decoder self.embeddings = TF{{cookiecutter.camelcase_modelname}}Embeddings(config, name="embeddings") self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, name="encoder") @@ -521,12 +647,16 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False, **kwargs, - ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: inputs = input_processing( func=self.call, config=self.config, @@ -536,6 +666,10 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -543,6 +677,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): kwargs_call=kwargs, ) + if not self.config.is_decoder: + inputs["use_cache"] = False + if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif inputs["input_ids"] is not None: @@ -552,8 +689,16 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): else: raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + + if inputs["past_key_values"] is None: + past_key_values_length = 0 + inputs["past_key_values"] = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] + if inputs["attention_mask"] is None: - inputs["attention_mask"] = tf.fill(dims=input_shape, value=1) + inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) if inputs["token_type_ids"] is None: inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) @@ -563,6 +708,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): position_ids=inputs["position_ids"], token_type_ids=inputs["token_type_ids"], inputs_embeds=inputs["inputs_embeds"], + past_key_values_length=past_key_values_length, training=inputs["training"], ) @@ -571,7 +717,29 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1])) + attention_mask_shape = shape_list(inputs["attention_mask"]) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) + extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + else: + extended_attention_mask = tf.reshape( + inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for @@ -583,6 +751,29 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and inputs["encoder_attention_mask"] is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + inputs["encoder_attention_mask"] = tf.cast( + inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype + ) + num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -597,6 +788,10 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): hidden_states=embedding_output, attention_mask=extended_attention_mask, head_mask=inputs["head_mask"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -610,10 +805,12 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): sequence_output, ) + encoder_outputs[1:] - return TFBaseModelOutput( + return TFBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=sequence_output, + past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, ) @@ -625,6 +822,24 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel): config_class = {{cookiecutter.camelcase_modelname}}Config base_model_prefix = "{{cookiecutter.lowercase_modelname}}" + @property + def dummy_inputs(self): + """ + Dummy inputs to build the network. + + Returns: + :obj:`Dict[str, tf.Tensor]`: The dummy inputs. + """ + dummy = {"input_ids": tf.constant(DUMMY_INPUTS)} + # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized + if self.config.add_cross_attention: + batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape + shape = (batch_size, seq_len) + (self.config.hidden_size,) + h = tf.random.uniform(shape=shape) + dummy["encoder_hidden_states"] = h + + return dummy + {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -732,7 +947,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPooling, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def call( @@ -743,12 +958,36 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: Optional[bool] = False, **kwargs, - ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation + """ inputs = input_processing( func=self.call, config=self.config, @@ -758,6 +997,10 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -771,6 +1014,10 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod position_ids=inputs["position_ids"], head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -779,12 +1026,26 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod return outputs - # Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output - def serving_output(self, output: TFBaseModelOutput) -> TFBaseModelOutput: + # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output + def serving_output( + self, output: TFBaseModelOutputWithPoolingAndCrossAttentions + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None - return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns) + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=output.last_hidden_state, + pooler_output=output.pooler_output, + past_key_values=pkv, + hidden_states=hs, + attentions=attns, + cross_attentions=cross_attns, + ) @add_start_docstrings("""{{cookiecutter.modelname}} Model with a `language modeling` head on top. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING) @@ -903,10 +1164,22 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca def get_lm_head(self) -> tf.keras.layers.Layer: return self.mlm.predictions + def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs): + # cut decoder_input_ids if past is used + if past: + inputs = tf.expand_dims(inputs[:, -1], -1) + + return { + "input_ids": inputs, + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": model_kwargs["use_cache"], + } + @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutput, + output_type=TFCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def call( @@ -917,14 +1190,36 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None, training: Optional[bool] = False, **kwargs, - ) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]: + ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]: r""" + encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the cross entropy classification loss. Indices should be in ``[0, ..., config.vocab_size - 1]``. @@ -938,6 +1233,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -952,6 +1251,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca position_ids=inputs["position_ids"], head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -971,19 +1274,28 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output - return TFCausalLMOutput( + return TFCausalLMOutputWithCrossAttentions( loss=loss, logits=logits, + past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output - def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput: + def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None - return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns) + return TFCausalLMOutputWithCrossAttentions( + logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns + ) class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.Layer): diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py index c352809f0a..2d2e7a865f 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -21,7 +21,7 @@ from transformers import is_tf_available, {{cookiecutter.camelcase_modelname}}Co from transformers.testing_utils import require_tf, slow from .test_configuration_common import ConfigTester -from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor +from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor if is_tf_available(): @@ -123,6 +123,33 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester: return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): diff --git a/tests/test_modeling_tf_bert.py b/tests/test_modeling_tf_bert.py index 639ba0be9d..47cf3f7300 100644 --- a/tests/test_modeling_tf_bert.py +++ b/tests/test_modeling_tf_bert.py @@ -21,7 +21,7 @@ from transformers.models.auto import get_values from transformers.testing_utils import require_tf, slow from .test_configuration_common import ConfigTester -from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor +from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor if is_tf_available(): @@ -125,6 +125,33 @@ class TFBertModelTester: return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + def create_and_check_bert_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 2b7b2c9143..49cd6e4fb1 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1393,6 +1393,22 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None): return output +def floats_tensor(shape, scale=1.0, rng=None, name=None, dtype=None): + """Creates a random float32 tensor""" + if rng is None: + rng = random.Random() + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.random() * scale) + + return tf.reshape(tf.constant(values, dtype=dtype if dtype is not None else tf.float32), shape=shape) + + @require_tf class UtilsFunctionsTest(unittest.TestCase): diff --git a/tests/test_modeling_tf_encoder_decoder.py b/tests/test_modeling_tf_encoder_decoder.py new file mode 100644 index 0000000000..a37d338ecb --- /dev/null +++ b/tests/test_modeling_tf_encoder_decoder.py @@ -0,0 +1,765 @@ +# coding=utf-8 +# Copyright 2020 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import tempfile +import unittest + +import numpy as np + +from transformers import is_tf_available, is_torch_available +from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_torch, slow, torch_device + +from .test_modeling_tf_bert import TFBertModelTester +from .test_modeling_tf_common import ids_tensor +from .test_modeling_tf_rembert import TFRemBertModelTester +from .test_modeling_tf_roberta import TFRobertaModelTester + + +if is_tf_available(): + from transformers import ( + AutoConfig, + AutoTokenizer, + EncoderDecoderConfig, + TFAutoModel, + TFAutoModelForCausalLM, + TFBertLMHeadModel, + TFBertModel, + TFEncoderDecoderModel, + TFRemBertForCausalLM, + TFRemBertModel, + TFRobertaForCausalLM, + TFRobertaModel, + ) + from transformers.modeling_tf_outputs import TFBaseModelOutput + +if is_torch_available(): + import torch + + from transformers import BertLMHeadModel, BertModel, EncoderDecoderModel + + +@require_tf +class TFEncoderDecoderMixin: + def get_encoder_decoder_model(self, config, decoder_config): + raise NotImplementedError + + def prepare_config_and_inputs(self): + raise NotImplementedError + + def get_pretrained_model(self): + raise NotImplementedError + + def check_encoder_decoder_model_from_pretrained_configs( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) + self.assertTrue(encoder_decoder_config.decoder.is_decoder) + + enc_dec_model = TFEncoderDecoderModel(encoder_decoder_config) + + self.assertTrue(enc_dec_model.config.is_encoder_decoder) + + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + self.assertEqual( + outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) + ) + self.assertEqual( + outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) + ) + + def check_encoder_decoder_model( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + self.assertTrue(enc_dec_model.config.decoder.is_decoder) + self.assertTrue(enc_dec_model.config.decoder.add_cross_attention) + self.assertTrue(enc_dec_model.config.is_encoder_decoder) + + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + self.assertEqual( + outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) + ) + self.assertEqual( + outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) + ) + + encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_hidden_states) + outputs_encoder_decoder = enc_dec_model( + input_ids=None, + encoder_outputs=encoder_outputs, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + self.assertEqual( + outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) + ) + self.assertEqual( + outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) + ) + + def check_encoder_decoder_model_from_pretrained( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + return_dict, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict} + enc_dec_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + return_dict=True, + ) + + self.assertEqual( + outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) + ) + self.assertEqual( + outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) + ) + + def check_save_and_load( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + + outputs = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + out_2 = np.array(outputs[0]) + out_2[np.isnan(out_2)] = 0 + + with tempfile.TemporaryDirectory() as tmpdirname: + enc_dec_model.save_pretrained(tmpdirname) + enc_dec_model = TFEncoderDecoderModel.from_pretrained(tmpdirname) + + after_outputs = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + out_1 = np.array(after_outputs[0]) + out_1[np.isnan(out_1)] = 0 + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + def check_encoder_decoder_model_labels( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + labels, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + labels=labels, + ) + + # Make sure `loss` exist + assert "loss" in outputs_encoder_decoder + + batch_size, seq_len = decoder_input_ids.shape + expected_shape = (batch_size, seq_len - 1, decoder_config.vocab_size) + self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape) + self.assertEqual( + outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) + ) + + def check_encoder_decoder_model_output_attentions( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + # make the decoder inputs a different shape from the encoder inputs to harden the test + decoder_input_ids = decoder_input_ids[:, :-1] + decoder_attention_mask = decoder_attention_mask[:, :-1] + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + output_attentions=True, + ) + + encoder_attentions = outputs_encoder_decoder["encoder_attentions"] + self.assertEqual(len(encoder_attentions), config.num_hidden_layers) + + self.assertEqual( + encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]) + ) + + decoder_attentions = outputs_encoder_decoder["decoder_attentions"] + num_decoder_layers = ( + decoder_config.num_decoder_layers + if hasattr(decoder_config, "num_decoder_layers") + else decoder_config.num_hidden_layers + ) + self.assertEqual(len(decoder_attentions), num_decoder_layers) + + self.assertEqual( + decoder_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]), + ) + + cross_attentions = outputs_encoder_decoder["cross_attentions"] + self.assertEqual(len(cross_attentions), num_decoder_layers) + + cross_attention_input_seq_len = decoder_input_ids.shape[-1] * ( + 1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0) + ) + self.assertEqual( + cross_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]), + ) + + def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + + # Bert does not have a bos token id, so use pad_token_id instead + generated_output = enc_dec_model.generate( + input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id + ) + self.assertEqual(tuple(generated_output.shape.as_list()), (input_ids.shape[0],) + (decoder_config.max_length,)) + + def test_encoder_decoder_model(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model(**input_ids_dict) + + def test_encoder_decoder_model_from_pretrained_configs(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict) + + def test_encoder_decoder_model_from_pretrained(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False) + + def test_encoder_decoder_model_from_pretrained_return_dict(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True) + + def test_save_and_load_from_pretrained(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_save_and_load(**input_ids_dict) + + def test_encoder_decoder_model_labels(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_labels(**input_ids_dict) + + def test_encoder_decoder_model_output_attentions(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_output_attentions(**input_ids_dict) + + def test_encoder_decoder_model_generate(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_generate(**input_ids_dict) + + @slow + def test_real_model_save_load_from_pretrained(self): + model_2 = self.get_pretrained_model() + input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size) + decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size) + attention_mask = ids_tensor([13, 5], vocab_size=2) + + outputs = model_2( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + ) + out_2 = np.array(outputs[0]) + out_2[np.isnan(out_2)] = 0 + + with tempfile.TemporaryDirectory() as tmp_dirname: + model_2.save_pretrained(tmp_dirname) + model_1 = TFEncoderDecoderModel.from_pretrained(tmp_dirname) + + after_outputs = model_1( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + ) + out_1 = np.array(after_outputs[0]) + out_1[np.isnan(out_1)] = 0 + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + +@require_tf +class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): + def get_pretrained_model(self): + return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased") + + def get_encoder_decoder_model(self, config, decoder_config): + encoder_model = TFBertModel(config, name="encoder") + decoder_model = TFBertLMHeadModel(decoder_config, name="decoder") + return encoder_model, decoder_model + + def prepare_config_and_inputs(self): + model_tester_encoder = TFBertModelTester(self, batch_size=13) + model_tester_decoder = TFBertModelTester(self, batch_size=13) + encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs() + decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder() + ( + config, + input_ids, + token_type_ids, + attention_mask, + sequence_labels, + token_labels, + choice_labels, + ) = encoder_config_and_inputs + ( + decoder_config, + decoder_input_ids, + decoder_token_type_ids, + decoder_attention_mask, + decoder_sequence_labels, + decoder_token_labels, + decoder_choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) = decoder_config_and_inputs + + # make sure that cross attention layers are added + decoder_config.add_cross_attention = True + # disable cache for now + decoder_config.use_cache = False + return { + "config": config, + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_config": decoder_config, + "decoder_input_ids": decoder_input_ids, + "decoder_token_type_ids": decoder_token_type_ids, + "decoder_attention_mask": decoder_attention_mask, + "decoder_sequence_labels": decoder_sequence_labels, + "decoder_token_labels": decoder_token_labels, + "decoder_choice_labels": decoder_choice_labels, + "encoder_hidden_states": encoder_hidden_states, + "labels": decoder_token_labels, + } + + @slow + @is_pt_tf_cross_test + def test_bert2bert_summarization(self): + + from transformers import EncoderDecoderModel + + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + + """Not working, because pt checkpoint has `encoder.encoder.layer...` while tf model has `encoder.bert.encoder.layer...` + (For Bert decoder, there is no issue, because `BertModel` is wrapped into `decoder` as `bert`) + model = TFEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16", from_pt=True) + """ + + # workaround to load from pt + _model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") + _model.encoder.save_pretrained("./encoder") + _model.decoder.save_pretrained("./decoder") + model = TFEncoderDecoderModel.from_encoder_decoder_pretrained( + "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True + ) + model.config = _model.config + + ARTICLE_STUDENTS = """(CNN)Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members singing a racist chant. SAE's national chapter suspended the students, but University of Oklahoma President David Boren took it a step further, saying the university's affiliation with the fraternity is permanently done. The news is shocking, but it's not the first time SAE has faced controversy. SAE was founded March 9, 1856, at the University of Alabama, five years before the American Civil War, according to the fraternity website. When the war began, the group had fewer than 400 members, of which "369 went to war for the Confederate States and seven for the Union Army," the website says. The fraternity now boasts more than 200,000 living alumni, along with about 15,000 undergraduates populating 219 chapters and 20 "colonies" seeking full membership at universities. SAE has had to work hard to change recently after a string of member deaths, many blamed on the hazing of new recruits, SAE national President Bradley Cohen wrote in a message on the fraternity's website. The fraternity's website lists more than 130 chapters cited or suspended for "health and safety incidents" since 2010. At least 30 of the incidents involved hazing, and dozens more involved alcohol. However, the list is missing numerous incidents from recent months. Among them, according to various media outlets: Yale University banned the SAEs from campus activities last month after members allegedly tried to interfere with a sexual misconduct investigation connected to an initiation rite. Stanford University in December suspended SAE housing privileges after finding sorority members attending a fraternity function were subjected to graphic sexual content. And Johns Hopkins University in November suspended the fraternity for underage drinking. "The media has labeled us as the 'nation's deadliest fraternity,' " Cohen said. In 2011, for example, a student died while being coerced into excessive alcohol consumption, according to a lawsuit. SAE's previous insurer dumped the fraternity. "As a result, we are paying Lloyd's of London the highest insurance rates in the Greek-letter world," Cohen said. Universities have turned down SAE's attempts to open new chapters, and the fraternity had to close 12 in 18 months over hazing incidents.""" + EXPECTED_SUMMARY_STUDENTS = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months.""" + + input_dict = tokenizer(ARTICLE_STUDENTS, return_tensors="tf") + output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist() + summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS]) + + +@require_tf +class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): + def get_pretrained_model(self): + return TFEncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base") + + def get_encoder_decoder_model(self, config, decoder_config): + encoder_model = TFRobertaModel(config, name="encoder") + decoder_model = TFRobertaForCausalLM(decoder_config, name="decoder") + return encoder_model, decoder_model + + def prepare_config_and_inputs(self): + model_tester_encoder = TFRobertaModelTester(self) + model_tester_decoder = TFRobertaModelTester(self) + encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs() + decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = encoder_config_and_inputs + ( + decoder_config, + decoder_input_ids, + decoder_token_type_ids, + decoder_input_mask, + decoder_sequence_labels, + decoder_token_labels, + decoder_choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) = decoder_config_and_inputs + + # make sure that cross attention layers are added + decoder_config.add_cross_attention = True + # disable cache for now + decoder_config.use_cache = False + return { + "config": config, + "input_ids": input_ids, + "attention_mask": input_mask, + "decoder_config": decoder_config, + "decoder_input_ids": decoder_input_ids, + "decoder_token_type_ids": decoder_token_type_ids, + "decoder_attention_mask": decoder_input_mask, + "decoder_sequence_labels": decoder_sequence_labels, + "decoder_token_labels": decoder_token_labels, + "decoder_choice_labels": decoder_choice_labels, + "encoder_hidden_states": encoder_hidden_states, + "labels": decoder_token_labels, + } + + +@require_tf +class TFRembertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): + def get_pretrained_model(self): + return TFEncoderDecoderModel.from_encoder_decoder_pretrained("google/rembert", "google/rembert") + + def get_encoder_decoder_model(self, config, decoder_config): + encoder_model = TFRemBertModel(config, name="encoder") + decoder_model = TFRemBertForCausalLM(decoder_config, name="decoder") + return encoder_model, decoder_model + + def prepare_config_and_inputs(self): + model_tester_encoder = TFRemBertModelTester(self) + model_tester_decoder = TFRemBertModelTester(self) + encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs() + decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = encoder_config_and_inputs + ( + decoder_config, + decoder_input_ids, + decoder_token_type_ids, + decoder_input_mask, + decoder_sequence_labels, + decoder_token_labels, + decoder_choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) = decoder_config_and_inputs + + # make sure that cross attention layers are added + decoder_config.add_cross_attention = True + # disable cache for now + decoder_config.use_cache = False + return { + "config": config, + "input_ids": input_ids, + "attention_mask": input_mask, + "decoder_config": decoder_config, + "decoder_input_ids": decoder_input_ids, + "decoder_token_type_ids": decoder_token_type_ids, + "decoder_attention_mask": decoder_input_mask, + "decoder_sequence_labels": decoder_sequence_labels, + "decoder_token_labels": decoder_token_labels, + "decoder_choice_labels": decoder_choice_labels, + "encoder_hidden_states": encoder_hidden_states, + "labels": decoder_token_labels, + } + + +@require_tf +class TFEncoderDecoderModelTest(unittest.TestCase): + def get_from_encoderdecoder_pretrained_model(self): + return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased") + + def get_decoder_config(self): + config = AutoConfig.from_pretrained("bert-base-cased") + config.is_decoder = True + config.add_cross_attention = True + return config + + def get_encoderdecoder_model(self): + return TFEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") + + def get_encoder_decoder_models(self): + encoder_model = TFBertModel.from_pretrained("bert-base-cased", name="encoder") + decoder_model = TFBertLMHeadModel.from_pretrained( + "bert-base-cased", config=self.get_decoder_config(), name="decoder" + ) + return {"encoder": encoder_model, "decoder": decoder_model} + + def _check_configuration_tie(self, model): + assert id(model.decoder.config) == id(model.config.decoder) + assert id(model.encoder.config) == id(model.config.encoder) + + @slow + def test_configuration_tie(self): + model = self.get_from_encoderdecoder_pretrained_model() + self._check_configuration_tie(model) + + model = TFEncoderDecoderModel(**self.get_encoder_decoder_models()) + self._check_configuration_tie(model) + + # # This should be enabled once we upload the TF version of + # # "patrickvonplaten/bert2bert-cnn_dailymail-fp16" to the Hub. + # model = self.get_encoderdecoder_model() + # self._check_configuration_tie(model) + + +@require_tf +class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase): + def get_encoder_decoder_config(self): + encoder_config = AutoConfig.from_pretrained("bert-base-uncased") + decoder_config = AutoConfig.from_pretrained("bert-base-uncased", is_decoder=True, add_cross_attention=True) + return EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config) + + def get_encoder_decoder_config_small(self): + encoder_config = AutoConfig.from_pretrained("hf-internal-testing/tiny-bert") + decoder_config = AutoConfig.from_pretrained( + "hf-internal-testing/tiny-bert", is_decoder=True, add_cross_attention=True + ) + return EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config) + + def test_encoder_decoder_save_load_from_encoder_decoder(self): + config = self.get_encoder_decoder_config_small() + + # create two random BERT models for bert2bert & initialize weights (+cross_attention weights) + encoder = TFBertModel(config.encoder) + encoder(encoder.dummy_inputs) + decoder = TFBertLMHeadModel(config.decoder) + decoder(decoder.dummy_inputs) + + encoder_decoder_orig = TFEncoderDecoderModel(encoder=encoder, decoder=decoder) + + input_ids = ids_tensor([13, 5], encoder.config.vocab_size) + decoder_input_ids = ids_tensor([13, 1], decoder.config.vocab_size) + + logits_orig = encoder_decoder_orig(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits + + with tempfile.TemporaryDirectory() as tmp_dirname: + encoder_path = os.path.join(tmp_dirname, "encoder") + decoder_path = os.path.join(tmp_dirname, "decoder") + + encoder.save_pretrained(encoder_path) + decoder.save_pretrained(decoder_path) + + encoder_decoder = TFEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_path, decoder_path) + + logits_1 = encoder_decoder(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits + + self.assertTrue(logits_orig.numpy().sum() - logits_1.numpy().sum() < 1e-3) + + max_diff = np.max(np.abs(logits_1.numpy() - logits_orig.numpy())) + self.assertAlmostEqual(max_diff, 0.0, places=4) + + with tempfile.TemporaryDirectory() as tmp_dirname: + encoder_decoder.save_pretrained(tmp_dirname) + encoder_decoder = TFEncoderDecoderModel.from_pretrained(tmp_dirname) + + logits_2 = encoder_decoder(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits + + max_diff = np.max(np.abs(logits_2.numpy() - logits_orig.numpy())) + self.assertAlmostEqual(max_diff, 0.0, places=4) + + @require_torch + @is_pt_tf_cross_test + def test_encoder_decoder_save_load_from_encoder_decoder_from_pt(self): + config = self.get_encoder_decoder_config_small() + + # create two random BERT models for bert2bert & initialize weights (+cross_attention weights) + encoder_pt = BertModel(config.encoder).to(torch_device).eval() + decoder_pt = BertLMHeadModel(config.decoder).to(torch_device).eval() + + encoder_decoder_pt = EncoderDecoderModel(encoder=encoder_pt, decoder=decoder_pt).to(torch_device).eval() + + input_ids = ids_tensor([13, 5], encoder_pt.config.vocab_size) + decoder_input_ids = ids_tensor([13, 1], decoder_pt.config.vocab_size) + + pt_input_ids = torch.tensor(input_ids.numpy(), device=torch_device, dtype=torch.long) + pt_decoder_input_ids = torch.tensor(decoder_input_ids.numpy(), device=torch_device, dtype=torch.long) + + logits_pt = encoder_decoder_pt(input_ids=pt_input_ids, decoder_input_ids=pt_decoder_input_ids).logits + + # PyTorch => TensorFlow + with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2: + encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1) + encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2) + encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained( + tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True + ) + + logits_tf = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits + + max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy())) + self.assertAlmostEqual(max_diff, 0.0, places=3) + + # TensorFlow => PyTorch + with tempfile.TemporaryDirectory() as tmp_dirname: + encoder_decoder_tf.save_pretrained(tmp_dirname) + encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True) + + max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy())) + self.assertAlmostEqual(max_diff, 0.0, places=3) + + @slow + def test_encoder_decoder_from_pretrained(self): + load_weight_prefix = "tf_encoder_decoder_model_1" + + config = self.get_encoder_decoder_config() + encoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + decoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + + input_ids = encoder_tokenizer("who sings does he love me with reba", return_tensors="tf").input_ids + decoder_input_ids = decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids + + with tempfile.TemporaryDirectory() as tmp_dirname: + + # Since most of HF's models don't have pretrained cross-attention layers, they are randomly + # initialized even if we create models using `from_pretrained` method. + # For the tests, the decoder need to be a model with pretrained cross-attention layers. + # So we create pretrained models (without `load_weight_prefix`), save them, and later, + # we load them using `from_pretrained`. + # (we don't need to do this for encoder, but let's make the code more similar between encoder/decoder) + encoder = TFAutoModel.from_pretrained("bert-base-uncased", name="encoder") + # It's necessary to specify `add_cross_attention=True` here. + decoder = TFAutoModelForCausalLM.from_pretrained( + "bert-base-uncased", is_decoder=True, add_cross_attention=True, name="decoder" + ) + pretrained_encoder_dir = os.path.join(tmp_dirname, "pretrained_encoder") + pretrained_decoder_dir = os.path.join(tmp_dirname, "pretrained_decoder") + encoder.save_pretrained(pretrained_encoder_dir) + decoder.save_pretrained(pretrained_decoder_dir) + del encoder + del decoder + + enc_dec_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained( + pretrained_encoder_dir, + pretrained_decoder_dir, + ) + # check that the from pretrained methods work + enc_dec_model.save_pretrained(tmp_dirname) + enc_dec_model = TFEncoderDecoderModel.from_pretrained(tmp_dirname) + + output = enc_dec_model(input_ids, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids) + + loss_pretrained = output.loss + del enc_dec_model + + # Create the model using `__init__` with loaded ``pretrained`` encoder / decoder + encoder = TFAutoModel.from_pretrained( + pretrained_encoder_dir, load_weight_prefix=load_weight_prefix, name="encoder" + ) + decoder = TFAutoModelForCausalLM.from_pretrained( + pretrained_decoder_dir, load_weight_prefix=load_weight_prefix, name="decoder" + ) + enc_dec_model = TFEncoderDecoderModel(config=config, encoder=encoder, decoder=decoder) + + output = enc_dec_model(input_ids, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids) + + loss_init = output.loss + + max_diff = np.max(np.abs(loss_pretrained - loss_init)) + expected_diff = 0.0 + + self.assertAlmostEqual(max_diff, expected_diff, places=4) diff --git a/tests/test_modeling_tf_rembert.py b/tests/test_modeling_tf_rembert.py index cd09408ba9..8908e6d02b 100644 --- a/tests/test_modeling_tf_rembert.py +++ b/tests/test_modeling_tf_rembert.py @@ -20,7 +20,7 @@ from transformers import RemBertConfig, is_tf_available from transformers.testing_utils import require_tf, slow from .test_configuration_common import ConfigTester -from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor +from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor if is_tf_available(): @@ -131,6 +131,33 @@ class TFRemBertModelTester: return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): diff --git a/tests/test_modeling_tf_roberta.py b/tests/test_modeling_tf_roberta.py index d40652efc9..082df3b0bb 100644 --- a/tests/test_modeling_tf_roberta.py +++ b/tests/test_modeling_tf_roberta.py @@ -20,7 +20,7 @@ from transformers import RobertaConfig, is_tf_available from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow from .test_configuration_common import ConfigTester -from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor +from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor if is_tf_available(): @@ -29,6 +29,7 @@ if is_tf_available(): from transformers.models.roberta.modeling_tf_roberta import ( TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + TFRobertaForCausalLM, TFRobertaForMaskedLM, TFRobertaForMultipleChoice, TFRobertaForQuestionAnswering, @@ -101,6 +102,33 @@ class TFRobertaModelTester: return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + def create_and_check_roberta_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -115,6 +143,13 @@ class TFRobertaModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + def create_and_check_roberta_for_causal_lm( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFRobertaForCausalLM(config=config) + result = model([input_ids, input_mask, token_type_ids]) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + def create_and_check_roberta_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -177,6 +212,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = ( ( TFRobertaModel, + TFRobertaForCausalLM, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, TFRobertaForTokenClassification, @@ -203,6 +239,10 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs) + def test_for_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_roberta_for_causal_lm(*config_and_inputs) + def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_roberta_for_token_classification(*config_and_inputs) diff --git a/utils/check_repo.py b/utils/check_repo.py index 810bd1495f..fe5eac2f38 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -160,6 +160,7 @@ def get_model_modules(): "modeling_flax_utils", "modeling_transfo_xl_utilities", "modeling_tf_auto", + "modeling_tf_encoder_decoder", "modeling_tf_outputs", "modeling_tf_pytorch_utils", "modeling_tf_utils", @@ -231,6 +232,7 @@ def get_model_test_files(): "test_modeling_flax_encoder_decoder", "test_modeling_marian", "test_modeling_tf_common", + "test_modeling_tf_encoder_decoder", ] test_files = [] for filename in os.listdir(PATH_TO_TESTS):