Make DataCollator a callable (#5015)

* Make DataCollator a callable

* Update src/transformers/data/data_collator.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Sylvain Gugger
2020-06-15 11:58:33 -04:00
committed by GitHub
parent f7c93b3cee
commit 1affde2f10
7 changed files with 60 additions and 83 deletions

View File

@@ -38,7 +38,6 @@ from transformers import (
BertConfig, BertConfig,
BertForSequenceClassification, BertForSequenceClassification,
BertTokenizer, BertTokenizer,
DefaultDataCollator,
DistilBertConfig, DistilBertConfig,
DistilBertForSequenceClassification, DistilBertForSequenceClassification,
DistilBertTokenizer, DistilBertTokenizer,
@@ -51,6 +50,7 @@ from transformers import (
XLNetConfig, XLNetConfig,
XLNetForSequenceClassification, XLNetForSequenceClassification,
XLNetTokenizer, XLNetTokenizer,
default_data_collator,
get_linear_schedule_with_warmup, get_linear_schedule_with_warmup,
) )
from utils_hans import HansDataset, hans_output_modes, hans_processors from utils_hans import HansDataset, hans_output_modes, hans_processors
@@ -91,10 +91,7 @@ def train(args, train_dataset, model, tokenizer):
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader( train_dataloader = DataLoader(
train_dataset, train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=default_data_collator,
sampler=train_sampler,
batch_size=args.train_batch_size,
collate_fn=DefaultDataCollator().collate_batch,
) )
if args.max_steps > 0: if args.max_steps > 0:
@@ -252,10 +249,7 @@ def evaluate(args, model, tokenizer, label_list, prefix=""):
# Note that DistributedSampler samples randomly # Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset) eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader( eval_dataloader = DataLoader(
eval_dataset, eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=default_data_collator,
sampler=eval_sampler,
batch_size=args.eval_batch_size,
collate_fn=DefaultDataCollator().collate_batch,
) )
# multi-gpu eval # multi-gpu eval

View File

@@ -34,8 +34,8 @@ from transformers import (
AutoConfig, AutoConfig,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
AutoTokenizer, AutoTokenizer,
DefaultDataCollator,
GlueDataset, GlueDataset,
default_data_collator,
glue_compute_metrics, glue_compute_metrics,
glue_output_modes, glue_output_modes,
glue_processors, glue_processors,
@@ -424,7 +424,7 @@ def main():
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset))))) eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_dataloader = DataLoader( eval_dataloader = DataLoader(
eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=DefaultDataCollator().collate_batch eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator
) )
# Compute head entropy and importance score # Compute head entropy and importance score

View File

@@ -364,7 +364,7 @@ if is_torch_available():
# Trainer # Trainer
from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction
from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
# Benchmarks # Benchmarks

View File

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, NewType, Tuple from typing import Any, Callable, Dict, List, NewType, Tuple
import torch import torch
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
@@ -8,28 +7,16 @@ from torch.nn.utils.rnn import pad_sequence
from ..tokenization_utils import PreTrainedTokenizer from ..tokenization_utils import PreTrainedTokenizer
class DataCollator(ABC):
"""
A `DataCollator` is responsible for batching
and pre-processing samples of data as requested by the training loop.
"""
@abstractmethod
def collate_batch(self) -> Dict[str, torch.Tensor]:
"""
Take a list of samples from a Dataset and collate them into a batch.
Returns:
A dictionary of tensors
"""
pass
InputDataClass = NewType("InputDataClass", Any) InputDataClass = NewType("InputDataClass", Any)
"""
A DataCollator is a function that takes a list of samples from a Dataset
and collates them into a batch, as a dictionary of Tensors.
"""
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
@dataclass
class DefaultDataCollator(DataCollator): def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
""" """
Very simple data collator that: Very simple data collator that:
- simply collates batches of dict-like objects - simply collates batches of dict-like objects
@@ -42,8 +29,7 @@ class DefaultDataCollator(DataCollator):
See glue and ner for example of how it's useful. See glue and ner for example of how it's useful.
""" """
def collate_batch(self, features: List[InputDataClass]) -> Dict[str, torch.Tensor]: # In this function we'll make the assumption that all `features` in the batch
# In this method we'll make the assumption that all `features` in the batch
# have the same attributes. # have the same attributes.
# So we will look at the first element as a proxy for what attributes exist # So we will look at the first element as a proxy for what attributes exist
# on the whole batch. # on the whole batch.
@@ -76,7 +62,7 @@ class DefaultDataCollator(DataCollator):
@dataclass @dataclass
class DataCollatorForLanguageModeling(DataCollator): class DataCollatorForLanguageModeling:
""" """
Data collator used for language modeling. Data collator used for language modeling.
- collates batches of tensors, honoring their tokenizer's pad_token - collates batches of tensors, honoring their tokenizer's pad_token
@@ -87,7 +73,7 @@ class DataCollatorForLanguageModeling(DataCollator):
mlm: bool = True mlm: bool = True
mlm_probability: float = 0.15 mlm_probability: float = 0.15
def collate_batch(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]: def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
batch = self._tensorize_batch(examples) batch = self._tensorize_batch(examples)
if self.mlm: if self.mlm:
inputs, labels = self.mask_tokens(batch) inputs, labels = self.mask_tokens(batch)

