Add support for seed in DataCollatorForLanguageModeling (#36497)
Add support for `seed` in `DataCollatorForLanguageModeling`. Also adds tests that verify the seeded behaviour.
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
@@ -787,6 +788,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
If set, will pad the sequence to a multiple of the provided value.
|
||||
return_tensors (`str`):
|
||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||
seed (`int`, *optional*):
|
||||
The seed to use for the random number generator for masking. If not provided, the global RNG will be used.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -827,6 +830,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
tf_experimental_compile: bool = False
|
||||
return_tensors: str = "pt"
|
||||
seed: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mlm and self.tokenizer.mask_token is None:
|
||||
@@ -852,12 +856,57 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
|
||||
self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)
|
||||
|
||||
self.generator = None
|
||||
|
||||
def get_generator(self, seed):
    """
    Build a random number generator seeded with `seed` for the configured tensor backend.

    The backend is selected from `self.return_tensors`: `"pt"` gives a `torch.Generator`, `"tf"` gives a
    `tf.random.Generator`, and every other value falls back to a NumPy `Generator`.

    Args:
        seed (`int`): The seed handed to the backend's generator constructor.

    Returns:
        The seeded generator object of the selected backend.
    """
    backend = self.return_tensors

    if backend == "pt":
        import torch

        torch_rng = torch.Generator()
        return torch_rng.manual_seed(seed)

    if backend == "tf":
        import tensorflow as tf

        return tf.random.Generator.from_seed(seed)

    # Neither "pt" nor "tf": use NumPy.
    import numpy as np

    return np.random.default_rng(seed)
|
||||
|
||||
def create_rng(self):
    """
    Create the seeded generator used for masking and store it on `self.generator`.

    In the main process the generator is seeded with `self.seed`. In any other process the worker ID
    reported by `torch.utils.data.get_worker_info()` is added to the seed, so that every DataLoader worker
    gets its own seed.

    Raises:
        ValueError: If called outside the main process while PyTorch reports no worker information.
    """
    if mp.current_process().name == "MainProcess":
        # If we are in the main process, we create a generator object with the seed
        self.generator = self.get_generator(self.seed)
        return

    # If we are in a worker process (i.e using multiprocessing), we need to set a unique seed for each
    # worker's generator, generated as the main seed + the worker's ID.
    # (https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading)
    # Only PyTorch DataLoader allows us to access the worker ID, and so we check for this.
    # For other frameworks, we will throw an error.
    import torch

    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Adjacent literals are joined into ONE string. Separating them with commas would build a tuple
        # and the exception message would be rendered as a tuple repr.
        error_string = (
            "Worker process information is not available for seeding the generator. This may be because "
            "you are using multiprocessing without using a PyTorch DataLoader. The `seed` parameter can "
            "only be used when using multiprocessing with a PyTorch DataLoader. Please either use a "
            "single process or use a PyTorch DataLoader with multiple workers."
        )
        raise ValueError(error_string)

    self.generator = self.get_generator(self.seed + worker_info.id)
|
||||
|
||||
@staticmethod
def tf_bernoulli(shape, probability, generator=None):
    """
    Draw a boolean tensor of the given `shape` whose entries are `True` with the given `probability`.

    Args:
        shape: Shape of the tensor to sample.
        probability: Value the probability matrix is filled with.
        generator (*optional*): Object exposing `uniform(shape, 0, 1)`. When omitted, TensorFlow's global
            RNG (`tf.random.uniform`) is used.

    Returns:
        A `tf.bool` tensor that is `True` wherever `probability` minus the uniform draw is non-negative.
    """
    import tensorflow as tf

    prob_matrix = tf.fill(shape, probability)

    # Use the dedicated generator when one was supplied, otherwise fall back to the global RNG.
    if generator is not None:
        uniform_draws = generator.uniform(shape, 0, 1)
    else:
        uniform_draws = tf.random.uniform(shape, 0, 1)

    return tf.cast(prob_matrix - uniform_draws >= 0, tf.bool)
