From 20fa82898495f516b221115fc3ef9ec8ebf50b1e Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 17 Jun 2020 15:24:51 -0400 Subject: [PATCH] Make default_data_collator more flexible and deprecate old behavior (#5060) * Make default_data_collator more flexible * Accept tensors for all features * Document code * Refactor * Formatting --- src/transformers/data/data_collator.py | 35 ++++++++++++++------------ src/transformers/trainer.py | 10 ++++++++ tests/test_trainer.py | 21 ++++++++++++++++ 3 files changed, 50 insertions(+), 16 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 629a8b0a6e..5e014d338b 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -33,31 +33,34 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten # have the same attributes. # So we will look at the first element as a proxy for what attributes exist # on the whole batch. + if not isinstance(features[0], dict): + features = [vars(f) for f in features] + first = features[0] + batch = {} # Special handling for labels. # Ensure that tensor is created with the correct type # (it should be automatically the case, but let's make sure of it.) - if hasattr(first, "label") and first.label is not None: - if type(first.label) is int: - labels = torch.tensor([f.label for f in features], dtype=torch.long) + if "label" in first: + dtype = torch.long if type(first["label"]) is int else torch.float + batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) + elif "label_ids" in first: + if isinstance(first["label_ids"], torch.Tensor): + batch["labels"] = torch.stack([f["label_ids"] for f in features]) else: - labels = torch.tensor([f.label for f in features], dtype=torch.float) - batch = {"labels": labels} - elif hasattr(first, "label_ids") and first.label_ids is not None: - if type(first.label_ids[0]) is int: - labels = torch.tensor([f.label_ids for f in features], dtype=torch.long) - else: - labels = torch.tensor([f.label_ids for f in features], dtype=torch.float) - batch = {"labels": labels} - else: - batch = {} + dtype = torch.long if type(first["label_ids"][0]) is int else torch.float + batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype) - # Handling of all other possible attributes. + # Handling of all other possible keys. # Again, we will use the first element to figure out which key/values are not None for this model. - for k, v in vars(first).items(): + for k, v in first.items(): if k not in ("label", "label_ids") and v is not None and not isinstance(v, str): - batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long) + if isinstance(v, torch.Tensor): + batch[k] = torch.stack([f[k] for f in features]) + else: + batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long) + return batch diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 20a4e0a71c..d1eb9ae7c4 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -4,6 +4,7 @@ import os import random import re import shutil +import warnings from contextlib import contextmanager from pathlib import Path from typing import Callable, Dict, List, Optional, Tuple @@ -205,6 +206,15 @@ class Trainer: # Set an xla_device flag on the model's config. # We'll find a more elegant and not need to do this in the future. self.model.config.xla_device = True + if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): + self.data_collator = self.data_collator.collate_batch + warnings.warn( + ( + "The `data_collator` should now be a simple callable (function, class with `__call__`), classes " + + "with a `collate_batch` are deprecated and won't be supported in a future version." + ), + FutureWarning, + ) def get_train_dataloader(self) -> DataLoader: if self.train_dataset is None: diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 89d6a77918..47cfa89918 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -24,6 +24,27 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt" @require_torch class DataCollatorIntegrationTest(unittest.TestCase): + def test_default_with_dict(self): + features = [{"labels": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)] + batch = default_data_collator(features) + self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8))))) + self.assertEqual(batch["labels"].dtype, torch.long) + self.assertEqual(batch["inputs"].shape, torch.Size([8, 6])) + + # With label_ids + features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)] + batch = default_data_collator(features) + self.assertTrue(batch["labels"].equal(torch.tensor([[0, 1, 2]] * 8))) + self.assertEqual(batch["labels"].dtype, torch.long) + self.assertEqual(batch["inputs"].shape, torch.Size([8, 6])) + + # Features can already be tensors + features = [{"labels": i, "inputs": torch.randint(10, [10])} for i in range(8)] + batch = default_data_collator(features) + self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8))))) + self.assertEqual(batch["labels"].dtype, torch.long) + self.assertEqual(batch["inputs"].shape, torch.Size([8, 10])) + def test_default_classification(self): MODEL_ID = "bert-base-cased-finetuned-mrpc" tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)