From 8635407bc724c45142c1f91dbc9ef3ea681e1a56 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 25 Feb 2022 17:11:46 +0100 Subject: [PATCH] Fix tf.concatenate + test past_key_values for TF models (#15774) * fix wrong method name tf.concatenate * add tests related to causal LM / decoder * make style and quality * clean-up * Fix TFBertModel's extended_attention_mask when past_key_values is provided * Fix tests * fix copies * More tf.int8 -> tf.int32 in TF test template * clean-up * Update TF test template * revert the previous commit + update the TF test template * Fix TF template extended_attention_mask when past_key_values is provided * Fix some styles manually * clean-up * Fix ValueError: too many values to unpack in the test * Fix more: too many values to unpack in the test * Add a comment for extended_attention_mask when there is past_key_values * Fix TFElectra extended_attention_mask when past_key_values is provided * Add tests to other TF models * Fix for TF Electra test: add prepare_config_and_inputs_for_decoder * Fix not passing training arg to lm_head in TFRobertaForCausalLM * Fix tests (with past) for TF Roberta * add testing for pask_key_values for TFElectra model Co-authored-by: ydshieh --- .../models/bert/modeling_tf_bert.py | 7 +- .../models/electra/modeling_tf_electra.py | 6 +- .../models/layoutlm/modeling_tf_layoutlm.py | 4 +- .../models/rembert/modeling_tf_rembert.py | 7 +- .../models/roberta/modeling_tf_roberta.py | 9 +- .../models/tapas/modeling_tf_tapas.py | 4 +- ...tf_{{cookiecutter.lowercase_modelname}}.py | 22 +- ...tf_{{cookiecutter.lowercase_modelname}}.py | 358 ++++++++++++++- tests/bert/test_modeling_tf_bert.py | 388 +++++++++++++++- tests/electra/test_modeling_tf_electra.py | 366 ++++++++++++++- tests/rembert/test_modeling_tf_rembert.py | 349 ++++++++++++++- tests/roberta/test_modeling_tf_roberta.py | 415 +++++++++++++++++- 12 files changed, 1851 insertions(+), 84 deletions(-) diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py index bf5ddb365b..18dd4e754e 100644 --- a/src/transformers/models/bert/modeling_tf_bert.py +++ b/src/transformers/models/bert/modeling_tf_bert.py @@ -274,8 +274,8 @@ class TFBertSelfAttention(tf.keras.layers.Layer): elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) - value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) else: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) @@ -817,6 +817,9 @@ class TFBertMainLayer(tf.keras.layers.Layer): extended_attention_mask = tf.reshape( extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) ) + if inputs["past_key_values"][0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] else: extended_attention_mask = tf.reshape( inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) diff --git a/src/transformers/models/electra/modeling_tf_electra.py b/src/transformers/models/electra/modeling_tf_electra.py index 68c639de91..5370c40fd4 100644 --- a/src/transformers/models/electra/modeling_tf_electra.py +++ b/src/transformers/models/electra/modeling_tf_electra.py @@ -140,8 +140,8 @@ class TFElectraSelfAttention(tf.keras.layers.Layer): elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) - value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) else: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) @@ -673,6 +673,8 @@ class TFElectraMainLayer(tf.keras.layers.Layer): extended_attention_mask = tf.reshape( extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) ) + if past_key_values_length > 0: + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] else: extended_attention_mask = tf.reshape( attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) diff --git a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py b/src/transformers/models/layoutlm/modeling_tf_layoutlm.py index 6f30883500..df8a49df09 100644 --- a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_tf_layoutlm.py @@ -252,8 +252,8 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer): elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) - value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) else: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) diff --git a/src/transformers/models/rembert/modeling_tf_rembert.py b/src/transformers/models/rembert/modeling_tf_rembert.py index 24a6387cd7..c7b65fe3a1 100644 --- a/src/transformers/models/rembert/modeling_tf_rembert.py +++ b/src/transformers/models/rembert/modeling_tf_rembert.py @@ -212,8 +212,8 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer): elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) - value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) else: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) @@ -740,6 +740,9 @@ class TFRemBertMainLayer(tf.keras.layers.Layer): extended_attention_mask = tf.reshape( extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) ) + if inputs["past_key_values"][0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] else: extended_attention_mask = tf.reshape( inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py index b74863fb20..dd45b6fd2d 100644 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ b/src/transformers/models/roberta/modeling_tf_roberta.py @@ -261,8 +261,8 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer): elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) - value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) else: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) @@ -704,6 +704,9 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): extended_attention_mask = tf.reshape( extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) ) + if inputs["past_key_values"][0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] else: extended_attention_mask = tf.reshape( inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) @@ -1305,7 +1308,7 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos ) sequence_output = outputs[0] - logits = self.lm_head(hidden_states=sequence_output) + logits = self.lm_head(hidden_states=sequence_output, training=inputs["training"]) loss = None if inputs["labels"] is not None: diff --git a/src/transformers/models/tapas/modeling_tf_tapas.py b/src/transformers/models/tapas/modeling_tf_tapas.py index 46baba2627..94892d6206 100644 --- a/src/transformers/models/tapas/modeling_tf_tapas.py +++ b/src/transformers/models/tapas/modeling_tf_tapas.py @@ -317,8 +317,8 @@ class TFTapasSelfAttention(tf.keras.layers.Layer): elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) - value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) else: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 3dbe073e68..726438c17d 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -32,7 +32,6 @@ from ...file_utils import ( ) from ...modeling_tf_outputs import ( TFBaseModelOutputWithPastAndCrossAttentions, - TFBaseModelOutputWithPoolingAndCrossAttentions, TFCausalLMOutputWithCrossAttentions, TFMaskedLMOutput, TFMultipleChoiceModelOutput, @@ -216,8 +215,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer) elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2) - value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) else: key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) @@ -654,7 +653,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): return_dict: Optional[bool] = None, training: bool = False, **kwargs, - ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: inputs = input_processing( func=self.call, config=self.config, @@ -734,6 +733,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): extended_attention_mask = tf.reshape( extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) ) + if inputs["past_key_values"][0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] else: extended_attention_mask = tf.reshape( inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) @@ -945,7 +947,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, + output_type=TFBaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def call( @@ -965,7 +967,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod return_dict: Optional[bool] = None, training: Optional[bool] = False, **kwargs, - ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: r""" encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -1024,10 +1026,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod return outputs - # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output def serving_output( - self, output: TFBaseModelOutputWithPoolingAndCrossAttentions - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions: + self, output: TFBaseModelOutputWithPastAndCrossAttentions + ) -> TFBaseModelOutputWithPastAndCrossAttentions: output_cache = self.config.use_cache and self.config.is_decoder pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None @@ -1036,9 +1037,8 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod if not (self.config.output_attentions and self.config.add_cross_attention): cross_attns = None - return TFBaseModelOutputWithPoolingAndCrossAttentions( + return TFBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=output.last_hidden_state, - pooler_output=output.pooler_output, past_key_values=pkv, hidden_states=hs, attentions=attns, diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 390298a115..d2875f232c 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -163,10 +163,59 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - def create_and_check_lm_head( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + def create_and_check_causal_lm_base_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.is_decoder = True + + model = TF{{cookiecutter.camelcase_modelname}}Model(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs) + + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TF{{cookiecutter.camelcase_modelname}}Model(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + } + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states) + + # Also check the case where encoder outputs are not passed + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_causal_lm_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.is_decoder = True + model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config) inputs = { "input_ids": input_ids, @@ -178,6 +227,260 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester: list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size] ) + def create_and_check_causal_lm_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + } + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states) + + prediction_scores = result["logits"] + self.parent.assertListEqual( + list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size] + ) + + def create_and_check_causal_lm_model_past( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config) + + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and attn_mask + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + + output_from_no_past = model(next_input_ids, output_hidden_states=True).hidden_states[0] + output_from_past = model( + next_tokens, past_key_values=past_key_values, output_hidden_states=True + ).hidden_states[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_causal_lm_model_past_with_attn_mask( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config) + + # create attention mask + half_seq_length = self.seq_length // 2 + attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32) + attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32) + attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) + + # first forward pass + outputs = model(input_ids, attention_mask=attn_mask, use_cache=True) + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + past_key_values = outputs.past_key_values + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size) + vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change) + condition = tf.transpose( + tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size)) + ) + input_ids = tf.where(condition, random_other_next_tokens, input_ids) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + attn_mask = tf.concat( + [attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], + axis=1, + ) + + output_from_no_past = model( + next_input_ids, + attention_mask=attn_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_hidden_states=True + ).hidden_states[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_causal_lm_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config) + + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, use_cache=True) + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + ).hidden_states[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config) + + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + encoder_hidden_states = encoder_hidden_states[:1, :, :] + encoder_attention_mask = encoder_attention_mask[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + ).hidden_states[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + def create_and_check_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -290,16 +593,59 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte self.config_tester.run_common_tests() def test_model(self): + """Test the base model""" config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_causal_lm_base_model(self): + """Test the base model of the causal LM model + + is_deocder=True, no cross_attention, no encoder outputs + """ + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_base_model(*config_and_inputs) + + def test_model_as_decoder(self): + """Test the base model as a decoder (of an encoder-decoder architecture) + + is_deocder=True + cross_attention + pass encoder outputs + """ + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) def test_for_causal_lm(self): + """Test the causal LM model""" config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_lm_head(*config_and_inputs) + self.model_tester.create_and_check_causal_lm_model(*config_and_inputs) + + def test_causal_lm_model_as_decoder(self): + """Test the causal LM model as a decoder""" + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_causal_lm_model_as_decoder(*config_and_inputs) + + def test_causal_lm_model_past(self): + """Test causal LM model with `past_key_values`""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past(*config_and_inputs) + + def test_causal_lm_model_past_with_attn_mask(self): + """Test the causal LM model with `past_key_values` and `attention_mask`""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past_with_attn_mask(*config_and_inputs) + + def test_causal_lm_model_past_with_large_inputs(self): + """Test the causal LM model with `past_key_values` and a longer decoder sequence length""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past_large_inputs(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + """Similar to `test_causal_lm_model_past_with_large_inputs` but with cross-attention""" + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -460,7 +806,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester: # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) - next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) # append to next input_ids and next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) @@ -488,9 +834,9 @@ def prepare_{{cookiecutter.lowercase_modelname}}_inputs_dict( decoder_attention_mask=None, ): if attention_mask is None: - attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) + attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int32) if decoder_attention_mask is None: - decoder_attention_mask = tf.concat([tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8), tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8)], axis=-1) + decoder_attention_mask = tf.concat([tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int32), tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int32)], axis=-1) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, diff --git a/tests/bert/test_modeling_tf_bert.py b/tests/bert/test_modeling_tf_bert.py index 518870ceb8..611268337f 100644 --- a/tests/bert/test_modeling_tf_bert.py +++ b/tests/bert/test_modeling_tf_bert.py @@ -153,12 +153,12 @@ class TFBertModelTester: encoder_attention_mask, ) - def create_and_check_bert_model( + def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFBertModel(config=config) inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} - sequence_output, pooled_output = model(inputs) + result = model(inputs) inputs = [input_ids, input_mask] result = model(inputs) @@ -168,10 +168,61 @@ class TFBertModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) - def create_and_check_bert_lm_head( + def create_and_check_causal_lm_base_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.is_decoder = True + + model = TFBertModel(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs) + + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFBertModel(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + } + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states) + + # Also check the case where encoder outputs are not passed + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def create_and_check_causal_lm_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.is_decoder = True + model = TFBertLMHeadModel(config=config) inputs = { "input_ids": input_ids, @@ -183,7 +234,261 @@ class TFBertModelTester: list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size] ) - def create_and_check_bert_for_masked_lm( + def create_and_check_causal_lm_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFBertLMHeadModel(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + } + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states) + + prediction_scores = result["logits"] + self.parent.assertListEqual( + list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size] + ) + + def create_and_check_causal_lm_model_past( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFBertLMHeadModel(config=config) + + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and attn_mask + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + + output_from_no_past = model(next_input_ids, output_hidden_states=True).hidden_states[0] + output_from_past = model( + next_tokens, past_key_values=past_key_values, output_hidden_states=True + ).hidden_states[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_causal_lm_model_past_with_attn_mask( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFBertLMHeadModel(config=config) + + # create attention mask + half_seq_length = self.seq_length // 2 + attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32) + attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32) + attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) + + # first forward pass + outputs = model(input_ids, attention_mask=attn_mask, use_cache=True) + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + past_key_values = outputs.past_key_values + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size) + vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change) + condition = tf.transpose( + tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size)) + ) + input_ids = tf.where(condition, random_other_next_tokens, input_ids) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + attn_mask = tf.concat( + [attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], + axis=1, + ) + + output_from_no_past = model( + next_input_ids, + attention_mask=attn_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_hidden_states=True + ).hidden_states[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_causal_lm_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFBertLMHeadModel(config=config) + + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, use_cache=True) + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + ).hidden_states[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFBertLMHeadModel(config=config) + + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + encoder_hidden_states = encoder_hidden_states[:1, :, :] + encoder_attention_mask = encoder_attention_mask[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + ).hidden_states[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def create_and_check_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFBertForMaskedLM(config=config) @@ -195,7 +500,7 @@ class TFBertModelTester: result = model(inputs) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_bert_for_next_sequence_prediction( + def create_and_check_for_next_sequence_prediction( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFBertForNextSentencePrediction(config=config) @@ -203,7 +508,7 @@ class TFBertModelTester: result = model(inputs) self.parent.assertEqual(result.logits.shape, (self.batch_size, 2)) - def create_and_check_bert_for_pretraining( + def create_and_check_for_pretraining( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFBertForPreTraining(config=config) @@ -212,7 +517,7 @@ class TFBertModelTester: self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2)) - def create_and_check_bert_for_sequence_classification( + def create_and_check_for_sequence_classification( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_labels = self.num_labels @@ -226,7 +531,7 @@ class TFBertModelTester: result = model(inputs) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) - def create_and_check_bert_for_multiple_choice( + def create_and_check_for_multiple_choice( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_choices = self.num_choices @@ -242,7 +547,7 @@ class TFBertModelTester: result = model(inputs) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices)) - def create_and_check_bert_for_token_classification( + def create_and_check_for_token_classification( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_labels = self.num_labels @@ -255,7 +560,7 @@ class TFBertModelTester: result = model(inputs) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) - def create_and_check_bert_for_question_answering( + def create_and_check_for_question_answering( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFBertForQuestionAnswering(config=config) @@ -323,41 +628,84 @@ class TFBertModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC def test_config(self): self.config_tester.run_common_tests() - def test_bert_model(self): + def test_model(self): + """Test the base model""" config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_model(*config_and_inputs) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_causal_lm_base_model(self): + """Test the base model of the causal LM model + + is_deocder=True, no cross_attention, no encoder outputs + """ + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_base_model(*config_and_inputs) + + def test_model_as_decoder(self): + """Test the base model as a decoder (of an encoder-decoder architecture) + + is_deocder=True + cross_attention + pass encoder outputs + """ + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs) + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) def test_for_causal_lm(self): + """Test the causal LM model""" config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_lm_head(*config_and_inputs) + self.model_tester.create_and_check_causal_lm_model(*config_and_inputs) + + def test_causal_lm_model_as_decoder(self): + """Test the causal LM model as a decoder""" + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_causal_lm_model_as_decoder(*config_and_inputs) + + def test_causal_lm_model_past(self): + """Test causal LM model with `past_key_values`""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past(*config_and_inputs) + + def test_causal_lm_model_past_with_attn_mask(self): + """Test the causal LM model with `past_key_values` and `attention_mask`""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past_with_attn_mask(*config_and_inputs) + + def test_causal_lm_model_past_with_large_inputs(self): + """Test the causal LM model with `past_key_values` and a longer decoder sequence length""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past_large_inputs(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + """Similar to `test_causal_lm_model_past_with_large_inputs` but with cross-attention""" + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs) + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) def test_for_next_sequence_prediction(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs) + self.model_tester.create_and_check_for_next_sequence_prediction(*config_and_inputs) def test_for_pretraining(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs) + self.model_tester.create_and_check_for_pretraining(*config_and_inputs) def test_for_question_answering(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs) + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) def test_for_sequence_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs) + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs) + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) def test_model_from_pretrained(self): model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random") diff --git a/tests/electra/test_modeling_tf_electra.py b/tests/electra/test_modeling_tf_electra.py index 9a5e6bf177..4593ecff61 100644 --- a/tests/electra/test_modeling_tf_electra.py +++ b/tests/electra/test_modeling_tf_electra.py @@ -20,7 +20,7 @@ from transformers import ElectraConfig, is_tf_available from transformers.testing_utils import require_tf, slow from ..test_configuration_common import ConfigTester -from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor +from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor if is_tf_available(): @@ -101,7 +101,34 @@ class TFElectraModelTester: return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - def create_and_check_electra_model( + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFElectraModel(config=config) @@ -115,7 +142,277 @@ class TFElectraModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - def create_and_check_electra_for_masked_lm( + def create_and_check_causal_lm_base_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.is_decoder = True + + model = TFElectraModel(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs) + + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFElectraModel(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + } + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states) + + # Also check the case where encoder outputs are not passed + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_causal_lm_base_model_past( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFElectraModel(config=config) + + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and attn_mask + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + + output_from_no_past = model(next_input_ids, output_hidden_states=True).hidden_states[0] + output_from_past = model( + next_tokens, past_key_values=past_key_values, output_hidden_states=True + ).hidden_states[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_causal_lm_base_model_past_with_attn_mask( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFElectraModel(config=config) + + # create attention mask + half_seq_length = self.seq_length // 2 + attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32) + attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32) + attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) + + # first forward pass + outputs = model(input_ids, attention_mask=attn_mask, use_cache=True) + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + past_key_values = outputs.past_key_values + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size) + vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change) + condition = tf.transpose( + tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size)) + ) + input_ids = tf.where(condition, random_other_next_tokens, input_ids) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + attn_mask = tf.concat( + [attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], + axis=1, + ) + + output_from_no_past = model( + next_input_ids, + attention_mask=attn_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_hidden_states=True + ).hidden_states[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_causal_lm_base_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFElectraModel(config=config) + + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, use_cache=True) + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + ).hidden_states[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFElectraModel(config=config) + + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + encoder_hidden_states = encoder_hidden_states[:1, :, :] + encoder_attention_mask = encoder_attention_mask[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + ).hidden_states[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def create_and_check_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFElectraForMaskedLM(config=config) @@ -123,7 +420,7 @@ class TFElectraModelTester: result = model(inputs) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_electra_for_pretraining( + def create_and_check_for_pretraining( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFElectraForPreTraining(config=config) @@ -131,7 +428,7 @@ class TFElectraModelTester: result = model(inputs) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length)) - def create_and_check_electra_for_sequence_classification( + def create_and_check_for_sequence_classification( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_labels = self.num_labels @@ -140,7 +437,7 @@ class TFElectraModelTester: result = model(inputs) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) - def create_and_check_electra_for_multiple_choice( + def create_and_check_for_multiple_choice( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_choices = self.num_choices @@ -156,7 +453,7 @@ class TFElectraModelTester: result = model(inputs) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices)) - def create_and_check_electra_for_question_answering( + def create_and_check_for_question_answering( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFElectraForQuestionAnswering(config=config) @@ -165,7 +462,7 @@ class TFElectraModelTester: self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) - def create_and_check_electra_for_token_classification( + def create_and_check_for_token_classification( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_labels = self.num_labels @@ -215,33 +512,70 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() - def test_electra_model(self): + def test_model(self): + """Test the base model""" config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_electra_model(*config_and_inputs) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_causal_lm_base_model(self): + """Test the base model of the causal LM model + + is_deocder=True, no cross_attention, no encoder outputs + """ + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_base_model(*config_and_inputs) + + def test_model_as_decoder(self): + """Test the base model as a decoder (of an encoder-decoder architecture) + + is_deocder=True + cross_attention + pass encoder outputs + """ + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + + def test_causal_lm_base_model_past(self): + """Test causal LM base model with `past_key_values`""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_base_model_past(*config_and_inputs) + + def test_causal_lm_base_model_past_with_attn_mask(self): + """Test the causal LM base model with `past_key_values` and `attention_mask`""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_base_model_past_with_attn_mask(*config_and_inputs) + + def test_causal_lm_base_model_past_with_large_inputs(self): + """Test the causal LM base model with `past_key_values` and a longer decoder sequence length""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_base_model_past_large_inputs(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + """Similar to `test_causal_lm_base_model_past_with_large_inputs` but with cross-attention""" + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_electra_for_masked_lm(*config_and_inputs) + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) def test_for_pretraining(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs) + self.model_tester.create_and_check_for_pretraining(*config_and_inputs) def test_for_question_answering(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs) + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) def test_for_sequence_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_electra_for_sequence_classification(*config_and_inputs) + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_electra_for_multiple_choice(*config_and_inputs) + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs) + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) @slow def test_model_from_pretrained(self): diff --git a/tests/rembert/test_modeling_tf_rembert.py b/tests/rembert/test_modeling_tf_rembert.py index 81dad5d327..f8f17f30a9 100644 --- a/tests/rembert/test_modeling_tf_rembert.py +++ b/tests/rembert/test_modeling_tf_rembert.py @@ -168,7 +168,55 @@ class TFRemBertModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - def create_and_check_lm_head( + def create_and_check_causal_lm_base_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.is_decoder = True + + model = TFRemBertModel(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs) + + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFRemBertModel(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + } + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states) + + # Also check the case where encoder outputs are not passed + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_causal_lm_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.is_decoder = True @@ -183,6 +231,260 @@ class TFRemBertModelTester: list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size] ) + def create_and_check_causal_lm_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFRemBertForCausalLM(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + } + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states) + + prediction_scores = result["logits"] + self.parent.assertListEqual( + list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size] + ) + + def create_and_check_causal_lm_model_past( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFRemBertForCausalLM(config=config) + + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and attn_mask + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + + output_from_no_past = model(next_input_ids, output_hidden_states=True).hidden_states[0] + output_from_past = model( + next_tokens, past_key_values=past_key_values, output_hidden_states=True + ).hidden_states[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_causal_lm_model_past_with_attn_mask( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFRemBertForCausalLM(config=config) + + # create attention mask + half_seq_length = self.seq_length // 2 + attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32) + attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32) + attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) + + # first forward pass + outputs = model(input_ids, attention_mask=attn_mask, use_cache=True) + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + past_key_values = outputs.past_key_values + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size) + vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change) + condition = tf.transpose( + tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size)) + ) + input_ids = tf.where(condition, random_other_next_tokens, input_ids) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + attn_mask = tf.concat( + [attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], + axis=1, + ) + + output_from_no_past = model( + next_input_ids, + attention_mask=attn_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_hidden_states=True + ).hidden_states[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_causal_lm_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFRemBertForCausalLM(config=config) + + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, use_cache=True) + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + ).hidden_states[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFRemBertForCausalLM(config=config) + + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + encoder_hidden_states = encoder_hidden_states[:1, :, :] + encoder_attention_mask = encoder_attention_mask[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + ).hidden_states[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + def create_and_check_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -295,16 +597,59 @@ class TFRemBertModelTest(TFModelTesterMixin, unittest.TestCase): self.config_tester.run_common_tests() def test_model(self): + """Test the base model""" config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_causal_lm_base_model(self): + """Test the base model of the causal LM model + + is_deocder=True, no cross_attention, no encoder outputs + """ + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_base_model(*config_and_inputs) + + def test_model_as_decoder(self): + """Test the base model as a decoder (of an encoder-decoder architecture) + + is_deocder=True + cross_attention + pass encoder outputs + """ + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) def test_for_causal_lm(self): + """Test the causal LM model""" config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_lm_head(*config_and_inputs) + self.model_tester.create_and_check_causal_lm_model(*config_and_inputs) + + def test_causal_lm_model_as_decoder(self): + """Test the causal LM model as a decoder""" + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_causal_lm_model_as_decoder(*config_and_inputs) + + def test_causal_lm_model_past(self): + """Test causal LM model with `past_key_values`""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past(*config_and_inputs) + + def test_causal_lm_model_past_with_attn_mask(self): + """Test the causal LM model with `past_key_values` and `attention_mask`""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past_with_attn_mask(*config_and_inputs) + + def test_causal_lm_model_past_with_large_inputs(self): + """Test the causal LM model with `past_key_values` and a longer decoder sequence length""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past_large_inputs(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + """Similar to `test_causal_lm_model_past_with_large_inputs` but with cross-attention""" + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/roberta/test_modeling_tf_roberta.py b/tests/roberta/test_modeling_tf_roberta.py index 5cda74ccc4..fa947d64f0 100644 --- a/tests/roberta/test_modeling_tf_roberta.py +++ b/tests/roberta/test_modeling_tf_roberta.py @@ -129,7 +129,7 @@ class TFRobertaModelTester: encoder_attention_mask, ) - def create_and_check_roberta_model( + def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFRobertaModel(config=config) @@ -143,21 +143,361 @@ class TFRobertaModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - def create_and_check_roberta_for_causal_lm( + def create_and_check_causal_lm_base_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): - model = TFRobertaForCausalLM(config=config) - result = model([input_ids, input_mask, token_type_ids]) - self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + config.is_decoder = True - def create_and_check_roberta_for_masked_lm( + model = TFRobertaModel(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs) + + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFRobertaModel(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + } + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states) + + # Also check the case where encoder outputs are not passed + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_causal_lm_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.is_decoder = True + + model = TFRobertaForCausalLM(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + } + prediction_scores = model(inputs)["logits"] + self.parent.assertListEqual( + list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size] + ) + + def create_and_check_causal_lm_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFRobertaForCausalLM(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + } + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states) + + prediction_scores = result["logits"] + self.parent.assertListEqual( + list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size] + ) + + def create_and_check_causal_lm_model_past( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFRobertaForCausalLM(config=config) + + # special to `RobertaEmbeddings` in `Roberta`: + # - its `padding_idx` and its effect on `position_ids` + # (TFRobertaEmbeddings.create_position_ids_from_input_ids) + # - `1` here is `TFRobertaEmbeddings.padding_idx` + input_ids = tf.where(input_ids == 1, 2, input_ids) + + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and attn_mask + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + + output_from_no_past = model(next_input_ids, output_hidden_states=True).hidden_states[0] + output_from_past = model( + next_tokens, past_key_values=past_key_values, output_hidden_states=True + ).hidden_states[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_causal_lm_model_past_with_attn_mask( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFRobertaForCausalLM(config=config) + + # special to `RobertaEmbeddings` in `Roberta`: + # - its `padding_idx` and its effect on `position_ids` + # (TFRobertaEmbeddings.create_position_ids_from_input_ids) + # - `1` here is `TFRobertaEmbeddings.padding_idx` + # avoid `padding_idx` in the past + input_ids = tf.where(input_ids == 1, 2, input_ids) + + # create attention mask + half_seq_length = self.seq_length // 2 + attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32) + attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32) + attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) + + # first forward pass + outputs = model(input_ids, attention_mask=attn_mask, use_cache=True) + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + past_key_values = outputs.past_key_values + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size) + vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change) + condition = tf.transpose( + tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size)) + ) + input_ids = tf.where(condition, random_other_next_tokens, input_ids) + # avoid `padding_idx` in the past + input_ids = tf.where(input_ids == 1, 2, input_ids) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + attn_mask = tf.concat( + [attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], + axis=1, + ) + + output_from_no_past = model( + next_input_ids, + attention_mask=attn_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_hidden_states=True + ).hidden_states[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_causal_lm_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + + model = TFRobertaForCausalLM(config=config) + + # special to `RobertaEmbeddings` in `Roberta`: + # - its `padding_idx` and its effect on `position_ids` + # (TFRobertaEmbeddings.create_position_ids_from_input_ids) + # - `1` here is `TFRobertaEmbeddings.padding_idx` + # avoid `padding_idx` in the past + input_ids = tf.where(input_ids == 1, 2, input_ids) + + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, use_cache=True) + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + ).hidden_states[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFRobertaForCausalLM(config=config) + + # special to `RobertaEmbeddings` in `Roberta`: + # - its `padding_idx` and its effect on `position_ids` + # (TFRobertaEmbeddings.create_position_ids_from_input_ids) + # - `1` here is `TFRobertaEmbeddings.padding_idx` + # avoid `padding_idx` in the past + input_ids = tf.where(input_ids == 1, 2, input_ids) + + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + encoder_hidden_states = encoder_hidden_states[:1, :, :] + encoder_attention_mask = encoder_attention_mask[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + ).hidden_states[0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + ).hidden_states[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def create_and_check_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFRobertaForMaskedLM(config=config) result = model([input_ids, input_mask, token_type_ids]) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_roberta_for_token_classification( + def create_and_check_for_token_classification( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_labels = self.num_labels @@ -166,7 +506,7 @@ class TFRobertaModelTester: result = model(inputs) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) - def create_and_check_roberta_for_question_answering( + def create_and_check_for_question_answering( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): model = TFRobertaForQuestionAnswering(config=config) @@ -175,7 +515,7 @@ class TFRobertaModelTester: self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) - def create_and_check_roberta_for_multiple_choice( + def create_and_check_for_multiple_choice( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_choices = self.num_choices @@ -231,29 +571,72 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() - def test_roberta_model(self): + def test_model(self): + """Test the base model""" config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_roberta_model(*config_and_inputs) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_causal_lm_base_model(self): + """Test the base model of the causal LM model + + is_deocder=True, no cross_attention, no encoder outputs + """ + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_base_model(*config_and_inputs) + + def test_model_as_decoder(self): + """Test the base model as a decoder (of an encoder-decoder architecture) + + is_deocder=True + cross_attention + pass encoder outputs + """ + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs) + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) def test_for_causal_lm(self): + """Test the causal LM model""" config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_roberta_for_causal_lm(*config_and_inputs) + self.model_tester.create_and_check_causal_lm_model(*config_and_inputs) + + def test_causal_lm_model_as_decoder(self): + """Test the causal LM model as a decoder""" + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_causal_lm_model_as_decoder(*config_and_inputs) + + def test_causal_lm_model_past(self): + """Test causal LM model with `past_key_values`""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past(*config_and_inputs) + + def test_causal_lm_model_past_with_attn_mask(self): + """Test the causal LM model with `past_key_values` and `attention_mask`""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past_with_attn_mask(*config_and_inputs) + + def test_causal_lm_model_past_with_large_inputs(self): + """Test the causal LM model with `past_key_values` and a longer decoder sequence length""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causal_lm_model_past_large_inputs(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + """Similar to `test_causal_lm_model_past_with_large_inputs` but with cross-attention""" + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_roberta_for_token_classification(*config_and_inputs) + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) def test_for_question_answering(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs) + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_roberta_for_multiple_choice(*config_and_inputs) + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) @slow def test_model_from_pretrained(self):