Fix tensor label type inference in default collator (#5250)

* allow tensor label inputs to default collator

* replace try/except with type check
This commit is contained in:
Joe Davison
2020-07-01 10:40:14 -06:00
committed by GitHub
parent fe81f7d12c
commit 35befd9ce3
2 changed files with 10 additions and 1 deletions

View File

@@ -44,6 +44,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["labels"].dtype, torch.long)
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
# Labels can already be tensors
features = [{"label": torch.tensor(i), "inputs": torch.randint(10, [10])} for i in range(8)]
batch = default_data_collator(features)
self.assertEqual(batch["labels"].dtype, torch.long)
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
self.assertEqual(batch["labels"].dtype, torch.long)
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
def test_default_with_no_labels(self):
features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
batch = default_data_collator(features)