From 5f721ad6e48c9d846de25c3fefa0e50a306cbf10 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Thu, 18 Jun 2020 19:20:04 -0400
Subject: [PATCH] Fix #5114 (#5122)

---
 src/transformers/data/data_collator.py |  4 ++--
 tests/test_trainer.py                  | 16 ++++++++++++++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index 5e014d338b..d575f48d6c 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -42,10 +42,10 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
     # Special handling for labels.
     # Ensure that tensor is created with the correct type
     # (it should be automatically the case, but let's make sure of it.)
-    if "label" in first:
+    if "label" in first and first["label"] is not None:
         dtype = torch.long if type(first["label"]) is int else torch.float
         batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
-    elif "label_ids" in first:
+    elif "label_ids" in first and first["label_ids"] is not None:
         if isinstance(first["label_ids"], torch.Tensor):
             batch["labels"] = torch.stack([f["label_ids"] for f in features])
         else:
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 47cfa89918..7aead04429 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -25,7 +25,7 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
 @require_torch
 class DataCollatorIntegrationTest(unittest.TestCase):
     def test_default_with_dict(self):
-        features = [{"labels": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
         batch = default_data_collator(features)
         self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
         self.assertEqual(batch["labels"].dtype, torch.long)
@@ -39,12 +39,24 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
 
         # Features can already be tensors
-        features = [{"labels": i, "inputs": torch.randint(10, [10])} for i in range(8)]
+        features = [{"label": i, "inputs": torch.randint(10, [10])} for i in range(8)]
         batch = default_data_collator(features)
         self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
         self.assertEqual(batch["labels"].dtype, torch.long)
         self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
 
+    def test_default_with_no_labels(self):
+        features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features)
+        self.assertTrue("labels" not in batch)
+        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
+
+        # With label_ids
+        features = [{"label_ids": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features)
+        self.assertTrue("labels" not in batch)
+        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
+
     def test_default_classification(self):
         MODEL_ID = "bert-base-cased-finetuned-mrpc"
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)