handle numpy inputs in whole word mask data collator (#22032)
This commit is contained in:
@@ -883,6 +883,8 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
|||||||
return {"input_ids": inputs, "labels": labels}
|
return {"input_ids": inputs, "labels": labels}
|
||||||
|
|
||||||
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
if isinstance(examples[0], Mapping):
|
if isinstance(examples[0], Mapping):
|
||||||
input_ids = [e["input_ids"] for e in examples]
|
input_ids = [e["input_ids"] for e in examples]
|
||||||
else:
|
else:
|
||||||
@@ -907,7 +909,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
|||||||
ref_tokens[i] = "##" + ref_tokens[i]
|
ref_tokens[i] = "##" + ref_tokens[i]
|
||||||
mask_labels.append(self._whole_word_mask(ref_tokens))
|
mask_labels.append(self._whole_word_mask(ref_tokens))
|
||||||
batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
||||||
inputs, labels = self.tf_mask_tokens(batch_input, batch_mask)
|
inputs, labels = self.tf_mask_tokens(tf.cast(batch_input, tf.int64), batch_mask)
|
||||||
return {"input_ids": inputs, "labels": labels}
|
return {"input_ids": inputs, "labels": labels}
|
||||||
|
|
||||||
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
|
|||||||
@@ -271,12 +271,17 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||||
|
|
||||||
def test_data_collator_for_whole_word_mask(self):
|
def test_data_collator_for_whole_word_mask(self):
|
||||||
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
|
||||||
|
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
|
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
|
||||||
batch = data_collator(features)
|
|
||||||
|
|
||||||
|
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
|
batch = data_collator(features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
||||||
|
|
||||||
|
# Features can already be tensors
|
||||||
|
features = [{"input_ids": np.arange(10)}, {"input_ids": np.arange(10)}]
|
||||||
|
batch = data_collator(features)
|
||||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
||||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
||||||
|
|
||||||
@@ -553,12 +558,17 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||||
|
|
||||||
def test_data_collator_for_whole_word_mask(self):
|
def test_data_collator_for_whole_word_mask(self):
|
||||||
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
|
||||||
|
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
|
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
|
||||||
batch = data_collator(features)
|
|
||||||
|
|
||||||
|
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
|
batch = data_collator(features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
|
||||||
|
self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
|
||||||
|
|
||||||
|
# Features can already be tensors
|
||||||
|
features = [{"input_ids": np.arange(10)}, {"input_ids": np.arange(10)}]
|
||||||
|
batch = data_collator(features)
|
||||||
self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
|
self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
|
||||||
self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
|
self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
|
||||||
|
|
||||||
@@ -825,12 +835,17 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||||
|
|
||||||
def test_data_collator_for_whole_word_mask(self):
|
def test_data_collator_for_whole_word_mask(self):
|
||||||
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
|
||||||
|
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")
|
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")
|
||||||
batch = data_collator(features)
|
|
||||||
|
|
||||||
|
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
|
batch = data_collator(features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, (2, 10))
|
||||||
|
self.assertEqual(batch["labels"].shape, (2, 10))
|
||||||
|
|
||||||
|
# Features can already be tensors
|
||||||
|
features = [{"input_ids": np.arange(10)}, {"input_ids": np.arange(10)}]
|
||||||
|
batch = data_collator(features)
|
||||||
self.assertEqual(batch["input_ids"].shape, (2, 10))
|
self.assertEqual(batch["input_ids"].shape, (2, 10))
|
||||||
self.assertEqual(batch["labels"].shape, (2, 10))
|
self.assertEqual(batch["labels"].shape, (2, 10))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user