Adding PaddingDataCollator (#6442)
* Data collator with padding * Add type annotation * Support tensors as well * Add comment * Fix for labels wrong shape * Data collator with padding * Add type annotation * Support tensors as well * Add comment * Fix for labels wrong shape * Remove changes rendered unnecessary
This commit is contained in:
@@ -438,6 +438,7 @@ if is_torch_available():
|
|||||||
DataCollator,
|
DataCollator,
|
||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
DataCollatorForPermutationLanguageModeling,
|
DataCollatorForPermutationLanguageModeling,
|
||||||
|
DataCollatorWithPadding,
|
||||||
)
|
)
|
||||||
from .data.datasets import (
|
from .data.datasets import (
|
||||||
GlueDataset,
|
GlueDataset,
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, List, NewType, Tuple, Union
|
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from ..tokenization_utils import PreTrainedTokenizer
|
from ..tokenization_utils import PreTrainedTokenizer
|
||||||
from ..tokenization_utils_base import BatchEncoding
|
from ..tokenization_utils_base import BatchEncoding, PaddingStrategy
|
||||||
|
from ..tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
|
|
||||||
|
|
||||||
InputDataClass = NewType("InputDataClass", Any)
|
InputDataClass = NewType("InputDataClass", Any)
|
||||||
@@ -66,6 +67,55 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataCollatorWithPadding:
|
||||||
|
"""
|
||||||
|
Data collator that will dynamically pad the inputs received.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
||||||
|
The tokenizer used for encoding the data.
|
||||||
|
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
||||||
|
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||||
|
index) among:
|
||||||
|
|
||||||
|
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||||
|
single sequence is provided).
|
||||||
|
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||||
|
maximum acceptable input length for the model if that argument is not provided.
|
||||||
|
* :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
|
||||||
|
different lengths).
|
||||||
|
max_length (:obj:`int`, `optional`):
|
||||||
|
Maximum length of the returned list and optionally padding length (see above).
|
||||||
|
pad_to_multiple_of (:obj:`int`, `optional`):
|
||||||
|
If set will pad the sequence to a multiple of the provided value.
|
||||||
|
|
||||||
|
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
||||||
|
>= 7.0 (Volta).
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
||||||
|
padding: Union[bool, str, PaddingStrategy] = True
|
||||||
|
max_length: Optional[int] = None
|
||||||
|
pad_to_multiple_of: Optional[int] = None
|
||||||
|
|
||||||
|
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||||
|
batch = self.tokenizer.pad(
|
||||||
|
features,
|
||||||
|
padding=self.padding,
|
||||||
|
max_length=self.max_length,
|
||||||
|
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
if "label" in batch:
|
||||||
|
batch["labels"] = batch["label"]
|
||||||
|
del batch["label"]
|
||||||
|
if "label_ids" in batch:
|
||||||
|
batch["labels"] = batch["label_ids"]
|
||||||
|
del batch["label_ids"]
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForLanguageModeling:
|
class DataCollatorForLanguageModeling:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user