Clean up data collators and datasets (#8308)
* Clean up data collators and datasets * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Remove needless clone Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -12,9 +12,7 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
DataCollatorForLanguageModeling,
|
||||
DataCollatorForNextSentencePrediction,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
DataCollatorForSOP,
|
||||
DataCollatorForTokenClassification,
|
||||
DataCollatorWithPadding,
|
||||
default_data_collator,
|
||||
@@ -201,13 +199,16 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_nsp(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
features = [{"tokens_a": [0, 1, 2, 3, 4], "tokens_b": [0, 1, 2, 3, 4], "is_random_next": i} for i in range(2)]
|
||||
data_collator = DataCollatorForNextSentencePrediction(tokenizer)
|
||||
features = [
|
||||
{"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
|
||||
for i in range(2)
|
||||
]
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer)
|
||||
batch = data_collator(features)
|
||||
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
||||
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 512)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5)))
|
||||
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
|
||||
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
|
||||
|
||||
def test_sop(self):
|
||||
@@ -216,11 +217,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
{
|
||||
"input_ids": torch.tensor([0, 1, 2, 3, 4]),
|
||||
"token_type_ids": torch.tensor([0, 1, 2, 3, 4]),
|
||||
"sentence_order_label": torch.tensor(i),
|
||||
"sentence_order_label": i,
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
data_collator = DataCollatorForSOP(tokenizer)
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer)
|
||||
batch = data_collator(features)
|
||||
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5)))
|
||||
|
||||
Reference in New Issue
Block a user