Correct attention mask dtype for Flax GPT2 (#25636)
* Correct attention mask dtype * reformat code * add a test for boolean mask * convert test to fast test * delete unwanted print * use assertTrue for testing
This commit is contained in:
@@ -753,7 +753,9 @@ class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
|
||||
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
if attention_mask is not None:
|
||||
position_ids = attention_mask.cumsum(axis=-1) - 1
|
||||
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
||||
extended_attention_mask = lax.dynamic_update_slice(
|
||||
extended_attention_mask, attention_mask.astype("i4"), (0, 0)
|
||||
)
|
||||
else:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
||||
|
||||
|
||||
@@ -187,6 +187,26 @@ class FlaxGPT2ModelTester:
|
||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||
|
||||
def check_bool_attention_mask_in_generation(self, model_class_name, config, input_ids, attention_mask):
|
||||
model = model_class_name(config)
|
||||
|
||||
output_int_att_mask = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=3,
|
||||
)
|
||||
|
||||
output_bool_att_mask = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask.astype(bool),
|
||||
max_new_tokens=3,
|
||||
)
|
||||
|
||||
self.parent.assertTrue(
|
||||
(output_bool_att_mask.sequences == output_int_att_mask.sequences).all(),
|
||||
"Generated response differ between boolean and integer attention mask",
|
||||
)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
@@ -208,6 +228,13 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
|
||||
model_class_name, config, input_ids, attention_mask
|
||||
)
|
||||
|
||||
def test_bool_attention_mask_in_generation(self):
|
||||
for model_class_name in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_bool_attention_mask_in_generation(
|
||||
model_class_name, config, input_ids, attention_mask
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_batch_generation(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
|
||||
|
||||
Reference in New Issue
Block a user