Fix T5 and BART for TF (#9063)
* Fix T5 for graphe compilation+execution * Fix BART * Fix import * Fix naming * fix attribute name * Oops * fix import * fix tests * fix tests * Update test * Add mising import * Address Patrick's comments * Style * Address Patrick's comment
This commit is contained in:
@@ -91,8 +91,6 @@ TensorFlow loss functions
|
|||||||
TensorFlow Helper Functions
|
TensorFlow Helper Functions
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autofunction:: transformers.modeling_tf_utils.cast_bool_to_primitive
|
|
||||||
|
|
||||||
.. autofunction:: transformers.modeling_tf_utils.get_initializer
|
.. autofunction:: transformers.modeling_tf_utils.get_initializer
|
||||||
|
|
||||||
.. autofunction:: transformers.modeling_tf_utils.keras_serializable
|
.. autofunction:: transformers.modeling_tf_utils.keras_serializable
|
||||||
|
|||||||
@@ -51,7 +51,9 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
|
|||||||
) # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
) # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
||||||
tf_name = re.sub(r"//+", "/", tf_name) # Remove empty levels at the end
|
tf_name = re.sub(r"//+", "/", tf_name) # Remove empty levels at the end
|
||||||
tf_name = tf_name.split("/") # Convert from TF2.0 '/' separators to PyTorch '.' separators
|
tf_name = tf_name.split("/") # Convert from TF2.0 '/' separators to PyTorch '.' separators
|
||||||
tf_name = tf_name[1:] # Remove level zero
|
# Some weights have a single name withtout "/" such as final_logits_bias in BART
|
||||||
|
if len(tf_name) > 1:
|
||||||
|
tf_name = tf_name[1:] # Remove level zero
|
||||||
|
|
||||||
# When should we transpose the weights
|
# When should we transpose the weights
|
||||||
transpose = bool(tf_name[-1] == "kernel" or "emb_projs" in tf_name or "out_projs" in tf_name)
|
transpose = bool(tf_name[-1] == "kernel" or "emb_projs" in tf_name or "out_projs" in tf_name)
|
||||||
|
|||||||
@@ -354,7 +354,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
if isinstance(v, allowed_types) or v is None:
|
if isinstance(v, allowed_types) or v is None:
|
||||||
output[k] = v
|
output[k] = v
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
|
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||||
|
|
||||||
if isinstance(input_ids, (tuple, list)):
|
if isinstance(input_ids, (tuple, list)):
|
||||||
for i, input in enumerate(input_ids):
|
for i, input in enumerate(input_ids):
|
||||||
@@ -372,7 +372,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
output[parameter_names[i]] = input
|
output[parameter_names[i]] = input
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Data of type {type(input)} is not allowed only tf.Tensor is accepted for {parameter_names[i]}."
|
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
|
||||||
)
|
)
|
||||||
elif isinstance(input_ids, (dict, BatchEncoding)):
|
elif isinstance(input_ids, (dict, BatchEncoding)):
|
||||||
if "inputs" in input_ids:
|
if "inputs" in input_ids:
|
||||||
@@ -399,13 +399,13 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
|
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||||
else:
|
else:
|
||||||
if isinstance(input_ids, tf.Tensor) or input_ids is None:
|
if isinstance(input_ids, tf.Tensor) or input_ids is None:
|
||||||
output[parameter_names[0]] = input_ids
|
output[parameter_names[0]] = input_ids
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Data of type {type(input_ids)} is not allowed only tf.Tensor is accepted for {parameter_names[0]}."
|
f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
|
||||||
)
|
)
|
||||||
|
|
||||||
for name in parameter_names:
|
for name in parameter_names:
|
||||||
@@ -1366,31 +1366,6 @@ def get_initializer(initializer_range: float = 0.02) -> tf.initializers.Truncate
|
|||||||
return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
|
return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
|
||||||
|
|
||||||
|
|
||||||
def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true=False) -> bool:
|
|
||||||
"""
|
|
||||||
Function arguments can be inserted as boolean tensor and bool variables to cope with Keras serialization we need to
|
|
||||||
cast the bool arguments (like :obj:`output_attentions` for instance) to correct boolean if it is a tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bool_variable (:obj:`Union[tf.Tensor, bool]`):
|
|
||||||
The variable to convert to a boolean.
|
|
||||||
default_tensor_to_true (:obj:`bool`, `optional`, defaults to `False`):
|
|
||||||
The default value to use in case the tensor has no numpy attribute.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:obj:`bool`: The converted value.
|
|
||||||
"""
|
|
||||||
# if bool variable is tensor and has numpy value
|
|
||||||
if tf.is_tensor(bool_variable):
|
|
||||||
if hasattr(bool_variable, "numpy"):
|
|
||||||
return bool(bool_variable.numpy())
|
|
||||||
elif default_tensor_to_true:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# else variable is bool
|
|
||||||
return bool_variable
|
|
||||||
|
|
||||||
|
|
||||||
class TFWrappedEmbeddings:
|
class TFWrappedEmbeddings:
|
||||||
"""
|
"""
|
||||||
this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer' class to avoid problem with
|
this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer' class to avoid problem with
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ from ...modeling_tf_utils import (
|
|||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TFWrappedEmbeddings,
|
TFWrappedEmbeddings,
|
||||||
cast_bool_to_primitive,
|
|
||||||
input_processing,
|
input_processing,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@@ -258,9 +257,11 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
|||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
x = self.self_attn_layer_norm(x)
|
x = self.self_attn_layer_norm(x)
|
||||||
x, self_attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask)
|
x, self_attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask)
|
||||||
assert shape_list(x) == shape_list(
|
tf.debugging.assert_equal(
|
||||||
residual
|
shape_list(x),
|
||||||
), f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}"
|
shape_list(residual),
|
||||||
|
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}",
|
||||||
|
)
|
||||||
x = self.dropout(x, training=training)
|
x = self.dropout(x, training=training)
|
||||||
x = residual + x
|
x = residual + x
|
||||||
if not self.normalize_before:
|
if not self.normalize_before:
|
||||||
@@ -295,9 +296,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
self.layerdrop = config.encoder_layerdrop
|
self.layerdrop = config.encoder_layerdrop
|
||||||
self.output_hidden_states = config.output_hidden_states
|
|
||||||
self.output_attentions = config.output_attentions
|
|
||||||
|
|
||||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.max_source_positions = config.max_position_embeddings
|
self.max_source_positions = config.max_position_embeddings
|
||||||
@@ -328,7 +326,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
if config.add_final_layer_norm
|
if config.add_final_layer_norm
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.return_dict = config.return_dict
|
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
@@ -355,10 +352,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
- **all_attentions** (List[tf.Tensor]): Attention weights for each layer.
|
- **all_attentions** (List[tf.Tensor]): Attention weights for each layer.
|
||||||
During training might not be of length n_layers because of layer dropout.
|
During training might not be of length n_layers because of layer dropout.
|
||||||
"""
|
"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
|
||||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
|
||||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
|
||||||
|
|
||||||
# check attention mask and invert
|
# check attention mask and invert
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert (
|
assert (
|
||||||
@@ -546,9 +539,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
self.output_hidden_states = config.output_hidden_states
|
|
||||||
self.output_attentions = config.output_attentions
|
|
||||||
self.use_cache = config.use_cache
|
|
||||||
self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm
|
self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
@@ -565,14 +555,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
return_dict=None,
|
return_dict=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
|
||||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
|
||||||
use_cache = use_cache if use_cache is not None else self.use_cache
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
|
||||||
if use_cache:
|
|
||||||
assert not training, "Training + use cache are incompatible"
|
|
||||||
# check attention mask and invert
|
# check attention mask and invert
|
||||||
use_cache = cast_bool_to_primitive(use_cache)
|
|
||||||
if encoder_padding_mask is not None:
|
if encoder_padding_mask is not None:
|
||||||
encoder_padding_mask = invert_mask(encoder_padding_mask)
|
encoder_padding_mask = invert_mask(encoder_padding_mask)
|
||||||
|
|
||||||
@@ -1046,7 +1029,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
|||||||
self.use_cache = config.use_cache
|
self.use_cache = config.use_cache
|
||||||
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
|
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
|
||||||
self.final_logits_bias = self.add_weight(
|
self.final_logits_bias = self.add_weight(
|
||||||
name="/final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
|
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def resize_token_embeddings(self, new_num_tokens):
|
def resize_token_embeddings(self, new_num_tokens):
|
||||||
|
|||||||
@@ -32,12 +32,16 @@ from ...file_utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput
|
from ...modeling_tf_outputs import (
|
||||||
|
TFBaseModelOutput,
|
||||||
|
TFBaseModelOutputWithPast,
|
||||||
|
TFSeq2SeqLMOutput,
|
||||||
|
TFSeq2SeqModelOutput,
|
||||||
|
)
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
TFCausalLanguageModelingLoss,
|
TFCausalLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
cast_bool_to_primitive,
|
|
||||||
input_processing,
|
input_processing,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@@ -311,7 +315,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# to cope with keras serialization
|
# to cope with keras serialization
|
||||||
if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True:
|
if self.is_decoder and use_cache:
|
||||||
present_key_value_state = (key_states, value_states)
|
present_key_value_state = (key_states, value_states)
|
||||||
else:
|
else:
|
||||||
present_key_value_state = None
|
present_key_value_state = None
|
||||||
@@ -594,6 +598,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
training=False,
|
training=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple:
|
) -> Tuple:
|
||||||
@@ -610,6 +615,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
training=training,
|
training=training,
|
||||||
kwargs_call=kwargs,
|
kwargs_call=kwargs,
|
||||||
)
|
)
|
||||||
@@ -713,10 +719,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
assert inputs["head_mask"] is None, "Head mask not supported"
|
assert inputs["head_mask"] is None, "Head mask not supported"
|
||||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||||
|
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
|
||||||
present_key_value_states = ()
|
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||||
all_hidden_states = ()
|
all_attentions = () if inputs["output_attentions"] else None
|
||||||
all_attentions = ()
|
|
||||||
position_bias = None
|
position_bias = None
|
||||||
encoder_decoder_position_bias = None
|
encoder_decoder_position_bias = None
|
||||||
|
|
||||||
@@ -725,7 +730,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
|
for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
@@ -739,6 +743,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# layer_outputs is a tuple with:
|
# layer_outputs is a tuple with:
|
||||||
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||||
hidden_states, present_key_value_state = layer_outputs[:2]
|
hidden_states, present_key_value_state = layer_outputs[:2]
|
||||||
@@ -747,10 +752,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
# layer_outputs = hidden-states, past_key_values, (self-attention weights),
|
# layer_outputs = hidden-states, past_key_values, (self-attention weights),
|
||||||
# (self-attention position bias), (cross-attention position bias), (cross-attention weights),
|
# (self-attention position bias), (cross-attention position bias), (cross-attention weights),
|
||||||
position_bias = layer_outputs[2]
|
position_bias = layer_outputs[2]
|
||||||
|
|
||||||
if self.is_decoder and inputs["encoder_hidden_states"] is not None:
|
if self.is_decoder and inputs["encoder_hidden_states"] is not None:
|
||||||
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
|
encoder_decoder_position_bias = layer_outputs[4 if inputs["output_attentions"] else 3]
|
||||||
|
|
||||||
# append next layer key value states
|
# append next layer key value states
|
||||||
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
if present_key_value_state is not None and inputs["use_cache"] and self.is_decoder:
|
||||||
|
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
||||||
|
|
||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_attentions = all_attentions + (layer_outputs[3],)
|
all_attentions = all_attentions + (layer_outputs[3],)
|
||||||
@@ -762,15 +770,30 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
if not inputs["return_dict"]:
|
||||||
# need to check if is decoder here as well for special cases when using keras compile
|
outputs = (hidden_states,)
|
||||||
if cast_bool_to_primitive(inputs["use_cache"], self.use_cache) is True and self.is_decoder:
|
# need to check if is decoder here as well for special cases when using keras compile
|
||||||
outputs = outputs + (present_key_value_states,)
|
if inputs["use_cache"] and self.is_decoder:
|
||||||
if inputs["output_hidden_states"]:
|
outputs = outputs + (present_key_value_states,)
|
||||||
outputs = outputs + (all_hidden_states,)
|
if inputs["output_hidden_states"]:
|
||||||
if inputs["output_attentions"]:
|
outputs = outputs + (all_hidden_states,)
|
||||||
outputs = outputs + (all_attentions,)
|
if inputs["output_attentions"]:
|
||||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
outputs = outputs + (all_attentions,)
|
||||||
|
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
return TFBaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=present_key_value_states,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_attentions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return TFBaseModelOutput(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
####################################################
|
####################################################
|
||||||
@@ -1102,6 +1125,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
|
return_dict=inputs["return_dict"],
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1119,38 +1143,25 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
use_cache=inputs["use_cache"],
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
|
return_dict=inputs["return_dict"],
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
past = (
|
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
||||||
(inputs["encoder_outputs"], decoder_outputs[1])
|
|
||||||
if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
if past is not None:
|
if past is not None:
|
||||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
||||||
return decoder_outputs + inputs["encoder_outputs"]
|
return decoder_outputs + inputs["encoder_outputs"]
|
||||||
|
|
||||||
# This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
|
|
||||||
# TF refuses to compile anymore.
|
|
||||||
if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
|
|
||||||
decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:]
|
|
||||||
if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
|
|
||||||
inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:]
|
|
||||||
decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:]
|
|
||||||
if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
|
|
||||||
inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,)
|
|
||||||
decoder_outputs = decoder_outputs + (None,)
|
|
||||||
|
|
||||||
return TFSeq2SeqModelOutput(
|
return TFSeq2SeqModelOutput(
|
||||||
last_hidden_state=decoder_outputs[0],
|
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||||
past_key_values=past,
|
past_key_values=past,
|
||||||
decoder_hidden_states=decoder_outputs[2],
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs[3],
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"][0],
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"][1],
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"][2],
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1280,6 +1291,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
|
return_dict=inputs["return_dict"],
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1313,6 +1325,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
use_cache=inputs["use_cache"],
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
|
return_dict=inputs["return_dict"],
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1327,37 +1340,41 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
|
|
||||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||||
|
|
||||||
past = (
|
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
||||||
(inputs["encoder_outputs"], decoder_outputs[1])
|
|
||||||
if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
if past is not None:
|
if past is not None:
|
||||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
||||||
output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
|
output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
# This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
|
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
|
||||||
# TF refuses to compile anymore.
|
elif isinstance(inputs["encoder_outputs"], tuple):
|
||||||
if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
|
last_hidden_state = inputs["encoder_outputs"][0]
|
||||||
decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:]
|
hidden_states = None
|
||||||
if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
|
attentions = None
|
||||||
inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:]
|
idx = 0
|
||||||
decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:]
|
if inputs["output_hidden_states"]:
|
||||||
if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
|
idx += 1
|
||||||
inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,)
|
hidden_states = inputs["encoder_outputs"][idx]
|
||||||
decoder_outputs = decoder_outputs + (None,)
|
if inputs["output_attentions"]:
|
||||||
|
idx += 1
|
||||||
|
attentions = inputs["encoder_outputs"][idx]
|
||||||
|
|
||||||
|
inputs["encoder_outputs"] = TFBaseModelOutput(
|
||||||
|
last_hidden_state=last_hidden_state,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attentions=attentions,
|
||||||
|
)
|
||||||
|
|
||||||
return TFSeq2SeqLMOutput(
|
return TFSeq2SeqLMOutput(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
past_key_values=past,
|
past_key_values=past,
|
||||||
decoder_hidden_states=decoder_outputs[2],
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs[3],
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"][0],
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"][1],
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"][2],
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs):
|
def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs):
|
||||||
@@ -1498,19 +1515,15 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
|
|||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
|
return_dict=inputs["return_dict"],
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return encoder_outputs
|
return encoder_outputs
|
||||||
|
|
||||||
if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
|
|
||||||
encoder_outputs = encoder_outputs[:1] + (None,) + encoder_outputs[1:]
|
|
||||||
if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
|
|
||||||
encoder_outputs = encoder_outputs + (None,)
|
|
||||||
|
|
||||||
return TFBaseModelOutput(
|
return TFBaseModelOutput(
|
||||||
last_hidden_state=encoder_outputs[0],
|
last_hidden_state=encoder_outputs.last_hidden_state,
|
||||||
hidden_states=encoder_outputs[1],
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
attentions=encoder_outputs[2],
|
attentions=encoder_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -118,14 +118,6 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# inputs_embeds not supported
|
# inputs_embeds not supported
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_saved_model_with_hidden_states_output(self):
|
|
||||||
# Should be uncommented during patrick TF refactor
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_saved_model_with_attentions_output(self):
|
|
||||||
# Should be uncommented during patrick TF refactor
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -171,6 +171,11 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
# A saved model is always executed in graph mode, since we merged the PR #8777
|
||||||
|
# the booleans in graph mode are always the ones in the config, then we update
|
||||||
|
# the use_cache property if it exists in order to have similar booleans with the inputs
|
||||||
|
if "use_cache" in class_inputs_dict:
|
||||||
|
config.use_cache = class_inputs_dict.pop("use_cache")
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
num_out = len(model(class_inputs_dict))
|
num_out = len(model(class_inputs_dict))
|
||||||
model._saved_model_inputs_spec = None
|
model._saved_model_inputs_spec = None
|
||||||
@@ -207,6 +212,11 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
# A saved model is always executed in graph mode, since we merged the PR #8777
|
||||||
|
# the booleans in graph mode are always the ones in the config, then we update
|
||||||
|
# the use_cache property if it exists in order to have similar booleans with the inputs
|
||||||
|
if "use_cache" in class_inputs_dict:
|
||||||
|
config.use_cache = class_inputs_dict.pop("use_cache")
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
num_out = len(model(class_inputs_dict))
|
num_out = len(model(class_inputs_dict))
|
||||||
model._saved_model_inputs_spec = None
|
model._saved_model_inputs_spec = None
|
||||||
@@ -249,10 +259,11 @@ class TFModelTesterMixin:
|
|||||||
if "T5" in main_layer_class.__name__:
|
if "T5" in main_layer_class.__name__:
|
||||||
# Take the same values than in TFT5ModelTester for this shared layer
|
# Take the same values than in TFT5ModelTester for this shared layer
|
||||||
shared = TFSharedEmbeddings(99, 32, name="shared")
|
shared = TFSharedEmbeddings(99, 32, name="shared")
|
||||||
config.use_cache = False
|
config.use_cache = inputs_dict.pop("use_cache", None)
|
||||||
main_layer = main_layer_class(config, embed_tokens=shared)
|
main_layer = main_layer_class(config, embed_tokens=shared)
|
||||||
else:
|
else:
|
||||||
main_layer = main_layer_class(config)
|
main_layer = main_layer_class(config)
|
||||||
|
|
||||||
symbolic_inputs = {
|
symbolic_inputs = {
|
||||||
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
||||||
}
|
}
|
||||||
@@ -321,10 +332,13 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||||
pt_model.eval()
|
pt_model.eval()
|
||||||
pt_inputs_dict = dict(
|
pt_inputs_dict = {}
|
||||||
(name, torch.from_numpy(key.numpy()).to(torch.long))
|
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
|
||||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items()
|
if type(key) == bool:
|
||||||
)
|
pt_inputs_dict[name] = key
|
||||||
|
else:
|
||||||
|
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||||
|
|
||||||
# need to rename encoder-decoder "inputs" for PyTorch
|
# need to rename encoder-decoder "inputs" for PyTorch
|
||||||
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||||
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||||
@@ -358,10 +372,13 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||||
pt_model.eval()
|
pt_model.eval()
|
||||||
pt_inputs_dict = dict(
|
pt_inputs_dict = {}
|
||||||
(name, torch.from_numpy(key.numpy()).to(torch.long))
|
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
|
||||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items()
|
if type(key) == bool:
|
||||||
)
|
key = np.array(key, dtype=bool)
|
||||||
|
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
|
||||||
|
else:
|
||||||
|
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||||
# need to rename encoder-decoder "inputs" for PyTorch
|
# need to rename encoder-decoder "inputs" for PyTorch
|
||||||
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||||
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||||
@@ -574,13 +591,29 @@ class TFModelTesterMixin:
|
|||||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[-1]
|
if model.config.is_encoder_decoder:
|
||||||
self.assertEqual(config.output_attentions, False)
|
encoder_hidden_states = outputs.encoder_hidden_states
|
||||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
decoder_hidden_states = outputs.decoder_hidden_states
|
||||||
self.assertListEqual(
|
|
||||||
list(hidden_states[0].shape[-2:]),
|
self.assertEqual(config.output_attentions, False)
|
||||||
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
self.assertEqual(len(encoder_hidden_states), expected_num_layers)
|
||||||
)
|
self.assertListEqual(
|
||||||
|
list(encoder_hidden_states[0].shape[-2:]),
|
||||||
|
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||||
|
)
|
||||||
|
self.assertEqual(len(decoder_hidden_states), expected_num_layers)
|
||||||
|
self.assertListEqual(
|
||||||
|
list(decoder_hidden_states[0].shape[-2:]),
|
||||||
|
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = outputs.hidden_states
|
||||||
|
self.assertEqual(config.output_attentions, False)
|
||||||
|
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||||
|
self.assertListEqual(
|
||||||
|
list(hidden_states[0].shape[-2:]),
|
||||||
|
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||||
|
)
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
inputs_dict["output_hidden_states"] = True
|
inputs_dict["output_hidden_states"] = True
|
||||||
@@ -796,7 +829,7 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
def test_lm_head_model_random_beam_search_generate(self):
|
def test_lm_head_model_random_beam_search_generate(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
input_ids = inputs_dict["input_ids"]
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|||||||
@@ -133,8 +133,6 @@ class TFT5ModelTester:
|
|||||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||||
|
|
||||||
output, past_key_values = outputs
|
|
||||||
|
|
||||||
# create hypothetical next token and extent to next_input_ids
|
# create hypothetical next token and extent to next_input_ids
|
||||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
@@ -142,7 +140,7 @@ class TFT5ModelTester:
|
|||||||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||||
|
|
||||||
output_from_no_past = model(next_input_ids)[0]
|
output_from_no_past = model(next_input_ids)[0]
|
||||||
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
|
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]
|
||||||
|
|
||||||
# select random slice
|
# select random slice
|
||||||
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
|
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
|
||||||
@@ -164,7 +162,7 @@ class TFT5ModelTester:
|
|||||||
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
|
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
|
||||||
|
|
||||||
# first forward pass
|
# first forward pass
|
||||||
_, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)
|
outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)
|
||||||
|
|
||||||
# create hypothetical next token and extent to next_input_ids
|
# create hypothetical next token and extent to next_input_ids
|
||||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
@@ -187,7 +185,7 @@ class TFT5ModelTester:
|
|||||||
|
|
||||||
# get two different outputs
|
# get two different outputs
|
||||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
|
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
|
||||||
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[0]
|
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values, attention_mask=attn_mask)[0]
|
||||||
|
|
||||||
# select random slice
|
# select random slice
|
||||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
|
||||||
@@ -208,8 +206,6 @@ class TFT5ModelTester:
|
|||||||
# first forward pass
|
# first forward pass
|
||||||
outputs = model(input_ids, use_cache=True)
|
outputs = model(input_ids, use_cache=True)
|
||||||
|
|
||||||
output, past_key_values = outputs
|
|
||||||
|
|
||||||
# create hypothetical next token and extent to next_input_ids
|
# create hypothetical next token and extent to next_input_ids
|
||||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||||
|
|
||||||
@@ -217,7 +213,7 @@ class TFT5ModelTester:
|
|||||||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||||
|
|
||||||
output_from_no_past = model(next_input_ids)[0]
|
output_from_no_past = model(next_input_ids)[0]
|
||||||
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
|
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]
|
||||||
|
|
||||||
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
|
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
|
||||||
|
|
||||||
@@ -236,7 +232,7 @@ class TFT5ModelTester:
|
|||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": input_ids,
|
"decoder_input_ids": input_ids,
|
||||||
"decoder_attention_mask": input_mask,
|
"decoder_attention_mask": input_mask,
|
||||||
"use_cache": tf.convert_to_tensor([False]),
|
"use_cache": False,
|
||||||
}
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -298,14 +294,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
model = TFT5Model.from_pretrained("t5-small")
|
model = TFT5Model.from_pretrained("t5-small")
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_saved_model_with_attentions_output(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_saved_model_with_hidden_states_output(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TFT5EncoderOnlyModelTester:
|
class TFT5EncoderOnlyModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -411,6 +399,7 @@ class TFT5EncoderOnlyModelTester:
|
|||||||
|
|
||||||
|
|
||||||
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
is_encoder_decoder = False
|
||||||
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user