From 0040469bb8e718f4ffafef829e497805df1aa1fb Mon Sep 17 00:00:00 2001 From: Tianlin Liu Date: Fri, 25 Aug 2023 17:36:37 +0200 Subject: [PATCH] Correct attention mask dtype for Flax GPT2 (#25636) * Correct attention mask dtype * reformat code * add a test for boolean mask * convert test to fast test * delete unwanted print * use assertTrue for testing --- .../models/gpt2/modeling_flax_gpt2.py | 4 ++- tests/models/gpt2/test_modeling_flax_gpt2.py | 27 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt2/modeling_flax_gpt2.py b/src/transformers/models/gpt2/modeling_flax_gpt2.py index 8973e081a3..50cfb5e112 100644 --- a/src/transformers/models/gpt2/modeling_flax_gpt2.py +++ b/src/transformers/models/gpt2/modeling_flax_gpt2.py @@ -753,7 +753,9 @@ class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel): extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + extended_attention_mask = lax.dynamic_update_slice( + extended_attention_mask, attention_mask.astype("i4"), (0, 0) + ) else: position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) diff --git a/tests/models/gpt2/test_modeling_flax_gpt2.py b/tests/models/gpt2/test_modeling_flax_gpt2.py index 9bdc17fa19..1e24ad0b00 100644 --- a/tests/models/gpt2/test_modeling_flax_gpt2.py +++ b/tests/models/gpt2/test_modeling_flax_gpt2.py @@ -187,6 +187,26 @@ class FlaxGPT2ModelTester: diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + def check_bool_attention_mask_in_generation(self, model_class_name, config, input_ids, attention_mask): + model = model_class_name(config) + + output_int_att_mask = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=3, + ) + + output_bool_att_mask = model.generate( + input_ids=input_ids, + attention_mask=attention_mask.astype(bool), + max_new_tokens=3, + ) + + self.parent.assertTrue( + (output_bool_att_mask.sequences == output_int_att_mask.sequences).all(), + "Generated response differ between boolean and integer attention mask", + ) + @require_flax class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase): @@ -208,6 +228,13 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes model_class_name, config, input_ids, attention_mask ) + def test_bool_attention_mask_in_generation(self): + for model_class_name in self.all_generative_model_classes: + config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_bool_attention_mask_in_generation( + model_class_name, config, input_ids, attention_mask + ) + @slow def test_batch_generation(self): tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="", padding_side="left")