handle numpy inputs in whole word mask data collator (#22032)
This commit is contained in:
@@ -271,12 +271,17 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||
|
||||
def test_data_collator_for_whole_word_mask(self):
|
||||
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
|
||||
batch = data_collator(features)
|
||||
|
||||
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
||||
|
||||
# Features can already be tensors
|
||||
features = [{"input_ids": np.arange(10)}, {"input_ids": np.arange(10)}]
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
||||
|
||||
@@ -553,12 +558,17 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
|
||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||
|
||||
def test_data_collator_for_whole_word_mask(self):
|
||||
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
|
||||
batch = data_collator(features)
|
||||
|
||||
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
|
||||
self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
|
||||
|
||||
# Features can already be tensors
|
||||
features = [{"input_ids": np.arange(10)}, {"input_ids": np.arange(10)}]
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
|
||||
self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
|
||||
|
||||
@@ -825,12 +835,17 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
|
||||
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||
|
||||
def test_data_collator_for_whole_word_mask(self):
|
||||
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")
|
||||
batch = data_collator(features)
|
||||
|
||||
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, (2, 10))
|
||||
self.assertEqual(batch["labels"].shape, (2, 10))
|
||||
|
||||
# Features can already be tensors
|
||||
features = [{"input_ids": np.arange(10)}, {"input_ids": np.arange(10)}]
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, (2, 10))
|
||||
self.assertEqual(batch["labels"].shape, (2, 10))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user