From d2370e1bd8513dcdb92599292ba09ecaa5e68c86 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 12 Aug 2020 11:32:27 -0400 Subject: [PATCH] Adding PaddingDataCollator (#6442) * Data collator with padding * Add type annotation * Support tensors as well * Add comment * Fix for labels wrong shape * Data collator with padding * Add type annotation * Support tensors as well * Add comment * Fix for labels wrong shape * Remove changes rendered unnecessary --- src/transformers/__init__.py | 1 + src/transformers/data/data_collator.py | 54 +++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 317d66baea..956a03f283 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -438,6 +438,7 @@ if is_torch_available(): DataCollator, DataCollatorForLanguageModeling, DataCollatorForPermutationLanguageModeling, + DataCollatorWithPadding, ) from .data.datasets import ( GlueDataset, diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index cf8eb996f8..193545eea3 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -1,11 +1,12 @@ from dataclasses import dataclass -from typing import Any, Callable, Dict, List, NewType, Tuple, Union +from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union import torch from torch.nn.utils.rnn import pad_sequence from ..tokenization_utils import PreTrainedTokenizer -from ..tokenization_utils_base import BatchEncoding +from ..tokenization_utils_base import BatchEncoding, PaddingStrategy +from ..tokenization_utils_fast import PreTrainedTokenizerFast InputDataClass = NewType("InputDataClass", Any) @@ -66,6 +67,55 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten return batch +@dataclass +class DataCollatorWithPadding: + """ + Data collator that will dynamically pad the inputs received. + + Args: + tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): + The tokenizer used for encoding the data. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the returned list and optionally padding length (see above). + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + >= 7.5 (Volta). + """ + + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + + def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: + batch = self.tokenizer.pad( + features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="pt", + ) + if "label" in batch: + batch["labels"] = batch["label"] + del batch["label"] + if "label_ids" in batch: + batch["labels"] = batch["label_ids"] + del batch["label_ids"] + return batch + + @dataclass class DataCollatorForLanguageModeling: """