minimal fixes to run DataCollatorForWholeWordMask with return_tensors="np" and return_tensors="tf" (#13891)
* minimal fixes to run DataCollatorForWholeWordMask with return_tensors="np" and return_tensors="tf" * more consinstent implementation for numpy_mask_tokens
This commit is contained in:
@@ -883,14 +883,14 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
|||||||
inputs, labels = self.tf_mask_tokens(batch_input, batch_mask)
|
inputs, labels = self.tf_mask_tokens(batch_input, batch_mask)
|
||||||
return {"input_ids": inputs, "labels": labels}
|
return {"input_ids": inputs, "labels": labels}
|
||||||
|
|
||||||
def np_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], (dict, BatchEncoding)):
|
||||||
input_ids = [e["input_ids"] for e in examples]
|
input_ids = [e["input_ids"] for e in examples]
|
||||||
else:
|
else:
|
||||||
input_ids = examples
|
input_ids = examples
|
||||||
examples = [{"input_ids": e} for e in examples]
|
examples = [{"input_ids": e} for e in examples]
|
||||||
|
|
||||||
batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
||||||
|
|
||||||
mask_labels = []
|
mask_labels = []
|
||||||
for e in examples:
|
for e in examples:
|
||||||
@@ -1009,15 +1009,15 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
|
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
|
||||||
)
|
)
|
||||||
labels = inputs.clone()
|
labels = tf.identity(inputs)
|
||||||
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
||||||
|
|
||||||
masked_indices = tf.cast(mask_labels, tf.bool)
|
masked_indices = tf.cast(mask_labels, tf.bool)
|
||||||
|
|
||||||
special_tokens_mask = [
|
special_tokens_mask = [
|
||||||
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
|
||||||
]
|
]
|
||||||
masked_indices = masked_indices & ~tf.convert_to_tensor(special_tokens_mask, dtype=tf.bool)
|
masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
|
||||||
if self.tokenizer._pad_token is not None:
|
if self.tokenizer._pad_token is not None:
|
||||||
padding_mask = inputs == self.tokenizer.pad_token_id
|
padding_mask = inputs == self.tokenizer.pad_token_id
|
||||||
masked_indices = masked_indices & ~padding_mask
|
masked_indices = masked_indices & ~padding_mask
|
||||||
@@ -1073,9 +1073,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
|||||||
indices_random = (
|
indices_random = (
|
||||||
np.random.binomial(1, 0.5, size=labels.shape).astype(np.bool) & masked_indices & ~indices_replaced
|
np.random.binomial(1, 0.5, size=labels.shape).astype(np.bool) & masked_indices & ~indices_replaced
|
||||||
)
|
)
|
||||||
random_words = np.random.randint(
|
random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
|
||||||
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
|
|
||||||
)
|
|
||||||
inputs[indices_random] = random_words[indices_random]
|
inputs[indices_random] = random_words[indices_random]
|
||||||
|
|
||||||
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from transformers import (
|
|||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
DataCollatorForPermutationLanguageModeling,
|
DataCollatorForPermutationLanguageModeling,
|
||||||
DataCollatorForTokenClassification,
|
DataCollatorForTokenClassification,
|
||||||
|
DataCollatorForWholeWordMask,
|
||||||
DataCollatorWithPadding,
|
DataCollatorWithPadding,
|
||||||
default_data_collator,
|
default_data_collator,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
@@ -224,6 +225,16 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
pad_features = [list(range(5)), list(range(10))]
|
pad_features = [list(range(5)), list(range(10))]
|
||||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||||
|
|
||||||
|
def test_data_collator_for_whole_word_mask(self):
|
||||||
|
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
|
|
||||||
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
|
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
|
||||||
|
batch = data_collator(features)
|
||||||
|
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
||||||
|
|
||||||
def test_plm(self):
|
def test_plm(self):
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
@@ -488,6 +499,16 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
pad_features = [list(range(5)), list(range(10))]
|
pad_features = [list(range(5)), list(range(10))]
|
||||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||||
|
|
||||||
|
def test_data_collator_for_whole_word_mask(self):
|
||||||
|
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
|
|
||||||
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
|
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
|
||||||
|
batch = data_collator(features)
|
||||||
|
|
||||||
|
self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
|
||||||
|
self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
|
||||||
|
|
||||||
def test_plm(self):
|
def test_plm(self):
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
@@ -750,6 +771,16 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
pad_features = [list(range(5)), list(range(10))]
|
pad_features = [list(range(5)), list(range(10))]
|
||||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||||
|
|
||||||
|
def test_data_collator_for_whole_word_mask(self):
|
||||||
|
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
|
|
||||||
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
|
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")
|
||||||
|
batch = data_collator(features)
|
||||||
|
|
||||||
|
self.assertEqual(batch["input_ids"].shape, (2, 10))
|
||||||
|
self.assertEqual(batch["labels"].shape, (2, 10))
|
||||||
|
|
||||||
def test_plm(self):
|
def test_plm(self):
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
|
|||||||
Reference in New Issue
Block a user