Clean up data collators and datasets (#8308)

* Clean up data collators and datasets

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Remove needless clone

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2020-11-04 17:24:49 -05:00
committed by GitHub
parent b1d3e95eb5
commit 9c4aa4ac1a
6 changed files with 136 additions and 197 deletions

View File

@@ -12,9 +12,7 @@ if is_torch_available():
 from transformers import (
 DataCollatorForLanguageModeling,
-DataCollatorForNextSentencePrediction,
 DataCollatorForPermutationLanguageModeling,
-DataCollatorForSOP,
 DataCollatorForTokenClassification,
 DataCollatorWithPadding,
 default_data_collator,
@@ -201,13 +199,16 @@ class DataCollatorIntegrationTest(unittest.TestCase):
 def test_nsp(self):
 tokenizer = BertTokenizer(self.vocab_file)
-features = [{"tokens_a": [0, 1, 2, 3, 4], "tokens_b": [0, 1, 2, 3, 4], "is_random_next": i} for i in range(2)]
-data_collator = DataCollatorForNextSentencePrediction(tokenizer)
+features = [
+{"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
+for i in range(2)
+]
+data_collator = DataCollatorForLanguageModeling(tokenizer)
 batch = data_collator(features)
-self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
-self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 512)))
-self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
+self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5)))
+self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5)))
+self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
 self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
 def test_sop(self):
@@ -216,11 +217,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
 {
 "input_ids": torch.tensor([0, 1, 2, 3, 4]),
 "token_type_ids": torch.tensor([0, 1, 2, 3, 4]),
-"sentence_order_label": torch.tensor(i),
+"sentence_order_label": i,
 }
 for i in range(2)
 ]
-data_collator = DataCollatorForSOP(tokenizer)
+data_collator = DataCollatorForLanguageModeling(tokenizer)
 batch = data_collator(features)
 self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5)))