Add TFEncoderDecoderModel + Add cross-attention to some TF models (#13222)
* Add cross attentions to TFGPT2Model * Add TFEncoderDecoderModel * Add TFBaseModelOutputWithPoolingAndCrossAttentions * Add cross attentions to TFBertModel * Fix past or past_key_values argument issue * Fix generation * Fix save and load * Add some checks and comments * Clean the code that deals with past keys/values * Add kwargs to processing_inputs * Add serving_output to TFEncoderDecoderModel * Some cleaning + fix use_cache value issue * Fix tests + add bert2bert/bert2gpt2 tests * Fix more tests * Ignore crossattention.bias when loading GPT2 weights into TFGPT2 * Fix return_dict_in_generate in tf generation * Fix is_token_logit_eos_token bug in tf generation * Finalize the tests after fixing some bugs * Fix another is_token_logit_eos_token bug in tf generation * Add/Update docs * Add TFBertEncoderDecoderModelTest * Clean test script * Add TFEncoderDecoderModel to the library * Add cross attentions to TFRobertaModel * Add TFRobertaEncoderDecoderModelTest * make style * Change the way of position_ids computation * bug fix * Fix copies in tf_albert * Remove some copied from and apply some fix-copies * Remove some copied * Add cross attentions to some other TF models * Remove encoder_hidden_states from TFLayoutLMModel.call for now * Make style * Fix TFRemBertForCausalLM * Revert the change to longformer + Remove copies * Revert the change to albert and convbert + Remove copies * make quality * make style * Add TFRembertEncoderDecoderModelTest * make quality and fix-copies * test TFRobertaForCausalLM * Fixes for failed tests * Fixes for failed tests * fix more tests * Fixes for failed tests * Fix Auto mapping order * Fix TFRemBertEncoder return value * fix tf_rembert * Check copies are OK * Fix missing TFBaseModelOutputWithPastAndCrossAttentions is not defined * Add TFEncoderDecoderModelSaveLoadTests * fix tf weight loading * check the change of use_cache * Revert the change * Add missing test_for_causal_lm for TFRobertaModelTest * Try cleaning past * fix _reorder_cache * Revert some files to original versions * Keep as many copies as possible * Apply suggested changes - Use raise ValueError instead of assert * Move import to top * Fix wrong require_torch * Replace more assert by raise ValueError * Add test_pt_tf_model_equivalence (the test won't pass for now) * add test for loading/saving * finish * finish * Remove test_pt_tf_model_equivalence * Update tf modeling template * Remove pooling, added in the prev. commit, from MainLayer * Update tf modeling test template * Move inputs["use_cache"] = False to modeling_tf_utils.py * Fix torch.Tensor in the comment * fix use_cache * Fix missing use_cache in ElectraConfig * Add a note to from_pretrained * Fix style * Change test_encoder_decoder_save_load_from_encoder_decoder_from_pt * Fix TFMLP (in TFGPT2) activation issue * Fix None past_key_values value in serving_output * Don't call get_encoderdecoder_model in TFEncoderDecoderModelTest.test_configuration_tie until we have a TF checkpoint on Hub * Apply review suggestions - style for cross_attns in serving_output * Apply review suggestions - change assert + docstrings * break the error message to respect the char limit * deprecate the argument past * fix docstring style * Update the encoder-decoder rst file * fix Unknown interpreted text role "method" * fix typo Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -379,7 +379,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
|
| Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
|||||||
@@ -210,6 +210,13 @@ TFBaseModelOutputWithPooling
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
TFBaseModelOutputWithPoolingAndCrossAttentions
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_tf_outputs.TFBaseModelOutputWithPoolingAndCrossAttentions
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
TFBaseModelOutputWithPast
|
TFBaseModelOutputWithPast
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -217,6 +224,13 @@ TFBaseModelOutputWithPast
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
TFBaseModelOutputWithPastAndCrossAttentions
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_tf_outputs.TFBaseModelOutputWithPastAndCrossAttentions
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
TFSeq2SeqModelOutput
|
TFSeq2SeqModelOutput
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -231,6 +245,13 @@ TFCausalLMOutput
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
TFCausalLMOutputWithCrossAttentions
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_tf_outputs.TFCausalLMOutputWithCrossAttentions
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
TFCausalLMOutputWithPast
|
TFCausalLMOutputWithPast
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,25 @@ An application of this architecture could be to leverage two pretrained :class:`
|
|||||||
and decoder for a summarization model as was shown in: `Text Summarization with Pretrained Encoders
|
and decoder for a summarization model as was shown in: `Text Summarization with Pretrained Encoders
|
||||||
<https://arxiv.org/abs/1908.08345>`__ by Yang Liu and Mirella Lapata.
|
<https://arxiv.org/abs/1908.08345>`__ by Yang Liu and Mirella Lapata.
|
||||||
|
|
||||||
|
The :meth:`~transformers.TFEncoderDecoderModel.from_pretrained` currently doesn't support initializing the model from a
|
||||||
|
pytorch checkpoint. Passing ``from_pt=True`` to this method will throw an exception. If there are only pytorch
|
||||||
|
checkpoints for a particular encoder-decoder model, a workaround is:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
>>> # a workaround to load from pytorch checkpoint
|
||||||
|
>>> _model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
||||||
|
>>> _model.encoder.save_pretrained("./encoder")
|
||||||
|
>>> _model.decoder.save_pretrained("./decoder")
|
||||||
|
>>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||||
|
... "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
|
||||||
|
... )
|
||||||
|
>>> # This is only for copying some specific attributes of this particular model.
|
||||||
|
>>> model.config = _model.config
|
||||||
|
|
||||||
|
This model was contributed by `thomwolf <https://github.com/thomwolf>`__. This model's TensorFlow and Flax versions
|
||||||
|
were contributed by `ydshieh <https://github.com/ydshieh>`__.
|
||||||
|
|
||||||
|
|
||||||
EncoderDecoderConfig
|
EncoderDecoderConfig
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
@@ -42,6 +61,13 @@ EncoderDecoderModel
|
|||||||
:members: forward, from_encoder_decoder_pretrained
|
:members: forward, from_encoder_decoder_pretrained
|
||||||
|
|
||||||
|
|
||||||
|
TFEncoderDecoderModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TFEncoderDecoderModel
|
||||||
|
:members: call, from_encoder_decoder_pretrained
|
||||||
|
|
||||||
|
|
||||||
FlaxEncoderDecoderModel
|
FlaxEncoderDecoderModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -126,6 +126,13 @@ TFRobertaModel
|
|||||||
:members: call
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
|
TFRobertaForCausalLM
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TFRobertaForCausalLM
|
||||||
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
TFRobertaForMaskedLM
|
TFRobertaForMaskedLM
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -1444,6 +1444,7 @@ if is_tf_available():
|
|||||||
"TFElectraPreTrainedModel",
|
"TFElectraPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.encoder_decoder"].append("TFEncoderDecoderModel")
|
||||||
_import_structure["models.flaubert"].extend(
|
_import_structure["models.flaubert"].extend(
|
||||||
[
|
[
|
||||||
"TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -1596,6 +1597,7 @@ if is_tf_available():
|
|||||||
_import_structure["models.roberta"].extend(
|
_import_structure["models.roberta"].extend(
|
||||||
[
|
[
|
||||||
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"TFRobertaForCausalLM",
|
||||||
"TFRobertaForMaskedLM",
|
"TFRobertaForMaskedLM",
|
||||||
"TFRobertaForMultipleChoice",
|
"TFRobertaForMultipleChoice",
|
||||||
"TFRobertaForQuestionAnswering",
|
"TFRobertaForQuestionAnswering",
|
||||||
@@ -3096,6 +3098,7 @@ if TYPE_CHECKING:
|
|||||||
TFElectraModel,
|
TFElectraModel,
|
||||||
TFElectraPreTrainedModel,
|
TFElectraPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.encoder_decoder import TFEncoderDecoderModel
|
||||||
from .models.flaubert import (
|
from .models.flaubert import (
|
||||||
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFFlaubertForMultipleChoice,
|
TFFlaubertForMultipleChoice,
|
||||||
@@ -3205,6 +3208,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.roberta import (
|
from .models.roberta import (
|
||||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
TFRobertaForCausalLM,
|
||||||
TFRobertaForMaskedLM,
|
TFRobertaForMaskedLM,
|
||||||
TFRobertaForMultipleChoice,
|
TFRobertaForMultipleChoice,
|
||||||
TFRobertaForQuestionAnswering,
|
TFRobertaForQuestionAnswering,
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ from . import (
|
|||||||
TFLxmertForPreTraining,
|
TFLxmertForPreTraining,
|
||||||
TFLxmertVisualFeatureEncoder,
|
TFLxmertVisualFeatureEncoder,
|
||||||
TFOpenAIGPTLMHeadModel,
|
TFOpenAIGPTLMHeadModel,
|
||||||
|
TFRobertaForCausalLM,
|
||||||
TFRobertaForMaskedLM,
|
TFRobertaForMaskedLM,
|
||||||
TFRobertaForSequenceClassification,
|
TFRobertaForSequenceClassification,
|
||||||
TFT5ForConditionalGeneration,
|
TFT5ForConditionalGeneration,
|
||||||
@@ -215,6 +216,7 @@ MODEL_CLASSES = {
|
|||||||
),
|
),
|
||||||
"roberta": (
|
"roberta": (
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
|
TFRobertaForCausalLM,
|
||||||
TFRobertaForMaskedLM,
|
TFRobertaForMaskedLM,
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
|||||||
@@ -650,7 +650,11 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
# current position and vocab size
|
# current position and vocab size
|
||||||
cur_len = shape_list(input_ids)[1] # unused
|
cur_len = shape_list(input_ids)[1] # unused
|
||||||
vocab_size = self.config.vocab_size
|
vocab_size = getattr(self.config, "vocab_size", None)
|
||||||
|
if vocab_size is None and self.config.is_encoder_decoder:
|
||||||
|
decoder_config = getattr(self.config, "decoder", None)
|
||||||
|
if decoder_config is not None:
|
||||||
|
vocab_size = getattr(self.config.decoder, "vocab_size", None)
|
||||||
|
|
||||||
# set effective batch size and effective batch multiplier according to do_sample
|
# set effective batch size and effective batch multiplier according to do_sample
|
||||||
if do_sample:
|
if do_sample:
|
||||||
@@ -678,6 +682,7 @@ class TFGenerationMixin:
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict_in_generate,
|
||||||
)
|
)
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
@@ -911,7 +916,7 @@ class TFGenerationMixin:
|
|||||||
if eos_token_id is not None and cur_len < min_length:
|
if eos_token_id is not None and cur_len < min_length:
|
||||||
# create eos_token_id boolean mask
|
# create eos_token_id boolean mask
|
||||||
is_token_logit_eos_token = tf.convert_to_tensor(
|
is_token_logit_eos_token = tf.convert_to_tensor(
|
||||||
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
|
[True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
|
||||||
)
|
)
|
||||||
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
||||||
|
|
||||||
@@ -1142,7 +1147,7 @@ class TFGenerationMixin:
|
|||||||
num_batch_hypotheses = batch_size * num_beams
|
num_batch_hypotheses = batch_size * num_beams
|
||||||
|
|
||||||
is_token_logit_eos_token = tf.convert_to_tensor(
|
is_token_logit_eos_token = tf.convert_to_tensor(
|
||||||
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
|
[True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
|
||||||
)
|
)
|
||||||
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
|
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
|
||||||
|
|
||||||
@@ -1446,11 +1451,17 @@ class TFGenerationMixin:
|
|||||||
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
|
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
|
||||||
the generate method.
|
the generate method.
|
||||||
"""
|
"""
|
||||||
|
vocab_size = getattr(self.config, "vocab_size", None)
|
||||||
|
if vocab_size is None and self.config.is_encoder_decoder:
|
||||||
|
decoder_config = getattr(self.config, "decoder", None)
|
||||||
|
if decoder_config is not None:
|
||||||
|
vocab_size = getattr(self.config.decoder, "vocab_size", None)
|
||||||
|
|
||||||
if cur_len == 1 and forced_bos_token_id is not None:
|
if cur_len == 1 and forced_bos_token_id is not None:
|
||||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
vocab_range = tf.constant(range(vocab_size))
|
||||||
return tf.where(vocab_range != forced_bos_token_id, -1e8, logits)
|
return tf.where(vocab_range != forced_bos_token_id, -1e8, logits)
|
||||||
elif cur_len == max_length - 1 and forced_eos_token_id is not None:
|
elif cur_len == max_length - 1 and forced_eos_token_id is not None:
|
||||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
vocab_range = tf.constant(range(vocab_size))
|
||||||
return tf.where(vocab_range != forced_eos_token_id, -1e8, logits)
|
return tf.where(vocab_range != forced_eos_token_id, -1e8, logits)
|
||||||
else:
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|||||||
@@ -80,6 +80,54 @@ class TFBaseModelOutputWithPooling(ModelOutput):
|
|||||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for model's outputs that also contains a pooling of the last hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
|
pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
|
||||||
|
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
|
||||||
|
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
|
||||||
|
prediction (classification) objective during pretraining.
|
||||||
|
|
||||||
|
This output is usually *not* a good summary of the semantic content of the input, you're often better with
|
||||||
|
averaging or pooling the sequence of hidden-states for the whole input sequence.
|
||||||
|
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
|
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size,
|
||||||
|
num_heads, sequence_length, embed_size_per_head)`).
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||||
|
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
|
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_hidden_state: tf.Tensor = None
|
||||||
|
pooler_output: tf.Tensor = None
|
||||||
|
past_key_values: Optional[List[tf.Tensor]] = None
|
||||||
|
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
cross_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TFBaseModelOutputWithPast(ModelOutput):
|
class TFBaseModelOutputWithPast(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@@ -317,6 +365,49 @@ class TFCausalLMOutputWithPast(ModelOutput):
|
|||||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TFCausalLMOutputWithCrossAttentions(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for causal language model (or autoregressive) outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss (:obj:`tf.Tensor` of shape :obj:`(n,)`, `optional`, where n is the number of non-masked labels, returned when :obj:`labels` is provided):
|
||||||
|
Language modeling loss (for next-token prediction).
|
||||||
|
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
|
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
|
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
|
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size,
|
||||||
|
num_heads, sequence_length, embed_size_per_head)`).
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||||
|
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
loss: Optional[tf.Tensor] = None
|
||||||
|
logits: tf.Tensor = None
|
||||||
|
past_key_values: Optional[List[tf.Tensor]] = None
|
||||||
|
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
cross_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TFMaskedLMOutput(ModelOutput):
|
class TFMaskedLMOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -296,7 +296,9 @@ def booleans_processing(config, **kwargs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "use_cache" in kwargs:
|
if "use_cache" in kwargs:
|
||||||
final_booleans["use_cache"] = kwargs["use_cache"] if kwargs["use_cache"] is not None else config.use_cache
|
final_booleans["use_cache"] = (
|
||||||
|
kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if (
|
if (
|
||||||
kwargs["output_attentions"] not in (None, config.output_attentions)
|
kwargs["output_attentions"] not in (None, config.output_attentions)
|
||||||
@@ -318,7 +320,7 @@ def booleans_processing(config, **kwargs):
|
|||||||
final_booleans["return_dict"] = True
|
final_booleans["return_dict"] = True
|
||||||
|
|
||||||
if "use_cache" in kwargs:
|
if "use_cache" in kwargs:
|
||||||
final_booleans["use_cache"] = config.use_cache
|
final_booleans["use_cache"] = getattr(config, "use_cache", None)
|
||||||
|
|
||||||
return final_booleans
|
return final_booleans
|
||||||
|
|
||||||
@@ -362,6 +364,15 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
)
|
)
|
||||||
output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
|
output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
|
||||||
|
|
||||||
|
if "past" in kwargs["kwargs_call"] and "past_key_values" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
|
||||||
|
elif "past_key_values" in kwargs["kwargs_call"] and "past" in kwargs:
|
||||||
|
kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
|
||||||
|
|
||||||
if len(kwargs["kwargs_call"]) > 0:
|
if len(kwargs["kwargs_call"]) > 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
|
f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
|
||||||
|
|||||||
@@ -157,6 +157,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
|
|||||||
position_ids: tf.Tensor = None,
|
position_ids: tf.Tensor = None,
|
||||||
token_type_ids: tf.Tensor = None,
|
token_type_ids: tf.Tensor = None,
|
||||||
inputs_embeds: tf.Tensor = None,
|
inputs_embeds: tf.Tensor = None,
|
||||||
|
past_key_values_length=0,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> tf.Tensor:
|
) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -177,7 +178,9 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
|
|||||||
token_type_ids = tf.fill(dims=input_shape, value=0)
|
token_type_ids = tf.fill(dims=input_shape, value=0)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
|
position_ids = tf.expand_dims(
|
||||||
|
tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
|
||||||
|
)
|
||||||
|
|
||||||
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
||||||
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
||||||
@@ -829,7 +832,6 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
|
|
||||||
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
|||||||
@@ -133,6 +133,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
# Model for Causal LM mapping
|
# Model for Causal LM mapping
|
||||||
("rembert", "TFRemBertForCausalLM"),
|
("rembert", "TFRemBertForCausalLM"),
|
||||||
("roformer", "TFRoFormerForCausalLM"),
|
("roformer", "TFRoFormerForCausalLM"),
|
||||||
|
("roberta", "TFRobertaForCausalLM"),
|
||||||
("bert", "TFBertLMHeadModel"),
|
("bert", "TFBertLMHeadModel"),
|
||||||
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
|
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
|
||||||
("gpt2", "TFGPT2LMHeadModel"),
|
("gpt2", "TFGPT2LMHeadModel"),
|
||||||
@@ -181,6 +182,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("blenderbot", "TFBlenderbotForConditionalGeneration"),
|
("blenderbot", "TFBlenderbotForConditionalGeneration"),
|
||||||
("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
|
("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
|
||||||
("bart", "TFBartForConditionalGeneration"),
|
("bart", "TFBartForConditionalGeneration"),
|
||||||
|
("encoder-decoder", "TFEncoderDecoderModel"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
|
DUMMY_INPUTS,
|
||||||
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
@@ -33,9 +34,9 @@ from ...file_utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFBaseModelOutputWithPooling,
|
TFBaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
TFCausalLMOutput,
|
TFCausalLMOutputWithCrossAttentions,
|
||||||
TFMaskedLMOutput,
|
TFMaskedLMOutput,
|
||||||
TFMultipleChoiceModelOutput,
|
TFMultipleChoiceModelOutput,
|
||||||
TFNextSentencePredictorOutput,
|
TFNextSentencePredictorOutput,
|
||||||
@@ -174,6 +175,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
position_ids: tf.Tensor = None,
|
position_ids: tf.Tensor = None,
|
||||||
token_type_ids: tf.Tensor = None,
|
token_type_ids: tf.Tensor = None,
|
||||||
inputs_embeds: tf.Tensor = None,
|
inputs_embeds: tf.Tensor = None,
|
||||||
|
past_key_values_length=0,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> tf.Tensor:
|
) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -194,7 +196,9 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
token_type_ids = tf.fill(dims=input_shape, value=0)
|
token_type_ids = tf.fill(dims=input_shape, value=0)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
|
position_ids = tf.expand_dims(
|
||||||
|
tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
|
||||||
|
)
|
||||||
|
|
||||||
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
||||||
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
||||||
@@ -232,6 +236,8 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
||||||
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
||||||
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||||
@@ -244,16 +250,49 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
batch_size = shape_list(hidden_states)[0]
|
batch_size = shape_list(hidden_states)[0]
|
||||||
mixed_query_layer = self.query(inputs=hidden_states)
|
mixed_query_layer = self.query(inputs=hidden_states)
|
||||||
mixed_key_layer = self.key(inputs=hidden_states)
|
|
||||||
mixed_value_layer = self.value(inputs=hidden_states)
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
|
# and values come from an encoder; the attention mask needs to be
|
||||||
|
# such that the encoder's padding tokens are not attended to.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
|
||||||
|
else:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
# (batch size, num_heads, seq_len_q, seq_len_k)
|
# (batch size, num_heads, seq_len_q, seq_len_k)
|
||||||
@@ -283,6 +322,8 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
||||||
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -319,6 +360,9 @@ class TFBertAttention(tf.keras.layers.Layer):
|
|||||||
input_tensor: tf.Tensor,
|
input_tensor: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
@@ -326,13 +370,17 @@ class TFBertAttention(tf.keras.layers.Layer):
|
|||||||
hidden_states=input_tensor,
|
hidden_states=input_tensor,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = self.dense_output(
|
attention_output = self.dense_output(
|
||||||
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
||||||
)
|
)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
# add attentions (possibly with past_key_value) if we output them
|
||||||
|
outputs = (attention_output,) + self_outputs[1:]
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -380,6 +428,12 @@ class TFBertLayer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.attention = TFBertAttention(config, name="attention")
|
self.attention = TFBertAttention(config, name="attention")
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
self.add_cross_attention = config.add_cross_attention
|
||||||
|
if self.add_cross_attention:
|
||||||
|
if not self.is_decoder:
|
||||||
|
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||||
|
self.crossattention = TFBertAttention(config, name="crossattention")
|
||||||
self.intermediate = TFBertIntermediate(config, name="intermediate")
|
self.intermediate = TFBertIntermediate(config, name="intermediate")
|
||||||
self.bert_output = TFBertOutput(config, name="output")
|
self.bert_output = TFBertOutput(config, name="output")
|
||||||
|
|
||||||
@@ -388,22 +442,69 @@ class TFBertLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: Optional[tf.Tensor],
|
||||||
|
encoder_attention_mask: Optional[tf.Tensor],
|
||||||
|
past_key_value: Optional[Tuple[tf.Tensor]],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
attention_outputs = self.attention(
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
input_tensor=hidden_states,
|
input_tensor=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
|
if not hasattr(self, "crossattention"):
|
||||||
|
raise ValueError(
|
||||||
|
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
|
||||||
|
"by setting `config.add_cross_attention=True`"
|
||||||
|
)
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
|
cross_attention_outputs = self.crossattention(
|
||||||
|
input_tensor=attention_output,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=cross_attn_past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
attention_output = cross_attention_outputs[0]
|
||||||
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
intermediate_output = self.intermediate(hidden_states=attention_output)
|
intermediate_output = self.intermediate(hidden_states=attention_output)
|
||||||
layer_output = self.bert_output(
|
layer_output = self.bert_output(
|
||||||
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
outputs = (layer_output,) + outputs # add attentions if we output them
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -411,7 +512,7 @@ class TFBertLayer(tf.keras.layers.Layer):
|
|||||||
class TFBertEncoder(tf.keras.layers.Layer):
|
class TFBertEncoder(tf.keras.layers.Layer):
|
||||||
def __init__(self, config: BertConfig, **kwargs):
|
def __init__(self, config: BertConfig, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.config = config
|
||||||
self.layer = [TFBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
self.layer = [TFBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
@@ -419,39 +520,61 @@ class TFBertEncoder(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: Optional[tf.Tensor],
|
||||||
|
encoder_attention_mask: Optional[tf.Tensor],
|
||||||
|
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
|
||||||
|
use_cache: Optional[bool],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
output_hidden_states: bool,
|
output_hidden_states: bool,
|
||||||
return_dict: bool,
|
return_dict: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (layer_outputs[1],)
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return TFBaseModelOutput(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -579,6 +702,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
self.embeddings = TFBertEmbeddings(config, name="embeddings")
|
self.embeddings = TFBertEmbeddings(config, name="embeddings")
|
||||||
self.encoder = TFBertEncoder(config, name="encoder")
|
self.encoder = TFBertEncoder(config, name="encoder")
|
||||||
@@ -606,12 +730,16 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -621,6 +749,10 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -628,6 +760,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
kwargs_call=kwargs,
|
kwargs_call=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.config.is_decoder:
|
||||||
|
inputs["use_cache"] = False
|
||||||
|
|
||||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif inputs["input_ids"] is not None:
|
elif inputs["input_ids"] is not None:
|
||||||
@@ -637,8 +772,16 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
|
if inputs["past_key_values"] is None:
|
||||||
|
past_key_values_length = 0
|
||||||
|
inputs["past_key_values"] = [None] * len(self.encoder.layer)
|
||||||
|
else:
|
||||||
|
past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2]
|
||||||
|
|
||||||
if inputs["attention_mask"] is None:
|
if inputs["attention_mask"] is None:
|
||||||
inputs["attention_mask"] = tf.fill(dims=input_shape, value=1)
|
inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
|
||||||
|
|
||||||
if inputs["token_type_ids"] is None:
|
if inputs["token_type_ids"] is None:
|
||||||
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
|
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
|
||||||
@@ -648,6 +791,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
token_type_ids=inputs["token_type_ids"],
|
token_type_ids=inputs["token_type_ids"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -656,7 +800,29 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
|
attention_mask_shape = shape_list(inputs["attention_mask"])
|
||||||
|
|
||||||
|
mask_seq_length = seq_length + past_key_values_length
|
||||||
|
# Copied from `modeling_tf_t5.py`
|
||||||
|
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
|
||||||
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||||
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||||
|
if self.is_decoder:
|
||||||
|
seq_ids = tf.range(mask_seq_length)
|
||||||
|
causal_mask = tf.less_equal(
|
||||||
|
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
|
||||||
|
seq_ids[None, :, None],
|
||||||
|
)
|
||||||
|
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
|
||||||
|
extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :]
|
||||||
|
attention_mask_shape = shape_list(extended_attention_mask)
|
||||||
|
extended_attention_mask = tf.reshape(
|
||||||
|
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extended_attention_mask = tf.reshape(
|
||||||
|
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
|
||||||
|
)
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
@@ -668,6 +834,29 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
|
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
|
||||||
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
||||||
|
|
||||||
|
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
|
||||||
|
if self.is_decoder and inputs["encoder_attention_mask"] is not None:
|
||||||
|
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
inputs["encoder_attention_mask"] = tf.cast(
|
||||||
|
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
|
||||||
|
)
|
||||||
|
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
|
||||||
|
if num_dims_encoder_attention_mask == 3:
|
||||||
|
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
|
||||||
|
if num_dims_encoder_attention_mask == 2:
|
||||||
|
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
|
||||||
|
|
||||||
|
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||||
|
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
||||||
|
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
|
||||||
|
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
|
||||||
|
|
||||||
|
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
@@ -682,6 +871,10 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=embedding_output,
|
hidden_states=embedding_output,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -697,11 +890,13 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
pooled_output,
|
pooled_output,
|
||||||
) + encoder_outputs[1:]
|
) + encoder_outputs[1:]
|
||||||
|
|
||||||
return TFBaseModelOutputWithPooling(
|
return TFBaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
pooler_output=pooled_output,
|
pooler_output=pooled_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
attentions=encoder_outputs.attentions,
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -714,6 +909,24 @@ class TFBertPreTrainedModel(TFPreTrainedModel):
|
|||||||
config_class = BertConfig
|
config_class = BertConfig
|
||||||
base_model_prefix = "bert"
|
base_model_prefix = "bert"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_inputs(self):
|
||||||
|
"""
|
||||||
|
Dummy inputs to build the network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||||
|
"""
|
||||||
|
dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
|
||||||
|
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
|
||||||
|
shape = (batch_size, seq_len) + (self.config.hidden_size,)
|
||||||
|
h = tf.random.uniform(shape=shape)
|
||||||
|
dummy["encoder_hidden_states"] = h
|
||||||
|
|
||||||
|
return dummy
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TFBertForPreTrainingOutput(ModelOutput):
|
class TFBertForPreTrainingOutput(ModelOutput):
|
||||||
@@ -853,7 +1066,7 @@ class TFBertModel(TFBertPreTrainedModel):
|
|||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=TFBaseModelOutputWithPooling,
|
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
@@ -864,12 +1077,36 @@ class TFBertModel(TFBertPreTrainedModel):
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: Optional[bool] = False,
|
training: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
|
||||||
|
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
|
||||||
|
"""
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -879,6 +1116,10 @@ class TFBertModel(TFBertPreTrainedModel):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -892,6 +1133,10 @@ class TFBertModel(TFBertPreTrainedModel):
|
|||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -900,15 +1145,24 @@ class TFBertModel(TFBertPreTrainedModel):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
def serving_output(
|
||||||
|
self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
|
||||||
|
) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
|
||||||
|
output_cache = self.config.use_cache and self.config.is_decoder
|
||||||
|
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||||
|
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||||
|
cross_attns = None
|
||||||
|
|
||||||
return TFBaseModelOutputWithPooling(
|
return TFBaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
last_hidden_state=output.last_hidden_state,
|
last_hidden_state=output.last_hidden_state,
|
||||||
pooler_output=output.pooler_output,
|
pooler_output=output.pooler_output,
|
||||||
|
past_key_values=pkv,
|
||||||
hidden_states=hs,
|
hidden_states=hs,
|
||||||
attentions=attns,
|
attentions=attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -960,11 +1214,11 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFBertForPreTrainingOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFBertForPreTrainingOutput, Tuple[tf.Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
|
labels (:obj:`tf.Tensor` of shape ``(batch_size, sequence_length)``, `optional`):
|
||||||
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
||||||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||||
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
|
next_sentence_label (``tf.Tensor`` of shape ``(batch_size,)``, `optional`):
|
||||||
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
|
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
|
||||||
(see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
|
(see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
|
||||||
|
|
||||||
@@ -1184,10 +1438,22 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
||||||
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
|
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past:
|
||||||
|
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": inputs,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": model_kwargs["use_cache"],
|
||||||
|
}
|
||||||
|
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=TFCausalLMOutput,
|
output_type=TFCausalLMOutputWithCrossAttentions,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
@@ -1198,14 +1464,36 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
training: Optional[bool] = False,
|
training: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
|
||||||
|
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
|
||||||
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||||
config.vocab_size - 1]``.
|
config.vocab_size - 1]``.
|
||||||
@@ -1219,6 +1507,10 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -1233,6 +1525,10 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -1252,18 +1548,27 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return TFCausalLMOutput(
|
return TFCausalLMOutputWithCrossAttentions(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
|
def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions:
|
||||||
|
output_cache = self.config.use_cache and self.config.is_decoder
|
||||||
|
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||||
|
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||||
|
cross_attns = None
|
||||||
|
|
||||||
return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
return TFCausalLMOutputWithCrossAttentions(
|
||||||
|
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
position_ids: tf.Tensor = None,
|
position_ids: tf.Tensor = None,
|
||||||
token_type_ids: tf.Tensor = None,
|
token_type_ids: tf.Tensor = None,
|
||||||
inputs_embeds: tf.Tensor = None,
|
inputs_embeds: tf.Tensor = None,
|
||||||
|
past_key_values_length=0,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> tf.Tensor:
|
) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -130,7 +131,9 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
token_type_ids = tf.fill(dims=input_shape, value=0)
|
token_type_ids = tf.fill(dims=input_shape, value=0)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
|
position_ids = tf.expand_dims(
|
||||||
|
tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
|
||||||
|
)
|
||||||
|
|
||||||
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
||||||
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
||||||
|
|||||||
@@ -104,6 +104,9 @@ class ElectraConfig(PretrainedConfig):
|
|||||||
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
|
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
|
||||||
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
|
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
|
||||||
<https://arxiv.org/abs/2009.13658>`__.
|
<https://arxiv.org/abs/2009.13658>`__.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if ``config.is_decoder=True``.
|
||||||
classifier_dropout (:obj:`float`, `optional`):
|
classifier_dropout (:obj:`float`, `optional`):
|
||||||
The dropout ratio for the classification head.
|
The dropout ratio for the classification head.
|
||||||
|
|
||||||
@@ -143,6 +146,7 @@ class ElectraConfig(PretrainedConfig):
|
|||||||
summary_last_dropout=0.1,
|
summary_last_dropout=0.1,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
position_embedding_type="absolute",
|
position_embedding_type="absolute",
|
||||||
|
use_cache=True,
|
||||||
classifier_dropout=None,
|
classifier_dropout=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -167,4 +171,5 @@ class ElectraConfig(PretrainedConfig):
|
|||||||
self.summary_activation = summary_activation
|
self.summary_activation = summary_activation
|
||||||
self.summary_last_dropout = summary_last_dropout
|
self.summary_last_dropout = summary_last_dropout
|
||||||
self.position_embedding_type = position_embedding_type
|
self.position_embedding_type = position_embedding_type
|
||||||
|
self.use_cache = use_cache
|
||||||
self.classifier_dropout = classifier_dropout
|
self.classifier_dropout = classifier_dropout
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
|
DUMMY_INPUTS,
|
||||||
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
@@ -31,7 +32,7 @@ from ...file_utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFMaskedLMOutput,
|
TFMaskedLMOutput,
|
||||||
TFMultipleChoiceModelOutput,
|
TFMultipleChoiceModelOutput,
|
||||||
TFQuestionAnsweringModelOutput,
|
TFQuestionAnsweringModelOutput,
|
||||||
@@ -99,6 +100,8 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
||||||
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
||||||
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||||
@@ -111,16 +114,49 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
batch_size = shape_list(hidden_states)[0]
|
batch_size = shape_list(hidden_states)[0]
|
||||||
mixed_query_layer = self.query(inputs=hidden_states)
|
mixed_query_layer = self.query(inputs=hidden_states)
|
||||||
mixed_key_layer = self.key(inputs=hidden_states)
|
|
||||||
mixed_value_layer = self.value(inputs=hidden_states)
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
|
# and values come from an encoder; the attention mask needs to be
|
||||||
|
# such that the encoder's padding tokens are not attended to.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
|
||||||
|
else:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
# (batch size, num_heads, seq_len_q, seq_len_k)
|
# (batch size, num_heads, seq_len_q, seq_len_k)
|
||||||
@@ -150,6 +186,8 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
|
|||||||
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
||||||
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -188,6 +226,9 @@ class TFElectraAttention(tf.keras.layers.Layer):
|
|||||||
input_tensor: tf.Tensor,
|
input_tensor: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
@@ -195,13 +236,17 @@ class TFElectraAttention(tf.keras.layers.Layer):
|
|||||||
hidden_states=input_tensor,
|
hidden_states=input_tensor,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = self.dense_output(
|
attention_output = self.dense_output(
|
||||||
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
||||||
)
|
)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
# add attentions (possibly with past_key_value) if we output them
|
||||||
|
outputs = (attention_output,) + self_outputs[1:]
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -252,6 +297,12 @@ class TFElectraLayer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.attention = TFElectraAttention(config, name="attention")
|
self.attention = TFElectraAttention(config, name="attention")
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
self.add_cross_attention = config.add_cross_attention
|
||||||
|
if self.add_cross_attention:
|
||||||
|
if not self.is_decoder:
|
||||||
|
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||||
|
self.crossattention = TFElectraAttention(config, name="crossattention")
|
||||||
self.intermediate = TFElectraIntermediate(config, name="intermediate")
|
self.intermediate = TFElectraIntermediate(config, name="intermediate")
|
||||||
self.bert_output = TFElectraOutput(config, name="output")
|
self.bert_output = TFElectraOutput(config, name="output")
|
||||||
|
|
||||||
@@ -260,22 +311,69 @@ class TFElectraLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: Optional[tf.Tensor],
|
||||||
|
encoder_attention_mask: Optional[tf.Tensor],
|
||||||
|
past_key_value: Optional[Tuple[tf.Tensor]],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
attention_outputs = self.attention(
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
input_tensor=hidden_states,
|
input_tensor=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
|
if not hasattr(self, "crossattention"):
|
||||||
|
raise ValueError(
|
||||||
|
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
|
||||||
|
"by setting `config.add_cross_attention=True`"
|
||||||
|
)
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
|
cross_attention_outputs = self.crossattention(
|
||||||
|
input_tensor=attention_output,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=cross_attn_past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
attention_output = cross_attention_outputs[0]
|
||||||
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
intermediate_output = self.intermediate(hidden_states=attention_output)
|
intermediate_output = self.intermediate(hidden_states=attention_output)
|
||||||
layer_output = self.bert_output(
|
layer_output = self.bert_output(
|
||||||
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
outputs = (layer_output,) + outputs # add attentions if we output them
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -284,7 +382,7 @@ class TFElectraLayer(tf.keras.layers.Layer):
|
|||||||
class TFElectraEncoder(tf.keras.layers.Layer):
|
class TFElectraEncoder(tf.keras.layers.Layer):
|
||||||
def __init__(self, config: ElectraConfig, **kwargs):
|
def __init__(self, config: ElectraConfig, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.config = config
|
||||||
self.layer = [TFElectraLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
self.layer = [TFElectraLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
@@ -292,39 +390,61 @@ class TFElectraEncoder(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: Optional[tf.Tensor],
|
||||||
|
encoder_attention_mask: Optional[tf.Tensor],
|
||||||
|
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
|
||||||
|
use_cache: Optional[bool],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
output_hidden_states: bool,
|
output_hidden_states: bool,
|
||||||
return_dict: bool,
|
return_dict: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (layer_outputs[1],)
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return TFBaseModelOutput(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -396,6 +516,7 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
|
|||||||
position_ids: tf.Tensor = None,
|
position_ids: tf.Tensor = None,
|
||||||
token_type_ids: tf.Tensor = None,
|
token_type_ids: tf.Tensor = None,
|
||||||
inputs_embeds: tf.Tensor = None,
|
inputs_embeds: tf.Tensor = None,
|
||||||
|
past_key_values_length=0,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> tf.Tensor:
|
) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -416,7 +537,9 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
|
|||||||
token_type_ids = tf.fill(dims=input_shape, value=0)
|
token_type_ids = tf.fill(dims=input_shape, value=0)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
|
position_ids = tf.expand_dims(
|
||||||
|
tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
|
||||||
|
)
|
||||||
|
|
||||||
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
||||||
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
||||||
@@ -471,6 +594,25 @@ class TFElectraPreTrainedModel(TFPreTrainedModel):
|
|||||||
_keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"]
|
_keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"]
|
||||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPreTrainedModel.dummy_inputs
|
||||||
|
def dummy_inputs(self):
|
||||||
|
"""
|
||||||
|
Dummy inputs to build the network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||||
|
"""
|
||||||
|
dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
|
||||||
|
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
|
||||||
|
shape = (batch_size, seq_len) + (self.config.hidden_size,)
|
||||||
|
h = tf.random.uniform(shape=shape)
|
||||||
|
dummy["encoder_hidden_states"] = h
|
||||||
|
|
||||||
|
return dummy
|
||||||
|
|
||||||
|
|
||||||
@keras_serializable
|
@keras_serializable
|
||||||
class TFElectraMainLayer(tf.keras.layers.Layer):
|
class TFElectraMainLayer(tf.keras.layers.Layer):
|
||||||
@@ -480,13 +622,14 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
self.embeddings = TFElectraEmbeddings(config, name="embeddings")
|
self.embeddings = TFElectraEmbeddings(config, name="embeddings")
|
||||||
|
|
||||||
if config.embedding_size != config.hidden_size:
|
if config.embedding_size != config.hidden_size:
|
||||||
self.embeddings_project = tf.keras.layers.Dense(config.hidden_size, name="embeddings_project")
|
self.embeddings_project = tf.keras.layers.Dense(config.hidden_size, name="embeddings_project")
|
||||||
|
|
||||||
self.encoder = TFElectraEncoder(config, name="encoder")
|
self.encoder = TFElectraEncoder(config, name="encoder")
|
||||||
self.config = config
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.embeddings
|
return self.embeddings
|
||||||
@@ -502,24 +645,50 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_extended_attention_mask(self, attention_mask, input_shape, dtype):
|
def get_extended_attention_mask(self, attention_mask, input_shape, dtype, past_key_values_length=0):
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.fill(input_shape, 1)
|
attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
|
||||||
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
|
attention_mask_shape = shape_list(attention_mask)
|
||||||
|
|
||||||
|
mask_seq_length = seq_length + past_key_values_length
|
||||||
|
# Copied from `modeling_tf_t5.py`
|
||||||
|
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
|
||||||
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||||
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||||
|
if self.is_decoder:
|
||||||
|
seq_ids = tf.range(mask_seq_length)
|
||||||
|
causal_mask = tf.less_equal(
|
||||||
|
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
|
||||||
|
seq_ids[None, :, None],
|
||||||
|
)
|
||||||
|
causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
|
||||||
|
extended_attention_mask = causal_mask * attention_mask[:, None, :]
|
||||||
|
attention_mask_shape = shape_list(extended_attention_mask)
|
||||||
|
extended_attention_mask = tf.reshape(
|
||||||
|
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extended_attention_mask = tf.reshape(
|
||||||
|
attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
|
||||||
|
)
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
# positions we want to attend and -10000.0 for masked positions.
|
# positions we want to attend and -10000.0 for masked positions.
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
# effectively the same as removing these entirely.
|
# effectively the same as removing these entirely.
|
||||||
extended_attention_mask = tf.cast(extended_attention_mask, dtype)
|
extended_attention_mask = tf.cast(extended_attention_mask, dtype=dtype)
|
||||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
one_cst = tf.constant(1.0, dtype=dtype)
|
||||||
|
ten_thousand_cst = tf.constant(-10000.0, dtype=dtype)
|
||||||
|
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
||||||
|
|
||||||
return extended_attention_mask
|
return extended_attention_mask
|
||||||
|
|
||||||
@@ -539,6 +708,10 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -554,6 +727,10 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -561,6 +738,9 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
|
|||||||
kwargs_call=kwargs,
|
kwargs_call=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.config.is_decoder:
|
||||||
|
inputs["use_cache"] = False
|
||||||
|
|
||||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif inputs["input_ids"] is not None:
|
elif inputs["input_ids"] is not None:
|
||||||
@@ -570,34 +750,71 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
|
if inputs["past_key_values"] is None:
|
||||||
|
past_key_values_length = 0
|
||||||
|
inputs["past_key_values"] = [None] * len(self.encoder.layer)
|
||||||
|
else:
|
||||||
|
past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2]
|
||||||
|
|
||||||
if inputs["attention_mask"] is None:
|
if inputs["attention_mask"] is None:
|
||||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
|
||||||
|
|
||||||
if inputs["token_type_ids"] is None:
|
if inputs["token_type_ids"] is None:
|
||||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
|
||||||
|
|
||||||
hidden_states = self.embeddings(
|
hidden_states = self.embeddings(
|
||||||
inputs["input_ids"],
|
input_ids=inputs["input_ids"],
|
||||||
inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
inputs["token_type_ids"],
|
token_type_ids=inputs["token_type_ids"],
|
||||||
inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
extended_attention_mask = self.get_extended_attention_mask(
|
extended_attention_mask = self.get_extended_attention_mask(
|
||||||
inputs["attention_mask"], input_shape, hidden_states.dtype
|
inputs["attention_mask"], input_shape, hidden_states.dtype, past_key_values_length
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
|
||||||
|
if self.is_decoder and inputs["encoder_attention_mask"] is not None:
|
||||||
|
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
inputs["encoder_attention_mask"] = tf.cast(
|
||||||
|
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
|
||||||
|
)
|
||||||
|
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
|
||||||
|
if num_dims_encoder_attention_mask == 3:
|
||||||
|
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
|
||||||
|
if num_dims_encoder_attention_mask == 2:
|
||||||
|
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
|
||||||
|
|
||||||
|
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||||
|
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
||||||
|
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
|
||||||
|
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
|
||||||
|
|
||||||
|
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
inputs["head_mask"] = self.get_head_mask(inputs["head_mask"])
|
inputs["head_mask"] = self.get_head_mask(inputs["head_mask"])
|
||||||
|
|
||||||
if hasattr(self, "embeddings_project"):
|
if hasattr(self, "embeddings_project"):
|
||||||
hidden_states = self.embeddings_project(hidden_states, training=inputs["training"])
|
hidden_states = self.embeddings_project(hidden_states, training=inputs["training"])
|
||||||
|
|
||||||
hidden_states = self.encoder(
|
hidden_states = self.encoder(
|
||||||
hidden_states,
|
hidden_states=hidden_states,
|
||||||
extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
inputs["output_attentions"],
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
inputs["output_hidden_states"],
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
inputs["return_dict"],
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
|
output_attentions=inputs["output_attentions"],
|
||||||
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
|
return_dict=inputs["return_dict"],
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -735,7 +952,7 @@ class TFElectraModel(TFElectraPreTrainedModel):
|
|||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=TFBaseModelOutput,
|
output_type=TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
@@ -746,12 +963,36 @@ class TFElectraModel(TFElectraPreTrainedModel):
|
|||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
training=False,
|
training=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
|
||||||
|
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
|
||||||
|
"""
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -761,6 +1002,10 @@ class TFElectraModel(TFElectraPreTrainedModel):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -773,6 +1018,10 @@ class TFElectraModel(TFElectraPreTrainedModel):
|
|||||||
token_type_ids=inputs["token_type_ids"],
|
token_type_ids=inputs["token_type_ids"],
|
||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
@@ -782,12 +1031,22 @@ class TFElectraModel(TFElectraPreTrainedModel):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
# Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output
|
|
||||||
def serving_output(self, output):
|
def serving_output(self, output):
|
||||||
|
output_cache = self.config.use_cache and self.config.is_decoder
|
||||||
|
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||||
|
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||||
|
cross_attns = None
|
||||||
|
|
||||||
return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns)
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=output.last_hidden_state,
|
||||||
|
past_key_values=pkv,
|
||||||
|
hidden_states=hs,
|
||||||
|
attentions=attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...file_utils import _LazyModule, is_flax_available, is_torch_available
|
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
@@ -28,6 +28,9 @@ _import_structure = {
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
_import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"]
|
_import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"]
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
_import_structure["modeling_tf_encoder_decoder"] = ["TFEncoderDecoderModel"]
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
_import_structure["modeling_flax_encoder_decoder"] = ["FlaxEncoderDecoderModel"]
|
_import_structure["modeling_flax_encoder_decoder"] = ["FlaxEncoderDecoderModel"]
|
||||||
|
|
||||||
@@ -37,6 +40,9 @@ if TYPE_CHECKING:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_encoder_decoder import EncoderDecoderModel
|
from .modeling_encoder_decoder import EncoderDecoderModel
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
from .modeling_tf_encoder_decoder import TFEncoderDecoderModel
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from .modeling_flax_encoder_decoder import FlaxEncoderDecoderModel
|
from .modeling_flax_encoder_decoder import FlaxEncoderDecoderModel
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,647 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Classes to support TF Encoder-Decoder architectures """
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...file_utils import (
|
||||||
|
DUMMY_INPUTS,
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
|
||||||
|
from ...modeling_tf_utils import TFPreTrainedModel, input_processing
|
||||||
|
from ...utils import logging
|
||||||
|
from ..auto.configuration_auto import AutoConfig
|
||||||
|
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
|
||||||
|
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
|
||||||
|
|
||||||
|
ENCODER_DECODER_START_DOCSTRING = r"""
|
||||||
|
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
|
||||||
|
encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
|
||||||
|
:meth:`~transformers.TFAutoModel.from_pretrained` function and the decoder is loaded via
|
||||||
|
:meth:`~transformers.TFAutoModelForCausalLM.from_pretrained` function. Cross-attention layers are automatically
|
||||||
|
added to the decoder and should be fine-tuned on a downstream generative task, like summarization.
|
||||||
|
|
||||||
|
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
|
||||||
|
tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks
|
||||||
|
<https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
|
||||||
|
Zhou, Wei Li, Peter J. Liu.
|
||||||
|
|
||||||
|
After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models
|
||||||
|
(see the examples for more information).
|
||||||
|
|
||||||
|
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
|
||||||
|
generic methods the library implements for all its model (such as downloading or saving, resizing the input
|
||||||
|
embeddings, pruning heads etc.)
|
||||||
|
|
||||||
|
This model is also a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ subclass. Use
|
||||||
|
it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage
|
||||||
|
and behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config (:class:`~transformers.EncoderDecoderConfig`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the
|
||||||
|
configuration. Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the
|
||||||
|
model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`np.ndarray`, :obj:`tf.Tensor`, :obj:`List[tf.Tensor]` :obj:`Dict[str, tf.Tensor]` or :obj:`Dict[str, np.ndarray]` and each example must have the shape :obj:`({0})`):
|
||||||
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
|
details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
decoder_input_ids (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||||
|
Indices of decoder input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
|
details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
|
||||||
|
If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
|
||||||
|
:obj:`past_key_values`).
|
||||||
|
|
||||||
|
Provide for sequence to sequence training to the decoder. Indices can be obtained using
|
||||||
|
:class:`~transformers.PreTrainedTokenizer`. See :meth:`transformers.PreTrainedTokenizer.encode` and
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.__call__` for details.
|
||||||
|
decoder_attention_mask (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||||
|
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||||
|
also be used by default.
|
||||||
|
encoder_outputs (:obj:`tuple(tuple(tf.Tensor)`, `optional`):
|
||||||
|
This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||||
|
:obj:`attentions`) :obj:`last_hidden_state` (:obj:`tf.Tensor` of shape :obj:`({0}, hidden_size)`) is a
|
||||||
|
tensor of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the
|
||||||
|
decoder.
|
||||||
|
past_key_values (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`({0})`.
|
||||||
|
inputs_embeds (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
||||||
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||||
|
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
||||||
|
vectors than the model's internal embedding lookup matrix.
|
||||||
|
decoder_inputs_embeds (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
|
||||||
|
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
|
||||||
|
representation. This is useful if you want more control over how to convert :obj:`decoder_input_ids`
|
||||||
|
indices into associated vectors than the model's internal embedding lookup matrix.
|
||||||
|
labels (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
|
||||||
|
Labels for computing the masked language modeling loss for the decoder. Indices should be in ``[-100, 0,
|
||||||
|
..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
If set to ``True``, the model will return a :class:`~transformers.file_utils.Seq2SeqLMOutput` instead of a
|
||||||
|
plain tuple.
|
||||||
|
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to use the model in training mode (some modules like dropout modules have different
|
||||||
|
behaviors between training and evaluation).
|
||||||
|
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
|
||||||
|
|
||||||
|
- Without a prefix which will be input as ``**encoder_kwargs`` for the encoder forward function.
|
||||||
|
- With a `decoder_` prefix which will be input as ``**decoder_kwargs`` for the decoder forward function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
|
||||||
|
class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.TFEncoderDecoder` is a generic model class that will be instantiated as a transformer
|
||||||
|
architecture with one of the base model classes of the library as encoder and another one as decoder when created
|
||||||
|
with the :meth`~transformers.TFAutoModel.from_pretrained` class method for the encoder and
|
||||||
|
:meth`~transformers.TFAutoModelForCausalLM.from_pretrained` class method for the decoder.
|
||||||
|
"""
|
||||||
|
config_class = EncoderDecoderConfig
|
||||||
|
base_model_prefix = "encoder_decoder"
|
||||||
|
load_weight_prefix = "tf_encoder_decoder_model_1"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Optional[PretrainedConfig] = None,
|
||||||
|
encoder: Optional[TFPreTrainedModel] = None,
|
||||||
|
decoder: Optional[TFPreTrainedModel] = None,
|
||||||
|
):
|
||||||
|
if config is None and (encoder is None or decoder is None):
|
||||||
|
raise ValueError("Either a configuration or an encoder and a decoder has to be provided")
|
||||||
|
if config is None:
|
||||||
|
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
|
||||||
|
else:
|
||||||
|
if not isinstance(config, self.config_class):
|
||||||
|
raise ValueError(f"config: {config} has to be of type {self.config_class}")
|
||||||
|
# initialize with config
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
if encoder is None:
|
||||||
|
encoder = TFAutoModel.from_config(config.encoder, name="encoder")
|
||||||
|
|
||||||
|
if decoder is None:
|
||||||
|
decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder")
|
||||||
|
|
||||||
|
self.encoder = encoder
|
||||||
|
self.decoder = decoder
|
||||||
|
|
||||||
|
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
|
||||||
|
logger.warning(
|
||||||
|
f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
|
||||||
|
)
|
||||||
|
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
|
||||||
|
logger.warning(
|
||||||
|
f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# make sure that the individual model's config refers to the shared config
|
||||||
|
# so that the updates to the config will be synced
|
||||||
|
self.encoder.config = self.config.encoder
|
||||||
|
self.decoder.config = self.config.decoder
|
||||||
|
|
||||||
|
if self.encoder.get_output_embeddings() is not None:
|
||||||
|
raise ValueError("The encoder {} should not have a LM Head. Please use a model without LM Head")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_inputs(self):
|
||||||
|
"""
|
||||||
|
Dummy inputs to build the network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||||
|
"""
|
||||||
|
# Add `decoder_input_ids` because `self.decoder` requires it.
|
||||||
|
input_ids = tf.constant(DUMMY_INPUTS)
|
||||||
|
dummy = {"input_ids": input_ids, "decoder_input_ids": input_ids}
|
||||||
|
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
batch_size, seq_len = input_ids.shape
|
||||||
|
shape = (batch_size, seq_len) + (self.config.hidden_size,)
|
||||||
|
h = tf.random.uniform(shape=shape)
|
||||||
|
dummy["encoder_hidden_states"] = h
|
||||||
|
|
||||||
|
return dummy
|
||||||
|
|
||||||
|
def get_encoder(self):
|
||||||
|
return self.encoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.decoder
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.encoder.get_input_embeddings()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.decoder.get_output_embeddings()
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
return self.decoder.set_output_embeddings(new_embeddings)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
r"""
|
||||||
|
Initializing `TFEncoderDecoderModel` from a pytorch checkpoint is not supported currently.
|
||||||
|
|
||||||
|
If there are only pytorch checkpoints for a particular encoder-decoder model, a workaround is::
|
||||||
|
|
||||||
|
>>> # a workaround to load from pytorch checkpoint
|
||||||
|
>>> _model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
||||||
|
>>> _model.encoder.save_pretrained("./encoder")
|
||||||
|
>>> _model.decoder.save_pretrained("./decoder")
|
||||||
|
>>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||||
|
... "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
|
||||||
|
... )
|
||||||
|
>>> # This is only for copying some specific attributes of this particular model.
|
||||||
|
>>> model.config = _model.config
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from_pt = kwargs.pop("from_pt", False)
|
||||||
|
if from_pt:
|
||||||
|
raise ValueError(
|
||||||
|
"Initializing `TFEncoderDecoderModel` from a pytorch checkpoint is not supported currently. "
|
||||||
|
"Use a tensorflow checkpoint instead. If only the pytorch checkpoints are available, "
|
||||||
|
"create the encoder and decoder models separately, and use them to initialize `TFEncoderDecoderModel`. "
|
||||||
|
"Check `TFEncoderDecoderModel.from_encoder_decoder_pretrained()` for more details."
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_encoder_decoder_pretrained(
|
||||||
|
cls,
|
||||||
|
encoder_pretrained_model_name_or_path: str = None,
|
||||||
|
decoder_pretrained_model_name_or_path: str = None,
|
||||||
|
*model_args,
|
||||||
|
**kwargs
|
||||||
|
) -> TFPreTrainedModel:
|
||||||
|
r"""
|
||||||
|
Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
|
||||||
|
checkpoints.
|
||||||
|
|
||||||
|
|
||||||
|
Params:
|
||||||
|
encoder_pretrained_model_name_or_path (:obj: `str`, `optional`):
|
||||||
|
Information necessary to initiate the encoder. Can be either:
|
||||||
|
|
||||||
|
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||||
|
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||||
|
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||||
|
- A path to a `directory` containing model weights saved using
|
||||||
|
:func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||||
|
- A path or url to a `pytorch index checkpoint file` (e.g, ``./pt_model/``). In this case,
|
||||||
|
``encoder_from_pt`` should be set to :obj:`True`.
|
||||||
|
|
||||||
|
decoder_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
|
||||||
|
Information necessary to initiate the decoder. Can be either:
|
||||||
|
|
||||||
|
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||||
|
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||||
|
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||||
|
- A path to a `directory` containing model weights saved using
|
||||||
|
:func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||||
|
- A path or url to a `pytorch checkpoint file` (e.g, ``./pt_model/``). In this case,
|
||||||
|
``decoder_from_pt`` should be set to :obj:`True`.
|
||||||
|
|
||||||
|
model_args (remaining positional arguments, `optional`):
|
||||||
|
All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||||
|
|
||||||
|
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||||
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||||
|
:obj:`output_attentions=True`).
|
||||||
|
|
||||||
|
- To update the encoder configuration, use the prefix `encoder_` for each configuration parameter.
|
||||||
|
- To update the decoder configuration, use the prefix `decoder_` for each configuration parameter.
|
||||||
|
- To update the parent model configuration, do not use a prefix for each configuration parameter.
|
||||||
|
|
||||||
|
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import TFEncoderDecoderModel
|
||||||
|
>>> # initialize a bert2gpt2 from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized
|
||||||
|
>>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'gpt2')
|
||||||
|
>>> # saving model after fine-tuning
|
||||||
|
>>> model.save_pretrained("./bert2gpt2")
|
||||||
|
>>> # load fine-tuned model
|
||||||
|
>>> model = TFEncoderDecoderModel.from_pretrained("./bert2gpt2")
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
kwargs_encoder = {
|
||||||
|
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
|
||||||
|
}
|
||||||
|
|
||||||
|
kwargs_decoder = {
|
||||||
|
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||||
|
}
|
||||||
|
|
||||||
|
# remove encoder, decoder kwargs from kwargs
|
||||||
|
for key in kwargs_encoder.keys():
|
||||||
|
del kwargs["encoder_" + key]
|
||||||
|
for key in kwargs_decoder.keys():
|
||||||
|
del kwargs["decoder_" + key]
|
||||||
|
|
||||||
|
# Load and initialize the encoder and decoder
|
||||||
|
# The distinction between encoder and decoder at the model level is made
|
||||||
|
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||||
|
encoder = kwargs_encoder.pop("model", None)
|
||||||
|
if encoder is None:
|
||||||
|
if encoder_pretrained_model_name_or_path is None:
|
||||||
|
raise ValueError(
|
||||||
|
"If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "config" not in kwargs_encoder:
|
||||||
|
|
||||||
|
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||||
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled."
|
||||||
|
)
|
||||||
|
encoder_config.is_decoder = False
|
||||||
|
encoder_config.add_cross_attention = False
|
||||||
|
|
||||||
|
kwargs_encoder["config"] = encoder_config
|
||||||
|
|
||||||
|
kwargs_encoder["name"] = "encoder"
|
||||||
|
kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
|
||||||
|
encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
||||||
|
|
||||||
|
decoder = kwargs_decoder.pop("model", None)
|
||||||
|
if decoder is None:
|
||||||
|
if decoder_pretrained_model_name_or_path is None:
|
||||||
|
raise ValueError(
|
||||||
|
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "config" not in kwargs_decoder:
|
||||||
|
|
||||||
|
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||||
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||||
|
logger.info(
|
||||||
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
||||||
|
)
|
||||||
|
decoder_config.is_decoder = True
|
||||||
|
decoder_config.add_cross_attention = True
|
||||||
|
|
||||||
|
kwargs_decoder["config"] = decoder_config
|
||||||
|
|
||||||
|
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||||
|
logger.warning(
|
||||||
|
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs_decoder["name"] = "decoder"
|
||||||
|
kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
|
||||||
|
decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||||
|
|
||||||
|
# Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
|
||||||
|
if encoder.name != "encoder":
|
||||||
|
raise ValueError("encoder model must be created with the name `encoder`.")
|
||||||
|
if decoder.name != "decoder":
|
||||||
|
raise ValueError("decoder model must be created with the name `decoder`.")
|
||||||
|
|
||||||
|
# instantiate config with corresponding kwargs
|
||||||
|
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
||||||
|
return cls(encoder=encoder, decoder=decoder, config=config)
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
|
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
decoder_input_ids=None,
|
||||||
|
decoder_attention_mask=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
decoder_inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
training=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import TFEncoderDecoderModel, BertTokenizer
|
||||||
|
|
||||||
|
>>> # initialize a bert2gpt2 from a pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
|
||||||
|
>>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-cased', 'gpt2')
|
||||||
|
|
||||||
|
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
||||||
|
|
||||||
|
>>> # forward
|
||||||
|
>>> input_ids = tokenizer.encode("Hello, my dog is cute", add_special_tokens=True, return_tensors='tf') # Batch size 1
|
||||||
|
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
|
||||||
|
|
||||||
|
>>> # training
|
||||||
|
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)
|
||||||
|
>>> loss, logits = outputs.loss, outputs.logits
|
||||||
|
|
||||||
|
>>> # save and load from pretrained
|
||||||
|
>>> model.save_pretrained("bert2gpt2")
|
||||||
|
>>> model = TFEncoderDecoderModel.from_pretrained("bert2gpt2")
|
||||||
|
|
||||||
|
>>> # generation
|
||||||
|
>>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.bos_token_id)
|
||||||
|
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
|
||||||
|
|
||||||
|
kwargs_decoder = {
|
||||||
|
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoder_outputs is None:
|
||||||
|
|
||||||
|
encoder_processing_inputs = {
|
||||||
|
"func": self.encoder.call,
|
||||||
|
"config": self.encoder.config,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"inputs_embeds": inputs_embeds,
|
||||||
|
"output_attentions": output_attentions,
|
||||||
|
"output_hidden_states": output_hidden_states,
|
||||||
|
"return_dict": return_dict,
|
||||||
|
"training": training,
|
||||||
|
"kwargs_call": kwargs_encoder,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add arguments to encoder from `kwargs_encoder`
|
||||||
|
for k, v in kwargs_encoder.items():
|
||||||
|
encoder_processing_inputs[k] = v
|
||||||
|
kwargs_encoder = {}
|
||||||
|
|
||||||
|
encoder_inputs = input_processing(**encoder_processing_inputs)
|
||||||
|
|
||||||
|
# handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`.
|
||||||
|
if "decoder_input_ids" in encoder_inputs:
|
||||||
|
decoder_input_ids = encoder_inputs.pop("decoder_input_ids")
|
||||||
|
# handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`.
|
||||||
|
if "decoder_attention_mask" in encoder_inputs:
|
||||||
|
decoder_attention_mask = encoder_inputs.pop("decoder_attention_mask")
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(**encoder_inputs)
|
||||||
|
|
||||||
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
decoder_processing_inputs = {
|
||||||
|
"func": self.decoder.call,
|
||||||
|
"config": self.decoder.config,
|
||||||
|
"input_ids": decoder_input_ids,
|
||||||
|
"attention_mask": decoder_attention_mask,
|
||||||
|
"encoder_hidden_states": encoder_hidden_states,
|
||||||
|
"encoder_attention_mask": attention_mask,
|
||||||
|
"inputs_embeds": decoder_inputs_embeds,
|
||||||
|
"labels": labels,
|
||||||
|
"output_attentions": output_attentions,
|
||||||
|
"output_hidden_states": output_hidden_states,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"return_dict": return_dict,
|
||||||
|
"training": training,
|
||||||
|
"kwargs_call": kwargs_decoder,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add arguments to decoder from `kwargs_decoder`
|
||||||
|
for k, v in kwargs_decoder.items():
|
||||||
|
decoder_processing_inputs[k] = v
|
||||||
|
kwargs_decoder = {}
|
||||||
|
|
||||||
|
decoder_inputs = input_processing(**decoder_processing_inputs)
|
||||||
|
decoder_outputs = self.decoder(**decoder_inputs)
|
||||||
|
|
||||||
|
loss = None if decoder_inputs["labels"] is None else decoder_outputs[0]
|
||||||
|
logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1]
|
||||||
|
past_key_values = None
|
||||||
|
|
||||||
|
if decoder_inputs["use_cache"]:
|
||||||
|
past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2]
|
||||||
|
# The starting index of the remaining elements in `decoder_outputs`
|
||||||
|
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
|
||||||
|
|
||||||
|
past = (encoder_outputs[0], past_key_values) if past_key_values else None
|
||||||
|
|
||||||
|
if not decoder_inputs["return_dict"]:
|
||||||
|
if not isinstance(encoder_outputs, tuple):
|
||||||
|
encoder_outputs = encoder_outputs.to_tuple()
|
||||||
|
output = (loss, logits, past) + decoder_outputs[start_index:] + encoder_outputs
|
||||||
|
output = tuple([x for x in output if x is not None])
|
||||||
|
return output
|
||||||
|
|
||||||
|
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
|
||||||
|
if not isinstance(encoder_outputs, TFBaseModelOutput):
|
||||||
|
encoder_outputs = TFBaseModelOutput(
|
||||||
|
last_hidden_state=encoder_outputs[0],
|
||||||
|
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
||||||
|
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TFSeq2SeqLMOutput(
|
||||||
|
loss=decoder_outputs.loss,
|
||||||
|
logits=decoder_outputs.logits,
|
||||||
|
past_key_values=past,
|
||||||
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||||
|
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||||
|
encoder_attentions=encoder_outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def serving_output(self, output):
|
||||||
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = (
|
||||||
|
tf.convert_to_tensor(output.cross_attentions)
|
||||||
|
if self.config.output_attentions and output.cross_attentions is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return TFSeq2SeqLMOutput(
|
||||||
|
logits=output.logits,
|
||||||
|
past_key_values=pkv,
|
||||||
|
decoder_hidden_states=dec_hs,
|
||||||
|
decoder_attentions=dec_attns,
|
||||||
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
|
encoder_hidden_states=enc_hs,
|
||||||
|
encoder_attentions=enc_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
decoder_input_ids,
|
||||||
|
past,
|
||||||
|
attention_mask,
|
||||||
|
use_cache=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if past is None or len(past) not in {1, 2}:
|
||||||
|
raise ValueError(f"past has to be an iterable of length 1,2 got {past}")
|
||||||
|
|
||||||
|
if len(past) == 1:
|
||||||
|
if not isinstance(past[0], tf.Tensor):
|
||||||
|
raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}")
|
||||||
|
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
|
||||||
|
past_key_values = None
|
||||||
|
else:
|
||||||
|
if len(past) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
|
||||||
|
)
|
||||||
|
encoder_outputs, past_key_values = past
|
||||||
|
if isinstance(encoder_outputs, tuple):
|
||||||
|
if not isinstance(encoder_outputs[0], tf.Tensor):
|
||||||
|
raise ValueError(
|
||||||
|
f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
|
||||||
|
)
|
||||||
|
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
|
||||||
|
elif isinstance(encoder_outputs, tf.Tensor):
|
||||||
|
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
|
||||||
|
if not past_key_values:
|
||||||
|
raise ValueError(
|
||||||
|
f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
|
||||||
|
)
|
||||||
|
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||||
|
|
||||||
|
if not isinstance(encoder_outputs, TFBaseModelOutput):
|
||||||
|
raise ValueError(f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}.")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||||
|
"encoder_outputs": encoder_outputs,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
|
}
|
||||||
|
|
||||||
|
def resize_token_embeddings(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported."
|
||||||
|
"Please use the respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or model.decoder.resize_token_embeddings(...))"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _reorder_cache(self, past, beam_idx):
|
||||||
|
# apply decoder cache reordering here
|
||||||
|
if len(past) == 1:
|
||||||
|
return past
|
||||||
|
|
||||||
|
encoder_outputs, past_key_values = past
|
||||||
|
|
||||||
|
return (encoder_outputs, self.decoder._reorder_cache(past_key_values, beam_idx))
|
||||||
@@ -24,8 +24,8 @@ import tensorflow as tf
|
|||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFBaseModelOutputWithPooling,
|
TFBaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
TFMaskedLMOutput,
|
TFMaskedLMOutput,
|
||||||
TFSequenceClassifierOutput,
|
TFSequenceClassifierOutput,
|
||||||
TFTokenClassifierOutput,
|
TFTokenClassifierOutput,
|
||||||
@@ -216,6 +216,8 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
||||||
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
||||||
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||||
@@ -228,16 +230,49 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
batch_size = shape_list(hidden_states)[0]
|
batch_size = shape_list(hidden_states)[0]
|
||||||
mixed_query_layer = self.query(inputs=hidden_states)
|
mixed_query_layer = self.query(inputs=hidden_states)
|
||||||
mixed_key_layer = self.key(inputs=hidden_states)
|
|
||||||
mixed_value_layer = self.value(inputs=hidden_states)
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
|
# and values come from an encoder; the attention mask needs to be
|
||||||
|
# such that the encoder's padding tokens are not attended to.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
|
||||||
|
else:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
# (batch size, num_heads, seq_len_q, seq_len_k)
|
# (batch size, num_heads, seq_len_q, seq_len_k)
|
||||||
@@ -267,6 +302,8 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
|
|||||||
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
||||||
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -305,6 +342,9 @@ class TFLayoutLMAttention(tf.keras.layers.Layer):
|
|||||||
input_tensor: tf.Tensor,
|
input_tensor: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
@@ -312,13 +352,17 @@ class TFLayoutLMAttention(tf.keras.layers.Layer):
|
|||||||
hidden_states=input_tensor,
|
hidden_states=input_tensor,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = self.dense_output(
|
attention_output = self.dense_output(
|
||||||
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
||||||
)
|
)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
# add attentions (possibly with past_key_value) if we output them
|
||||||
|
outputs = (attention_output,) + self_outputs[1:]
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -369,6 +413,12 @@ class TFLayoutLMLayer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.attention = TFLayoutLMAttention(config, name="attention")
|
self.attention = TFLayoutLMAttention(config, name="attention")
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
self.add_cross_attention = config.add_cross_attention
|
||||||
|
if self.add_cross_attention:
|
||||||
|
if not self.is_decoder:
|
||||||
|
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||||
|
self.crossattention = TFLayoutLMAttention(config, name="crossattention")
|
||||||
self.intermediate = TFLayoutLMIntermediate(config, name="intermediate")
|
self.intermediate = TFLayoutLMIntermediate(config, name="intermediate")
|
||||||
self.bert_output = TFLayoutLMOutput(config, name="output")
|
self.bert_output = TFLayoutLMOutput(config, name="output")
|
||||||
|
|
||||||
@@ -377,22 +427,69 @@ class TFLayoutLMLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: Optional[tf.Tensor],
|
||||||
|
encoder_attention_mask: Optional[tf.Tensor],
|
||||||
|
past_key_value: Optional[Tuple[tf.Tensor]],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
attention_outputs = self.attention(
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
input_tensor=hidden_states,
|
input_tensor=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
|
if not hasattr(self, "crossattention"):
|
||||||
|
raise ValueError(
|
||||||
|
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
|
||||||
|
"by setting `config.add_cross_attention=True`"
|
||||||
|
)
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
|
cross_attention_outputs = self.crossattention(
|
||||||
|
input_tensor=attention_output,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=cross_attn_past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
attention_output = cross_attention_outputs[0]
|
||||||
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
intermediate_output = self.intermediate(hidden_states=attention_output)
|
intermediate_output = self.intermediate(hidden_states=attention_output)
|
||||||
layer_output = self.bert_output(
|
layer_output = self.bert_output(
|
||||||
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
outputs = (layer_output,) + outputs # add attentions if we output them
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -401,7 +498,7 @@ class TFLayoutLMLayer(tf.keras.layers.Layer):
|
|||||||
class TFLayoutLMEncoder(tf.keras.layers.Layer):
|
class TFLayoutLMEncoder(tf.keras.layers.Layer):
|
||||||
def __init__(self, config: LayoutLMConfig, **kwargs):
|
def __init__(self, config: LayoutLMConfig, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.config = config
|
||||||
self.layer = [TFLayoutLMLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
self.layer = [TFLayoutLMLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
@@ -409,39 +506,61 @@ class TFLayoutLMEncoder(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: Optional[tf.Tensor],
|
||||||
|
encoder_attention_mask: Optional[tf.Tensor],
|
||||||
|
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
|
||||||
|
use_cache: Optional[bool],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
output_hidden_states: bool,
|
output_hidden_states: bool,
|
||||||
return_dict: bool,
|
return_dict: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (layer_outputs[1],)
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return TFBaseModelOutput(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -585,12 +704,14 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -665,6 +786,11 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=embedding_output,
|
hidden_states=embedding_output,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
|
# Need to pass these required positional arguments to `Encoder`
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=False,
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -680,11 +806,12 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer):
|
|||||||
pooled_output,
|
pooled_output,
|
||||||
) + encoder_outputs[1:]
|
) + encoder_outputs[1:]
|
||||||
|
|
||||||
return TFBaseModelOutputWithPooling(
|
return TFBaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
pooler_output=pooled_output,
|
pooler_output=pooled_output,
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
attentions=encoder_outputs.attentions,
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -802,7 +929,9 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
|
|||||||
self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm")
|
self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm")
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(
|
||||||
|
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[TFModelInputType] = None,
|
input_ids: Optional[TFModelInputType] = None,
|
||||||
@@ -812,12 +941,14 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: Optional[bool] = False,
|
training: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
@@ -859,6 +990,8 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -881,15 +1014,25 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
|
||||||
|
def serving_output(
|
||||||
|
self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
|
||||||
|
) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
|
||||||
|
output_cache = self.config.use_cache and self.config.is_decoder
|
||||||
|
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||||
|
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||||
|
cross_attns = None
|
||||||
|
|
||||||
return TFBaseModelOutputWithPooling(
|
return TFBaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
last_hidden_state=output.last_hidden_state,
|
last_hidden_state=output.last_hidden_state,
|
||||||
pooler_output=output.pooler_output,
|
pooler_output=output.pooler_output,
|
||||||
|
past_key_values=pkv,
|
||||||
hidden_states=hs,
|
hidden_states=hs,
|
||||||
attentions=attns,
|
attentions=attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -510,7 +510,7 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
def create_position_ids_from_input_ids(self, input_ids):
|
def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
|
||||||
"""
|
"""
|
||||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
|
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
|
||||||
symbols are ignored. This is modified from fairseq's `utils.make_positions`.
|
symbols are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||||
@@ -520,11 +520,19 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
|
|||||||
Returns: tf.Tensor
|
Returns: tf.Tensor
|
||||||
"""
|
"""
|
||||||
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
|
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
|
||||||
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
|
incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask
|
||||||
|
|
||||||
return incremental_indices + self.padding_idx
|
return incremental_indices + self.padding_idx
|
||||||
|
|
||||||
def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
past_key_values_length=0,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Applies embedding based on inputs tensor.
|
Applies embedding based on inputs tensor.
|
||||||
|
|
||||||
@@ -544,7 +552,9 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
|
|||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||||
position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
|
position_ids = self.create_position_ids_from_input_ids(
|
||||||
|
input_ids=input_ids, past_key_values_length=past_key_values_length
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
position_ids = tf.expand_dims(
|
position_ids = tf.expand_dims(
|
||||||
tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
|
tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
|
||||||
|
|||||||
@@ -982,7 +982,6 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
|
|
||||||
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
|||||||
@@ -729,7 +729,6 @@ class TFMPNetModel(TFMPNetPreTrainedModel):
|
|||||||
)
|
)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
|
|
||||||
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
|||||||
@@ -673,7 +673,6 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
|
|||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
|
|
||||||
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
|
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
|||||||
@@ -23,15 +23,16 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
|
DUMMY_INPUTS,
|
||||||
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFBaseModelOutputWithPooling,
|
TFBaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
TFCausalLMOutput,
|
TFCausalLMOutputWithCrossAttentions,
|
||||||
TFMaskedLMOutput,
|
TFMaskedLMOutput,
|
||||||
TFMultipleChoiceModelOutput,
|
TFMultipleChoiceModelOutput,
|
||||||
TFQuestionAnsweringModelOutput,
|
TFQuestionAnsweringModelOutput,
|
||||||
@@ -112,6 +113,7 @@ class TFRemBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
position_ids: tf.Tensor = None,
|
position_ids: tf.Tensor = None,
|
||||||
token_type_ids: tf.Tensor = None,
|
token_type_ids: tf.Tensor = None,
|
||||||
inputs_embeds: tf.Tensor = None,
|
inputs_embeds: tf.Tensor = None,
|
||||||
|
past_key_values_length=0,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> tf.Tensor:
|
) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -131,7 +133,9 @@ class TFRemBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
token_type_ids = tf.fill(dims=input_shape, value=0)
|
token_type_ids = tf.fill(dims=input_shape, value=0)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
|
position_ids = tf.expand_dims(
|
||||||
|
tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
|
||||||
|
)
|
||||||
|
|
||||||
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
||||||
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
||||||
@@ -170,6 +174,8 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
||||||
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
||||||
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||||
@@ -182,16 +188,49 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
batch_size = shape_list(hidden_states)[0]
|
batch_size = shape_list(hidden_states)[0]
|
||||||
mixed_query_layer = self.query(inputs=hidden_states)
|
mixed_query_layer = self.query(inputs=hidden_states)
|
||||||
mixed_key_layer = self.key(inputs=hidden_states)
|
|
||||||
mixed_value_layer = self.value(inputs=hidden_states)
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
|
# and values come from an encoder; the attention mask needs to be
|
||||||
|
# such that the encoder's padding tokens are not attended to.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
|
||||||
|
else:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
# (batch size, num_heads, seq_len_q, seq_len_k)
|
# (batch size, num_heads, seq_len_q, seq_len_k)
|
||||||
@@ -221,6 +260,8 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
||||||
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -259,6 +300,9 @@ class TFRemBertAttention(tf.keras.layers.Layer):
|
|||||||
input_tensor: tf.Tensor,
|
input_tensor: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
@@ -266,13 +310,17 @@ class TFRemBertAttention(tf.keras.layers.Layer):
|
|||||||
hidden_states=input_tensor,
|
hidden_states=input_tensor,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = self.dense_output(
|
attention_output = self.dense_output(
|
||||||
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
||||||
)
|
)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
# add attentions (possibly with past_key_value) if we output them
|
||||||
|
outputs = (attention_output,) + self_outputs[1:]
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -323,6 +371,12 @@ class TFRemBertLayer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.attention = TFRemBertAttention(config, name="attention")
|
self.attention = TFRemBertAttention(config, name="attention")
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
self.add_cross_attention = config.add_cross_attention
|
||||||
|
if self.add_cross_attention:
|
||||||
|
if not self.is_decoder:
|
||||||
|
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||||
|
self.crossattention = TFRemBertAttention(config, name="crossattention")
|
||||||
self.intermediate = TFRemBertIntermediate(config, name="intermediate")
|
self.intermediate = TFRemBertIntermediate(config, name="intermediate")
|
||||||
self.bert_output = TFRemBertOutput(config, name="output")
|
self.bert_output = TFRemBertOutput(config, name="output")
|
||||||
|
|
||||||
@@ -331,22 +385,69 @@ class TFRemBertLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: Optional[tf.Tensor],
|
||||||
|
encoder_attention_mask: Optional[tf.Tensor],
|
||||||
|
past_key_value: Optional[Tuple[tf.Tensor]],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
attention_outputs = self.attention(
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
input_tensor=hidden_states,
|
input_tensor=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
|
if not hasattr(self, "crossattention"):
|
||||||
|
raise ValueError(
|
||||||
|
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
|
||||||
|
"by setting `config.add_cross_attention=True`"
|
||||||
|
)
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
|
cross_attention_outputs = self.crossattention(
|
||||||
|
input_tensor=attention_output,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=cross_attn_past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
attention_output = cross_attention_outputs[0]
|
||||||
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
intermediate_output = self.intermediate(hidden_states=attention_output)
|
intermediate_output = self.intermediate(hidden_states=attention_output)
|
||||||
layer_output = self.bert_output(
|
layer_output = self.bert_output(
|
||||||
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
outputs = (layer_output,) + outputs # add attentions if we output them
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -354,6 +455,7 @@ class TFRemBertLayer(tf.keras.layers.Layer):
|
|||||||
class TFRemBertEncoder(tf.keras.layers.Layer):
|
class TFRemBertEncoder(tf.keras.layers.Layer):
|
||||||
def __init__(self, config: RemBertConfig, **kwargs):
|
def __init__(self, config: RemBertConfig, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.embedding_hidden_mapping_in = tf.keras.layers.Dense(
|
self.embedding_hidden_mapping_in = tf.keras.layers.Dense(
|
||||||
units=config.hidden_size,
|
units=config.hidden_size,
|
||||||
@@ -367,40 +469,62 @@ class TFRemBertEncoder(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_values: Tuple[Tuple[tf.Tensor]],
|
||||||
|
use_cache: bool,
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
output_hidden_states: bool,
|
output_hidden_states: bool,
|
||||||
return_dict: bool,
|
return_dict: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states)
|
hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states)
|
||||||
all_hidden_states = (hidden_states,) if output_hidden_states else None
|
all_hidden_states = (hidden_states,) if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (layer_outputs[1],)
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return TFBaseModelOutput(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -500,6 +624,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
self.embeddings = TFRemBertEmbeddings(config, name="embeddings")
|
self.embeddings = TFRemBertEmbeddings(config, name="embeddings")
|
||||||
self.encoder = TFRemBertEncoder(config, name="encoder")
|
self.encoder = TFRemBertEncoder(config, name="encoder")
|
||||||
@@ -519,6 +644,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[TFModelInputType] = None,
|
input_ids: Optional[TFModelInputType] = None,
|
||||||
@@ -527,12 +653,16 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -542,6 +672,10 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -549,6 +683,9 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
|
|||||||
kwargs_call=kwargs,
|
kwargs_call=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.config.is_decoder:
|
||||||
|
inputs["use_cache"] = False
|
||||||
|
|
||||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif inputs["input_ids"] is not None:
|
elif inputs["input_ids"] is not None:
|
||||||
@@ -558,8 +695,16 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
|
if inputs["past_key_values"] is None:
|
||||||
|
past_key_values_length = 0
|
||||||
|
inputs["past_key_values"] = [None] * len(self.encoder.layer)
|
||||||
|
else:
|
||||||
|
past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2]
|
||||||
|
|
||||||
if inputs["attention_mask"] is None:
|
if inputs["attention_mask"] is None:
|
||||||
inputs["attention_mask"] = tf.fill(dims=input_shape, value=1)
|
inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
|
||||||
|
|
||||||
if inputs["token_type_ids"] is None:
|
if inputs["token_type_ids"] is None:
|
||||||
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
|
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
|
||||||
@@ -569,6 +714,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
token_type_ids=inputs["token_type_ids"],
|
token_type_ids=inputs["token_type_ids"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -577,7 +723,29 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
|
|||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
|
attention_mask_shape = shape_list(inputs["attention_mask"])
|
||||||
|
|
||||||
|
mask_seq_length = seq_length + past_key_values_length
|
||||||
|
# Copied from `modeling_tf_t5.py`
|
||||||
|
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
|
||||||
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||||
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||||
|
if self.is_decoder:
|
||||||
|
seq_ids = tf.range(mask_seq_length)
|
||||||
|
causal_mask = tf.less_equal(
|
||||||
|
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
|
||||||
|
seq_ids[None, :, None],
|
||||||
|
)
|
||||||
|
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
|
||||||
|
extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :]
|
||||||
|
attention_mask_shape = shape_list(extended_attention_mask)
|
||||||
|
extended_attention_mask = tf.reshape(
|
||||||
|
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extended_attention_mask = tf.reshape(
|
||||||
|
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
|
||||||
|
)
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
@@ -589,6 +757,29 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
|
|||||||
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
|
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
|
||||||
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
||||||
|
|
||||||
|
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
|
||||||
|
if self.is_decoder and inputs["encoder_attention_mask"] is not None:
|
||||||
|
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
inputs["encoder_attention_mask"] = tf.cast(
|
||||||
|
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
|
||||||
|
)
|
||||||
|
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
|
||||||
|
if num_dims_encoder_attention_mask == 3:
|
||||||
|
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
|
||||||
|
if num_dims_encoder_attention_mask == 2:
|
||||||
|
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
|
||||||
|
|
||||||
|
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||||
|
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
||||||
|
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
|
||||||
|
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
|
||||||
|
|
||||||
|
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
@@ -603,6 +794,10 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=embedding_output,
|
hidden_states=embedding_output,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -613,13 +808,18 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
|
|||||||
pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
|
pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
return (
|
||||||
|
sequence_output,
|
||||||
|
pooled_output,
|
||||||
|
) + encoder_outputs[1:]
|
||||||
|
|
||||||
return TFBaseModelOutputWithPooling(
|
return TFBaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
pooler_output=pooled_output,
|
pooler_output=pooled_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
attentions=encoder_outputs.attentions,
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -632,6 +832,24 @@ class TFRemBertPreTrainedModel(TFPreTrainedModel):
|
|||||||
config_class = RemBertConfig
|
config_class = RemBertConfig
|
||||||
base_model_prefix = "rembert"
|
base_model_prefix = "rembert"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_inputs(self):
|
||||||
|
"""
|
||||||
|
Dummy inputs to build the network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||||
|
"""
|
||||||
|
dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
|
||||||
|
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
|
||||||
|
shape = (batch_size, seq_len) + (self.config.hidden_size,)
|
||||||
|
h = tf.random.uniform(shape=shape)
|
||||||
|
dummy["encoder_hidden_states"] = h
|
||||||
|
|
||||||
|
return dummy
|
||||||
|
|
||||||
|
|
||||||
REMBERT_START_DOCSTRING = r"""
|
REMBERT_START_DOCSTRING = r"""
|
||||||
|
|
||||||
@@ -740,7 +958,7 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
|
|||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint="rembert",
|
checkpoint="rembert",
|
||||||
output_type=TFBaseModelOutputWithPooling,
|
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
@@ -751,12 +969,36 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: Optional[bool] = False,
|
training: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
|
||||||
|
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
|
||||||
|
"""
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -766,6 +1008,10 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -779,6 +1025,10 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
|
|||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -787,15 +1037,25 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
|
||||||
|
def serving_output(
|
||||||
|
self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
|
||||||
|
) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
|
||||||
|
output_cache = self.config.use_cache and self.config.is_decoder
|
||||||
|
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||||
|
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||||
|
cross_attns = None
|
||||||
|
|
||||||
return TFBaseModelOutputWithPooling(
|
return TFBaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
last_hidden_state=output.last_hidden_state,
|
last_hidden_state=output.last_hidden_state,
|
||||||
pooler_output=output.pooler_output,
|
pooler_output=output.pooler_output,
|
||||||
|
past_key_values=pkv,
|
||||||
hidden_states=hs,
|
hidden_states=hs,
|
||||||
attentions=attns,
|
attentions=attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -912,10 +1172,23 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
def get_lm_head(self) -> tf.keras.layers.Layer:
|
def get_lm_head(self) -> tf.keras.layers.Layer:
|
||||||
return self.mlm.predictions
|
return self.mlm.predictions
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
|
||||||
|
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past:
|
||||||
|
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": inputs,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": model_kwargs["use_cache"],
|
||||||
|
}
|
||||||
|
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint="rembert",
|
checkpoint="rembert",
|
||||||
output_type=TFCausalLMOutput,
|
output_type=TFCausalLMOutputWithCrossAttentions,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
@@ -926,14 +1199,36 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
training: Optional[bool] = False,
|
training: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
|
||||||
|
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
|
||||||
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||||
config.vocab_size - 1]``.
|
config.vocab_size - 1]``.
|
||||||
@@ -947,6 +1242,10 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -961,6 +1260,10 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -980,18 +1283,28 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
|
|||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return TFCausalLMOutput(
|
return TFCausalLMOutputWithCrossAttentions(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
|
||||||
|
def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions:
|
||||||
|
output_cache = self.config.use_cache and self.config.is_decoder
|
||||||
|
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||||
|
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||||
|
cross_attns = None
|
||||||
|
|
||||||
return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
return TFCausalLMOutputWithCrossAttentions(
|
||||||
|
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ if is_torch_available():
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_roberta"] = [
|
_import_structure["modeling_tf_roberta"] = [
|
||||||
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"TFRobertaForCausalLM",
|
||||||
"TFRobertaForMaskedLM",
|
"TFRobertaForMaskedLM",
|
||||||
"TFRobertaForMultipleChoice",
|
"TFRobertaForMultipleChoice",
|
||||||
"TFRobertaForQuestionAnswering",
|
"TFRobertaForQuestionAnswering",
|
||||||
@@ -90,6 +91,7 @@ if TYPE_CHECKING:
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_roberta import (
|
from .modeling_tf_roberta import (
|
||||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
TFRobertaForCausalLM,
|
||||||
TFRobertaForMaskedLM,
|
TFRobertaForMaskedLM,
|
||||||
TFRobertaForMultipleChoice,
|
TFRobertaForMultipleChoice,
|
||||||
TFRobertaForQuestionAnswering,
|
TFRobertaForQuestionAnswering,
|
||||||
|
|||||||
@@ -24,14 +24,16 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
|
DUMMY_INPUTS,
|
||||||
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFBaseModelOutputWithPooling,
|
TFBaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
|
TFCausalLMOutputWithCrossAttentions,
|
||||||
TFMaskedLMOutput,
|
TFMaskedLMOutput,
|
||||||
TFMultipleChoiceModelOutput,
|
TFMultipleChoiceModelOutput,
|
||||||
TFQuestionAnsweringModelOutput,
|
TFQuestionAnsweringModelOutput,
|
||||||
@@ -39,6 +41,7 @@ from ...modeling_tf_outputs import (
|
|||||||
TFTokenClassifierOutput,
|
TFTokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFMaskedLanguageModelingLoss,
|
TFMaskedLanguageModelingLoss,
|
||||||
TFModelInputType,
|
TFModelInputType,
|
||||||
TFMultipleChoiceLoss,
|
TFMultipleChoiceLoss,
|
||||||
@@ -112,7 +115,7 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
def create_position_ids_from_input_ids(self, input_ids):
|
def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
|
||||||
"""
|
"""
|
||||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
|
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
|
||||||
symbols are ignored. This is modified from fairseq's `utils.make_positions`.
|
symbols are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||||
@@ -122,11 +125,19 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
|
|||||||
Returns: tf.Tensor
|
Returns: tf.Tensor
|
||||||
"""
|
"""
|
||||||
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
|
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
|
||||||
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
|
incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask
|
||||||
|
|
||||||
return incremental_indices + self.padding_idx
|
return incremental_indices + self.padding_idx
|
||||||
|
|
||||||
def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
past_key_values_length=0,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Applies embedding based on inputs tensor.
|
Applies embedding based on inputs tensor.
|
||||||
|
|
||||||
@@ -146,7 +157,9 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
|
|||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||||
position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
|
position_ids = self.create_position_ids_from_input_ids(
|
||||||
|
input_ids=input_ids, past_key_values_length=past_key_values_length
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
position_ids = tf.expand_dims(
|
position_ids = tf.expand_dims(
|
||||||
tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
|
tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
|
||||||
@@ -210,6 +223,8 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
||||||
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
||||||
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||||
@@ -222,16 +237,49 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
batch_size = shape_list(hidden_states)[0]
|
batch_size = shape_list(hidden_states)[0]
|
||||||
mixed_query_layer = self.query(inputs=hidden_states)
|
mixed_query_layer = self.query(inputs=hidden_states)
|
||||||
mixed_key_layer = self.key(inputs=hidden_states)
|
|
||||||
mixed_value_layer = self.value(inputs=hidden_states)
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
|
# and values come from an encoder; the attention mask needs to be
|
||||||
|
# such that the encoder's padding tokens are not attended to.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
|
||||||
|
else:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
# (batch size, num_heads, seq_len_q, seq_len_k)
|
# (batch size, num_heads, seq_len_q, seq_len_k)
|
||||||
@@ -261,6 +309,8 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
|
|||||||
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
||||||
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -299,6 +349,9 @@ class TFRobertaAttention(tf.keras.layers.Layer):
|
|||||||
input_tensor: tf.Tensor,
|
input_tensor: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
@@ -306,13 +359,17 @@ class TFRobertaAttention(tf.keras.layers.Layer):
|
|||||||
hidden_states=input_tensor,
|
hidden_states=input_tensor,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = self.dense_output(
|
attention_output = self.dense_output(
|
||||||
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
||||||
)
|
)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
# add attentions (possibly with past_key_value) if we output them
|
||||||
|
outputs = (attention_output,) + self_outputs[1:]
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -363,6 +420,12 @@ class TFRobertaLayer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.attention = TFRobertaAttention(config, name="attention")
|
self.attention = TFRobertaAttention(config, name="attention")
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
self.add_cross_attention = config.add_cross_attention
|
||||||
|
if self.add_cross_attention:
|
||||||
|
if not self.is_decoder:
|
||||||
|
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||||
|
self.crossattention = TFRobertaAttention(config, name="crossattention")
|
||||||
self.intermediate = TFRobertaIntermediate(config, name="intermediate")
|
self.intermediate = TFRobertaIntermediate(config, name="intermediate")
|
||||||
self.bert_output = TFRobertaOutput(config, name="output")
|
self.bert_output = TFRobertaOutput(config, name="output")
|
||||||
|
|
||||||
@@ -371,22 +434,69 @@ class TFRobertaLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: Optional[tf.Tensor],
|
||||||
|
encoder_attention_mask: Optional[tf.Tensor],
|
||||||
|
past_key_value: Optional[Tuple[tf.Tensor]],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
attention_outputs = self.attention(
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
input_tensor=hidden_states,
|
input_tensor=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
|
if not hasattr(self, "crossattention"):
|
||||||
|
raise ValueError(
|
||||||
|
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
|
||||||
|
"by setting `config.add_cross_attention=True`"
|
||||||
|
)
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
|
cross_attention_outputs = self.crossattention(
|
||||||
|
input_tensor=attention_output,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=cross_attn_past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
attention_output = cross_attention_outputs[0]
|
||||||
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
intermediate_output = self.intermediate(hidden_states=attention_output)
|
intermediate_output = self.intermediate(hidden_states=attention_output)
|
||||||
layer_output = self.bert_output(
|
layer_output = self.bert_output(
|
||||||
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
outputs = (layer_output,) + outputs # add attentions if we output them
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -395,7 +505,7 @@ class TFRobertaLayer(tf.keras.layers.Layer):
|
|||||||
class TFRobertaEncoder(tf.keras.layers.Layer):
|
class TFRobertaEncoder(tf.keras.layers.Layer):
|
||||||
def __init__(self, config: RobertaConfig, **kwargs):
|
def __init__(self, config: RobertaConfig, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.config = config
|
||||||
self.layer = [TFRobertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
self.layer = [TFRobertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
@@ -403,39 +513,61 @@ class TFRobertaEncoder(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: Optional[tf.Tensor],
|
||||||
|
encoder_attention_mask: Optional[tf.Tensor],
|
||||||
|
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
|
||||||
|
use_cache: Optional[bool],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
output_hidden_states: bool,
|
output_hidden_states: bool,
|
||||||
return_dict: bool,
|
return_dict: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (layer_outputs[1],)
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return TFBaseModelOutput(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -447,6 +579,8 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
self.num_hidden_layers = config.num_hidden_layers
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
self.initializer_range = config.initializer_range
|
self.initializer_range = config.initializer_range
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
@@ -483,12 +617,16 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -498,6 +636,10 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -505,6 +647,9 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
|||||||
kwargs_call=kwargs,
|
kwargs_call=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.config.is_decoder:
|
||||||
|
inputs["use_cache"] = False
|
||||||
|
|
||||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif inputs["input_ids"] is not None:
|
elif inputs["input_ids"] is not None:
|
||||||
@@ -514,8 +659,16 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
|
if inputs["past_key_values"] is None:
|
||||||
|
past_key_values_length = 0
|
||||||
|
inputs["past_key_values"] = [None] * len(self.encoder.layer)
|
||||||
|
else:
|
||||||
|
past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2]
|
||||||
|
|
||||||
if inputs["attention_mask"] is None:
|
if inputs["attention_mask"] is None:
|
||||||
inputs["attention_mask"] = tf.fill(dims=input_shape, value=1)
|
inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
|
||||||
|
|
||||||
if inputs["token_type_ids"] is None:
|
if inputs["token_type_ids"] is None:
|
||||||
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
|
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
|
||||||
@@ -525,6 +678,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
token_type_ids=inputs["token_type_ids"],
|
token_type_ids=inputs["token_type_ids"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -533,7 +687,29 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
|||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
|
attention_mask_shape = shape_list(inputs["attention_mask"])
|
||||||
|
|
||||||
|
mask_seq_length = seq_length + past_key_values_length
|
||||||
|
# Copied from `modeling_tf_t5.py`
|
||||||
|
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
|
||||||
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||||
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||||
|
if self.is_decoder:
|
||||||
|
seq_ids = tf.range(mask_seq_length)
|
||||||
|
causal_mask = tf.less_equal(
|
||||||
|
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
|
||||||
|
seq_ids[None, :, None],
|
||||||
|
)
|
||||||
|
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
|
||||||
|
extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :]
|
||||||
|
attention_mask_shape = shape_list(extended_attention_mask)
|
||||||
|
extended_attention_mask = tf.reshape(
|
||||||
|
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extended_attention_mask = tf.reshape(
|
||||||
|
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
|
||||||
|
)
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
@@ -545,6 +721,29 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
|||||||
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
|
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
|
||||||
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
||||||
|
|
||||||
|
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
|
||||||
|
if self.is_decoder and inputs["encoder_attention_mask"] is not None:
|
||||||
|
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
inputs["encoder_attention_mask"] = tf.cast(
|
||||||
|
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
|
||||||
|
)
|
||||||
|
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
|
||||||
|
if num_dims_encoder_attention_mask == 3:
|
||||||
|
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
|
||||||
|
if num_dims_encoder_attention_mask == 2:
|
||||||
|
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
|
||||||
|
|
||||||
|
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||||
|
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
||||||
|
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
|
||||||
|
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
|
||||||
|
|
||||||
|
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
@@ -559,6 +758,10 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=embedding_output,
|
hidden_states=embedding_output,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -574,11 +777,13 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
|||||||
pooled_output,
|
pooled_output,
|
||||||
) + encoder_outputs[1:]
|
) + encoder_outputs[1:]
|
||||||
|
|
||||||
return TFBaseModelOutputWithPooling(
|
return TFBaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
pooler_output=pooled_output,
|
pooler_output=pooled_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
attentions=encoder_outputs.attentions,
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -591,6 +796,25 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
|
|||||||
config_class = RobertaConfig
|
config_class = RobertaConfig
|
||||||
base_model_prefix = "roberta"
|
base_model_prefix = "roberta"
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPreTrainedModel.dummy_inputs
|
||||||
|
def dummy_inputs(self):
|
||||||
|
"""
|
||||||
|
Dummy inputs to build the network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||||
|
"""
|
||||||
|
dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
|
||||||
|
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
|
||||||
|
shape = (batch_size, seq_len) + (self.config.hidden_size,)
|
||||||
|
h = tf.random.uniform(shape=shape)
|
||||||
|
dummy["encoder_hidden_states"] = h
|
||||||
|
|
||||||
|
return dummy
|
||||||
|
|
||||||
@tf.function(
|
@tf.function(
|
||||||
input_signature=[
|
input_signature=[
|
||||||
{
|
{
|
||||||
@@ -711,7 +935,7 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
|
|||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=TFBaseModelOutputWithPooling,
|
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
@@ -722,12 +946,36 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
|
|||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
training=False,
|
training=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
|
||||||
|
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
|
||||||
|
"""
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -737,6 +985,10 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -750,6 +1002,10 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
|
|||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -759,15 +1015,24 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
|
||||||
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
def serving_output(
|
||||||
|
self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
|
||||||
|
) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
|
||||||
|
output_cache = self.config.use_cache and self.config.is_decoder
|
||||||
|
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||||
|
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||||
|
cross_attns = None
|
||||||
|
|
||||||
return TFBaseModelOutputWithPooling(
|
return TFBaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
last_hidden_state=output.last_hidden_state,
|
last_hidden_state=output.last_hidden_state,
|
||||||
pooler_output=output.pooler_output,
|
pooler_output=output.pooler_output,
|
||||||
|
past_key_values=pkv,
|
||||||
hidden_states=hs,
|
hidden_states=hs,
|
||||||
attentions=attns,
|
attentions=attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -922,6 +1187,163 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
|
|||||||
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||||
|
|
||||||
|
|
||||||
|
class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
|
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"]
|
||||||
|
|
||||||
|
def __init__(self, config: RobertaConfig, *inputs, **kwargs):
|
||||||
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|
||||||
|
if not config.is_decoder:
|
||||||
|
logger.warning("If you want to use `TFRobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
|
||||||
|
|
||||||
|
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
|
||||||
|
self.lm_head = TFRobertaLMHead(config, input_embeddings=self.roberta.embeddings, name="lm_head")
|
||||||
|
|
||||||
|
def get_lm_head(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def get_prefix_bias_name(self):
|
||||||
|
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
||||||
|
return self.name + "/" + self.lm_head.name
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
|
||||||
|
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past:
|
||||||
|
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": inputs,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": model_kwargs["use_cache"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=TFCausalLMOutputWithCrossAttentions,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
|
||||||
|
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
|
||||||
|
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||||
|
config.vocab_size - 1]``.
|
||||||
|
"""
|
||||||
|
inputs = input_processing(
|
||||||
|
func=self.call,
|
||||||
|
config=self.config,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
labels=labels,
|
||||||
|
training=training,
|
||||||
|
kwargs_call=kwargs,
|
||||||
|
)
|
||||||
|
outputs = self.roberta(
|
||||||
|
input_ids=inputs["input_ids"],
|
||||||
|
attention_mask=inputs["attention_mask"],
|
||||||
|
token_type_ids=inputs["token_type_ids"],
|
||||||
|
position_ids=inputs["position_ids"],
|
||||||
|
head_mask=inputs["head_mask"],
|
||||||
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
|
output_attentions=inputs["output_attentions"],
|
||||||
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
|
return_dict=inputs["return_dict"],
|
||||||
|
training=inputs["training"],
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
logits = self.lm_head(hidden_states=sequence_output)
|
||||||
|
loss = None
|
||||||
|
|
||||||
|
if inputs["labels"] is not None:
|
||||||
|
# shift labels to the left and cut last logit token
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
labels = inputs["labels"][:, 1:]
|
||||||
|
loss = self.compute_loss(labels=labels, logits=logits)
|
||||||
|
|
||||||
|
if not inputs["return_dict"]:
|
||||||
|
output = (logits,) + outputs[2:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return TFCausalLMOutputWithCrossAttentions(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
|
||||||
|
def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions:
|
||||||
|
output_cache = self.config.use_cache and self.config.is_decoder
|
||||||
|
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||||
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||||
|
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||||
|
cross_attns = None
|
||||||
|
|
||||||
|
return TFCausalLMOutputWithCrossAttentions(
|
||||||
|
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TFRobertaClassificationHead(tf.keras.layers.Layer):
|
class TFRobertaClassificationHead(tf.keras.layers.Layer):
|
||||||
"""Head for sentence-level classification tasks."""
|
"""Head for sentence-level classification tasks."""
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
from ...file_utils import add_start_docstrings
|
from ...file_utils import add_start_docstrings
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..roberta.modeling_tf_roberta import (
|
from ..roberta.modeling_tf_roberta import (
|
||||||
|
TFRobertaForCausalLM,
|
||||||
TFRobertaForMaskedLM,
|
TFRobertaForMaskedLM,
|
||||||
TFRobertaForMultipleChoice,
|
TFRobertaForMultipleChoice,
|
||||||
TFRobertaForQuestionAnswering,
|
TFRobertaForQuestionAnswering,
|
||||||
@@ -85,6 +86,19 @@ class TFXLMRobertaModel(TFRobertaModel):
|
|||||||
config_class = XLMRobertaConfig
|
config_class = XLMRobertaConfig
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"XLM-RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.",
|
||||||
|
XLM_ROBERTA_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class XLMRobertaForCausalLM(TFRobertaForCausalLM):
|
||||||
|
"""
|
||||||
|
This class overrides :class:`~transformers.TFRobertaForCausalLM`. Please check the superclass for the appropriate
|
||||||
|
documentation alongside usage examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = XLMRobertaConfig
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""XLM-RoBERTa Model with a `language modeling` head on top. """,
|
"""XLM-RoBERTa Model with a `language modeling` head on top. """,
|
||||||
XLM_ROBERTA_START_DOCSTRING,
|
XLM_ROBERTA_START_DOCSTRING,
|
||||||
|
|||||||
@@ -929,6 +929,15 @@ class TFElectraPreTrainedModel:
|
|||||||
requires_backends(cls, ["tf"])
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFEncoderDecoderModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -1712,6 +1721,15 @@ class TFRemBertPreTrainedModel:
|
|||||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class TFRobertaForCausalLM:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFRobertaForMaskedLM:
|
class TFRobertaForMaskedLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
|
|||||||
@@ -24,15 +24,16 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
|
DUMMY_INPUTS,
|
||||||
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFBaseModelOutputWithPooling,
|
TFBaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
TFCausalLMOutput,
|
TFCausalLMOutputWithCrossAttentions,
|
||||||
TFMaskedLMOutput,
|
TFMaskedLMOutput,
|
||||||
TFMultipleChoiceModelOutput,
|
TFMultipleChoiceModelOutput,
|
||||||
TFQuestionAnsweringModelOutput,
|
TFQuestionAnsweringModelOutput,
|
||||||
@@ -116,6 +117,7 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
|
|||||||
position_ids: tf.Tensor = None,
|
position_ids: tf.Tensor = None,
|
||||||
token_type_ids: tf.Tensor = None,
|
token_type_ids: tf.Tensor = None,
|
||||||
inputs_embeds: tf.Tensor = None,
|
inputs_embeds: tf.Tensor = None,
|
||||||
|
past_key_values_length=0,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> tf.Tensor:
|
) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -135,7 +137,9 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
|
|||||||
token_type_ids = tf.fill(dims=input_shape, value=0)
|
token_type_ids = tf.fill(dims=input_shape, value=0)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
|
position_ids = tf.expand_dims(
|
||||||
|
tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
|
||||||
|
)
|
||||||
|
|
||||||
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
||||||
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
||||||
@@ -174,6 +178,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
|||||||
)
|
)
|
||||||
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
||||||
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
||||||
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||||
@@ -186,16 +192,49 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
batch_size = shape_list(hidden_states)[0]
|
batch_size = shape_list(hidden_states)[0]
|
||||||
mixed_query_layer = self.query(inputs=hidden_states)
|
mixed_query_layer = self.query(inputs=hidden_states)
|
||||||
mixed_key_layer = self.key(inputs=hidden_states)
|
|
||||||
mixed_value_layer = self.value(inputs=hidden_states)
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
|
# and values come from an encoder; the attention mask needs to be
|
||||||
|
# such that the encoder's padding tokens are not attended to.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
|
||||||
|
else:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||||
|
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
# (batch size, num_heads, seq_len_q, seq_len_k)
|
# (batch size, num_heads, seq_len_q, seq_len_k)
|
||||||
@@ -225,6 +264,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
|||||||
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
||||||
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -263,6 +304,9 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
|||||||
input_tensor: tf.Tensor,
|
input_tensor: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: tf.Tensor,
|
||||||
|
encoder_attention_mask: tf.Tensor,
|
||||||
|
past_key_value: Tuple[tf.Tensor],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
@@ -270,13 +314,17 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
|||||||
hidden_states=input_tensor,
|
hidden_states=input_tensor,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = self.dense_output(
|
attention_output = self.dense_output(
|
||||||
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
||||||
)
|
)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
# add attentions (possibly with past_key_value) if we output them
|
||||||
|
outputs = (attention_output,) + self_outputs[1:]
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -327,6 +375,12 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.attention = TF{{cookiecutter.camelcase_modelname}}Attention(config, name="attention")
|
self.attention = TF{{cookiecutter.camelcase_modelname}}Attention(config, name="attention")
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
self.add_cross_attention = config.add_cross_attention
|
||||||
|
if self.add_cross_attention:
|
||||||
|
if not self.is_decoder:
|
||||||
|
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||||
|
self.crossattention = TF{{cookiecutter.camelcase_modelname}}Attention(config, name="crossattention")
|
||||||
self.intermediate = TF{{cookiecutter.camelcase_modelname}}Intermediate(config, name="intermediate")
|
self.intermediate = TF{{cookiecutter.camelcase_modelname}}Intermediate(config, name="intermediate")
|
||||||
self.bert_output = TF{{cookiecutter.camelcase_modelname}}Output(config, name="output")
|
self.bert_output = TF{{cookiecutter.camelcase_modelname}}Output(config, name="output")
|
||||||
|
|
||||||
@@ -335,20 +389,69 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: Optional[tf.Tensor],
|
||||||
|
encoder_attention_mask: Optional[tf.Tensor],
|
||||||
|
past_key_value: Optional[Tuple[tf.Tensor]],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor]:
|
) -> Tuple[tf.Tensor]:
|
||||||
attention_outputs = self.attention(
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
input_tensor=hidden_states,
|
input_tensor=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
|
if not hasattr(self, "crossattention"):
|
||||||
|
raise ValueError(
|
||||||
|
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
|
||||||
|
"by setting `config.add_cross_attention=True`"
|
||||||
|
)
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
|
cross_attention_outputs = self.crossattention(
|
||||||
|
input_tensor=attention_output,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=cross_attn_past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
attention_output = cross_attention_outputs[0]
|
||||||
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
intermediate_output = self.intermediate(hidden_states=attention_output)
|
intermediate_output = self.intermediate(hidden_states=attention_output)
|
||||||
layer_output = self.bert_output(hidden_states=intermediate_output, input_tensor=attention_output, training=training)
|
layer_output = self.bert_output(
|
||||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
||||||
|
)
|
||||||
|
outputs = (layer_output,) + outputs # add attentions if we output them
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -357,7 +460,7 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
|
|||||||
class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.config = config
|
||||||
self.layer = [TF{{cookiecutter.camelcase_modelname}}Layer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
self.layer = [TF{{cookiecutter.camelcase_modelname}}Layer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
@@ -365,39 +468,61 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||||||
hidden_states: tf.Tensor,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: tf.Tensor,
|
attention_mask: tf.Tensor,
|
||||||
head_mask: tf.Tensor,
|
head_mask: tf.Tensor,
|
||||||
|
encoder_hidden_states: Optional[tf.Tensor],
|
||||||
|
encoder_attention_mask: Optional[tf.Tensor],
|
||||||
|
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
|
||||||
|
use_cache: Optional[bool],
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
output_hidden_states: bool,
|
output_hidden_states: bool,
|
||||||
return_dict: bool,
|
return_dict: bool,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (layer_outputs[1],)
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return TFBaseModelOutput(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -492,6 +617,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
self.embeddings = TF{{cookiecutter.camelcase_modelname}}Embeddings(config, name="embeddings")
|
self.embeddings = TF{{cookiecutter.camelcase_modelname}}Embeddings(config, name="embeddings")
|
||||||
self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, name="encoder")
|
self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, name="encoder")
|
||||||
@@ -521,12 +647,16 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -536,6 +666,10 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -543,6 +677,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
kwargs_call=kwargs,
|
kwargs_call=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.config.is_decoder:
|
||||||
|
inputs["use_cache"] = False
|
||||||
|
|
||||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif inputs["input_ids"] is not None:
|
elif inputs["input_ids"] is not None:
|
||||||
@@ -552,8 +689,16 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
|
if inputs["past_key_values"] is None:
|
||||||
|
past_key_values_length = 0
|
||||||
|
inputs["past_key_values"] = [None] * len(self.encoder.layer)
|
||||||
|
else:
|
||||||
|
past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2]
|
||||||
|
|
||||||
if inputs["attention_mask"] is None:
|
if inputs["attention_mask"] is None:
|
||||||
inputs["attention_mask"] = tf.fill(dims=input_shape, value=1)
|
inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
|
||||||
|
|
||||||
if inputs["token_type_ids"] is None:
|
if inputs["token_type_ids"] is None:
|
||||||
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
|
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
|
||||||
@@ -563,6 +708,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
token_type_ids=inputs["token_type_ids"],
|
token_type_ids=inputs["token_type_ids"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -571,7 +717,29 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
|
attention_mask_shape = shape_list(inputs["attention_mask"])
|
||||||
|
|
||||||
|
mask_seq_length = seq_length + past_key_values_length
|
||||||
|
# Copied from `modeling_tf_t5.py`
|
||||||
|
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
|
||||||
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||||
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||||
|
if self.is_decoder:
|
||||||
|
seq_ids = tf.range(mask_seq_length)
|
||||||
|
causal_mask = tf.less_equal(
|
||||||
|
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
|
||||||
|
seq_ids[None, :, None],
|
||||||
|
)
|
||||||
|
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
|
||||||
|
extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :]
|
||||||
|
attention_mask_shape = shape_list(extended_attention_mask)
|
||||||
|
extended_attention_mask = tf.reshape(
|
||||||
|
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extended_attention_mask = tf.reshape(
|
||||||
|
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
|
||||||
|
)
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
@@ -583,6 +751,29 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
|
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
|
||||||
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
||||||
|
|
||||||
|
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
|
||||||
|
if self.is_decoder and inputs["encoder_attention_mask"] is not None:
|
||||||
|
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
inputs["encoder_attention_mask"] = tf.cast(
|
||||||
|
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
|
||||||
|
)
|
||||||
|
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
|
||||||
|
if num_dims_encoder_attention_mask == 3:
|
||||||
|
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
|
||||||
|
if num_dims_encoder_attention_mask == 2:
|
||||||
|
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
|
||||||
|
|
||||||
|
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||||
|
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
||||||
|
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
|
||||||
|
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
|
||||||
|
|
||||||
|
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
@@ -597,6 +788,10 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=embedding_output,
|
hidden_states=embedding_output,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -610,10 +805,12 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
sequence_output,
|
sequence_output,
|
||||||
) + encoder_outputs[1:]
|
) + encoder_outputs[1:]
|
||||||
|
|
||||||
return TFBaseModelOutput(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
attentions=encoder_outputs.attentions,
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -625,6 +822,24 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
|||||||
config_class = {{cookiecutter.camelcase_modelname}}Config
|
config_class = {{cookiecutter.camelcase_modelname}}Config
|
||||||
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
|
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_inputs(self):
|
||||||
|
"""
|
||||||
|
Dummy inputs to build the network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||||
|
"""
|
||||||
|
dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
|
||||||
|
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
|
||||||
|
shape = (batch_size, seq_len) + (self.config.hidden_size,)
|
||||||
|
h = tf.random.uniform(shape=shape)
|
||||||
|
dummy["encoder_hidden_states"] = h
|
||||||
|
|
||||||
|
return dummy
|
||||||
|
|
||||||
|
|
||||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
||||||
|
|
||||||
@@ -732,7 +947,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=TFBaseModelOutputWithPooling,
|
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
@@ -743,12 +958,36 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: Optional[bool] = False,
|
training: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
|
||||||
|
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
|
||||||
|
"""
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -758,6 +997,10 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -771,6 +1014,10 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -779,12 +1026,26 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
# Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
|
||||||
def serving_output(self, output: TFBaseModelOutput) -> TFBaseModelOutput:
|
def serving_output(
|
||||||
|
self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
|
||||||
|
) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
|
||||||
|
output_cache = self.config.use_cache and self.config.is_decoder
|
||||||
|
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||||
|
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||||
|
cross_attns = None
|
||||||
|
|
||||||
return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns)
|
return TFBaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
|
last_hidden_state=output.last_hidden_state,
|
||||||
|
pooler_output=output.pooler_output,
|
||||||
|
past_key_values=pkv,
|
||||||
|
hidden_states=hs,
|
||||||
|
attentions=attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""{{cookiecutter.modelname}} Model with a `language modeling` head on top. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING)
|
@add_start_docstrings("""{{cookiecutter.modelname}} Model with a `language modeling` head on top. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING)
|
||||||
@@ -903,10 +1164,22 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
|||||||
def get_lm_head(self) -> tf.keras.layers.Layer:
|
def get_lm_head(self) -> tf.keras.layers.Layer:
|
||||||
return self.mlm.predictions
|
return self.mlm.predictions
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past:
|
||||||
|
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": inputs,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": model_kwargs["use_cache"],
|
||||||
|
}
|
||||||
|
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=TFCausalLMOutput,
|
output_type=TFCausalLMOutputWithCrossAttentions,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
@@ -917,14 +1190,36 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
|||||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
training: Optional[bool] = False,
|
training: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
|
||||||
|
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
|
||||||
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||||
config.vocab_size - 1]``.
|
config.vocab_size - 1]``.
|
||||||
@@ -938,6 +1233,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -952,6 +1251,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
|||||||
position_ids=inputs["position_ids"],
|
position_ids=inputs["position_ids"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
|
past_key_values=inputs["past_key_values"],
|
||||||
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=inputs["return_dict"],
|
||||||
@@ -971,19 +1274,28 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
|||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return TFCausalLMOutput(
|
return TFCausalLMOutputWithCrossAttentions(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
|
||||||
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
|
def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions:
|
||||||
|
output_cache = self.config.use_cache and self.config.is_decoder
|
||||||
|
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||||
|
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||||
|
cross_attns = None
|
||||||
|
|
||||||
return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
return TFCausalLMOutputWithCrossAttentions(
|
||||||
|
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.Layer):
|
class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.Layer):
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from transformers import is_tf_available, {{cookiecutter.camelcase_modelname}}Co
|
|||||||
from transformers.testing_utils import require_tf, slow
|
from transformers.testing_utils import require_tf, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -123,6 +123,33 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
|
|||||||
|
|
||||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_decoder(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config.is_decoder = True
|
||||||
|
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
def create_and_check_model(
|
def create_and_check_model(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from transformers.models.auto import get_values
|
|||||||
from transformers.testing_utils import require_tf, slow
|
from transformers.testing_utils import require_tf, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -125,6 +125,33 @@ class TFBertModelTester:
|
|||||||
|
|
||||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_decoder(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config.is_decoder = True
|
||||||
|
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
def create_and_check_bert_model(
|
def create_and_check_bert_model(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1393,6 +1393,22 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def floats_tensor(shape, scale=1.0, rng=None, name=None, dtype=None):
|
||||||
|
"""Creates a random float32 tensor"""
|
||||||
|
if rng is None:
|
||||||
|
rng = random.Random()
|
||||||
|
|
||||||
|
total_dims = 1
|
||||||
|
for dim in shape:
|
||||||
|
total_dims *= dim
|
||||||
|
|
||||||
|
values = []
|
||||||
|
for _ in range(total_dims):
|
||||||
|
values.append(rng.random() * scale)
|
||||||
|
|
||||||
|
return tf.reshape(tf.constant(values, dtype=dtype if dtype is not None else tf.float32), shape=shape)
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class UtilsFunctionsTest(unittest.TestCase):
|
class UtilsFunctionsTest(unittest.TestCase):
|
||||||
|
|
||||||
|
|||||||
765
tests/test_modeling_tf_encoder_decoder.py
Normal file
765
tests/test_modeling_tf_encoder_decoder.py
Normal file
@@ -0,0 +1,765 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020 HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import is_tf_available, is_torch_available
|
||||||
|
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
from .test_modeling_tf_bert import TFBertModelTester
|
||||||
|
from .test_modeling_tf_common import ids_tensor
|
||||||
|
from .test_modeling_tf_rembert import TFRemBertModelTester
|
||||||
|
from .test_modeling_tf_roberta import TFRobertaModelTester
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoTokenizer,
|
||||||
|
EncoderDecoderConfig,
|
||||||
|
TFAutoModel,
|
||||||
|
TFAutoModelForCausalLM,
|
||||||
|
TFBertLMHeadModel,
|
||||||
|
TFBertModel,
|
||||||
|
TFEncoderDecoderModel,
|
||||||
|
TFRemBertForCausalLM,
|
||||||
|
TFRemBertModel,
|
||||||
|
TFRobertaForCausalLM,
|
||||||
|
TFRobertaModel,
|
||||||
|
)
|
||||||
|
from transformers.modeling_tf_outputs import TFBaseModelOutput
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import BertLMHeadModel, BertModel, EncoderDecoderModel
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFEncoderDecoderMixin:
|
||||||
|
def get_encoder_decoder_model(self, config, decoder_config):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_pretrained_model(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_from_pretrained_configs(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||||
|
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
|
||||||
|
|
||||||
|
enc_dec_model = TFEncoderDecoderModel(encoder_decoder_config)
|
||||||
|
|
||||||
|
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||||
|
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_encoder_decoder_model(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
|
||||||
|
self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
|
||||||
|
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||||
|
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_hidden_states)
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=None,
|
||||||
|
encoder_outputs=encoder_outputs,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_from_pretrained(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
return_dict,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
|
||||||
|
enc_dec_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_save_and_load(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
|
||||||
|
outputs = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
)
|
||||||
|
out_2 = np.array(outputs[0])
|
||||||
|
out_2[np.isnan(out_2)] = 0
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
enc_dec_model.save_pretrained(tmpdirname)
|
||||||
|
enc_dec_model = TFEncoderDecoderModel.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
after_outputs = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
)
|
||||||
|
out_1 = np.array(after_outputs[0])
|
||||||
|
out_1[np.isnan(out_1)] = 0
|
||||||
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_labels(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
labels,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
labels=labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure `loss` exist
|
||||||
|
assert "loss" in outputs_encoder_decoder
|
||||||
|
|
||||||
|
batch_size, seq_len = decoder_input_ids.shape
|
||||||
|
expected_shape = (batch_size, seq_len - 1, decoder_config.vocab_size)
|
||||||
|
self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_output_attentions(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||||
|
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||||
|
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
output_attentions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||||
|
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
|
||||||
|
num_decoder_layers = (
|
||||||
|
decoder_config.num_decoder_layers
|
||||||
|
if hasattr(decoder_config, "num_decoder_layers")
|
||||||
|
else decoder_config.num_hidden_layers
|
||||||
|
)
|
||||||
|
self.assertEqual(len(decoder_attentions), num_decoder_layers)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
decoder_attentions[0].shape[-3:],
|
||||||
|
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
cross_attentions = outputs_encoder_decoder["cross_attentions"]
|
||||||
|
self.assertEqual(len(cross_attentions), num_decoder_layers)
|
||||||
|
|
||||||
|
cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
|
||||||
|
1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
cross_attentions[0].shape[-3:],
|
||||||
|
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
|
||||||
|
# Bert does not have a bos token id, so use pad_token_id instead
|
||||||
|
generated_output = enc_dec_model.generate(
|
||||||
|
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
||||||
|
)
|
||||||
|
self.assertEqual(tuple(generated_output.shape.as_list()), (input_ids.shape[0],) + (decoder_config.max_length,))
|
||||||
|
|
||||||
|
def test_encoder_decoder_model(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model(**input_ids_dict)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_from_pretrained_configs(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_from_pretrained(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_from_pretrained_return_dict(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
|
||||||
|
|
||||||
|
def test_save_and_load_from_pretrained(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_save_and_load(**input_ids_dict)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_labels(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model_labels(**input_ids_dict)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_output_attentions(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_generate(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_real_model_save_load_from_pretrained(self):
|
||||||
|
model_2 = self.get_pretrained_model()
|
||||||
|
input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
|
||||||
|
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
|
||||||
|
attention_mask = ids_tensor([13, 5], vocab_size=2)
|
||||||
|
|
||||||
|
outputs = model_2(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
out_2 = np.array(outputs[0])
|
||||||
|
out_2[np.isnan(out_2)] = 0
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
model_2.save_pretrained(tmp_dirname)
|
||||||
|
model_1 = TFEncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||||
|
|
||||||
|
after_outputs = model_1(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
out_1 = np.array(after_outputs[0])
|
||||||
|
out_1[np.isnan(out_1)] = 0
|
||||||
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
|
||||||
|
def get_pretrained_model(self):
|
||||||
|
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||||
|
|
||||||
|
def get_encoder_decoder_model(self, config, decoder_config):
|
||||||
|
encoder_model = TFBertModel(config, name="encoder")
|
||||||
|
decoder_model = TFBertLMHeadModel(decoder_config, name="decoder")
|
||||||
|
return encoder_model, decoder_model
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
model_tester_encoder = TFBertModelTester(self, batch_size=13)
|
||||||
|
model_tester_decoder = TFBertModelTester(self, batch_size=13)
|
||||||
|
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||||
|
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
attention_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = encoder_config_and_inputs
|
||||||
|
(
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_token_type_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
decoder_sequence_labels,
|
||||||
|
decoder_token_labels,
|
||||||
|
decoder_choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
) = decoder_config_and_inputs
|
||||||
|
|
||||||
|
# make sure that cross attention layers are added
|
||||||
|
decoder_config.add_cross_attention = True
|
||||||
|
# disable cache for now
|
||||||
|
decoder_config.use_cache = False
|
||||||
|
return {
|
||||||
|
"config": config,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"decoder_config": decoder_config,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"decoder_token_type_ids": decoder_token_type_ids,
|
||||||
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
"decoder_sequence_labels": decoder_sequence_labels,
|
||||||
|
"decoder_token_labels": decoder_token_labels,
|
||||||
|
"decoder_choice_labels": decoder_choice_labels,
|
||||||
|
"encoder_hidden_states": encoder_hidden_states,
|
||||||
|
"labels": decoder_token_labels,
|
||||||
|
}
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_bert2bert_summarization(self):
|
||||||
|
|
||||||
|
from transformers import EncoderDecoderModel
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
|
||||||
|
"""Not working, because pt checkpoint has `encoder.encoder.layer...` while tf model has `encoder.bert.encoder.layer...`
|
||||||
|
(For Bert decoder, there is no issue, because `BertModel` is wrapped into `decoder` as `bert`)
|
||||||
|
model = TFEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16", from_pt=True)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# workaround to load from pt
|
||||||
|
_model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
||||||
|
_model.encoder.save_pretrained("./encoder")
|
||||||
|
_model.decoder.save_pretrained("./decoder")
|
||||||
|
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||||
|
"./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
|
||||||
|
)
|
||||||
|
model.config = _model.config
|
||||||
|
|
||||||
|
ARTICLE_STUDENTS = """(CNN)Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members singing a racist chant. SAE's national chapter suspended the students, but University of Oklahoma President David Boren took it a step further, saying the university's affiliation with the fraternity is permanently done. The news is shocking, but it's not the first time SAE has faced controversy. SAE was founded March 9, 1856, at the University of Alabama, five years before the American Civil War, according to the fraternity website. When the war began, the group had fewer than 400 members, of which "369 went to war for the Confederate States and seven for the Union Army," the website says. The fraternity now boasts more than 200,000 living alumni, along with about 15,000 undergraduates populating 219 chapters and 20 "colonies" seeking full membership at universities. SAE has had to work hard to change recently after a string of member deaths, many blamed on the hazing of new recruits, SAE national President Bradley Cohen wrote in a message on the fraternity's website. The fraternity's website lists more than 130 chapters cited or suspended for "health and safety incidents" since 2010. At least 30 of the incidents involved hazing, and dozens more involved alcohol. However, the list is missing numerous incidents from recent months. Among them, according to various media outlets: Yale University banned the SAEs from campus activities last month after members allegedly tried to interfere with a sexual misconduct investigation connected to an initiation rite. Stanford University in December suspended SAE housing privileges after finding sorority members attending a fraternity function were subjected to graphic sexual content. And Johns Hopkins University in November suspended the fraternity for underage drinking. "The media has labeled us as the 'nation's deadliest fraternity,' " Cohen said. In 2011, for example, a student died while being coerced into excessive alcohol consumption, according to a lawsuit. SAE's previous insurer dumped the fraternity. "As a result, we are paying Lloyd's of London the highest insurance rates in the Greek-letter world," Cohen said. Universities have turned down SAE's attempts to open new chapters, and the fraternity had to close 12 in 18 months over hazing incidents."""
|
||||||
|
EXPECTED_SUMMARY_STUDENTS = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months."""
|
||||||
|
|
||||||
|
input_dict = tokenizer(ARTICLE_STUDENTS, return_tensors="tf")
|
||||||
|
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
|
||||||
|
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
|
||||||
|
def get_pretrained_model(self):
|
||||||
|
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
|
||||||
|
|
||||||
|
def get_encoder_decoder_model(self, config, decoder_config):
|
||||||
|
encoder_model = TFRobertaModel(config, name="encoder")
|
||||||
|
decoder_model = TFRobertaForCausalLM(decoder_config, name="decoder")
|
||||||
|
return encoder_model, decoder_model
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
model_tester_encoder = TFRobertaModelTester(self)
|
||||||
|
model_tester_decoder = TFRobertaModelTester(self)
|
||||||
|
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||||
|
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = encoder_config_and_inputs
|
||||||
|
(
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_token_type_ids,
|
||||||
|
decoder_input_mask,
|
||||||
|
decoder_sequence_labels,
|
||||||
|
decoder_token_labels,
|
||||||
|
decoder_choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
) = decoder_config_and_inputs
|
||||||
|
|
||||||
|
# make sure that cross attention layers are added
|
||||||
|
decoder_config.add_cross_attention = True
|
||||||
|
# disable cache for now
|
||||||
|
decoder_config.use_cache = False
|
||||||
|
return {
|
||||||
|
"config": config,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
"decoder_config": decoder_config,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"decoder_token_type_ids": decoder_token_type_ids,
|
||||||
|
"decoder_attention_mask": decoder_input_mask,
|
||||||
|
"decoder_sequence_labels": decoder_sequence_labels,
|
||||||
|
"decoder_token_labels": decoder_token_labels,
|
||||||
|
"decoder_choice_labels": decoder_choice_labels,
|
||||||
|
"encoder_hidden_states": encoder_hidden_states,
|
||||||
|
"labels": decoder_token_labels,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFRembertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
|
||||||
|
def get_pretrained_model(self):
|
||||||
|
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("google/rembert", "google/rembert")
|
||||||
|
|
||||||
|
def get_encoder_decoder_model(self, config, decoder_config):
|
||||||
|
encoder_model = TFRemBertModel(config, name="encoder")
|
||||||
|
decoder_model = TFRemBertForCausalLM(decoder_config, name="decoder")
|
||||||
|
return encoder_model, decoder_model
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
model_tester_encoder = TFRemBertModelTester(self)
|
||||||
|
model_tester_decoder = TFRemBertModelTester(self)
|
||||||
|
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||||
|
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = encoder_config_and_inputs
|
||||||
|
(
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_token_type_ids,
|
||||||
|
decoder_input_mask,
|
||||||
|
decoder_sequence_labels,
|
||||||
|
decoder_token_labels,
|
||||||
|
decoder_choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
) = decoder_config_and_inputs
|
||||||
|
|
||||||
|
# make sure that cross attention layers are added
|
||||||
|
decoder_config.add_cross_attention = True
|
||||||
|
# disable cache for now
|
||||||
|
decoder_config.use_cache = False
|
||||||
|
return {
|
||||||
|
"config": config,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
"decoder_config": decoder_config,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"decoder_token_type_ids": decoder_token_type_ids,
|
||||||
|
"decoder_attention_mask": decoder_input_mask,
|
||||||
|
"decoder_sequence_labels": decoder_sequence_labels,
|
||||||
|
"decoder_token_labels": decoder_token_labels,
|
||||||
|
"decoder_choice_labels": decoder_choice_labels,
|
||||||
|
"encoder_hidden_states": encoder_hidden_states,
|
||||||
|
"labels": decoder_token_labels,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFEncoderDecoderModelTest(unittest.TestCase):
|
||||||
|
def get_from_encoderdecoder_pretrained_model(self):
|
||||||
|
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")
|
||||||
|
|
||||||
|
def get_decoder_config(self):
|
||||||
|
config = AutoConfig.from_pretrained("bert-base-cased")
|
||||||
|
config.is_decoder = True
|
||||||
|
config.add_cross_attention = True
|
||||||
|
return config
|
||||||
|
|
||||||
|
def get_encoderdecoder_model(self):
|
||||||
|
return TFEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
||||||
|
|
||||||
|
def get_encoder_decoder_models(self):
|
||||||
|
encoder_model = TFBertModel.from_pretrained("bert-base-cased", name="encoder")
|
||||||
|
decoder_model = TFBertLMHeadModel.from_pretrained(
|
||||||
|
"bert-base-cased", config=self.get_decoder_config(), name="decoder"
|
||||||
|
)
|
||||||
|
return {"encoder": encoder_model, "decoder": decoder_model}
|
||||||
|
|
||||||
|
def _check_configuration_tie(self, model):
|
||||||
|
assert id(model.decoder.config) == id(model.config.decoder)
|
||||||
|
assert id(model.encoder.config) == id(model.config.encoder)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_configuration_tie(self):
|
||||||
|
model = self.get_from_encoderdecoder_pretrained_model()
|
||||||
|
self._check_configuration_tie(model)
|
||||||
|
|
||||||
|
model = TFEncoderDecoderModel(**self.get_encoder_decoder_models())
|
||||||
|
self._check_configuration_tie(model)
|
||||||
|
|
||||||
|
# # This should be enabled once we upload the TF version of
|
||||||
|
# # "patrickvonplaten/bert2bert-cnn_dailymail-fp16" to the Hub.
|
||||||
|
# model = self.get_encoderdecoder_model()
|
||||||
|
# self._check_configuration_tie(model)
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||||
|
def get_encoder_decoder_config(self):
|
||||||
|
encoder_config = AutoConfig.from_pretrained("bert-base-uncased")
|
||||||
|
decoder_config = AutoConfig.from_pretrained("bert-base-uncased", is_decoder=True, add_cross_attention=True)
|
||||||
|
return EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
|
||||||
|
|
||||||
|
def get_encoder_decoder_config_small(self):
|
||||||
|
encoder_config = AutoConfig.from_pretrained("hf-internal-testing/tiny-bert")
|
||||||
|
decoder_config = AutoConfig.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-bert", is_decoder=True, add_cross_attention=True
|
||||||
|
)
|
||||||
|
return EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
|
||||||
|
|
||||||
|
def test_encoder_decoder_save_load_from_encoder_decoder(self):
|
||||||
|
config = self.get_encoder_decoder_config_small()
|
||||||
|
|
||||||
|
# create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
|
||||||
|
encoder = TFBertModel(config.encoder)
|
||||||
|
encoder(encoder.dummy_inputs)
|
||||||
|
decoder = TFBertLMHeadModel(config.decoder)
|
||||||
|
decoder(decoder.dummy_inputs)
|
||||||
|
|
||||||
|
encoder_decoder_orig = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||||
|
|
||||||
|
input_ids = ids_tensor([13, 5], encoder.config.vocab_size)
|
||||||
|
decoder_input_ids = ids_tensor([13, 1], decoder.config.vocab_size)
|
||||||
|
|
||||||
|
logits_orig = encoder_decoder_orig(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
encoder_path = os.path.join(tmp_dirname, "encoder")
|
||||||
|
decoder_path = os.path.join(tmp_dirname, "decoder")
|
||||||
|
|
||||||
|
encoder.save_pretrained(encoder_path)
|
||||||
|
decoder.save_pretrained(decoder_path)
|
||||||
|
|
||||||
|
encoder_decoder = TFEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_path, decoder_path)
|
||||||
|
|
||||||
|
logits_1 = encoder_decoder(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
|
||||||
|
|
||||||
|
self.assertTrue(logits_orig.numpy().sum() - logits_1.numpy().sum() < 1e-3)
|
||||||
|
|
||||||
|
max_diff = np.max(np.abs(logits_1.numpy() - logits_orig.numpy()))
|
||||||
|
self.assertAlmostEqual(max_diff, 0.0, places=4)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
encoder_decoder.save_pretrained(tmp_dirname)
|
||||||
|
encoder_decoder = TFEncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||||
|
|
||||||
|
logits_2 = encoder_decoder(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
|
||||||
|
|
||||||
|
max_diff = np.max(np.abs(logits_2.numpy() - logits_orig.numpy()))
|
||||||
|
self.assertAlmostEqual(max_diff, 0.0, places=4)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_encoder_decoder_save_load_from_encoder_decoder_from_pt(self):
|
||||||
|
config = self.get_encoder_decoder_config_small()
|
||||||
|
|
||||||
|
# create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
|
||||||
|
encoder_pt = BertModel(config.encoder).to(torch_device).eval()
|
||||||
|
decoder_pt = BertLMHeadModel(config.decoder).to(torch_device).eval()
|
||||||
|
|
||||||
|
encoder_decoder_pt = EncoderDecoderModel(encoder=encoder_pt, decoder=decoder_pt).to(torch_device).eval()
|
||||||
|
|
||||||
|
input_ids = ids_tensor([13, 5], encoder_pt.config.vocab_size)
|
||||||
|
decoder_input_ids = ids_tensor([13, 1], decoder_pt.config.vocab_size)
|
||||||
|
|
||||||
|
pt_input_ids = torch.tensor(input_ids.numpy(), device=torch_device, dtype=torch.long)
|
||||||
|
pt_decoder_input_ids = torch.tensor(decoder_input_ids.numpy(), device=torch_device, dtype=torch.long)
|
||||||
|
|
||||||
|
logits_pt = encoder_decoder_pt(input_ids=pt_input_ids, decoder_input_ids=pt_decoder_input_ids).logits
|
||||||
|
|
||||||
|
# PyTorch => TensorFlow
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2:
|
||||||
|
encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
|
||||||
|
encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
|
||||||
|
encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||||
|
tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
logits_tf = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
|
||||||
|
|
||||||
|
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
|
||||||
|
self.assertAlmostEqual(max_diff, 0.0, places=3)
|
||||||
|
|
||||||
|
# TensorFlow => PyTorch
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
encoder_decoder_tf.save_pretrained(tmp_dirname)
|
||||||
|
encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True)
|
||||||
|
|
||||||
|
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
|
||||||
|
self.assertAlmostEqual(max_diff, 0.0, places=3)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_encoder_decoder_from_pretrained(self):
|
||||||
|
load_weight_prefix = "tf_encoder_decoder_model_1"
|
||||||
|
|
||||||
|
config = self.get_encoder_decoder_config()
|
||||||
|
encoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
decoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
|
||||||
|
input_ids = encoder_tokenizer("who sings does he love me with reba", return_tensors="tf").input_ids
|
||||||
|
decoder_input_ids = decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
|
||||||
|
# Since most of HF's models don't have pretrained cross-attention layers, they are randomly
|
||||||
|
# initialized even if we create models using `from_pretrained` method.
|
||||||
|
# For the tests, the decoder need to be a model with pretrained cross-attention layers.
|
||||||
|
# So we create pretrained models (without `load_weight_prefix`), save them, and later,
|
||||||
|
# we load them using `from_pretrained`.
|
||||||
|
# (we don't need to do this for encoder, but let's make the code more similar between encoder/decoder)
|
||||||
|
encoder = TFAutoModel.from_pretrained("bert-base-uncased", name="encoder")
|
||||||
|
# It's necessary to specify `add_cross_attention=True` here.
|
||||||
|
decoder = TFAutoModelForCausalLM.from_pretrained(
|
||||||
|
"bert-base-uncased", is_decoder=True, add_cross_attention=True, name="decoder"
|
||||||
|
)
|
||||||
|
pretrained_encoder_dir = os.path.join(tmp_dirname, "pretrained_encoder")
|
||||||
|
pretrained_decoder_dir = os.path.join(tmp_dirname, "pretrained_decoder")
|
||||||
|
encoder.save_pretrained(pretrained_encoder_dir)
|
||||||
|
decoder.save_pretrained(pretrained_decoder_dir)
|
||||||
|
del encoder
|
||||||
|
del decoder
|
||||||
|
|
||||||
|
enc_dec_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||||
|
pretrained_encoder_dir,
|
||||||
|
pretrained_decoder_dir,
|
||||||
|
)
|
||||||
|
# check that the from pretrained methods work
|
||||||
|
enc_dec_model.save_pretrained(tmp_dirname)
|
||||||
|
enc_dec_model = TFEncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||||
|
|
||||||
|
output = enc_dec_model(input_ids, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids)
|
||||||
|
|
||||||
|
loss_pretrained = output.loss
|
||||||
|
del enc_dec_model
|
||||||
|
|
||||||
|
# Create the model using `__init__` with loaded ``pretrained`` encoder / decoder
|
||||||
|
encoder = TFAutoModel.from_pretrained(
|
||||||
|
pretrained_encoder_dir, load_weight_prefix=load_weight_prefix, name="encoder"
|
||||||
|
)
|
||||||
|
decoder = TFAutoModelForCausalLM.from_pretrained(
|
||||||
|
pretrained_decoder_dir, load_weight_prefix=load_weight_prefix, name="decoder"
|
||||||
|
)
|
||||||
|
enc_dec_model = TFEncoderDecoderModel(config=config, encoder=encoder, decoder=decoder)
|
||||||
|
|
||||||
|
output = enc_dec_model(input_ids, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids)
|
||||||
|
|
||||||
|
loss_init = output.loss
|
||||||
|
|
||||||
|
max_diff = np.max(np.abs(loss_pretrained - loss_init))
|
||||||
|
expected_diff = 0.0
|
||||||
|
|
||||||
|
self.assertAlmostEqual(max_diff, expected_diff, places=4)
|
||||||
@@ -20,7 +20,7 @@ from transformers import RemBertConfig, is_tf_available
|
|||||||
from transformers.testing_utils import require_tf, slow
|
from transformers.testing_utils import require_tf, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -131,6 +131,33 @@ class TFRemBertModelTester:
|
|||||||
|
|
||||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_decoder(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config.is_decoder = True
|
||||||
|
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
def create_and_check_model(
|
def create_and_check_model(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import RobertaConfig, is_tf_available
|
|||||||
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
|
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -29,6 +29,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
from transformers.models.roberta.modeling_tf_roberta import (
|
from transformers.models.roberta.modeling_tf_roberta import (
|
||||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
TFRobertaForCausalLM,
|
||||||
TFRobertaForMaskedLM,
|
TFRobertaForMaskedLM,
|
||||||
TFRobertaForMultipleChoice,
|
TFRobertaForMultipleChoice,
|
||||||
TFRobertaForQuestionAnswering,
|
TFRobertaForQuestionAnswering,
|
||||||
@@ -101,6 +102,33 @@ class TFRobertaModelTester:
|
|||||||
|
|
||||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_decoder(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config.is_decoder = True
|
||||||
|
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
def create_and_check_roberta_model(
|
def create_and_check_roberta_model(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
@@ -115,6 +143,13 @@ class TFRobertaModelTester:
|
|||||||
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_roberta_for_causal_lm(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = TFRobertaForCausalLM(config=config)
|
||||||
|
result = model([input_ids, input_mask, token_type_ids])
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
def create_and_check_roberta_for_masked_lm(
|
def create_and_check_roberta_for_masked_lm(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
@@ -177,6 +212,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
TFRobertaModel,
|
TFRobertaModel,
|
||||||
|
TFRobertaForCausalLM,
|
||||||
TFRobertaForMaskedLM,
|
TFRobertaForMaskedLM,
|
||||||
TFRobertaForSequenceClassification,
|
TFRobertaForSequenceClassification,
|
||||||
TFRobertaForTokenClassification,
|
TFRobertaForTokenClassification,
|
||||||
@@ -203,6 +239,10 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs)
|
self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_causal_lm(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_roberta_for_causal_lm(*config_and_inputs)
|
||||||
|
|
||||||
def test_for_token_classification(self):
|
def test_for_token_classification(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_roberta_for_token_classification(*config_and_inputs)
|
self.model_tester.create_and_check_roberta_for_token_classification(*config_and_inputs)
|
||||||
|
|||||||
@@ -160,6 +160,7 @@ def get_model_modules():
|
|||||||
"modeling_flax_utils",
|
"modeling_flax_utils",
|
||||||
"modeling_transfo_xl_utilities",
|
"modeling_transfo_xl_utilities",
|
||||||
"modeling_tf_auto",
|
"modeling_tf_auto",
|
||||||
|
"modeling_tf_encoder_decoder",
|
||||||
"modeling_tf_outputs",
|
"modeling_tf_outputs",
|
||||||
"modeling_tf_pytorch_utils",
|
"modeling_tf_pytorch_utils",
|
||||||
"modeling_tf_utils",
|
"modeling_tf_utils",
|
||||||
@@ -231,6 +232,7 @@ def get_model_test_files():
|
|||||||
"test_modeling_flax_encoder_decoder",
|
"test_modeling_flax_encoder_decoder",
|
||||||
"test_modeling_marian",
|
"test_modeling_marian",
|
||||||
"test_modeling_tf_common",
|
"test_modeling_tf_common",
|
||||||
|
"test_modeling_tf_encoder_decoder",
|
||||||
]
|
]
|
||||||
test_files = []
|
test_files = []
|
||||||
for filename in os.listdir(PATH_TO_TESTS):
|
for filename in os.listdir(PATH_TO_TESTS):
|
||||||
|
|||||||
Reference in New Issue
Block a user