[cleanup] TF T5 tests only init t5-base once. (#5410)
This commit is contained in:
@@ -478,7 +478,7 @@ class TFT5Block(tf.keras.layers.Layer):
|
||||
return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||
|
||||
|
||||
class _NoLayerEmbedTokens(object):
|
||||
class _NoLayerEmbedTokens:
|
||||
"""
|
||||
this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer'
|
||||
class to avoid problem with weight restoring. Also it makes sure that the layer is
|
||||
@@ -655,7 +655,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
|
||||
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
||||
# extended_attention_mask = tf.math.equal(extended_attention_mask,
|
||||
# tf.transpose(extended_attention_mask, perm=(-1, -2)))
|
||||
@@ -682,16 +682,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
assert head_mask is None, "Head mask not supported"
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
|
||||
present_key_value_states = ()
|
||||
all_hidden_states = ()
|
||||
@@ -1054,8 +1046,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
|
||||
r"""
|
||||
Returns:
|
||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
|
||||
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_label` is provided):
|
||||
Classification loss (cross entropy).
|
||||
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user