Enhance DataCollatorForLanguageModeling with Configurable Token Replacement Probabilities (#35251)
* DataCollatorForLanguageModeling class was updated with new parameters that provide more control over the token masking and replacing * DataCollatorForLanguageModeling class was updated with new parameters that provide more control over the token masking and replacing * Addressed review comments, modified the docstring and made a test for the DataCollatorForLanguageModeling
This commit is contained in:
committed by
GitHub
parent
b0cdbd9119
commit
c61fcde910
@@ -1020,6 +1020,52 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(tf.reduce_any(masked_tokens))
|
||||
# self.assertTrue(all(x == -100 for x in batch["labels"].numpy()[~masked_tokens.numpy()].tolist()))
|
||||
|
||||
def test_probability_sum_error(self):
    """Constructing the collator must raise ValueError when the two replacement probabilities sum above 1."""
    bert_tokenizer = BertTokenizer(self.vocab_file)
    # 0.9 + 0.2 = 1.1, which is larger than 1 and therefore expected to be rejected.
    invalid_probabilities = {"mask_replace_prob": 0.9, "random_replace_prob": 0.2}
    with self.assertRaises(ValueError):
        DataCollatorForLanguageModeling(tokenizer=bert_tokenizer, **invalid_probabilities)
|
||||
|
||||
def test_all_mask_replacement(self):
    """Check the collator output when mask_replace_prob=1 and random_replace_prob=0.

    The same check runs in turn for the "pt", "tf" and "np" return types: after
    collating eight copies of the ids 0..5, every entry of ``input_ids`` must be
    either the id that was fed in or the tokenizer's mask token id.
    """
    bert_tokenizer = BertTokenizer(self.vocab_file)
    raw_ids = [0, 1, 2, 3, 4, 5]

    # Each entry: (return_tensors value, input builder, "all elements true" reducer).
    # The lambdas postpone touching a framework module until that framework's iteration.
    backends = (
        ("pt", lambda: torch.tensor(raw_ids), lambda flags: torch.all(flags)),
        ("tf", lambda: tf.constant(raw_ids), lambda flags: tf.reduce_all(flags)),
        ("np", lambda: np.array(raw_ids), lambda flags: np.all(flags)),
    )

    for tensor_type, build_ids, all_true in backends:
        collator = DataCollatorForLanguageModeling(
            tokenizer=bert_tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors=tensor_type
        )
        original_ids = build_ids()
        samples = [{"input_ids": original_ids} for _ in range(8)]
        collated = collator(samples)

        # confirm that every token is either the original token or [MASK]
        kept_original = collated["input_ids"] == original_ids
        became_mask = collated["input_ids"] == bert_tokenizer.mask_token_id
        self.assertTrue(all_true(kept_original | became_mask))
|
||||
|
||||
def test_data_collator_for_language_modeling(self):
|
||||
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||
pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
|
||||
|
||||
Reference in New Issue
Block a user