Correct attention mask dtype for Flax GPT2 (#25636)
* Correct attention mask dtype
* reformat code
* add a test for boolean mask
* convert test to fast test
* delete unwanted print
* use assertTrue for testing
This commit is contained in:
@@ -753,7 +753,9 @@ class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
|
|||||||
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
position_ids = attention_mask.cumsum(axis=-1) - 1
|
position_ids = attention_mask.cumsum(axis=-1) - 1
|
||||||
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
extended_attention_mask = lax.dynamic_update_slice(
|
||||||
|
extended_attention_mask, attention_mask.astype("i4"), (0, 0)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
||||||
|
|
||||||
|
|||||||
@@ -187,6 +187,26 @@ class FlaxGPT2ModelTester:
|
|||||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
|
def check_bool_attention_mask_in_generation(self, model_class_name, config, input_ids, attention_mask):
    """Check that generation gives the same result for boolean and integer attention masks.

    The model is built from ``config`` and ``generate`` is called twice with the
    same ``input_ids``: once with ``attention_mask`` as passed in, and once with
    the mask cast to ``bool``. The two ``sequences`` outputs must match.

    Args:
        model_class_name: Callable that builds the model when called with ``config``.
        config: Configuration object handed to ``model_class_name``.
        input_ids: Prompt token ids forwarded unchanged to both ``generate`` calls.
        attention_mask: Mask forwarded as-is to the first call and as
            ``attention_mask.astype(bool)`` to the second.
    """
    model = model_class_name(config)

    output_int_att_mask = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=3,
    )

    output_bool_att_mask = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask.astype(bool),
        max_new_tokens=3,
    )

    # np.array_equal returns a plain False when shapes or values differ, so a
    # mismatch always surfaces as an assertion failure carrying the message
    # below rather than as a broadcasting error from an element-wise `==`.
    self.parent.assertTrue(
        np.array_equal(output_bool_att_mask.sequences, output_int_att_mask.sequences),
        "Generated response differ between boolean and integer attention mask",
    )
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||||
@@ -208,6 +228,13 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
|
|||||||
model_class_name, config, input_ids, attention_mask
|
model_class_name, config, input_ids, attention_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_bool_attention_mask_in_generation(self):
    """Run the boolean-vs-integer attention-mask generation check on each generative model class."""
    for generative_cls in self.all_generative_model_classes:
        # A fresh config and input set is requested from the tester for each class.
        prepared = self.model_tester.prepare_config_and_inputs()
        config, input_ids, attention_mask = prepared
        self.model_tester.check_bool_attention_mask_in_generation(
            generative_cls, config, input_ids, attention_mask
        )
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_batch_generation(self):
|
def test_batch_generation(self):
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
|
||||||
|
|||||||
Reference in New Issue
Block a user