Adding PaddingDataCollator (#6442)
* Data collator with padding * Add type annotation * Support tensors as well * Add comment * Fix for labels wrong shape * Data collator with padding * Add type annotation * Support tensors as well * Add comment * Fix for labels wrong shape * Remove changes rendered unnecessary
This commit is contained in:
@@ -438,6 +438,7 @@ if is_torch_available():
|
||||
DataCollator,
|
||||
DataCollatorForLanguageModeling,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
DataCollatorWithPadding,
|
||||
)
|
||||
from .data.datasets import (
|
||||
GlueDataset,
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, NewType, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..tokenization_utils_base import BatchEncoding
|
||||
from ..tokenization_utils_base import BatchEncoding, PaddingStrategy
|
||||
from ..tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
|
||||
|
||||
InputDataClass = NewType("InputDataClass", Any)
|
||||
@@ -66,6 +67,55 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorWithPadding:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs received.
|
||||
|
||||
Args:
|
||||
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
||||
The tokenizer used for encoding the data.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||
index) among:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (:obj:`int`, `optional`):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
||||
>= 7.5 (Volta).
|
||||
"""
|
||||
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
batch = self.tokenizer.pad(
|
||||
features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
if "label" in batch:
|
||||
batch["labels"] = batch["label"]
|
||||
del batch["label"]
|
||||
if "label_ids" in batch:
|
||||
batch["labels"] = batch["label_ids"]
|
||||
del batch["label_ids"]
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForLanguageModeling:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user