Add support for seed in DataCollatorForLanguageModeling (#36497)

Add support for `seed` in `DataCollatorForLanguageModeling`. Also wrote tests for verifying behaviour.
This commit is contained in:
gautham
2025-03-20 23:57:43 +05:30
committed by GitHub
parent ecd60d01c3
commit 9e771bf402
2 changed files with 248 additions and 22 deletions

View File

@@ -350,6 +350,86 @@ class DataCollatorIntegrationTest(unittest.TestCase):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)
def test_data_collator_for_language_modeling_with_seed(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
# check if seed is respected between two different DataCollatorForLanguageModeling instances
data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42)
batch_1 = data_collator(features)
self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 1000)))
self.assertEqual(batch_1["labels"].shape, torch.Size((2, 1000)))
data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42)
batch_2 = data_collator(features)
self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 1000)))
self.assertEqual(batch_2["labels"].shape, torch.Size((2, 1000)))
self.assertTrue(torch.all(batch_1["input_ids"] == batch_2["input_ids"]))
self.assertTrue(torch.all(batch_1["labels"] == batch_2["labels"]))
# check if seed is respected in multiple workers situation
features = [{"input_ids": list(range(1000))} for _ in range(10)]
dataloader = torch.utils.data.DataLoader(
features,
batch_size=2,
num_workers=2,
generator=torch.Generator().manual_seed(42),
collate_fn=DataCollatorForLanguageModeling(tokenizer, seed=42),
)
batch_3_input_ids = []
batch_3_labels = []
for batch in dataloader:
batch_3_input_ids.append(batch["input_ids"])
batch_3_labels.append(batch["labels"])
batch_3_input_ids = torch.stack(batch_3_input_ids)
batch_3_labels = torch.stack(batch_3_labels)
self.assertEqual(batch_3_input_ids.shape, torch.Size((5, 2, 1000)))
self.assertEqual(batch_3_labels.shape, torch.Size((5, 2, 1000)))
dataloader = torch.utils.data.DataLoader(
features,
batch_size=2,
num_workers=2,
collate_fn=DataCollatorForLanguageModeling(tokenizer, seed=42),
)
batch_4_input_ids = []
batch_4_labels = []
for batch in dataloader:
batch_4_input_ids.append(batch["input_ids"])
batch_4_labels.append(batch["labels"])
batch_4_input_ids = torch.stack(batch_4_input_ids)
batch_4_labels = torch.stack(batch_4_labels)
self.assertEqual(batch_4_input_ids.shape, torch.Size((5, 2, 1000)))
self.assertEqual(batch_4_labels.shape, torch.Size((5, 2, 1000)))
self.assertTrue(torch.all(batch_3_input_ids == batch_4_input_ids))
self.assertTrue(torch.all(batch_3_labels == batch_4_labels))
# try with different seed
dataloader = torch.utils.data.DataLoader(
features,
batch_size=2,
num_workers=2,
collate_fn=DataCollatorForLanguageModeling(tokenizer, seed=43),
)
batch_5_input_ids = []
batch_5_labels = []
for batch in dataloader:
batch_5_input_ids.append(batch["input_ids"])
batch_5_labels.append(batch["labels"])
batch_5_input_ids = torch.stack(batch_5_input_ids)
batch_5_labels = torch.stack(batch_5_labels)
self.assertEqual(batch_5_input_ids.shape, torch.Size((5, 2, 1000)))
self.assertEqual(batch_5_labels.shape, torch.Size((5, 2, 1000)))
self.assertFalse(torch.all(batch_3_input_ids == batch_5_input_ids))
self.assertFalse(torch.all(batch_3_labels == batch_5_labels))
def test_data_collator_for_whole_word_mask(self):
tokenizer = BertTokenizer(self.vocab_file)
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
@@ -1077,6 +1157,33 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)
def test_data_collator_for_language_modeling_with_seed(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
# check if seed is respected between two different DataCollatorForLanguageModeling instances
data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42, return_tensors="tf")
batch_1 = data_collator(features)
self.assertEqual(batch_1["input_ids"].shape.as_list(), [2, 1000])
self.assertEqual(batch_1["labels"].shape.as_list(), [2, 1000])
data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42, return_tensors="tf")
batch_2 = data_collator(features)
self.assertEqual(batch_2["input_ids"].shape.as_list(), [2, 1000])
self.assertEqual(batch_2["labels"].shape.as_list(), [2, 1000])
self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))
# try with different seed
data_collator = DataCollatorForLanguageModeling(tokenizer, seed=43, return_tensors="tf")
batch_3 = data_collator(features)
self.assertEqual(batch_3["input_ids"].shape.as_list(), [2, 1000])
self.assertEqual(batch_3["labels"].shape.as_list(), [2, 1000])
self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))
def test_data_collator_for_whole_word_mask(self):
tokenizer = BertTokenizer(self.vocab_file)
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
@@ -1772,6 +1879,32 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)
def test_data_collator_for_language_modeling_with_seed(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
# check if seed is respected between two different DataCollatorForLanguageModeling instances
data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42, return_tensors="np")
batch_1 = data_collator(features)
self.assertEqual(batch_1["input_ids"].shape, (2, 1000))
self.assertEqual(batch_1["labels"].shape, (2, 1000))
data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42, return_tensors="np")
batch_2 = data_collator(features)
self.assertEqual(batch_2["input_ids"].shape, (2, 1000))
self.assertEqual(batch_2["labels"].shape, (2, 1000))
self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))
data_collator = DataCollatorForLanguageModeling(tokenizer, seed=43, return_tensors="np")
batch_3 = data_collator(features)
self.assertEqual(batch_3["input_ids"].shape, (2, 1000))
self.assertEqual(batch_3["labels"].shape, (2, 1000))
self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))
def test_data_collator_for_whole_word_mask(self):
tokenizer = BertTokenizer(self.vocab_file)
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")