From 48385aa4f4ccb6983d3c2beabca45e54a4c514d2 Mon Sep 17 00:00:00 2001 From: gautham <91133513+capemox@users.noreply.github.com> Date: Mon, 24 Mar 2025 22:27:17 +0530 Subject: [PATCH] Added support for seed in `DataCollatorForWholeWordMask` (#36903) * Added support for seed in `DataCollatorForWholeWordMask`, and also wrote tests. Also fixed bugs where the code hardcoded values for mask replacement probability and random replacement probability, instead of using the values passed by the user. * formatting issues * Used better way to generate seed in TF. Made tests more consistent. --- src/transformers/data/data_collator.py | 142 +++++++++++++++++++++---- tests/trainer/test_data_collator.py | 133 +++++++++++++++++++++++ 2 files changed, 253 insertions(+), 22 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index dce6991365..07490a25f9 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -1193,6 +1193,11 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): """ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + if self.seed and self.generator is None: + # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator. + # If no seed supplied, we will use the global RNG + self.create_rng() + if isinstance(examples[0], Mapping): input_ids = [e["input_ids"] for e in examples] else: @@ -1223,6 +1228,11 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: import tensorflow as tf + if self.seed and self.generator is None: + # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator. + # If no seed supplied, we will use the global RNG + self.create_rng() + if isinstance(examples[0], Mapping): input_ids = [e["input_ids"] for e in examples] else: @@ -1251,6 +1261,11 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): return {"input_ids": inputs, "labels": labels} def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + if self.seed and self.generator is None: + # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator. + # If no seed supplied, we will use the global RNG + self.create_rng() + if isinstance(examples[0], Mapping): input_ids = [e["input_ids"] for e in examples] else: @@ -1278,6 +1293,30 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask) return {"input_ids": inputs, "labels": labels} + def _shuffle(self, cand_indexes): + # if no seed, just use random's shuffle + if self.seed is None: + random.shuffle(cand_indexes) + return cand_indexes + + # if seed is provided, use the generator to shuffle + if self.return_tensors == "pt": + import torch + + indices = torch.randperm(len(cand_indexes), generator=self.generator) + return [cand_indexes[i] for i in indices] + + elif self.return_tensors == "tf": + import tensorflow as tf + + seed = self.generator.make_seeds(2)[0] + indices = tf.random.experimental.stateless_shuffle(tf.range(len(cand_indexes)), seed=seed).numpy().tolist() + return [cand_indexes[i] for i in indices] + + elif self.return_tensors == "np": + self.generator.shuffle(cand_indexes) + return cand_indexes + def _whole_word_mask(self, input_tokens: List[str], max_predictions=512): """ Get 0/1 labels for masked tokens with whole word mask proxy @@ -1298,7 +1337,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): else: cand_indexes.append([i]) - random.shuffle(cand_indexes) + cand_indexes = self._shuffle(cand_indexes) num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability)))) masked_lms = [] covered_indexes = set() @@ -1346,16 +1385,32 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): masked_indices = probability_matrix.bool() labels[~masked_indices] = -100 # We only compute loss on masked tokens - # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) - indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices + # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = ( + torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool() + & masked_indices + ) inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) - # 10% of the time, we replace masked input tokens with random word - indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced - random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) + if self.mask_replace_prob == 1 or self.random_replace_prob == 0: + return inputs, labels + + remaining_prob = 1 - self.mask_replace_prob + # scaling the random_replace_prob to the remaining probability for example if + # mask_replace_prob = 0.8 and random_replace_prob = 0.1, + # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5 + random_replace_prob_scaled = self.random_replace_prob / remaining_prob + + # random_replacement_prob% of the time, we replace masked input tokens with random word + indices_random = ( + torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool() + & masked_indices + & ~indices_replaced + ) + random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator) inputs[indices_random] = random_words[indices_random] - # The rest of the time (10% of the time) we keep the masked input tokens unchanged + # The rest of the time ((1-random_replacement_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged return inputs, labels def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]: @@ -1387,17 +1442,35 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens labels = tf.where(masked_indices, inputs, -100) - # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) - indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices + # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs) - # 10% of the time, we replace masked input tokens with random word - indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced - random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64) + if self.mask_replace_prob == 1 or self.random_replace_prob == 0: + return inputs, labels + + remaining_prob = 1 - self.mask_replace_prob + # scaling the random_replace_prob to the remaining probability for example if + # mask_replace_prob = 0.8 and random_replace_prob = 0.1, + # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5 + random_replace_prob_scaled = self.random_replace_prob / remaining_prob + + # random_replace_prob% of the time, we replace masked input tokens with random word + indices_random = ( + self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator) + & masked_indices + & ~indices_replaced + ) + + if self.generator: + random_words = self.generator.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64) + else: + random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64) + inputs = tf.where(indices_random, random_words, inputs) - # The rest of the time (10% of the time) we keep the masked input tokens unchanged + # The rest of the time ((1-mask_replace_prob-random_replace_prob)% of the time) we keep the masked input tokens unchanged return inputs, labels def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]: @@ -1425,19 +1498,44 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): labels[~masked_indices] = -100 # We only compute loss on masked tokens - # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) - indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices + # mask_replacement_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + if self.generator: + indices_replaced = ( + self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices + ) + else: + indices_replaced = ( + np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices + ) inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) - # 10% of the time, we replace masked input tokens with random word - # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced - indices_random = ( - np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced - ) - random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64) + if self.mask_replace_prob == 1 or self.random_replace_prob == 0: + return inputs, labels + + remaining_prob = 1 - self.mask_replace_prob + # scaling the random_replace_prob to the remaining probability for example if + # mask_replace_prob = 0.8 and random_replace_prob = 0.1, + # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5 + random_replace_prob_scaled = self.random_replace_prob / remaining_prob + + if self.generator: + indices_random = ( + self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool) + & masked_indices + & ~indices_replaced + ) + random_words = self.generator.integers(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64) + else: + indices_random = ( + np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool) + & masked_indices + & ~indices_replaced + ) + random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64) + inputs[indices_random] = random_words[indices_random] - # The rest of the time (10% of the time) we keep the masked input tokens unchanged + # The rest of the time ((1-mask_replace_prob-random_replace_prob)% of the time) we keep the masked input tokens unchanged return inputs, labels diff --git a/tests/trainer/test_data_collator.py b/tests/trainer/test_data_collator.py index ca88b3c79c..a88641ca16 100644 --- a/tests/trainer/test_data_collator.py +++ b/tests/trainer/test_data_collator.py @@ -445,6 +445,86 @@ class DataCollatorIntegrationTest(unittest.TestCase): self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10))) self.assertEqual(batch["labels"].shape, torch.Size((2, 10))) + def test_data_collator_for_whole_word_mask_with_seed(self): + tokenizer = BertTokenizer(self.vocab_file) + features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}] + + # check if seed is respected between two different DataCollatorForWholeWordMask instances + data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42) + batch_1 = data_collator(features) + self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 1000))) + self.assertEqual(batch_1["labels"].shape, torch.Size((2, 1000))) + + data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42) + batch_2 = data_collator(features) + self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 1000))) + self.assertEqual(batch_2["labels"].shape, torch.Size((2, 1000))) + + self.assertTrue(torch.all(batch_1["input_ids"] == batch_2["input_ids"])) + self.assertTrue(torch.all(batch_1["labels"] == batch_2["labels"])) + + # check if seed is respected in multiple workers situation + features = [{"input_ids": list(range(1000))} for _ in range(10)] + dataloader = torch.utils.data.DataLoader( + features, + batch_size=2, + num_workers=2, + generator=torch.Generator().manual_seed(42), + collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42), + ) + + batch_3_input_ids = [] + batch_3_labels = [] + for batch in dataloader: + batch_3_input_ids.append(batch["input_ids"]) + batch_3_labels.append(batch["labels"]) + + batch_3_input_ids = torch.stack(batch_3_input_ids) + batch_3_labels = torch.stack(batch_3_labels) + self.assertEqual(batch_3_input_ids.shape, torch.Size((5, 2, 1000))) + self.assertEqual(batch_3_labels.shape, torch.Size((5, 2, 1000))) + + dataloader = torch.utils.data.DataLoader( + features, + batch_size=2, + num_workers=2, + collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42), + ) + + batch_4_input_ids = [] + batch_4_labels = [] + for batch in dataloader: + batch_4_input_ids.append(batch["input_ids"]) + batch_4_labels.append(batch["labels"]) + batch_4_input_ids = torch.stack(batch_4_input_ids) + batch_4_labels = torch.stack(batch_4_labels) + self.assertEqual(batch_4_input_ids.shape, torch.Size((5, 2, 1000))) + self.assertEqual(batch_4_labels.shape, torch.Size((5, 2, 1000))) + + self.assertTrue(torch.all(batch_3_input_ids == batch_4_input_ids)) + self.assertTrue(torch.all(batch_3_labels == batch_4_labels)) + + # try with different seed + dataloader = torch.utils.data.DataLoader( + features, + batch_size=2, + num_workers=2, + collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=43), + ) + + batch_5_input_ids = [] + batch_5_labels = [] + for batch in dataloader: + batch_5_input_ids.append(batch["input_ids"]) + batch_5_labels.append(batch["labels"]) + batch_5_input_ids = torch.stack(batch_5_input_ids) + batch_5_labels = torch.stack(batch_5_labels) + self.assertEqual(batch_5_input_ids.shape, torch.Size((5, 2, 1000))) + self.assertEqual(batch_5_labels.shape, torch.Size((5, 2, 1000))) + + self.assertFalse(torch.all(batch_3_input_ids == batch_5_input_ids)) + self.assertFalse(torch.all(batch_3_labels == batch_5_labels)) + def test_plm(self): tokenizer = BertTokenizer(self.vocab_file) no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] @@ -1199,6 +1279,33 @@ class TFDataCollatorIntegrationTest(unittest.TestCase): self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10]) self.assertEqual(batch["labels"].shape.as_list(), [2, 10]) + def test_data_collator_for_whole_word_mask_with_seed(self): + tokenizer = BertTokenizer(self.vocab_file) + features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}] + + # check if seed is respected between two different DataCollatorForWholeWordMask instances + data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf") + batch_1 = data_collator(features) + self.assertEqual(batch_1["input_ids"].shape.as_list(), [2, 1000]) + self.assertEqual(batch_1["labels"].shape.as_list(), [2, 1000]) + + data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf") + batch_2 = data_collator(features) + self.assertEqual(batch_2["input_ids"].shape.as_list(), [2, 1000]) + self.assertEqual(batch_2["labels"].shape.as_list(), [2, 1000]) + + self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"])) + self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"])) + + # try with different seed + data_collator = DataCollatorForWholeWordMask(tokenizer, seed=43, return_tensors="tf") + batch_3 = data_collator(features) + self.assertEqual(batch_3["input_ids"].shape.as_list(), [2, 1000]) + self.assertEqual(batch_3["labels"].shape.as_list(), [2, 1000]) + + self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"])) + self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"])) + def test_plm(self): tokenizer = BertTokenizer(self.vocab_file) no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] @@ -1920,6 +2027,32 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase): self.assertEqual(batch["input_ids"].shape, (2, 10)) self.assertEqual(batch["labels"].shape, (2, 10)) + def test_data_collator_for_whole_word_mask_with_seed(self): + tokenizer = BertTokenizer(self.vocab_file) + features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}] + + # check if seed is respected between two different DataCollatorForWholeWordMask instances + data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np") + batch_1 = data_collator(features) + self.assertEqual(batch_1["input_ids"].shape, (2, 1000)) + self.assertEqual(batch_1["labels"].shape, (2, 1000)) + + data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np") + batch_2 = data_collator(features) + self.assertEqual(batch_2["input_ids"].shape, (2, 1000)) + self.assertEqual(batch_2["labels"].shape, (2, 1000)) + + self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"])) + self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"])) + + data_collator = DataCollatorForWholeWordMask(tokenizer, seed=43, return_tensors="np") + batch_3 = data_collator(features) + self.assertEqual(batch_3["input_ids"].shape, (2, 1000)) + self.assertEqual(batch_3["labels"].shape, (2, 1000)) + + self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"])) + self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"])) + def test_plm(self): tokenizer = BertTokenizer(self.vocab_file) no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]