Added support for seed in DataCollatorForWholeWordMask (#36903)
* Added support for `seed` in `DataCollatorForWholeWordMask` and wrote tests for it. Also fixed bugs where the code used hardcoded values for the mask-replacement probability and the random-replacement probability instead of the values passed by the user. * Fixed formatting issues. * Used a better way to generate the seed in TF and made the tests more consistent.
This commit is contained in:
@@ -445,6 +445,86 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
||||
|
||||
def test_data_collator_for_whole_word_mask_with_seed(self):
    """Check that ``DataCollatorForWholeWordMask`` honours its ``seed`` argument (PyTorch).

    Three properties are asserted:

    * two separately constructed collators with the same seed produce
      identical ``input_ids`` and ``labels`` for the same features;
    * the same holds when the collator is used as ``collate_fn`` of a
      ``DataLoader`` running with two worker processes;
    * a different seed produces different output.
    """
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]

    # check if seed is respected between two different DataCollatorForWholeWordMask instances
    data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
    batch_1 = data_collator(features)
    self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 1000)))
    self.assertEqual(batch_1["labels"].shape, torch.Size((2, 1000)))

    data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
    batch_2 = data_collator(features)
    self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 1000)))
    self.assertEqual(batch_2["labels"].shape, torch.Size((2, 1000)))

    self.assertTrue(torch.all(batch_1["input_ids"] == batch_2["input_ids"]))
    self.assertTrue(torch.all(batch_1["labels"] == batch_2["labels"]))

    # check if seed is respected in multiple workers situation
    features = [{"input_ids": list(range(1000))} for _ in range(10)]

    def collect_batches(seed, generator=None):
        """Run a 2-worker DataLoader over ``features`` and return stacked (input_ids, labels)."""
        dataloader = torch.utils.data.DataLoader(
            features,
            batch_size=2,
            num_workers=2,
            generator=generator,
            collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=seed),
        )

        input_ids = []
        labels = []
        for batch in dataloader:
            input_ids.append(batch["input_ids"])
            labels.append(batch["labels"])

        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)
        # 10 features in batches of 2 -> 5 batches of shape (2, 1000)
        self.assertEqual(input_ids.shape, torch.Size((5, 2, 1000)))
        self.assertEqual(labels.shape, torch.Size((5, 2, 1000)))
        return input_ids, labels

    # The first run also passes a seeded torch.Generator to the DataLoader; the
    # second run does not, and the two are still expected to match.
    batch_3_input_ids, batch_3_labels = collect_batches(seed=42, generator=torch.Generator().manual_seed(42))
    batch_4_input_ids, batch_4_labels = collect_batches(seed=42)

    self.assertTrue(torch.all(batch_3_input_ids == batch_4_input_ids))
    self.assertTrue(torch.all(batch_3_labels == batch_4_labels))

    # try with different seed
    batch_5_input_ids, batch_5_labels = collect_batches(seed=43)

    self.assertFalse(torch.all(batch_3_input_ids == batch_5_input_ids))
    self.assertFalse(torch.all(batch_3_labels == batch_5_labels))
|
||||
|
||||
def test_plm(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||
@@ -1199,6 +1279,33 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
|
||||
self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
|
||||
|
||||
def test_data_collator_for_whole_word_mask_with_seed(self):
    """Check that ``DataCollatorForWholeWordMask`` honours its ``seed`` argument (TensorFlow).

    Two collators built with the same seed must yield identical batches,
    and a collator built with a different seed must not.
    """
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(1000))} for _ in range(2)]

    checked_keys = ("input_ids", "labels")
    expected_shape = [2, 1000]

    # check if seed is respected between two different DataCollatorForWholeWordMask instances
    first_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
    same_seed_a = first_collator(features)
    for key in checked_keys:
        self.assertEqual(same_seed_a[key].shape.as_list(), expected_shape)

    second_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
    same_seed_b = second_collator(features)
    for key in checked_keys:
        self.assertEqual(same_seed_b[key].shape.as_list(), expected_shape)

    for key in checked_keys:
        self.assertTrue(np.all(same_seed_a[key] == same_seed_b[key]))

    # try with different seed
    third_collator = DataCollatorForWholeWordMask(tokenizer, seed=43, return_tensors="tf")
    other_seed = third_collator(features)
    for key in checked_keys:
        self.assertEqual(other_seed[key].shape.as_list(), expected_shape)

    for key in checked_keys:
        self.assertFalse(np.all(same_seed_a[key] == other_seed[key]))
|
||||
|
||||
def test_plm(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||
@@ -1920,6 +2027,32 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(batch["input_ids"].shape, (2, 10))
|
||||
self.assertEqual(batch["labels"].shape, (2, 10))
|
||||
|
||||
def test_data_collator_for_whole_word_mask_with_seed(self):
    """Check that ``DataCollatorForWholeWordMask`` honours its ``seed`` argument (NumPy).

    Two collators built with the same seed must yield identical batches,
    and a collator built with a different seed must not.
    """
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(1000))} for _ in range(2)]

    def collate_with_seed(seed):
        """Build a fresh NumPy collator with ``seed``, collate ``features``, check shapes."""
        collator = DataCollatorForWholeWordMask(tokenizer, seed=seed, return_tensors="np")
        batch = collator(features)
        self.assertEqual(batch["input_ids"].shape, (2, 1000))
        self.assertEqual(batch["labels"].shape, (2, 1000))
        return batch

    # check if seed is respected between two different DataCollatorForWholeWordMask instances
    reference = collate_with_seed(42)
    repeat = collate_with_seed(42)

    self.assertTrue(np.all(reference["input_ids"] == repeat["input_ids"]))
    self.assertTrue(np.all(reference["labels"] == repeat["labels"]))

    # try with different seed
    reseeded = collate_with_seed(43)

    self.assertFalse(np.all(reference["input_ids"] == reseeded["input_ids"]))
    self.assertFalse(np.all(reference["labels"] == reseeded["labels"]))
|
||||
|
||||
def test_plm(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||
|
||||
Reference in New Issue
Block a user