Make default_data_collator more flexible and deprecate old behavior (#5060)
* Make default_data_collator more flexible * Accept tensors for all features * Document code * Refactor * Formatting
This commit is contained in:
@@ -24,6 +24,27 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
|
||||
|
||||
@require_torch
|
||||
class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
def test_default_with_dict(self):
|
||||
features = [{"labels": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
||||
batch = default_data_collator(features)
|
||||
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
|
||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
||||
|
||||
# With label_ids
|
||||
features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
||||
batch = default_data_collator(features)
|
||||
self.assertTrue(batch["labels"].equal(torch.tensor([[0, 1, 2]] * 8)))
|
||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
||||
|
||||
# Features can already be tensors
|
||||
features = [{"labels": i, "inputs": torch.randint(10, [10])} for i in range(8)]
|
||||
batch = default_data_collator(features)
|
||||
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
|
||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
|
||||
|
||||
def test_default_classification(self):
|
||||
MODEL_ID = "bert-base-cased-finetuned-mrpc"
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
|
||||
Reference in New Issue
Block a user