|
||||
|
||||
def tf_mask_tokens(
|
||||
self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
|
||||
@@ -872,12 +921,12 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
input_shape = tf.shape(inputs)
|
||||
# 1 for a special token, 0 for a normal token in the special tokens mask
|
||||
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
||||
masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
|
||||
masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability, self.generator) & ~special_tokens_mask
|
||||
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
|
||||
labels = tf.where(masked_indices, inputs, -100)
|
||||
|
||||
# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||
indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob) & masked_indices
|
||||
indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices
|
||||
|
||||
inputs = tf.where(indices_replaced, mask_token_id, inputs)
|
||||
|
||||
@@ -891,9 +940,15 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
random_replace_prob_scaled = self.random_replace_prob / remaining_prob
|
||||
# random_replace_prob% of the time, we replace masked input tokens with random word
|
||||
indices_random = (
|
||||
self.tf_bernoulli(input_shape, random_replace_prob_scaled) & masked_indices & ~indices_replaced
|
||||
self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator)
|
||||
& masked_indices
|
||||
& ~indices_replaced
|
||||
)
|
||||
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
|
||||
|
||||
if self.generator:
|
||||
random_words = self.generator.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
|
||||
else:
|
||||
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
|
||||
|
||||
inputs = tf.where(indices_random, random_words, inputs)
|
||||
|
||||
@@ -903,6 +958,11 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||
import tensorflow as tf
|
||||
|
||||
if self.seed and self.generator is None:
|
||||
# If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
|
||||
# If no seed supplied, we will use the global RNG
|
||||
self.create_rng()
|
||||
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
if isinstance(examples[0], Mapping):
|
||||
batch = pad_without_fast_tokenizer_warning(
|
||||
@@ -943,6 +1003,12 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
|
||||
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
|
||||
if self.seed and self.generator is None:
|
||||
# If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
|
||||
# If no seed supplied, we will use the global RNG
|
||||
self.create_rng()
|
||||
|
||||
if isinstance(examples[0], Mapping):
|
||||
batch = pad_without_fast_tokenizer_warning(
|
||||
self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
|
||||
@@ -983,11 +1049,14 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
special_tokens_mask = special_tokens_mask.bool()
|
||||
|
||||
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
|
||||
masked_indices = torch.bernoulli(probability_matrix).bool()
|
||||
masked_indices = torch.bernoulli(probability_matrix, generator=self.generator).bool()
|
||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||
|
||||
# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||
indices_replaced = torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob)).bool() & masked_indices
|
||||
indices_replaced = (
|
||||
torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool()
|
||||
& masked_indices
|
||||
)
|
||||
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
||||
|
||||
if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
|
||||
@@ -1001,11 +1070,11 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
|
||||
# random_replace_prob% of the time, we replace masked input tokens with random word
|
||||
indices_random = (
|
||||
torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled)).bool()
|
||||
torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool()
|
||||
& masked_indices
|
||||
& ~indices_replaced
|
||||
)
|
||||
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
||||
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator)
|
||||
inputs[indices_random] = random_words[indices_random]
|
||||
|
||||
# The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
|
||||
@@ -1013,6 +1082,12 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
|
||||
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
|
||||
if self.seed and self.generator is None:
|
||||
# If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
|
||||
# If no seed supplied, we will use the global RNG
|
||||
self.create_rng()
|
||||
|
||||
if isinstance(examples[0], Mapping):
|
||||
batch = pad_without_fast_tokenizer_warning(
|
||||
self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
|
||||
@@ -1052,13 +1127,21 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
|
||||
probability_matrix[special_tokens_mask] = 0
|
||||
# Numpy doesn't have bernoulli, so we use a binomial with 1 trial
|
||||
masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
|
||||
if self.generator:
|
||||
masked_indices = self.generator.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
|
||||
else:
|
||||
masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
|
||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||
|
||||
# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||
indices_replaced = (
|
||||
np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
|
||||
)
|
||||
if self.generator:
|
||||
indices_replaced = (
|
||||
self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
|
||||
)
|
||||
else:
|
||||
indices_replaced = (
|
||||
np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
|
||||
)
|
||||
inputs[indices_replaced] = self.tokenizer.mask_token_id
|
||||
|
||||
if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
|
||||
@@ -1069,14 +1152,24 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
# mask_replace_prob = 0.8 and random_replace_prob = 0.1,
|
||||
# then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
|
||||
random_replace_prob_scaled = self.random_replace_prob / remaining_prob
|
||||
indices_random = (
|
||||
np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
|
||||
& masked_indices
|
||||
& ~indices_replaced
|
||||
)
|
||||
random_words = np.random.randint(
|
||||
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
|
||||
)
|
||||
if self.generator:
|
||||
indices_random = (
|
||||
self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
|
||||
& masked_indices
|
||||
& ~indices_replaced
|
||||
)
|
||||
random_words = self.generator.integers(
|
||||
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
|
||||
)
|
||||
else:
|
||||
indices_random = (
|
||||
np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
|
||||
& masked_indices
|
||||
& ~indices_replaced
|
||||
)
|
||||
random_words = np.random.randint(
|
||||
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
|
||||
)
|
||||
inputs[indices_random] = random_words
|
||||
|
||||
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
||||
|
||||
@@ -350,6 +350,86 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
pad_features = [list(range(5)), list(range(10))]
|
||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||
|
||||
def test_data_collator_for_language_modeling_with_seed(self):
    """Seeded collators must mask reproducibly, both in a single process and with DataLoader workers."""
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(1000))} for _ in range(2)]
    pair_shape = torch.Size((2, 1000))

    # check if seed is respected between two different DataCollatorForLanguageModeling instances
    single_process_batches = []
    for _ in range(2):
        data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42)
        collated = data_collator(features)
        self.assertEqual(collated["input_ids"].shape, pair_shape)
        self.assertEqual(collated["labels"].shape, pair_shape)
        single_process_batches.append(collated)
    batch_1, batch_2 = single_process_batches

    self.assertTrue(torch.all(batch_1["input_ids"] == batch_2["input_ids"]))
    self.assertTrue(torch.all(batch_1["labels"] == batch_2["labels"]))

    # check if seed is respected in multiple workers situation
    features = [{"input_ids": list(range(1000))} for _ in range(10)]
    stacked_shape = torch.Size((5, 2, 1000))

    def run_loader(seed, **loader_kwargs):
        # Iterate a two-worker DataLoader with a seeded collator and stack everything it yields.
        dataloader = torch.utils.data.DataLoader(
            features,
            batch_size=2,
            num_workers=2,
            collate_fn=DataCollatorForLanguageModeling(tokenizer, seed=seed),
            **loader_kwargs,
        )
        collected_ids = []
        collected_labels = []
        for batch in dataloader:
            collected_ids.append(batch["input_ids"])
            collected_labels.append(batch["labels"])

        stacked_ids = torch.stack(collected_ids)
        stacked_labels = torch.stack(collected_labels)
        self.assertEqual(stacked_ids.shape, stacked_shape)
        self.assertEqual(stacked_labels.shape, stacked_shape)
        return stacked_ids, stacked_labels

    batch_3_input_ids, batch_3_labels = run_loader(42, generator=torch.Generator().manual_seed(42))
    batch_4_input_ids, batch_4_labels = run_loader(42)

    self.assertTrue(torch.all(batch_3_input_ids == batch_4_input_ids))
    self.assertTrue(torch.all(batch_3_labels == batch_4_labels))

    # try with different seed
    batch_5_input_ids, batch_5_labels = run_loader(43)

    self.assertFalse(torch.all(batch_3_input_ids == batch_5_input_ids))
    self.assertFalse(torch.all(batch_3_labels == batch_5_labels))
