From 35befd9ce31c23a774fd34f57bc44033ce70141d Mon Sep 17 00:00:00 2001 From: Joe Davison Date: Wed, 1 Jul 2020 10:40:14 -0600 Subject: [PATCH] Fix tensor label type inference in default collator (#5250) * allow tensor label inputs to default collator * replace try/except with type check --- src/transformers/data/data_collator.py | 3 ++- tests/test_trainer.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index acb0807af1..29331cc83d 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -43,7 +43,8 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten # Ensure that tensor is created with the correct type # (it should be automatically the case, but let's make sure of it.) if "label" in first and first["label"] is not None: - dtype = torch.long if type(first["label"]) is int else torch.float + label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"] + dtype = torch.long if isinstance(label, int) else torch.float batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) elif "label_ids" in first and first["label_ids"] is not None: if isinstance(first["label_ids"], torch.Tensor): diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 141c7128d2..d68eee524d 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -44,6 +44,14 @@ class DataCollatorIntegrationTest(unittest.TestCase): self.assertEqual(batch["labels"].dtype, torch.long) self.assertEqual(batch["inputs"].shape, torch.Size([8, 10])) + # Labels can already be tensors + features = [{"label": torch.tensor(i), "inputs": torch.randint(10, [10])} for i in range(8)] + batch = default_data_collator(features) + self.assertEqual(batch["labels"].dtype, torch.long) + self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8))))) + self.assertEqual(batch["labels"].dtype, torch.long) + self.assertEqual(batch["inputs"].shape, torch.Size([8, 10])) + def test_default_with_no_labels(self): features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)] batch = default_data_collator(features)