@@ -42,10 +42,10 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
|
|||||||
# Special handling for labels.
|
# Special handling for labels.
|
||||||
# Ensure that tensor is created with the correct type
|
# Ensure that tensor is created with the correct type
|
||||||
# (it should be automatically the case, but let's make sure of it.)
|
# (it should be automatically the case, but let's make sure of it.)
|
||||||
if "label" in first:
|
if "label" in first and first["label"] is not None:
|
||||||
dtype = torch.long if type(first["label"]) is int else torch.float
|
dtype = torch.long if type(first["label"]) is int else torch.float
|
||||||
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
||||||
elif "label_ids" in first:
|
elif "label_ids" in first and first["label_ids"] is not None:
|
||||||
if isinstance(first["label_ids"], torch.Tensor):
|
if isinstance(first["label_ids"], torch.Tensor):
|
||||||
batch["labels"] = torch.stack([f["label_ids"] for f in features])
|
batch["labels"] = torch.stack([f["label_ids"] for f in features])
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
|
|||||||
@require_torch
|
@require_torch
|
||||||
class DataCollatorIntegrationTest(unittest.TestCase):
|
class DataCollatorIntegrationTest(unittest.TestCase):
|
||||||
def test_default_with_dict(self):
|
def test_default_with_dict(self):
|
||||||
features = [{"labels": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
||||||
batch = default_data_collator(features)
|
batch = default_data_collator(features)
|
||||||
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
|
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
|
||||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||||
@@ -39,12 +39,24 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
||||||
|
|
||||||
# Features can already be tensors
|
# Features can already be tensors
|
||||||
features = [{"labels": i, "inputs": torch.randint(10, [10])} for i in range(8)]
|
features = [{"label": i, "inputs": torch.randint(10, [10])} for i in range(8)]
|
||||||
batch = default_data_collator(features)
|
batch = default_data_collator(features)
|
||||||
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
|
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
|
||||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||||
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
|
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
|
||||||
|
|
||||||
|
def test_default_with_no_labels(self):
|
||||||
|
features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
||||||
|
batch = default_data_collator(features)
|
||||||
|
self.assertTrue("labels" not in batch)
|
||||||
|
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
||||||
|
|
||||||
|
# With label_ids
|
||||||
|
features = [{"label_ids": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
||||||
|
batch = default_data_collator(features)
|
||||||
|
self.assertTrue("labels" not in batch)
|
||||||
|
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
||||||
|
|
||||||
def test_default_classification(self):
|
def test_default_classification(self):
|
||||||
MODEL_ID = "bert-base-cased-finetuned-mrpc"
|
MODEL_ID = "bert-base-cased-finetuned-mrpc"
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||||
|
|||||||
Reference in New Issue
Block a user