From 10b76987fc45a36aab5c5549dd607e2475c32e73 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 4 Mar 2022 17:04:43 +0100
Subject: [PATCH] [FlaxT5 Example] fix flax t5 example pretraining (#15835)

---
 examples/flax/language-modeling/run_t5_mlm_flax.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py
index 2845d9ce58..e0ea0fa3fb 100755
--- a/examples/flax/language-modeling/run_t5_mlm_flax.py
+++ b/examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -368,7 +368,9 @@ class FlaxDataCollatorForT5MLM:
         batch_size = input_ids.shape[0]
 
         input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
-        input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1))
+        # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
+        # masked tokens coming after sentinel tokens and should be removed
+        input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
         input_ids = np.concatenate(
             [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
         )