Make DataCollator a callable (#5015)

* Make DataCollator a callable

* Update src/transformers/data/data_collator.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Sylvain Gugger
2020-06-15 11:58:33 -04:00
committed by GitHub
parent f7c93b3cee
commit 1affde2f10
7 changed files with 60 additions and 83 deletions

View File

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, NewType, Tuple
from typing import Any, Callable, Dict, List, NewType, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
@@ -8,28 +7,16 @@ from torch.nn.utils.rnn import pad_sequence
from ..tokenization_utils import PreTrainedTokenizer
class DataCollator(ABC):
"""
A `DataCollator` is responsible for batching
and pre-processing samples of data as requested by the training loop.
"""
@abstractmethod
def collate_batch(self) -> Dict[str, torch.Tensor]:
"""
Take a list of samples from a Dataset and collate them into a batch.
Returns:
A dictionary of tensors
"""
pass
InputDataClass = NewType("InputDataClass", Any)
"""
A DataCollator is a function that takes a list of samples from a Dataset
and collate them into a batch, as a dictionary of Tensors.
"""
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
@dataclass
class DefaultDataCollator(DataCollator):
def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
"""
Very simple data collator that:
- simply collates batches of dict-like objects
@@ -42,41 +29,40 @@ class DefaultDataCollator(DataCollator):
See glue and ner for example of how it's useful.
"""
def collate_batch(self, features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
# In this method we'll make the assumption that all `features` in the batch
# have the same attributes.
# So we will look at the first element as a proxy for what attributes exist
# on the whole batch.
first = features[0]
# In this function we'll make the assumption that all `features` in the batch
# have the same attributes.
# So we will look at the first element as a proxy for what attributes exist
# on the whole batch.
first = features[0]
# Special handling for labels.
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
if hasattr(first, "label") and first.label is not None:
if type(first.label) is int:
labels = torch.tensor([f.label for f in features], dtype=torch.long)
else:
labels = torch.tensor([f.label for f in features], dtype=torch.float)
batch = {"labels": labels}
elif hasattr(first, "label_ids") and first.label_ids is not None:
if type(first.label_ids[0]) is int:
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
else:
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
batch = {"labels": labels}
# Special handling for labels.
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
if hasattr(first, "label") and first.label is not None:
if type(first.label) is int:
labels = torch.tensor([f.label for f in features], dtype=torch.long)
else:
batch = {}
labels = torch.tensor([f.label for f in features], dtype=torch.float)
batch = {"labels": labels}
elif hasattr(first, "label_ids") and first.label_ids is not None:
if type(first.label_ids[0]) is int:
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
else:
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
batch = {"labels": labels}
else:
batch = {}
# Handling of all other possible attributes.
# Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in vars(first).items():
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
return batch
# Handling of all other possible attributes.
# Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in vars(first).items():
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
return batch
@dataclass
class DataCollatorForLanguageModeling(DataCollator):
class DataCollatorForLanguageModeling:
"""
Data collator used for language modeling.
- collates batches of tensors, honoring their tokenizer's pad_token
@@ -87,7 +73,7 @@ class DataCollatorForLanguageModeling(DataCollator):
mlm: bool = True
mlm_probability: float = 0.15
def collate_batch(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
batch = self._tensorize_batch(examples)
if self.mlm:
inputs, labels = self.mask_tokens(batch)