Make DataCollator a callable (#5015)
* Make DataCollator a callable * Update src/transformers/data/data_collator.py Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NewType, Tuple
|
||||
from typing import Any, Callable, Dict, List, NewType, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
@@ -8,28 +7,16 @@ from torch.nn.utils.rnn import pad_sequence
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
|
||||
class DataCollator(ABC):
|
||||
"""
|
||||
A `DataCollator` is responsible for batching
|
||||
and pre-processing samples of data as requested by the training loop.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def collate_batch(self) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Take a list of samples from a Dataset and collate them into a batch.
|
||||
|
||||
Returns:
|
||||
A dictionary of tensors
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
InputDataClass = NewType("InputDataClass", Any)
|
||||
|
||||
"""
|
||||
A DataCollator is a function that takes a list of samples from a Dataset
|
||||
and collate them into a batch, as a dictionary of Tensors.
|
||||
"""
|
||||
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
|
||||
|
||||
@dataclass
|
||||
class DefaultDataCollator(DataCollator):
|
||||
|
||||
def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Very simple data collator that:
|
||||
- simply collates batches of dict-like objects
|
||||
@@ -42,41 +29,40 @@ class DefaultDataCollator(DataCollator):
|
||||
See glue and ner for example of how it's useful.
|
||||
"""
|
||||
|
||||
def collate_batch(self, features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
|
||||
# In this method we'll make the assumption that all `features` in the batch
|
||||
# have the same attributes.
|
||||
# So we will look at the first element as a proxy for what attributes exist
|
||||
# on the whole batch.
|
||||
first = features[0]
|
||||
# In this function we'll make the assumption that all `features` in the batch
|
||||
# have the same attributes.
|
||||
# So we will look at the first element as a proxy for what attributes exist
|
||||
# on the whole batch.
|
||||
first = features[0]
|
||||
|
||||
# Special handling for labels.
|
||||
# Ensure that tensor is created with the correct type
|
||||
# (it should be automatically the case, but let's make sure of it.)
|
||||
if hasattr(first, "label") and first.label is not None:
|
||||
if type(first.label) is int:
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
else:
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
elif hasattr(first, "label_ids") and first.label_ids is not None:
|
||||
if type(first.label_ids[0]) is int:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
|
||||
else:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
# Special handling for labels.
|
||||
# Ensure that tensor is created with the correct type
|
||||
# (it should be automatically the case, but let's make sure of it.)
|
||||
if hasattr(first, "label") and first.label is not None:
|
||||
if type(first.label) is int:
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
else:
|
||||
batch = {}
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
elif hasattr(first, "label_ids") and first.label_ids is not None:
|
||||
if type(first.label_ids[0]) is int:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
|
||||
else:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
else:
|
||||
batch = {}
|
||||
|
||||
# Handling of all other possible attributes.
|
||||
# Again, we will use the first element to figure out which key/values are not None for this model.
|
||||
for k, v in vars(first).items():
|
||||
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
||||
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
|
||||
return batch
|
||||
# Handling of all other possible attributes.
|
||||
# Again, we will use the first element to figure out which key/values are not None for this model.
|
||||
for k, v in vars(first).items():
|
||||
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
||||
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForLanguageModeling(DataCollator):
|
||||
class DataCollatorForLanguageModeling:
|
||||
"""
|
||||
Data collator used for language modeling.
|
||||
- collates batches of tensors, honoring their tokenizer's pad_token
|
||||
@@ -87,7 +73,7 @@ class DataCollatorForLanguageModeling(DataCollator):
|
||||
mlm: bool = True
|
||||
mlm_probability: float = 0.15
|
||||
|
||||
def collate_batch(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
batch = self._tensorize_batch(examples)
|
||||
if self.mlm:
|
||||
inputs, labels = self.mask_tokens(batch)
|
||||
|
||||
Reference in New Issue
Block a user