From 2f4cdd97f5b837858f33d7d1095fba4b90871f57 Mon Sep 17 00:00:00 2001 From: Dean Wyatte <2512762+dwyatte@users.noreply.github.com> Date: Fri, 10 Mar 2023 08:50:29 -0700 Subject: [PATCH] handle numpy inputs in whole word mask data collator (#22032) --- src/transformers/data/data_collator.py | 4 +++- tests/trainer/test_data_collator.py | 33 +++++++++++++++++++------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 587a4f4d00..cd36358875 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -883,6 +883,8 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): return {"input_ids": inputs, "labels": labels} def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + import tensorflow as tf + if isinstance(examples[0], Mapping): input_ids = [e["input_ids"] for e in examples] else: @@ -907,7 +909,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): ref_tokens[i] = "##" + ref_tokens[i] mask_labels.append(self._whole_word_mask(ref_tokens)) batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of) - inputs, labels = self.tf_mask_tokens(batch_input, batch_mask) + inputs, labels = self.tf_mask_tokens(tf.cast(batch_input, tf.int64), batch_mask) return {"input_ids": inputs, "labels": labels} def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: diff --git a/tests/trainer/test_data_collator.py b/tests/trainer/test_data_collator.py index 39277ca8cc..f5104cd375 100644 --- a/tests/trainer/test_data_collator.py +++ b/tests/trainer/test_data_collator.py @@ -271,12 +271,17 @@ class DataCollatorIntegrationTest(unittest.TestCase): self._test_no_pad_and_pad(no_pad_features, pad_features) def test_data_collator_for_whole_word_mask(self): - features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] - tokenizer = BertTokenizer(self.vocab_file) data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt") - batch = data_collator(features) + features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] + batch = data_collator(features) + self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10))) + self.assertEqual(batch["labels"].shape, torch.Size((2, 10))) + + # Features can already be tensors + features = [{"input_ids": np.arange(10)}, {"input_ids": np.arange(10)}] + batch = data_collator(features) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10))) self.assertEqual(batch["labels"].shape, torch.Size((2, 10))) @@ -553,12 +558,17 @@ class TFDataCollatorIntegrationTest(unittest.TestCase): self._test_no_pad_and_pad(no_pad_features, pad_features) def test_data_collator_for_whole_word_mask(self): - features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] - tokenizer = BertTokenizer(self.vocab_file) data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf") - batch = data_collator(features) + features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] + batch = data_collator(features) + self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10]) + self.assertEqual(batch["labels"].shape.as_list(), [2, 10]) + + # Features can already be tensors + features = [{"input_ids": np.arange(10)}, {"input_ids": np.arange(10)}] + batch = data_collator(features) self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10]) self.assertEqual(batch["labels"].shape.as_list(), [2, 10]) @@ -825,12 +835,17 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase): self._test_no_pad_and_pad(no_pad_features, pad_features) def test_data_collator_for_whole_word_mask(self): - features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] - tokenizer = BertTokenizer(self.vocab_file) data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np") - batch = data_collator(features) + features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] + batch = data_collator(features) + self.assertEqual(batch["input_ids"].shape, (2, 10)) + self.assertEqual(batch["labels"].shape, (2, 10)) + + # Features can already be tensors + features = [{"input_ids": np.arange(10)}, {"input_ids": np.arange(10)}] + batch = data_collator(features) self.assertEqual(batch["input_ids"].shape, (2, 10)) self.assertEqual(batch["labels"].shape, (2, 10))