From df3f4d2aef70d409b4d9bea18d8821693bea3877 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 14 Dec 2020 18:47:00 +0100 Subject: [PATCH] Fix T5 and BART for TF (#9063) * Fix T5 for graph compilation+execution * Fix BART * Fix import * Fix naming * fix attribute name * Oops * fix import * fix tests * fix tests * Update test * Add missing import * Address Patrick's comments * Style * Address Patrick's comment --- docs/source/internal/modeling_utils.rst | 2 - src/transformers/modeling_tf_pytorch_utils.py | 4 +- src/transformers/modeling_tf_utils.py | 33 +--- .../models/bart/modeling_tf_bart.py | 29 +--- src/transformers/models/t5/modeling_tf_t5.py | 151 ++++++++++-------- tests/test_modeling_tf_bart.py | 8 - tests/test_modeling_tf_common.py | 67 ++++++-- tests/test_modeling_tf_t5.py | 23 +-- 8 files changed, 151 insertions(+), 166 deletions(-) diff --git a/docs/source/internal/modeling_utils.rst b/docs/source/internal/modeling_utils.rst index 0b62d006bc..3d6d770dcd 100644 --- a/docs/source/internal/modeling_utils.rst +++ b/docs/source/internal/modeling_utils.rst @@ -91,8 +91,6 @@ TensorFlow loss functions TensorFlow Helper Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: transformers.modeling_tf_utils.cast_bool_to_primitive - .. autofunction:: transformers.modeling_tf_utils.get_initializer .. 
autofunction:: transformers.modeling_tf_utils.keras_serializable diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 761cf7d721..cdf979bf18 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -51,7 +51,9 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="") ) # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) tf_name = re.sub(r"//+", "/", tf_name) # Remove empty levels at the end tf_name = tf_name.split("/") # Convert from TF2.0 '/' separators to PyTorch '.' separators - tf_name = tf_name[1:] # Remove level zero + # Some weights have a single name withtout "/" such as final_logits_bias in BART + if len(tf_name) > 1: + tf_name = tf_name[1:] # Remove level zero # When should we transpose the weights transpose = bool(tf_name[-1] == "kernel" or "emb_projs" in tf_name or "out_projs" in tf_name) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index f7e982c1ca..b15bbe0690 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -354,7 +354,7 @@ def input_processing(func, config, input_ids, **kwargs): if isinstance(v, allowed_types) or v is None: output[k] = v else: - raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.") + raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") if isinstance(input_ids, (tuple, list)): for i, input in enumerate(input_ids): @@ -372,7 +372,7 @@ def input_processing(func, config, input_ids, **kwargs): output[parameter_names[i]] = input else: raise ValueError( - f"Data of type {type(input)} is not allowed only tf.Tensor is accepted for {parameter_names[i]}." + f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}." 
) elif isinstance(input_ids, (dict, BatchEncoding)): if "inputs" in input_ids: @@ -399,13 +399,13 @@ def input_processing(func, config, input_ids, **kwargs): ) continue else: - raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.") + raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") else: if isinstance(input_ids, tf.Tensor) or input_ids is None: output[parameter_names[0]] = input_ids else: raise ValueError( - f"Data of type {type(input_ids)} is not allowed only tf.Tensor is accepted for {parameter_names[0]}." + f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}." ) for name in parameter_names: @@ -1366,31 +1366,6 @@ def get_initializer(initializer_range: float = 0.02) -> tf.initializers.Truncate return tf.keras.initializers.TruncatedNormal(stddev=initializer_range) -def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true=False) -> bool: - """ - Function arguments can be inserted as boolean tensor and bool variables to cope with Keras serialization we need to - cast the bool arguments (like :obj:`output_attentions` for instance) to correct boolean if it is a tensor. - - Args: - bool_variable (:obj:`Union[tf.Tensor, bool]`): - The variable to convert to a boolean. - default_tensor_to_true (:obj:`bool`, `optional`, defaults to `False`): - The default value to use in case the tensor has no numpy attribute. - - Returns: - :obj:`bool`: The converted value. 
- """ - # if bool variable is tensor and has numpy value - if tf.is_tensor(bool_variable): - if hasattr(bool_variable, "numpy"): - return bool(bool_variable.numpy()) - elif default_tensor_to_true: - return True - - # else variable is bool - return bool_variable - - class TFWrappedEmbeddings: """ this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer' class to avoid problem with diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 88af7f8336..4d731a923c 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -41,7 +41,6 @@ from ...modeling_tf_utils import ( TFPreTrainedModel, TFSharedEmbeddings, TFWrappedEmbeddings, - cast_bool_to_primitive, input_processing, keras_serializable, shape_list, @@ -258,9 +257,11 @@ class TFEncoderLayer(tf.keras.layers.Layer): if self.normalize_before: x = self.self_attn_layer_norm(x) x, self_attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask) - assert shape_list(x) == shape_list( - residual - ), f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}" + tf.debugging.assert_equal( + shape_list(x), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}", + ) x = self.dropout(x, training=training) x = residual + x if not self.normalize_before: @@ -295,9 +296,6 @@ class TFBartEncoder(tf.keras.layers.Layer): self.dropout = tf.keras.layers.Dropout(config.dropout) self.layerdrop = config.encoder_layerdrop - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings @@ -328,7 +326,6 @@ class TFBartEncoder(tf.keras.layers.Layer): if 
config.add_final_layer_norm else None ) - self.return_dict = config.return_dict def call( self, @@ -355,10 +352,6 @@ class TFBartEncoder(tf.keras.layers.Layer): - **all_attentions** (List[tf.Tensor]): Attention weights for each layer. During training might not be of length n_layers because of layer dropout. """ - output_attentions = output_attentions if output_attentions is not None else self.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states - return_dict = return_dict if return_dict is not None else self.return_dict - # check attention mask and invert if attention_mask is not None: assert ( @@ -546,9 +539,6 @@ class TFBartDecoder(tf.keras.layers.Layer): ) self.dropout = tf.keras.layers.Dropout(config.dropout) - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - self.use_cache = config.use_cache self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm def call( @@ -565,14 +555,7 @@ class TFBartDecoder(tf.keras.layers.Layer): return_dict=None, training=False, ): - output_attentions = output_attentions if output_attentions is not None else self.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states - use_cache = use_cache if use_cache is not None else self.use_cache - return_dict = return_dict if return_dict is not None else self.config.return_dict - if use_cache: - assert not training, "Training + use cache are incompatible" # check attention mask and invert - use_cache = cast_bool_to_primitive(use_cache) if encoder_padding_mask is not None: encoder_padding_mask = invert_mask(encoder_padding_mask) @@ -1046,7 +1029,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): self.use_cache = config.use_cache # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency. 
self.final_logits_bias = self.add_weight( - name="/final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False + name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False ) def resize_token_embeddings(self, new_num_tokens): diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 60d1e95ebd..e238584d3a 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -32,12 +32,16 @@ from ...file_utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPast, + TFSeq2SeqLMOutput, + TFSeq2SeqModelOutput, +) from ...modeling_tf_utils import ( TFCausalLanguageModelingLoss, TFPreTrainedModel, TFSharedEmbeddings, - cast_bool_to_primitive, input_processing, keras_serializable, shape_list, @@ -311,7 +315,7 @@ class TFT5Attention(tf.keras.layers.Layer): ) # to cope with keras serialization - if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True: + if self.is_decoder and use_cache: present_key_value_state = (key_states, value_states) else: present_key_value_state = None @@ -594,6 +598,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): use_cache=None, output_attentions=None, output_hidden_states=None, + return_dict=None, training=False, **kwargs, ) -> Tuple: @@ -610,6 +615,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict, training=training, kwargs_call=kwargs, ) @@ -713,10 +719,9 @@ class TFT5MainLayer(tf.keras.layers.Layer): assert inputs["head_mask"] is None, "Head mask not supported" inputs["head_mask"] = [None] * self.num_hidden_layers - - present_key_value_states 
= () - all_hidden_states = () - all_attentions = () + present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None + all_hidden_states = () if inputs["output_hidden_states"] else None + all_attentions = () if inputs["output_attentions"] else None position_bias = None encoder_decoder_position_bias = None @@ -725,7 +730,6 @@ class TFT5MainLayer(tf.keras.layers.Layer): for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])): if inputs["output_hidden_states"]: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer_module( hidden_states, attention_mask=extended_attention_mask, @@ -739,6 +743,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): output_attentions=inputs["output_attentions"], training=inputs["training"], ) + # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) hidden_states, present_key_value_state = layer_outputs[:2] @@ -747,10 +752,13 @@ class TFT5MainLayer(tf.keras.layers.Layer): # layer_outputs = hidden-states, past_key_values, (self-attention weights), # (self-attention position bias), (cross-attention position bias), (cross-attention weights), position_bias = layer_outputs[2] + if self.is_decoder and inputs["encoder_hidden_states"] is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + encoder_decoder_position_bias = layer_outputs[4 if inputs["output_attentions"] else 3] + # append next layer key value states - present_key_value_states = present_key_value_states + (present_key_value_state,) + if present_key_value_state is not None and inputs["use_cache"] and self.is_decoder: + present_key_value_states = present_key_value_states + (present_key_value_state,) if inputs["output_attentions"]: all_attentions = all_attentions + (layer_outputs[3],) @@ -762,15 +770,30 @@ class 
TFT5MainLayer(tf.keras.layers.Layer): if inputs["output_hidden_states"]: all_hidden_states = all_hidden_states + (hidden_states,) - outputs = (hidden_states,) - # need to check if is decoder here as well for special cases when using keras compile - if cast_bool_to_primitive(inputs["use_cache"], self.use_cache) is True and self.is_decoder: - outputs = outputs + (present_key_value_states,) - if inputs["output_hidden_states"]: - outputs = outputs + (all_hidden_states,) - if inputs["output_attentions"]: - outputs = outputs + (all_attentions,) - return outputs # last-layer hidden state, (all hidden states), (all attentions) + if not inputs["return_dict"]: + outputs = (hidden_states,) + # need to check if is decoder here as well for special cases when using keras compile + if inputs["use_cache"] and self.is_decoder: + outputs = outputs + (present_key_value_states,) + if inputs["output_hidden_states"]: + outputs = outputs + (all_hidden_states,) + if inputs["output_attentions"]: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + if self.is_decoder: + return TFBaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + else: + return TFBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) #################################################### @@ -1102,6 +1125,7 @@ class TFT5Model(TFT5PreTrainedModel): use_cache=False, output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], training=inputs["training"], ) @@ -1119,38 +1143,25 @@ class TFT5Model(TFT5PreTrainedModel): use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], 
training=inputs["training"], ) - past = ( - (inputs["encoder_outputs"], decoder_outputs[1]) - if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache) - else None - ) + past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None + if not inputs["return_dict"]: if past is not None: decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] return decoder_outputs + inputs["encoder_outputs"] - # This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch) - # TF refuses to compile anymore. - if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache): - decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:] - if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states): - inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:] - decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:] - if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions): - inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,) - decoder_outputs = decoder_outputs + (None,) - return TFSeq2SeqModelOutput( - last_hidden_state=decoder_outputs[0], + last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=past, - decoder_hidden_states=decoder_outputs[2], - decoder_attentions=decoder_outputs[3], - encoder_last_hidden_state=inputs["encoder_outputs"][0], - encoder_hidden_states=inputs["encoder_outputs"][1], - encoder_attentions=inputs["encoder_outputs"][2], + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, + encoder_hidden_states=inputs["encoder_outputs"].hidden_states, + encoder_attentions=inputs["encoder_outputs"].attentions, ) @@ -1280,6 +1291,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, 
TFCausalLanguageModeling head_mask=inputs["head_mask"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], training=inputs["training"], ) @@ -1313,6 +1325,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], training=inputs["training"], ) @@ -1327,37 +1340,41 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) - past = ( - (inputs["encoder_outputs"], decoder_outputs[1]) - if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache) - else None - ) + past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None if not inputs["return_dict"]: if past is not None: decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"] return ((loss,) + output) if loss is not None else output - # This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch) - # TF refuses to compile anymore. 
- if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache): - decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:] - if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states): - inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:] - decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:] - if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions): - inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,) - decoder_outputs = decoder_outputs + (None,) + # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True + elif isinstance(inputs["encoder_outputs"], tuple): + last_hidden_state = inputs["encoder_outputs"][0] + hidden_states = None + attentions = None + idx = 0 + if inputs["output_hidden_states"]: + idx += 1 + hidden_states = inputs["encoder_outputs"][idx] + if inputs["output_attentions"]: + idx += 1 + attentions = inputs["encoder_outputs"][idx] + + inputs["encoder_outputs"] = TFBaseModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) return TFSeq2SeqLMOutput( loss=loss, logits=logits, past_key_values=past, - decoder_hidden_states=decoder_outputs[2], - decoder_attentions=decoder_outputs[3], - encoder_last_hidden_state=inputs["encoder_outputs"][0], - encoder_hidden_states=inputs["encoder_outputs"][1], - encoder_attentions=inputs["encoder_outputs"][2], + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, + encoder_hidden_states=inputs["encoder_outputs"].hidden_states, + encoder_attentions=inputs["encoder_outputs"].attentions, ) def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs): @@ -1498,19 +1515,15 @@ class 
TFT5EncoderModel(TFT5PreTrainedModel): use_cache=False, output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], training=inputs["training"], ) if not inputs["return_dict"]: return encoder_outputs - if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states): - encoder_outputs = encoder_outputs[:1] + (None,) + encoder_outputs[1:] - if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions): - encoder_outputs = encoder_outputs + (None,) - return TFBaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1], - attentions=encoder_outputs[2], + last_hidden_state=encoder_outputs.last_hidden_state, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, ) diff --git a/tests/test_modeling_tf_bart.py b/tests/test_modeling_tf_bart.py index 99c3d03eca..53d238ed9d 100644 --- a/tests/test_modeling_tf_bart.py +++ b/tests/test_modeling_tf_bart.py @@ -118,14 +118,6 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase): # inputs_embeds not supported pass - def test_saved_model_with_hidden_states_output(self): - # Should be uncommented during patrick TF refactor - pass - - def test_saved_model_with_attentions_output(self): - # Should be uncommented during patrick TF refactor - pass - def test_model_common_attributes(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 5aa1e78e17..2a65b1e183 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -171,6 +171,11 @@ class TFModelTesterMixin: for model_class in self.all_model_classes: class_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + # A saved model is always executed in graph mode, since we merged the PR #8777 + # the booleans in graph mode are 
always the ones in the config, then we update + # the use_cache property if it exists in order to have similar booleans with the inputs + if "use_cache" in class_inputs_dict: + config.use_cache = class_inputs_dict.pop("use_cache") model = model_class(config) num_out = len(model(class_inputs_dict)) model._saved_model_inputs_spec = None @@ -207,6 +212,11 @@ class TFModelTesterMixin: for model_class in self.all_model_classes: class_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + # A saved model is always executed in graph mode, since we merged the PR #8777 + # the booleans in graph mode are always the ones in the config, then we update + # the use_cache property if it exists in order to have similar booleans with the inputs + if "use_cache" in class_inputs_dict: + config.use_cache = class_inputs_dict.pop("use_cache") model = model_class(config) num_out = len(model(class_inputs_dict)) model._saved_model_inputs_spec = None @@ -249,10 +259,11 @@ class TFModelTesterMixin: if "T5" in main_layer_class.__name__: # Take the same values than in TFT5ModelTester for this shared layer shared = TFSharedEmbeddings(99, 32, name="shared") - config.use_cache = False + config.use_cache = inputs_dict.pop("use_cache", None) main_layer = main_layer_class(config, embed_tokens=shared) else: main_layer = main_layer_class(config) + symbolic_inputs = { name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items() } @@ -321,10 +332,13 @@ class TFModelTesterMixin: # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences pt_model.eval() - pt_inputs_dict = dict( - (name, torch.from_numpy(key.numpy()).to(torch.long)) - for name, key in self._prepare_for_class(inputs_dict, model_class).items() - ) + pt_inputs_dict = {} + for name, key in self._prepare_for_class(inputs_dict, model_class).items(): + if type(key) == bool: + pt_inputs_dict[name] = key + else: + pt_inputs_dict[name] = 
torch.from_numpy(key.numpy()).to(torch.long) + # need to rename encoder-decoder "inputs" for PyTorch if "inputs" in pt_inputs_dict and self.is_encoder_decoder: pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs") @@ -358,10 +372,13 @@ class TFModelTesterMixin: # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences pt_model.eval() - pt_inputs_dict = dict( - (name, torch.from_numpy(key.numpy()).to(torch.long)) - for name, key in self._prepare_for_class(inputs_dict, model_class).items() - ) + pt_inputs_dict = {} + for name, key in self._prepare_for_class(inputs_dict, model_class).items(): + if type(key) == bool: + key = np.array(key, dtype=bool) + pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long) + else: + pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) # need to rename encoder-decoder "inputs" for PyTorch if "inputs" in pt_inputs_dict and self.is_encoder_decoder: pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs") @@ -574,13 +591,29 @@ class TFModelTesterMixin: self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 ) - hidden_states = outputs[-1] - self.assertEqual(config.output_attentions, False) - self.assertEqual(len(hidden_states), expected_num_layers) - self.assertListEqual( - list(hidden_states[0].shape[-2:]), - [self.model_tester.seq_length, self.model_tester.hidden_size], - ) + if model.config.is_encoder_decoder: + encoder_hidden_states = outputs.encoder_hidden_states + decoder_hidden_states = outputs.decoder_hidden_states + + self.assertEqual(config.output_attentions, False) + self.assertEqual(len(encoder_hidden_states), expected_num_layers) + self.assertListEqual( + list(encoder_hidden_states[0].shape[-2:]), + [self.model_tester.seq_length, self.model_tester.hidden_size], + ) + self.assertEqual(len(decoder_hidden_states), expected_num_layers) + self.assertListEqual( + list(decoder_hidden_states[0].shape[-2:]), + 
[self.model_tester.seq_length, self.model_tester.hidden_size], + ) + else: + hidden_states = outputs.hidden_states + self.assertEqual(config.output_attentions, False) + self.assertEqual(len(hidden_states), expected_num_layers) + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.seq_length, self.model_tester.hidden_size], + ) for model_class in self.all_model_classes: inputs_dict["output_hidden_states"] = True @@ -796,7 +829,7 @@ class TFModelTesterMixin: def test_lm_head_model_random_beam_search_generate(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"] + input_ids = inputs_dict["input_ids"] for model_class in self.all_generative_model_classes: model = model_class(config) diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index 64bb41bef1..4f4b732fc8 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -133,8 +133,6 @@ class TFT5ModelTester: self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) - output, past_key_values = outputs - # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) @@ -142,7 +140,7 @@ class TFT5ModelTester: next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) output_from_no_past = model(next_input_ids)[0] - output_from_past = model(next_tokens, past_key_values=past_key_values)[0] + output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0] # select random slice random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) @@ -164,7 +162,7 @@ class TFT5ModelTester: attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) # first forward pass - _, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True) + outputs = 
model(input_ids, attention_mask=attn_mask, use_cache=True) # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) @@ -187,7 +185,7 @@ class TFT5ModelTester: # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0] - output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[0] + output_from_past = model(next_tokens, past_key_values=outputs.past_key_values, attention_mask=attn_mask)[0] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item() @@ -208,8 +206,6 @@ class TFT5ModelTester: # first forward pass outputs = model(input_ids, use_cache=True) - output, past_key_values = outputs - # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) @@ -217,7 +213,7 @@ class TFT5ModelTester: next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) output_from_no_past = model(next_input_ids)[0] - output_from_past = model(next_tokens, past_key_values=past_key_values)[0] + output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0] self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) @@ -236,7 +232,7 @@ class TFT5ModelTester: "input_ids": input_ids, "decoder_input_ids": input_ids, "decoder_attention_mask": input_mask, - "use_cache": tf.convert_to_tensor([False]), + "use_cache": False, } return config, inputs_dict @@ -298,14 +294,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): model = TFT5Model.from_pretrained("t5-small") self.assertIsNotNone(model) - @slow - def test_saved_model_with_attentions_output(self): - pass - - @slow - def test_saved_model_with_hidden_states_output(self): - pass - class TFT5EncoderOnlyModelTester: def __init__( @@ -411,6 +399,7 @@ class TFT5EncoderOnlyModelTester: class TFT5EncoderOnlyModelTest(TFModelTesterMixin, 
unittest.TestCase): + is_encoder_decoder = False all_model_classes = (TFT5EncoderModel,) if is_tf_available() else () def setUp(self):