[cleanup] TF T5 tests only init t5-base once. (#5410)
This commit is contained in:
@@ -478,7 +478,7 @@ class TFT5Block(tf.keras.layers.Layer):
|
|||||||
return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||||
|
|
||||||
|
|
||||||
class _NoLayerEmbedTokens(object):
|
class _NoLayerEmbedTokens:
|
||||||
"""
|
"""
|
||||||
this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer'
|
this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer'
|
||||||
class to avoid problem with weight restoring. Also it makes sure that the layer is
|
class to avoid problem with weight restoring. Also it makes sure that the layer is
|
||||||
@@ -655,7 +655,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
# Since we are adding it to the raw scores before the softmax, this is
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
# effectively the same as removing these entirely.
|
# effectively the same as removing these entirely.
|
||||||
|
|
||||||
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
|
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||||
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
||||||
# extended_attention_mask = tf.math.equal(extended_attention_mask,
|
# extended_attention_mask = tf.math.equal(extended_attention_mask,
|
||||||
# tf.transpose(extended_attention_mask, perm=(-1, -2)))
|
# tf.transpose(extended_attention_mask, perm=(-1, -2)))
|
||||||
@@ -682,16 +682,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
encoder_extended_attention_mask = None
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
# Prepare head mask if needed
|
assert head_mask is None, "Head mask not supported"
|
||||||
# 1.0 in head_mask indicate we keep the head
|
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
|
||||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
|
||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
|
||||||
if head_mask is not None:
|
|
||||||
raise NotImplementedError
|
|
||||||
else:
|
|
||||||
head_mask = [None] * self.num_hidden_layers
|
head_mask = [None] * self.num_hidden_layers
|
||||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
|
||||||
|
|
||||||
present_key_value_states = ()
|
present_key_value_states = ()
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
@@ -1054,8 +1046,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
|
|||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
|
||||||
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_label` is provided):
|
|
||||||
Classification loss (cross entropy).
|
|
||||||
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
|
decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user