Enhance DataCollatorForLanguageModeling with Configurable Token Replacement Probabilities (#35251)

* DataCollatorForLanguageModeling class was updated with new parameters that provides more control over the token masking and relacing

* DataCollatorForLanguageModeling class was updated with new parameters that provides more control over the token masking and relacing

* Addressed review comments, modified the docstring and made a test for the DataCollatorForLanguageModeling
This commit is contained in:
Mahdi Baghbanzadeh
2025-01-14 12:01:10 -05:00
committed by GitHub
parent b0cdbd9119
commit c61fcde910
2 changed files with 139 additions and 20 deletions

View File

@@ -1020,6 +1020,52 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
self.assertTrue(tf.reduce_any(masked_tokens))
# self.assertTrue(all(x == -100 for x in batch["labels"].numpy()[~masked_tokens.numpy()].tolist()))
def test_probability_sum_error(self):
"""Test that the sum of mask_replace_prob and random_replace_prob exceeding 1 raises an error."""
tokenizer = BertTokenizer(self.vocab_file)
with self.assertRaises(ValueError):
DataCollatorForLanguageModeling(tokenizer=tokenizer, mask_replace_prob=0.9, random_replace_prob=0.2)
def test_all_mask_replacement(self):
"""Test behavior when mask_replace_prob=1."""
tokenizer = BertTokenizer(self.vocab_file)
# pytorch call
collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="pt"
)
inputs = torch.tensor([0, 1, 2, 3, 4, 5])
features = [{"input_ids": inputs} for _ in range(8)]
batch = collator(features)
# confirm that every token is either the original token or [MASK]
self.assertTrue(torch.all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id)))
# tf call
collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="tf"
)
inputs = tf.constant([0, 1, 2, 3, 4, 5])
features = [{"input_ids": inputs} for _ in range(8)]
batch = collator(features)
# confirm that every token is either the original token or [MASK]
self.assertTrue(
tf.reduce_all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id))
)
# numpy call
collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="np"
)
inputs = np.array([0, 1, 2, 3, 4, 5])
features = [{"input_ids": inputs} for _ in range(8)]
batch = collator(features)
# confirm that every token is either the original token or [MASK]
self.assertTrue(np.all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id)))
def test_data_collator_for_language_modeling(self):
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]