From 8332327dca1461fc9a1b526f1f73936acda4f4a7 Mon Sep 17 00:00:00 2001 From: Rahul Nadkarni Date: Mon, 29 Nov 2021 08:30:17 -0800 Subject: [PATCH] Fix sentinel token IDs in data collator for Flax T5 pretraining script (#14477) --- examples/flax/language-modeling/run_t5_mlm_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index b78dc0431a..b62a144449 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -291,7 +291,7 @@ class FlaxDataCollatorForT5MLM: start_indices[:, 0] = mask_indices[:, 0] sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices) - sentinel_ids = np.where(sentinel_ids != 0, (sentinel_ids + self.tokenizer.vocab_size - 1), 0) + sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0) sentinel_ids -= mask_indices - start_indices return sentinel_ids