|
||||
|
||||
def test_data_collator_for_whole_word_mask(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
|
||||
@@ -1077,6 +1157,33 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
|
||||
pad_features = [list(range(5)), list(range(10))]
|
||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||
|
||||
def test_data_collator_for_language_modeling_with_seed(self):
    """A TF collator must give identical batches for equal seeds and different ones for other seeds."""
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(1000))} for _ in range(2)]

    def collate_with_seed(seed):
        # Build a fresh collator for every call so no generator state is shared between runs.
        data_collator = DataCollatorForLanguageModeling(tokenizer, seed=seed, return_tensors="tf")
        collated = data_collator(features)
        self.assertEqual(collated["input_ids"].shape.as_list(), [2, 1000])
        self.assertEqual(collated["labels"].shape.as_list(), [2, 1000])
        return collated

    # check if seed is respected between two different DataCollatorForLanguageModeling instances
    batch_1 = collate_with_seed(42)
    batch_2 = collate_with_seed(42)

    self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
    self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))

    # try with different seed
    batch_3 = collate_with_seed(43)

    self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
    self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))
|
||||
|
||||
def test_data_collator_for_whole_word_mask(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
|
||||
@@ -1772,6 +1879,32 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
|
||||
pad_features = [list(range(5)), list(range(10))]
|
||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||
|
||||
def test_data_collator_for_language_modeling_with_seed(self):
    """A NumPy collator must give identical batches for equal seeds and different ones for other seeds."""
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(1000))} for _ in range(2)]
    expected_shape = (2, 1000)

    def collate_with_seed(seed):
        # Build a fresh collator for every call so no generator state is shared between runs.
        data_collator = DataCollatorForLanguageModeling(tokenizer, seed=seed, return_tensors="np")
        collated = data_collator(features)
        self.assertEqual(collated["input_ids"].shape, expected_shape)
        self.assertEqual(collated["labels"].shape, expected_shape)
        return collated

    # check if seed is respected between two different DataCollatorForLanguageModeling instances
    batch_1 = collate_with_seed(42)
    batch_2 = collate_with_seed(42)

    self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
    self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))

    # a different seed has to change the masking
    batch_3 = collate_with_seed(43)

    self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
    self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))
|
||||
|
||||
def test_data_collator_for_whole_word_mask(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")
|
||||
|
||||
Reference in New Issue
Block a user