View File

@@ -19,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange from tqdm.auto import tqdm, trange
from .data.data_collator import DataCollator, DefaultDataCollator from .data.data_collator import DataCollator, default_data_collator
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup from .optimization import AdamW, get_linear_schedule_with_warmup
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
@@ -190,10 +190,7 @@ class Trainer:
""" """
self.model = model.to(args.device) self.model = model.to(args.device)
self.args = args self.args = args
if data_collator is not None: self.data_collator = data_collator if data_collator is not None else default_data_collator
self.data_collator = data_collator
else:
self.data_collator = DefaultDataCollator()
self.train_dataset = train_dataset self.train_dataset = train_dataset
self.eval_dataset = eval_dataset self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics self.compute_metrics = compute_metrics
@@ -239,7 +236,7 @@ class Trainer:
self.train_dataset, self.train_dataset,
batch_size=self.args.train_batch_size, batch_size=self.args.train_batch_size,
sampler=train_sampler, sampler=train_sampler,
collate_fn=self.data_collator.collate_batch, collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last, drop_last=self.args.dataloader_drop_last,
) )
@@ -264,7 +261,7 @@ class Trainer:
eval_dataset, eval_dataset,
sampler=sampler, sampler=sampler,
batch_size=self.args.eval_batch_size, batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator.collate_batch, collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last, drop_last=self.args.dataloader_drop_last,
) )
@@ -285,7 +282,7 @@ class Trainer:
test_dataset, test_dataset,
sampler=sampler, sampler=sampler,
batch_size=self.args.eval_batch_size, batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator.collate_batch, collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last, drop_last=self.args.dataloader_drop_last,
) )

View File

@@ -11,7 +11,7 @@ if is_torch_available():
Trainer, Trainer,
LineByLineTextDataset, LineByLineTextDataset,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
DefaultDataCollator, default_data_collator,
DataCollatorForLanguageModeling, DataCollatorForLanguageModeling,
GlueDataset, GlueDataset,
GlueDataTrainingArguments, GlueDataTrainingArguments,
@@ -31,8 +31,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
) )
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
data_collator = DefaultDataCollator() data_collator = default_data_collator
batch = data_collator.collate_batch(dataset.features) batch = data_collator(dataset.features)
self.assertEqual(batch["labels"].dtype, torch.long) self.assertEqual(batch["labels"].dtype, torch.long)
def test_default_regression(self): def test_default_regression(self):
@@ -42,8 +42,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
) )
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
data_collator = DefaultDataCollator() data_collator = default_data_collator
batch = data_collator.collate_batch(dataset.features) batch = data_collator(dataset.features)
self.assertEqual(batch["labels"].dtype, torch.float) self.assertEqual(batch["labels"].dtype, torch.float)
def test_lm_tokenizer_without_padding(self): def test_lm_tokenizer_without_padding(self):
@@ -55,11 +55,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
examples = [dataset[i] for i in range(len(dataset))] examples = [dataset[i] for i in range(len(dataset))]
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# Expect error due to padding token missing on gpt2: # Expect error due to padding token missing on gpt2:
data_collator.collate_batch(examples) data_collator(examples)
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True) dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))] examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator.collate_batch(examples) batch = data_collator(examples)
self.assertIsInstance(batch, dict) self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512))) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512))) self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
@@ -71,14 +71,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512) dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
examples = [dataset[i] for i in range(len(dataset))] examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator.collate_batch(examples) batch = data_collator(examples)
self.assertIsInstance(batch, dict) self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107))) self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
self.assertEqual(batch["labels"].shape, torch.Size((31, 107))) self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True) dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))] examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator.collate_batch(examples) batch = data_collator(examples)
self.assertIsInstance(batch, dict) self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512))) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512))) self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))

View File

@@ -29,7 +29,7 @@ if is_torch_available():
from torch import nn from torch import nn
from torch.utils.data.dataset import Dataset from torch.utils.data.dataset import Dataset
from transformers import DataCollator, Trainer from transformers import Trainer
class DummyDataset(Dataset): class DummyDataset(Dataset):
def __init__(self, length: int = 101): def __init__(self, length: int = 101):
@@ -41,8 +41,8 @@ if is_torch_available():
def __getitem__(self, i) -> int: def __getitem__(self, i) -> int:
return i return i
class DummyDataCollator(DataCollator): class DummyDataCollator:
def collate_batch(self, features): def __call__(self, features):
return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)} return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
class DummyModel(nn.Module): class DummyModel(nn.Module):